| from PIL import Image |
| import io |
| import streamlit as st |
| import google.generativeai as genai |
|
|
| safety_settings = [ |
| { |
| "category": "HARM_CATEGORY_HARASSMENT", |
| "threshold": "BLOCK_NONE" |
| }, |
| { |
| "category": "HARM_CATEGORY_HATE_SPEECH", |
| "threshold": "BLOCK_NONE" |
| }, |
| { |
| "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", |
| "threshold": "BLOCK_NONE" |
| }, |
| { |
| "category": "HARM_CATEGORY_DANGEROUS_CONTENT", |
| "threshold": "BLOCK_NONE" |
| }, |
| ] |
|
|
|
|
| password_placeholder = st.empty() |
| password = password_placeholder.text_input("пасскод", type="password") |
| if password == st.secrets["real_password"]: |
| password_placeholder.empty() |
| |
|
|
| with st.sidebar: |
| st.title("Gemini Pro") |
| |
| CONFIG = { |
| "temperature": 0.5, |
| "top_p": 1, |
| "top_k": 32, |
| "max_output_tokens": 4096, |
| } |
| |
| genai.configure(api_key=st.secrets["api_key"]) |
|
|
| uploaded_image = st.file_uploader( |
| label="загрузи изображение", |
| label_visibility="visible", |
| help="если загружено изображение - можно спрашивать по нему что-то, если нет - будет обычный чат", |
| accept_multiple_files=False, |
| type=["png", "jpg"], |
| ) |
|
|
| if uploaded_image: |
| image_bytes = uploaded_image.read() |
|
|
|
|
| def get_response(messages, model="gemini-pro"): |
| try: |
| model = genai.GenerativeModel(model, generation_config=genai.GenerationConfig(candidate_count=1, max_output_tokens=4096, temperature=0.6)) |
| res = model.generate_content(messages, stream=True, safety_settings=safety_settings) |
| return res |
| except: |
| return "Извини, но запрос не прошел цензуру." |
|
|
|
|
| if "messages" not in st.session_state: |
| st.session_state["messages"] = [] |
| messages = st.session_state["messages"] |
|
|
| if messages: |
| for item in messages: |
| role, parts = item.values() |
| if role == "user": |
| st.chat_message("user").markdown(parts[0]) |
| elif role == "model": |
| st.chat_message("assistant").markdown(parts[0]) |
|
|
| chat_message = st.chat_input("Спроси что-нибудь!") |
|
|
| if chat_message: |
| st.chat_message("user").markdown(chat_message) |
| res_area = st.chat_message("assistant").empty() |
|
|
| if "image_bytes" in globals(): |
| vision_message = [chat_message, Image.open(io.BytesIO(image_bytes))] |
| res = get_response(vision_message, model="gemini-pro-vision") |
| else: |
| vision_message = [{"role": "user", "parts": [chat_message]}] |
| res = get_response(vision_message) |
|
|
| res_text = "" |
| try: |
| for chunk in res: |
| res_text += chunk.text |
| res_area.markdown(res_text) |
| except: |
| res_text += f"запрос не прошел цензуру:\n{str(res.prompt_feedback)}" |
| res_area.markdown(res_text) |
|
|
|
|
| messages.append({"role": "model", "parts": [res_text]}) |
| else: |
| st.warning("неправильный пароль, увы...") |
|
|