Spaces:
Runtime error
Runtime error
Add application file
Browse files
app.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
from typing import List
|
| 6 |
+
from scene_gen import *
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, RootModel, ValidationError
|
| 9 |
+
from ollama import chat
|
| 10 |
+
|
| 11 |
+
# -----------------------------
|
| 12 |
+
# Models
|
| 13 |
+
# -----------------------------
|
| 14 |
+
class QAItem(BaseModel):
    """A single generated question/answer entry.

    Validated against the LLM's structured-output response in
    ``generate_qa``.
    """
    # NOTE: the capitalized field names are intentional — they define the
    # exact JSON keys ("Question", "Answer", "Voice_Over") that the prompt
    # and the generated JSON schema require the model to emit.
    Question: str        # the question text
    Answer: str          # the answer text
    Voice_Over: str      # narration text (presumably read aloud in the Manim scene — confirm with consumer)
    include_audio: bool  # whether audio should accompany this entry
|
| 19 |
+
|
| 20 |
+
class QAList(RootModel[List[QAItem]]):
    """RootModel wrapping a list of QAItem.

    Parametrizing ``RootModel[List[QAItem]]`` already declares the ``root``
    field with that type, so the explicit ``root: List[QAItem]`` annotation
    the original carried was redundant and has been removed. Access the
    underlying list via ``.root``.
    """
|
| 23 |
+
|
| 24 |
+
# -----------------------------
# Configuration
# -----------------------------
# Ollama model tag passed to chat(model=MODEL, ...) in generate_qa.
MODEL = "qwen3:0.6b"
# Root-logger configuration happens at import time; acceptable here because
# this module is an executable script, not a library.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
|
| 29 |
+
|
| 30 |
+
# -----------------------------
|
| 31 |
+
# Core logic
|
| 32 |
+
# -----------------------------
|
| 33 |
+
def generate_qa(topic: str, count: int = 10) -> List[QAItem]:
    """Generate ``count`` Q&A items for ``topic`` via the Ollama model.

    The model is constrained to structured output matching the ``QAList``
    JSON schema, and the response is validated with pydantic before being
    returned.

    Args:
        topic: Subject to generate questions about (interpolated into the prompt).
        count: Number of QA items requested (default 10).

    Returns:
        The validated list of ``QAItem`` instances.

    Raises:
        ValidationError: if the model's response does not match the schema.
    """
    schema = QAList.model_json_schema()
    prompt = (
        f'Given the topic "{topic}", generate {count} entries in JSON format, '
        "each with keys Question, Answer, Voice_Over, and include_audio (true/false)."
    )

    response = chat(
        model=MODEL,
        think=False,
        messages=[{"role": "user", "content": prompt}],
        # Structured-output mode: constrains generation to the QAList schema.
        format=schema,
        # temperature 0 → deterministic output for a given prompt
        options={"temperature": 0},
    )

    try:
        qa_list = QAList.model_validate_json(response.message.content)
    except ValidationError as e:
        logging.error("Response validation failed:\n%s", e)
        raise

    items = qa_list.root
    # The JSON schema cannot enforce the list length, so the model may return
    # a different number of entries than requested; surface that instead of
    # failing silently (the original ignored it).
    if len(items) != count:
        logging.warning("Requested %d QA items but model returned %d", count, len(items))
    return items
|
| 58 |
+
|
| 59 |
+
# -----------------------------
|
| 60 |
+
# CLI entrypoint
|
| 61 |
+
# -----------------------------
|
| 62 |
+
def cli_main():
    """CLI entrypoint: parse args, generate QA items, print and save as JSON.

    On generation failure, logs the traceback and returns without writing
    any file.
    """
    parser = argparse.ArgumentParser(description="Generate QA JSON via Ollama")
    parser.add_argument("topic", type=str, help="Topic to generate Q&A for")
    parser.add_argument(
        "-n", "--count", type=int, default=10,
        help="Number of QA items to generate (default: 10)"
    )
    args = parser.parse_args()
    # Reject nonsensical counts up front rather than asking the model for
    # zero or negative entries (the original passed them through unchecked).
    if args.count < 1:
        parser.error("--count must be a positive integer")

    logging.info("Generating %d QA items for topic: %s", args.count, args.topic)
    try:
        items = generate_qa(args.topic, args.count)
    except Exception:
        # exc_info=True captures the traceback; the original logged only a
        # static message, hiding the actual cause of the failure.
        logging.critical("Aborting due to errors", exc_info=True)
        return

    # Convert to plain data
    output = [item.model_dump() for item in items]

    # 1) Pretty-print to stdout
    print(json.dumps(output, indent=2, ensure_ascii=False))

    # 2) Save to file
    # NOTE(review): the topic is used verbatim as the filename, so a topic
    # containing a path separator (e.g. "A/B testing") will fail to open —
    # confirm whether sanitizing is desired before changing the naming scheme.
    filename = f"{args.topic}.json"
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    logging.info("Saved output to %s", filename)
|
| 89 |
+
|
| 90 |
+
# -----------------------------
|
| 91 |
+
# Gradio entrypoint
|
| 92 |
+
# -----------------------------
|
| 93 |
+
def gradio_generate(topic: str, count: int = 10) -> str:
    """Gradio wrapper: generate QA items, persist them, return the JSON text.

    Args:
        topic: Subject to generate questions about.
        count: Number of QA items (a Gradio slider may deliver this as a
            float, so it is coerced to int before use).

    Returns:
        The pretty-printed JSON string, also written to ``questions.json``.
    """
    # Coerce to int so a float from the slider never leaks into the prompt
    # (e.g. "generate 10.0 entries").
    items = generate_qa(topic, int(count))
    output = [item.model_dump() for item in items]

    # Serialize once and reuse the same text for both the saved file and the
    # returned UI value (the original serialized the payload twice).
    text = json.dumps(output, indent=2, ensure_ascii=False)
    with open("questions.json", "w", encoding="utf-8") as f:
        f.write(text)
    return text
|
| 104 |
+
|
| 105 |
+
def app():
    """Build and launch the Gradio UI for the transcript generator."""
    # Components named up front so the Interface wiring reads cleanly.
    topic_input = gr.Textbox(label="Topic", placeholder="Enter your topic here")
    count_input = gr.Slider(
        minimum=1, maximum=50, step=1, label="Number of Q&A items", value=10
    )
    json_output = gr.Textbox(label="Generated JSON")

    demo = gr.Interface(
        fn=gradio_generate,
        inputs=[topic_input, count_input],
        outputs=json_output,
        title="Transcript Generator for Manim Scene",
        description="Generates JSON transcript for Manim Scene, with voiceover.",
    )
    # NOTE(review): share=True is unnecessary on hosted platforms, and
    # mcp_server=True requires a Gradio build with MCP support — confirm both
    # flags are intended for the deployment target.
    demo.launch(share=True, mcp_server=True)
|
| 117 |
+
|
| 118 |
+
# -----------------------------
|
| 119 |
+
# Bootstrap
|
| 120 |
+
# -----------------------------
|
| 121 |
+
if __name__ == "__main__":
    import sys

    # Any command-line argument selects the CLI path; a bare invocation
    # launches the Gradio UI instead.
    entrypoint = cli_main if len(sys.argv) > 1 else app
    entrypoint()
|