""" StoryBox Gradio Web UI — model/provider picker + multilingual + style presets + NIM RPM warning. """ from __future__ import annotations import json import os import sys from pathlib import Path import gradio as gr # Add project root sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from reverie.config.config import Config from reverie.skills.writing_styles import list_style_presets, get_style_preset from reverie.skills.multilingual import get_language_config, LanguageConfig # --------------------------------------------------------------------------- # Provider / model registry # --------------------------------------------------------------------------- PROVIDER_MODELS = { "OpenAI": { "models": ["gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"], "env_var": "OPENAI_API_KEY", "needs_key": True, "rpm": 60, "rpm_note": "Tier 1: 60 RPM", }, "Ollama": { "models": ["llama3.1:8b", "llama3.1:70b", "gemma4", "gemma4:9b", "gemma4:27b", "mistral", "mixtral", "phi4", "qwen2.5", "command-r", "deepseek-r1", "llama3.3", "codellama"], "env_var": "", "needs_key": False, "rpm": 10000, "rpm_note": "Local — unlimited", }, "MLX (Apple Silicon)": { "models": ["llama3.1-8b-mlx", "llama3.1-70b-mlx", "mistral-mlx", "phi3-mlx", "qwen2.5-mlx", "gemma2-mlx", "deepseek-mlx"], "env_var": "", "needs_key": False, "rpm": 10000, "rpm_note": "Local — unlimited", }, "NVIDIA NIM": { "models": ["nvidia/meta/llama-3.1-8b-instruct", "nvidia/meta/llama-3.1-70b-instruct", "nvidia/mistralai/mistral-7b-instruct-v0.3", "nvidia/google/gemma-2-9b-it", "nvidia/microsoft/phi-3-mini-128k-instruct"], "env_var": "NIM_API_KEY", "needs_key": True, "rpm": 40, "rpm_note": "⚠️ Free tier: 40 RPM. Rate limiting auto-enabled.", }, "HuggingFace": { "models": ["mistralai/Mistral-7B-Instruct-v0.3", "meta-llama/Llama-3.1-8B-Instruct"], "env_var": "HF_TOKEN", "needs_key": True, "rpm": 60, "rpm_note": "Inference API free tier: 60 RPM", }, } LANGUAGES = { "English": "en", "Hindi (हिन्दी)": "hi", "Arabic (العربية)": "ar", "Chinese (中文)": "zh", "Japanese (日本語)": "ja", "Spanish (Español)": "es", "French (Français)": "fr", "Tamil (தமிழ்)": "ta", "Telugu (తెలుగు)": "te", "Bengali (বাংলা)": "bn", "Marathi (मराठी)": "mr", "Gujarati (ગુજરાતી)": "gu", "Urdu (اردو)": "ur", "Malayalam (മലയാളം)": "ml", "Kannada (ಕನ್ನಡ)": "kn", "Punjabi (ਪੰਜਾਬੀ)": "pa", "Korean (한국어)": "ko", "Portuguese (Português)": "pt", "German (Deutsch)": "de", "Russian (Русский)": "ru", } def update_model_dropdown(provider: str): """Update model list when provider changes.""" models = PROVIDER_MODELS.get(provider, {}).get("models", []) return gr.Dropdown(choices=models, value=models[0] if models else None) def update_key_visibility(provider: str): """Show/hide API key input based on provider.""" info = PROVIDER_MODELS.get(provider, {}) needs_key = info.get("needs_key", False) env_var = info.get("env_var", "") placeholder = f"${env_var} env var or paste key here" if env_var else "" return gr.Textbox(visible=needs_key, placeholder=placeholder) def update_rpm_warning(provider: str): """Show RPM warning for providers with strict limits.""" info = PROVIDER_MODELS.get(provider, {}) rpm_note = info.get("rpm_note", "") is_warning = "⚠️" in rpm_note or "40" in rpm_note return gr.Markdown(value=f"**{rpm_note}**", visible=bool(rpm_note)) def generate_story( provider: str, model: str, api_key: str, story_setting: str, language: str, multilingual_mode: str, num_days: int, num_personas: int, temperature: float, style_preset: str, enable_features: bool, max_tokens: int, progress=gr.Progress(), ): """Run StoryBox pipeline with UI parameters.""" # Apply config provider_map = { "OpenAI": "openai", "Ollama": "ollama", "MLX (Apple Silicon)": "mlx", "NVIDIA NIM": "nim", "HuggingFace": "huggingface", } Config.model_provider = provider_map.get(provider, "ollama") Config.llm_model_name = model Config.temperature = temperature Config.max_tokens = max_tokens Config.story_name = story_setting Config.story_dir = f"{Config.data_dir}/{story_setting}" Config.max_iteration = 24 * num_days Config.target_language = LANGUAGES.get(language, "en") Config.multilingual_mode = multilingual_mode # API keys if provider == "OpenAI" and api_key: Config.api_key = api_key elif provider == "NVIDIA NIM" and api_key: Config.nim_api_key = api_key # Style if style_preset: Config.writing_style = style_preset # Run simulation import asyncio from reverie.run import reverie_task progress(0.1, desc="Loading world and personas...") asyncio.run(reverie_task()) # Load result output_path = Path(Config.output_dir) / "story.json" if not output_path.exists(): return "Error: Story generation failed.", {} with open(output_path) as f: data = json.load(f) story_text = data.get("story", "") metadata = { "title": data.get("story_title", ""), "type": data.get("story_type", ""), "language": language, "multilingual_mode": multilingual_mode, "provider": provider, "model": model, "temperature": temperature, "word_count": len(story_text.split()), "style": style_preset or "default", } return story_text, metadata def build_ui(): with gr.Blocks(title="StoryBox — AI Story Generator", theme=gr.themes.Soft()) as demo: gr.Markdown("# 📖 StoryBox") gr.Markdown("Generate long-form stories with multi-agent simulation. Supports 20+ languages and multiple LLM providers.") with gr.Row(): # Left panel: Configuration with gr.Column(scale=1): gr.Markdown("## 🤖 Model") provider = gr.Dropdown( choices=list(PROVIDER_MODELS.keys()), value="Ollama", label="Provider", ) model = gr.Dropdown( choices=PROVIDER_MODELS["Ollama"]["models"], value="llama3.1:8b", label="Model", ) api_key = gr.Textbox( label="API Key", type="password", visible=False, placeholder="Set env var or paste key", ) rpm_warning = gr.Markdown(visible=False) provider.change(update_model_dropdown, inputs=provider, outputs=model) provider.change(update_key_visibility, inputs=provider, outputs=api_key) provider.change(update_rpm_warning, inputs=provider, outputs=rpm_warning) gr.Markdown("## 🌍 Language") language = gr.Dropdown( choices=list(LANGUAGES.keys()), value="English", label="Output Language", ) multilingual_mode = gr.Radio( choices=["native", "translate"], value="native", label="Generation Mode", info="Native = generate directly in target language. Translate = English then translate.", ) gr.Markdown("## ⚙️ Simulation") story_setting = gr.Dropdown( choices=[f"story{i:02d}" for i in range(1, 21)], value="story01", label="Story Setting", ) num_days = gr.Slider(1, 14, value=3, step=1, label="Simulation Days") num_personas = gr.Slider(2, 6, value=0, step=1, label="Number of Characters (0 = all)") gr.Markdown("## 🎨 Style") style_preset = gr.Dropdown( choices=[""] + list_style_presets(), value="", label="Writing Style Preset", ) enable_features = gr.Checkbox( label="Enable Advanced Features (arcs, subplots, pacing)", value=False, ) gr.Markdown("## 🔧 Parameters") temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature") max_tokens = gr.Slider(1000, 16000, value=8000, step=1000, label="Max Tokens") generate_btn = gr.Button("🚀 Generate Story", variant="primary", size="lg") # Right panel: Output with gr.Column(scale=2): gr.Markdown("## 📜 Generated Story") output_story = gr.Textbox( label="", lines=35, max_lines=50, show_copy_button=True, autoscroll=False, ) output_metadata = gr.JSON(label="Story Metadata") with gr.Row(): download_json = gr.DownloadButton("⬇ Download JSON", visible=False) download_txt = gr.DownloadButton("⬇ Download TXT", visible=False) # Examples gr.Markdown("## 💡 Quick Examples") gr.Examples( examples=[ ["Ollama", "llama3.1:8b", "", "story01", "English", "native", 3, 0, 0.8, "", False, 8000], ["MLX (Apple Silicon)", "llama3.1-8b-mlx", "", "story01", "English", "native", 7, 0, 0.8, "classic_fantasy", False, 8000], ["Ollama", "qwen2.5", "", "story01", "Hindi (हिन्दी)", "native", 3, 0, 0.7, "hindi_mythological", False, 8000], ["Ollama", "llama3.1:8b", "", "story01", "Hindi (हिन्दी)", "translate", 3, 0, 0.8, "bollywood_masala", False, 8000], ["NVIDIA NIM", "nvidia/meta/llama-3.1-8b-instruct", "YOUR_KEY", "story02", "English", "native", 5, 0, 0.8, "noir_detective", False, 8000], ["OpenAI", "gpt-4o-mini", "YOUR_KEY", "story03", "English", "native", 3, 0, 0.9, "cyberpunk_neon", True, 8000], ], inputs=[provider, model, api_key, story_setting, language, multilingual_mode, num_days, num_personas, temperature, style_preset, enable_features, max_tokens], label="Click to load example", ) # Footer gr.Markdown("---") gr.Markdown(""" **Supported Providers:** OpenAI, Ollama, MLX (Apple Silicon), NVIDIA NIM, HuggingFace **Supported Languages:** English, Hindi, Arabic, Chinese, Japanese, Spanish, French, Tamil, Telugu, Bengali, Marathi, Gujarati, Urdu, Malayalam, Kannada, Punjabi, Korean, Portuguese, German, Russian **Native vs Translate:** Native generation produces culturally authentic stories. Translate mode is a fallback for models with weak multilingual capability. **Rate Limits:** NIM free tier = 40 RPM (auto-throttled). OpenAI tier 1 = 60 RPM. Local providers (Ollama, MLX) have no limits. """) generate_btn.click( fn=generate_story, inputs=[provider, model, api_key, story_setting, language, multilingual_mode, num_days, num_personas, temperature, style_preset, enable_features, max_tokens], outputs=[output_story, output_metadata], ) return demo if __name__ == "__main__": demo = build_ui() demo.launch(server_name="0.0.0.0", server_port=7860, share=False)