# backend/app.py
# (Hugging Face Spaces file-viewer residue: author "AIDev07",
#  commit "Update app.py", 3e2439e verified — kept here as a comment
#  so the module parses.)
import os
import httpx
import json
import chromadb
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from datasets import load_dataset
from huggingface_hub import HfApi
from tavily import TavilyClient
from fastapi.middleware.cors import CORSMiddleware
# --- Secrets & Config ---
# Credentials are injected via environment variables (Hugging Face Space
# secrets); any of them may be None when the secret is not configured.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
# Hugging Face dataset repo used as the persistent document store.
DATASET_ID = "AIDev07/AIModelsLoaded"
# Local on-disk location of the Chroma vector database.
DB_PATH = "./chroma_db"
# --- Vector DB Initialization ---
# Module-level singletons created once at import time and shared by the
# lifespan handler, save_to_hf(), and the request handlers below.
chroma_client = chromadb.PersistentClient(path=DB_PATH)
collection = chroma_client.get_or_create_collection(name="project_memory")
# --- Lifespan Handler (Replaces @app.on_event) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler (replaces the deprecated @app.on_event).

    Startup: if the local Chroma collection is empty, pull the backing
    Hugging Face dataset and index every record into vector memory.
    Shutdown: currently only logs; add cleanup here if needed.
    """
    # Startup Logic: Sync Vector Memory
    print("🔄 Lifespan Startup: Syncing Memory...")
    if collection.count() == 0:
        try:
            ds = load_dataset(DATASET_ID, token=HF_TOKEN, split="train")
            # Prefer an explicit text field; fall back to the raw record repr.
            documents = [
                record.get("content") or record.get("text") or str(record)
                for record in ds
            ]
            # Batch insert: one collection.add() call instead of one per
            # record avoids N separate round-trips through Chroma.
            if documents:
                collection.add(
                    documents=documents,
                    ids=[f"doc_{i}" for i in range(len(documents))],
                )
            print("✅ Vector Memory Loaded.")
        except Exception as e:
            # Best-effort sync: the app still starts with empty memory.
            print(f"❌ Pull Failed: {e}")
    yield
    # Shutdown Logic (Optional: Clear cache or close DB connections)
    print("🛑 Lifespan Shutdown: Cleaning up...")
# Application instance wired to the lifespan handler defined above.
app = FastAPI(lifespan=lifespan)
# --- Cross Origin Resource Sharing ---
# NOTE(review): wildcard origins/methods/headers let any site call this
# API — acceptable for a demo Space, but tighten before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Helper: Save to HF ---
def save_to_hf(new_content: str):
api = HfApi()
try:
filename = f"updates/update_{collection.count()}.txt"
with open("update.txt", "w") as f:
f.write(new_content)
api.upload_file(
path_or_fileobj="update.txt",
path_in_repo=filename,
repo_id=DATASET_ID,
repo_type="dataset",
token=HF_TOKEN
)
collection.add(documents=[new_content], ids=[f"doc_{collection.count()}"])
return True
except Exception as e:
print(f"❌ Write Failed: {e}")
return False
# --- Chat Endpoint (Httpx Proxy) ---
# --- Chat Endpoint (Httpx Proxy) ---
@app.post("/v1/chat")
async def chat_endpoint(request: Request):
    """RAG-augmented streaming chat proxy to the Groq OpenAI-compatible API.

    Retrieves the top vector-memory matches for the latest user message,
    prepends them as a system message, routes to a model via simple keyword
    heuristics, and relays Groq's SSE stream to the client.
    """
    data = await request.json()
    messages = data.get("messages", [])
    # Guard: an empty history would raise IndexError on messages[-1] below.
    if not messages:
        return {"error": "messages must be a non-empty list"}
    # Guard: a malformed last message without "content" would raise KeyError.
    user_query = messages[-1].get("content", "")
    # 1. RAG Retrieval: top-5 nearest documents from project memory.
    docs = collection.query(query_texts=[user_query], n_results=5)
    hits = docs.get("documents") or [[]]
    context_str = "\n".join(hits[0])
    # 2. Dynamic Model Routing (Bypass Logic): cheap keyword heuristics,
    # lower-cased once instead of per comparison.
    query_lower = user_query.lower()
    model = "llama-3.3-70b-versatile"
    if any(k in query_lower for k in ("code", "python", "script")):
        model = "openai/gpt-oss-120b"
    elif "kimi" in query_lower:
        model = "kimi-k2-instruct-0905"

    async def stream_logic():
        headers = {"Authorization": f"Bearer {GROQ_API_KEY}"}
        # Prepend the vector context; cap history to the last 5 messages.
        payload = {
            "model": model,
            "messages": [{"role": "system", "content": f"Project Context: {context_str}"}] + messages[-5:],
            "stream": True,
        }
        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream(
                "POST",
                "https://api.groq.com/openai/v1/chat/completions",
                headers=headers,
                json=payload,
            ) as response:
                # Relay only SSE data lines, re-terminated per the SSE spec.
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        yield f"{line}\n\n"

    return StreamingResponse(stream_logic(), media_type="text/event-stream")
@app.post("/v1/debug")
async def debug_endpoint(request: Request):
    """Non-streaming Groq round-trip used to isolate connectivity/auth issues.

    Echoes the upstream status code, whether the API key is configured, and
    the raw Groq response so failures can be diagnosed from the client.
    """
    data = await request.json()
    messages = data.get("messages", [])
    headers = {"Authorization": f"Bearer {GROQ_API_KEY}"}
    payload = {
        "model": "llama-3.3-70b-versatile",
        "messages": messages,
        "stream": False  # non-streaming first to isolate the issue
    }
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=payload
        )
    # Upstream error pages may not be JSON; fall back to raw text so this
    # debug endpoint never 500s while diagnosing the upstream call.
    try:
        body = response.json()
    except ValueError:
        body = {"raw": response.text}
    return {
        "groq_status": response.status_code,
        "groq_key_set": bool(GROQ_API_KEY),
        # NOTE(review): exposing even an 8-char key prefix is a mild info
        # leak — keep only while actively debugging.
        "groq_key_prefix": GROQ_API_KEY[:8] if GROQ_API_KEY else "MISSING",
        "groq_response": body,
    }
# --- Floating Terminal / Shell Control ---
@app.post("/v1/shell")
async def shell_exec(request: Request):
    """Handle floating-terminal commands.

    'save <content>' persists <content> to the HF dataset via save_to_hf();
    anything else is echoed back for the client to copy — nothing is ever
    executed server-side.
    """
    data = await request.json()
    command = data.get("command", "")
    # Internal 'save' command to trigger the HF Write bypass.
    if command.startswith("save "):
        # Bug fix: .replace("save ", "") stripped EVERY occurrence of
        # "save " inside the saved content; only the leading command
        # prefix should be removed.
        content = command[len("save "):]
        status = save_to_hf(content)
        return {"status": "success" if status else "failed", "action": "HF_WRITE"}
    return {"status": "ready_to_copy", "command": command}
# Script entry point: bind all interfaces on port 7860, the port a
# Hugging Face Space expects a web app to listen on.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)