import streamlit as st
from PIL import Image
import io
import base64
import uuid
from gtts import gTTS
import google.generativeai as genai
from io import BytesIO
import PyPDF2
from audio_recorder_streamlit import audio_recorder
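# Streamlit chat app for Google's Gemini 1.5 models: supports text, image, PDF,
# audio, and video inputs, microphone recording, and optional gTTS speech output.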
# Set your API key (never commit a real key; prefer Streamlit secrets or an environment variable)
api_key = "YOUR_API_KEY"  # Replace with your actual API key
genai.configure(api_key=api_key)
# Configure the generative AI model
generation_config = genai.GenerationConfig(
    temperature=0.9,
    max_output_tokens=3000
)
# Safety settings configuration
safety_settings = [
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
]
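# Note: BLOCK_NONE turns off Gemini's content filtering for all four harm categories.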
# Initialize session state
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
if 'file_uploader_key' not in st.session_state:
    st.session_state['file_uploader_key'] = str(uuid.uuid4())
# --- Streamlit UI ---
st.title("Gemini Chatbot")
st.write("Interact with the powerful Gemini 1.5 models.")
# Model Selection Dropdown
selected_model = st.selectbox("Choose a Gemini 1.5 Model:", ["gemini-1.5-flash-latest", "gemini-1.5-pro-latest"])
# TTS Option Checkbox
enable_tts = st.checkbox("Enable Text-to-Speech")
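# When enabled, each model reply is also synthesized with gTTS and played back as audio.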
# --- Helper Functions ---
def get_file_base64(file_content, mime_type):
    """Wrap raw file bytes as a base64 inline-data part for the Gemini API."""
    base64_data = base64.b64encode(file_content).decode()
    return {"mime_type": mime_type, "data": base64_data}
def clear_conversation():
    """Reset the chat history and remount the file uploader empty."""
    st.session_state['chat_history'] = []
    st.session_state['file_uploader_key'] = str(uuid.uuid4())
def display_chat_history():
    """Render each chat history entry: text, images, PDFs, audio, and video."""
    chat_container = st.empty()
    with chat_container.container():
        for entry in st.session_state['chat_history']:
            role = entry["role"]
            part = entry["parts"][0]
            if 'text' in part:
                st.markdown(f"**{role.title()}:** {part['text']}")
            elif 'data' in part:
                mime_type = part.get('mime_type', '')
                if mime_type.startswith('image'):
                    st.image(Image.open(io.BytesIO(base64.b64decode(part['data']))),
                             caption='Uploaded Image', use_column_width=True)
                elif mime_type == 'application/pdf':
                    st.write("**PDF Content:**")
                    pdf_reader = PyPDF2.PdfReader(io.BytesIO(base64.b64decode(part['data'])))
                    for page in pdf_reader.pages:
                        st.write(page.extract_text())
                elif mime_type.startswith('audio'):
                    st.audio(io.BytesIO(base64.b64decode(part['data'])), format=mime_type)
                elif mime_type.startswith('video'):
                    st.video(io.BytesIO(base64.b64decode(part['data'])))
# --- Send Message Function ---
def send_message(audio_data=None):
    """Collect the user's text, files, and audio, query Gemini, and store the reply."""
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files
    prompt_parts = []

    # Add user input to the prompt
    if user_input:
        prompt_parts.append({"text": user_input})
        st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})

    # Handle uploaded files
    if uploaded_files:
        for uploaded_file in uploaded_files:
            file_content = uploaded_file.read()
            prompt_parts.append(get_file_base64(file_content, uploaded_file.type))
            st.session_state['chat_history'].append(
                {"role": "user", "parts": [get_file_base64(file_content, uploaded_file.type)]}
            )

    # Handle recorded audio data
    if audio_data:
        prompt_parts.append(get_file_base64(audio_data, 'audio/wav'))
        st.session_state['chat_history'].append(
            {"role": "user", "parts": [get_file_base64(audio_data, 'audio/wav')]}
        )

    # Generate a response with the selected model (skip the API call if nothing was sent)
    if prompt_parts:
        try:
            model = genai.GenerativeModel(
                model_name=selected_model,
                generation_config=generation_config,
                safety_settings=safety_settings
            )
            response = model.generate_content([{"role": "user", "parts": prompt_parts}])
            response_text = response.text if hasattr(response, "text") else "No response text found."
            if response_text:
                st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
                if enable_tts:
                    tts = gTTS(text=response_text, lang='en')
                    tts_file = BytesIO()
                    tts.write_to_fp(tts_file)
                    tts_file.seek(0)
                    st.audio(tts_file, format='audio/mp3')
        except Exception as e:
            st.error(f"An error occurred: {e}")

    # Reset the input widgets for the next message
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())

    # Update the chat history display
    display_chat_history()
# --- User Input Area ---
col1, col2 = st.columns([3, 1])
with col1:
    user_input = st.text_area(
        "Enter your message:",
        value="",
        key="user_input"
    )
with col2:
    send_button = st.button(
        "Send",
        on_click=send_message,
        type="primary"
    )
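# on_click runs send_message at the start of the next rerun, before the page redraws,
# so the updated chat history is already available when it is displayed below.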
# --- File Uploader ---
uploaded_files = st.file_uploader(
    "Upload Files (Images, Videos, PDFs, MP3):",
    type=["png", "jpg", "jpeg", "mp4", "pdf", "mp3"],
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key
)
# --- Audio Recorder ---
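# audio_recorder captures microphone input and returns the recording as raw bytes
# (treated below as WAV when previewing and when sending to the model).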
audio_bytes = audio_recorder()
if audio_bytes:
    st.audio(audio_bytes, format="audio/wav")
    # Use an on_click callback so send_message can safely reset widget state afterwards
    st.button("Send Recording", on_click=send_message, kwargs={"audio_data": audio_bytes})
# --- Other Buttons ---
st.button("Clear Conversation", on_click=clear_conversation)
# --- Mirror the uploader value into session state so the send_message callback can read it ---
st.session_state.uploaded_files = uploaded_files
# --- JavaScript for Ctrl+Enter ---
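# Note: <script> tags passed through st.markdown are generally not executed by Streamlit,
# so this shortcut may not fire; st.components.v1.html is typically needed to run custom JS.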
st.markdown(
    """
    <script>
    document.addEventListener('DOMContentLoaded', (event) => {
        document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && e.ctrlKey) {
                document.querySelector('.stButton > button').click();
                e.preventDefault();
            }
        });
    });
    </script>
    """,
    unsafe_allow_html=True
)
# --- Display Chat History ---
display_chat_history()