| import os |
| from typing import List, Optional |
| from langchain_google_genai import ChatGoogleGenerativeAI |
| from langchain_core.messages import HumanMessage |
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
|
|
| def simple_chat( |
| question: str, |
| context: str, |
| image_paths: Optional[List[str]] = None |
| ) -> str: |
| """ |
| Simple chat function that answers questions based on context and optional images. |
| |
| Args: |
| question: User's question |
| context: Context information (e.g., dataset summary, analysis results) |
| image_paths: Optional list of image file paths to include |
| |
| Returns: |
| AI response as a string |
| """ |
| try: |
| |
| llm = ChatGoogleGenerativeAI( |
| model="gemini-2.0-flash-exp", |
| temperature=0, |
| api_key=os.getenv("GOOGLE_API_KEY"), |
| ) |
| |
| |
| prompt = f"""You are a helpful data analysis assistant. |
| |
| Context: |
| {context} |
| |
| User Question: {question} |
| |
| Please provide a clear, concise answer based on the context provided.""" |
| |
| |
| if image_paths: |
| content = [{"type": "text", "text": prompt}] |
| |
| for img_path in image_paths: |
| if os.path.exists(img_path): |
| import base64 |
| with open(img_path, "rb") as f: |
| img_data = base64.b64encode(f.read()).decode() |
| |
| |
| ext = os.path.splitext(img_path)[1].lower() |
| mime_type = { |
| '.png': 'image/png', |
| '.jpg': 'image/jpeg', |
| '.jpeg': 'image/jpeg', |
| '.gif': 'image/gif', |
| '.webp': 'image/webp' |
| }.get(ext, 'image/png') |
| |
| content.append({ |
| "type": "image_url", |
| "image_url": f"data:{mime_type};base64,{img_data}" |
| }) |
| |
| message = HumanMessage(content=content) |
| else: |
| message = HumanMessage(content=prompt) |
| |
| |
| response = llm.invoke([message]) |
| |
| |
| if hasattr(response, 'content'): |
| return str(response.content) |
| return str(response) |
| |
| except Exception as e: |
| return f"Error: {str(e)}" |
|
|
|
|
| |
| if __name__ == "__main__": |
| |
| context = """ |
| Dataset: Customer Sales Data |
| - 1000 rows, 15 columns |
| - Label: purchase_made (binary) |
| - Task: Classification |
| - Missing values: 5% in age column |
| """ |
| |
| question = "What's the main task for this dataset?" |
| response = simple_chat(question, context) |
| print(response) |
| |
| |
| question2 = "What do you see in the visualization?" |
| response2 = simple_chat(question2, context, image_paths=["/path/to/plot.png"]) |
| print(response2) |