| """ |
| StoryBox Gradio Web UI — model/provider picker + multilingual + style presets + NIM RPM warning. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import gradio as gr |
|
|
| |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
|
|
| from reverie.config.config import Config |
| from reverie.skills.writing_styles import list_style_presets, get_style_preset |
| from reverie.skills.multilingual import get_language_config, LanguageConfig |
|
|
| |
| |
| |
# Provider catalogue that drives the UI. Keys are the labels shown in the
# provider dropdown; each entry carries:
#   models    - model identifiers offered for that provider
#   env_var   - environment variable named in the API-key placeholder ("" = none)
#   needs_key - whether the API-key textbox is shown for the provider
#   rpm       - numeric requests-per-minute figure for the provider
#   rpm_note  - human-readable rate-limit note displayed under the picker
PROVIDER_MODELS = {
    "OpenAI": {
        "models": [
            "gpt-4o-mini",
            "gpt-4o",
            "gpt-4-turbo",
            "gpt-3.5-turbo",
        ],
        "env_var": "OPENAI_API_KEY",
        "needs_key": True,
        "rpm": 60,
        "rpm_note": "Tier 1: 60 RPM",
    },
    "Ollama": {
        "models": [
            "llama3.1:8b",
            "llama3.1:70b",
            "gemma4",
            "gemma4:9b",
            "gemma4:27b",
            "mistral",
            "mixtral",
            "phi4",
            "qwen2.5",
            "command-r",
            "deepseek-r1",
            "llama3.3",
            "codellama",
        ],
        "env_var": "",
        "needs_key": False,
        "rpm": 10000,
        "rpm_note": "Local — unlimited",
    },
    "MLX (Apple Silicon)": {
        "models": [
            "llama3.1-8b-mlx",
            "llama3.1-70b-mlx",
            "mistral-mlx",
            "phi3-mlx",
            "qwen2.5-mlx",
            "gemma2-mlx",
            "deepseek-mlx",
        ],
        "env_var": "",
        "needs_key": False,
        "rpm": 10000,
        "rpm_note": "Local — unlimited",
    },
    "NVIDIA NIM": {
        "models": [
            "nvidia/meta/llama-3.1-8b-instruct",
            "nvidia/meta/llama-3.1-70b-instruct",
            "nvidia/mistralai/mistral-7b-instruct-v0.3",
            "nvidia/google/gemma-2-9b-it",
            "nvidia/microsoft/phi-3-mini-128k-instruct",
        ],
        "env_var": "NIM_API_KEY",
        "needs_key": True,
        "rpm": 40,
        "rpm_note": "⚠️ Free tier: 40 RPM. Rate limiting auto-enabled.",
    },
    "HuggingFace": {
        "models": [
            "mistralai/Mistral-7B-Instruct-v0.3",
            "meta-llama/Llama-3.1-8B-Instruct",
        ],
        "env_var": "HF_TOKEN",
        "needs_key": True,
        "rpm": 60,
        "rpm_note": "Inference API free tier: 60 RPM",
    },
}
|
|
# Output-language choices: dropdown label (English name plus native script)
# mapped to the two-letter code stored in Config.target_language.
# Insertion order is the order shown in the dropdown.
LANGUAGES = {
    "English": "en",
    "Hindi (हिन्दी)": "hi",
    "Arabic (العربية)": "ar",
    "Chinese (中文)": "zh",
    "Japanese (日本語)": "ja",
    "Spanish (Español)": "es",
    "French (Français)": "fr",
    "Tamil (தமிழ்)": "ta",
    "Telugu (తెలుగు)": "te",
    "Bengali (বাংলা)": "bn",
    "Marathi (मराठी)": "mr",
    "Gujarati (ગુજરાતી)": "gu",
    "Urdu (اردو)": "ur",
    "Malayalam (മലയാളം)": "ml",
    "Kannada (ಕನ್ನಡ)": "kn",
    "Punjabi (ਪੰਜਾਬੀ)": "pa",
    "Korean (한국어)": "ko",
    "Portuguese (Português)": "pt",
    "German (Deutsch)": "de",
    "Russian (Русский)": "ru",
}
|
|
|
|
def update_model_dropdown(provider: str):
    """Rebuild the model dropdown for the newly selected provider.

    The provider's model list becomes the dropdown choices and its first
    entry is pre-selected. A provider missing from PROVIDER_MODELS (or one
    with no models) yields an empty dropdown with no selection.
    """
    provider_info = PROVIDER_MODELS.get(provider, {})
    available = provider_info.get("models", [])
    if available:
        selected = available[0]
    else:
        selected = None
    return gr.Dropdown(choices=available, value=selected)
|
|
|
|
def update_key_visibility(provider: str):
    """Toggle the API-key textbox to suit the selected provider.

    The textbox is visible only when the provider's ``needs_key`` flag is
    set. Its placeholder names the provider's environment variable when one
    is configured and is empty otherwise.
    """
    provider_info = PROVIDER_MODELS.get(provider, {})
    key_env = provider_info.get("env_var", "")
    if key_env:
        hint = f"${key_env} env var or paste key here"
    else:
        hint = ""
    return gr.Textbox(
        visible=provider_info.get("needs_key", False),
        placeholder=hint,
    )
|
|
|
|
def update_rpm_warning(provider: str):
    """Show the selected provider's rate-limit note, or hide the banner.

    Every provider that defines a non-empty ``rpm_note`` gets its note
    rendered in bold; a provider without one (or an unknown provider) gets
    an empty, hidden component.
    """
    rpm_note = PROVIDER_MODELS.get(provider, {}).get("rpm_note", "")
    if not rpm_note:
        # Nothing to display: keep the value empty instead of emitting a
        # stray "****" into the hidden component.
        return gr.Markdown(value="", visible=False)
    return gr.Markdown(value=f"**{rpm_note}**", visible=True)
|
|
|
|
def generate_story(
    provider: str,
    model: str,
    api_key: str,
    story_setting: str,
    language: str,
    multilingual_mode: str,
    num_days: int,
    num_personas: int,
    temperature: float,
    style_preset: str,
    enable_features: bool,
    max_tokens: int,
    progress=gr.Progress(),
):
    """Run the StoryBox pipeline with the values chosen in the UI.

    The UI selections are written onto the shared ``Config`` class, the
    simulation coroutine ``reverie_task`` is run to completion, and the
    resulting ``story.json`` is read back from ``Config.output_dir``.

    Args:
        provider: Provider label as shown in the dropdown (a key of
            PROVIDER_MODELS). Unknown labels fall back to ``"ollama"``.
        model: Model identifier stored in ``Config.llm_model_name``.
        api_key: Key pasted in the UI; may be empty. Surrounding whitespace
            is ignored.
        story_setting: Story name; also selects the sub-directory of
            ``Config.data_dir`` used as ``Config.story_dir``.
        language: Language label (a key of LANGUAGES). Unknown labels fall
            back to ``"en"``.
        multilingual_mode: Value stored in ``Config.multilingual_mode``.
        num_days: Simulated days; ``Config.max_iteration`` is set to
            24 iterations per day.
        num_personas: Accepted so the signature matches the UI inputs, but
            not applied to ``Config`` by this function.
        temperature: Sampling temperature stored in ``Config.temperature``.
        style_preset: Writing-style preset name; an empty value leaves
            ``Config.writing_style`` untouched.
        enable_features: Accepted so the signature matches the UI inputs,
            but not applied to ``Config`` by this function.
        max_tokens: Token limit stored in ``Config.max_tokens``.
        progress: Gradio progress tracker.

    Returns:
        ``(story_text, metadata)`` on success, or
        ``("Error: Story generation failed.", {})`` when ``story.json`` is
        missing or cannot be parsed.
    """
    provider_map = {
        "OpenAI": "openai",
        "Ollama": "ollama",
        "MLX (Apple Silicon)": "mlx",
        "NVIDIA NIM": "nim",
        "HuggingFace": "huggingface",
    }
    Config.model_provider = provider_map.get(provider, "ollama")
    Config.llm_model_name = model
    Config.temperature = temperature
    # Slider values may arrive as floats (e.g. 8000.0); keep counts integral.
    Config.max_tokens = int(max_tokens)
    Config.story_name = story_setting
    Config.story_dir = f"{Config.data_dir}/{story_setting}"
    Config.max_iteration = 24 * int(num_days)
    Config.target_language = LANGUAGES.get(language, "en")
    Config.multilingual_mode = multilingual_mode

    # Pasted keys frequently carry stray whitespace or a trailing newline.
    api_key = (api_key or "").strip()
    if api_key:
        if provider == "OpenAI":
            Config.api_key = api_key
        elif provider == "NVIDIA NIM":
            Config.nim_api_key = api_key
        else:
            # Other key-based providers (currently HuggingFace) previously had
            # their pasted key dropped. Export it under the variable the
            # provider table advertises in the key placeholder.
            # NOTE(review): assumes the backend reads this environment
            # variable (HF_TOKEN for HuggingFace) - confirm.
            env_var = PROVIDER_MODELS.get(provider, {}).get("env_var", "")
            if env_var:
                os.environ[env_var] = api_key

    if style_preset:
        Config.writing_style = style_preset

    # Imported only after Config has been populated, preserving the original
    # ordering between configuration and loading the pipeline module.
    import asyncio
    from reverie.run import reverie_task

    progress(0.1, desc="Loading world and personas...")
    asyncio.run(reverie_task())

    output_path = Path(Config.output_dir) / "story.json"
    try:
        # Explicit UTF-8: stories may be in any of the supported scripts and
        # the platform default encoding is not guaranteed to be UTF-8.
        with open(output_path, encoding="utf-8") as f:
            data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return "Error: Story generation failed.", {}

    # Tolerate an explicit null story so the word count cannot crash.
    story_text = data.get("story") or ""
    metadata = {
        "title": data.get("story_title", ""),
        "type": data.get("story_type", ""),
        "language": language,
        "multilingual_mode": multilingual_mode,
        "provider": provider,
        "model": model,
        "temperature": temperature,
        "word_count": len(story_text.split()),
        "style": style_preset or "default",
    }

    return story_text, metadata
|
|
|
|
def build_ui():
    """Assemble the StoryBox Gradio interface and return it unlaunched.

    Layout: a left column with the model, language, simulation, style and
    parameter controls plus the generate button; a right column with the
    story text, its metadata and two download buttons; then clickable
    example rows and a footer with usage notes.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    with gr.Blocks(title="StoryBox — AI Story Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📖 StoryBox")
        gr.Markdown("Generate long-form stories with multi-agent simulation. Supports 20+ languages and multiple LLM providers.")

        with gr.Row():
            # Left column: all generation controls.
            with gr.Column(scale=1):
                gr.Markdown("## 🤖 Model")

                provider = gr.Dropdown(
                    choices=list(PROVIDER_MODELS.keys()),
                    value="Ollama",
                    label="Provider",
                )
                model = gr.Dropdown(
                    choices=PROVIDER_MODELS["Ollama"]["models"],
                    value="llama3.1:8b",
                    label="Model",
                )
                api_key = gr.Textbox(
                    label="API Key",
                    type="password",
                    visible=False,
                    placeholder="Set env var or paste key",
                )
                rpm_warning = gr.Markdown(visible=False)

                # Changing the provider refreshes the model list, the key
                # field visibility and the rate-limit note.
                provider.change(update_model_dropdown, inputs=provider, outputs=model)
                provider.change(update_key_visibility, inputs=provider, outputs=api_key)
                provider.change(update_rpm_warning, inputs=provider, outputs=rpm_warning)

                gr.Markdown("## 🌍 Language")
                language = gr.Dropdown(
                    choices=list(LANGUAGES.keys()),
                    value="English",
                    label="Output Language",
                )
                multilingual_mode = gr.Radio(
                    choices=["native", "translate"],
                    value="native",
                    label="Generation Mode",
                    info="Native = generate directly in target language. Translate = English then translate.",
                )

                gr.Markdown("## ⚙️ Simulation")
                story_setting = gr.Dropdown(
                    choices=[f"story{i:02d}" for i in range(1, 21)],
                    value="story01",
                    label="Story Setting",
                )
                num_days = gr.Slider(1, 14, value=3, step=1, label="Simulation Days")
                # The range must include 0: it is the default, the "all"
                # sentinel named in the label, and the value every example
                # row supplies. The previous minimum of 2 made it invalid.
                num_personas = gr.Slider(0, 6, value=0, step=1,
                                         label="Number of Characters (0 = all)")

                gr.Markdown("## 🎨 Style")
                style_preset = gr.Dropdown(
                    choices=[""] + list_style_presets(),
                    value="",
                    label="Writing Style Preset",
                )
                enable_features = gr.Checkbox(
                    label="Enable Advanced Features (arcs, subplots, pacing)",
                    value=False,
                )

                gr.Markdown("## 🔧 Parameters")
                temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
                max_tokens = gr.Slider(1000, 16000, value=8000, step=1000, label="Max Tokens")

                generate_btn = gr.Button("🚀 Generate Story", variant="primary", size="lg")

            # Right column: generated output.
            with gr.Column(scale=2):
                gr.Markdown("## 📜 Generated Story")
                output_story = gr.Textbox(
                    label="",
                    lines=35,
                    max_lines=50,
                    show_copy_button=True,
                    autoscroll=False,
                )
                output_metadata = gr.JSON(label="Story Metadata")

                with gr.Row():
                    # NOTE(review): both buttons are created hidden and
                    # nothing in this function shows or wires them.
                    download_json = gr.DownloadButton("⬇ Download JSON", visible=False)
                    download_txt = gr.DownloadButton("⬇ Download TXT", visible=False)

        # Single source of truth for the argument order of generate_story,
        # shared by the example loader and the generate button.
        story_inputs = [
            provider, model, api_key, story_setting, language, multilingual_mode,
            num_days, num_personas, temperature, style_preset, enable_features, max_tokens,
        ]

        gr.Markdown("## 💡 Quick Examples")
        gr.Examples(
            examples=[
                ["Ollama", "llama3.1:8b", "", "story01", "English", "native", 3, 0, 0.8, "", False, 8000],
                ["MLX (Apple Silicon)", "llama3.1-8b-mlx", "", "story01", "English", "native", 7, 0, 0.8, "classic_fantasy", False, 8000],
                ["Ollama", "qwen2.5", "", "story01", "Hindi (हिन्दी)", "native", 3, 0, 0.7, "hindi_mythological", False, 8000],
                ["Ollama", "llama3.1:8b", "", "story01", "Hindi (हिन्दी)", "translate", 3, 0, 0.8, "bollywood_masala", False, 8000],
                ["NVIDIA NIM", "nvidia/meta/llama-3.1-8b-instruct", "YOUR_KEY", "story02", "English", "native", 5, 0, 0.8, "noir_detective", False, 8000],
                ["OpenAI", "gpt-4o-mini", "YOUR_KEY", "story03", "English", "native", 3, 0, 0.9, "cyberpunk_neon", True, 8000],
            ],
            inputs=story_inputs,
            label="Click to load example",
        )

        # Footer with usage notes.
        gr.Markdown("---")
        gr.Markdown("""
        **Supported Providers:** OpenAI, Ollama, MLX (Apple Silicon), NVIDIA NIM, HuggingFace

        **Supported Languages:** English, Hindi, Arabic, Chinese, Japanese, Spanish, French,
        Tamil, Telugu, Bengali, Marathi, Gujarati, Urdu, Malayalam, Kannada, Punjabi,
        Korean, Portuguese, German, Russian

        **Native vs Translate:** Native generation produces culturally authentic stories.
        Translate mode is a fallback for models with weak multilingual capability.

        **Rate Limits:** NIM free tier = 40 RPM (auto-throttled). OpenAI tier 1 = 60 RPM.
        Local providers (Ollama, MLX) have no limits.
        """)

        generate_btn.click(
            fn=generate_story,
            inputs=story_inputs,
            outputs=[output_story, output_metadata],
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it on all network
    # interfaces at port 7860 without requesting a public share link.
    demo = build_ui()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
|
|