# Provenance: uploaded to a Hugging Face Space as ui/app.py (revision 5f9d45e).
"""
StoryBox Gradio Web UI — model/provider picker + multilingual + style presets + NIM RPM warning.
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
import gradio as gr
# Add project root
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from reverie.config.config import Config
from reverie.skills.writing_styles import list_style_presets, get_style_preset
from reverie.skills.multilingual import get_language_config, LanguageConfig
# ---------------------------------------------------------------------------
# Provider / model registry
# ---------------------------------------------------------------------------
# Schema per provider entry:
#   models    - model identifiers offered in the Model dropdown; the first
#               entry becomes the default selection when the provider changes.
#   env_var   - environment variable name shown in the API-key placeholder
#               ("" for local providers that need no key).
#   needs_key - whether the API-key textbox is made visible for this provider.
#   rpm       - requests-per-minute budget (10000 is used as "effectively
#               unlimited" for local providers).
#   rpm_note  - human-readable rate-limit note rendered under the picker;
#               notes containing "⚠️" signal a strict free-tier limit.
PROVIDER_MODELS = {
    "OpenAI": {
        "models": ["gpt-4o-mini", "gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"],
        "env_var": "OPENAI_API_KEY",
        "needs_key": True,
        "rpm": 60,
        "rpm_note": "Tier 1: 60 RPM",
    },
    "Ollama": {
        "models": ["llama3.1:8b", "llama3.1:70b", "gemma4", "gemma4:9b",
                   "gemma4:27b", "mistral", "mixtral", "phi4", "qwen2.5", "command-r",
                   "deepseek-r1", "llama3.3", "codellama"],
        "env_var": "",
        "needs_key": False,
        "rpm": 10000,
        "rpm_note": "Local — unlimited",
    },
    "MLX (Apple Silicon)": {
        "models": ["llama3.1-8b-mlx", "llama3.1-70b-mlx", "mistral-mlx",
                   "phi3-mlx", "qwen2.5-mlx", "gemma2-mlx", "deepseek-mlx"],
        "env_var": "",
        "needs_key": False,
        "rpm": 10000,
        "rpm_note": "Local — unlimited",
    },
    "NVIDIA NIM": {
        "models": ["nvidia/meta/llama-3.1-8b-instruct",
                   "nvidia/meta/llama-3.1-70b-instruct",
                   "nvidia/mistralai/mistral-7b-instruct-v0.3",
                   "nvidia/google/gemma-2-9b-it",
                   "nvidia/microsoft/phi-3-mini-128k-instruct"],
        "env_var": "NIM_API_KEY",
        "needs_key": True,
        "rpm": 40,
        "rpm_note": "⚠️ Free tier: 40 RPM. Rate limiting auto-enabled.",
    },
    "HuggingFace": {
        "models": ["mistralai/Mistral-7B-Instruct-v0.3",
                   "meta-llama/Llama-3.1-8B-Instruct"],
        "env_var": "HF_TOKEN",
        "needs_key": True,
        "rpm": 60,
        "rpm_note": "Inference API free tier: 60 RPM",
    },
}
# UI display label -> ISO 639-1 language code; the code is assigned to
# Config.target_language in generate_story (falls back to "en").
LANGUAGES = {
    "English": "en",
    "Hindi (हिन्दी)": "hi",
    "Arabic (العربية)": "ar",
    "Chinese (中文)": "zh",
    "Japanese (日本語)": "ja",
    "Spanish (Español)": "es",
    "French (Français)": "fr",
    "Tamil (தமிழ்)": "ta",
    "Telugu (తెలుగు)": "te",
    "Bengali (বাংলা)": "bn",
    "Marathi (मराठी)": "mr",
    "Gujarati (ગુજરાતી)": "gu",
    "Urdu (اردو)": "ur",
    "Malayalam (മലയാളം)": "ml",
    "Kannada (ಕನ್ನಡ)": "kn",
    "Punjabi (ਪੰਜਾਬੀ)": "pa",
    "Korean (한국어)": "ko",
    "Portuguese (Português)": "pt",
    "German (Deutsch)": "de",
    "Russian (Русский)": "ru",
}
def update_model_dropdown(provider: str):
    """Refresh the Model dropdown choices for the newly selected provider.

    Defaults to the first model of the provider, or ``None`` when the
    provider is unknown / has no models.
    """
    entry = PROVIDER_MODELS.get(provider, {})
    choices = entry.get("models", [])
    default = choices[0] if choices else None
    return gr.Dropdown(choices=choices, value=default)
def update_key_visibility(provider: str):
    """Toggle the API-key textbox based on the selected provider.

    The box is visible only for providers flagged ``needs_key``; its
    placeholder names the provider's expected environment variable.
    """
    entry = PROVIDER_MODELS.get(provider, {})
    env_var = entry.get("env_var", "")
    hint = f"${env_var} env var or paste key here" if env_var else ""
    return gr.Textbox(visible=entry.get("needs_key", False), placeholder=hint)
def update_rpm_warning(provider: str):
    """Show the provider's rate-limit note under the model picker.

    Args:
        provider: Display name of the selected provider (a key of
            ``PROVIDER_MODELS``).

    Returns:
        A ``gr.Markdown`` update containing the bolded note, visible only
        when the provider defines a non-empty ``rpm_note``.
    """
    info = PROVIDER_MODELS.get(provider, {})
    rpm_note = info.get("rpm_note", "")
    # Fix: the original computed an `is_warning` flag ("⚠️"/"40" in the note)
    # that was never used — visibility depends only on the note existing.
    return gr.Markdown(value=f"**{rpm_note}**", visible=bool(rpm_note))
def generate_story(
    provider: str,
    model: str,
    api_key: str,
    story_setting: str,
    language: str,
    multilingual_mode: str,
    num_days: int,
    num_personas: int,
    temperature: float,
    style_preset: str,
    enable_features: bool,
    max_tokens: int,
    progress=gr.Progress(),
):
    """Run the StoryBox pipeline with the parameters chosen in the UI.

    Mutates class-level attributes on the global ``Config``, runs the
    simulation to completion, then reads the resulting ``story.json``.

    Returns:
        ``(story_text, metadata)``; on failure the text is an error message
        and the metadata dict is empty.
    """
    # Map UI display names onto Config provider identifiers.
    provider_map = {
        "OpenAI": "openai",
        "Ollama": "ollama",
        "MLX (Apple Silicon)": "mlx",
        "NVIDIA NIM": "nim",
        "HuggingFace": "huggingface",
    }
    Config.model_provider = provider_map.get(provider, "ollama")
    Config.llm_model_name = model
    Config.temperature = temperature
    Config.max_tokens = max_tokens
    Config.story_name = story_setting
    Config.story_dir = f"{Config.data_dir}/{story_setting}"
    # One simulated day = 24 iterations (presumably hourly steps — confirm).
    Config.max_iteration = 24 * num_days
    Config.target_language = LANGUAGES.get(language, "en")
    Config.multilingual_mode = multilingual_mode
    # NOTE(review): `num_personas` and `enable_features` come from the UI but
    # are never written to Config here — confirm whether the pipeline reads
    # them elsewhere or whether this is an omission.
    # API keys. NOTE(review): a pasted HuggingFace key is silently ignored
    # (only the HF_TOKEN env var would apply) — confirm this is intended.
    if provider == "OpenAI" and api_key:
        Config.api_key = api_key
    elif provider == "NVIDIA NIM" and api_key:
        Config.nim_api_key = api_key
    # Style
    if style_preset:
        Config.writing_style = style_preset
    # Run the simulation synchronously; this blocks until generation ends.
    import asyncio
    from reverie.run import reverie_task
    progress(0.1, desc="Loading world and personas...")
    asyncio.run(reverie_task())
    # Load result. Fix: read as explicit UTF-8 — the story may contain
    # non-ASCII text (Hindi, Arabic, ...) and the platform default encoding
    # can differ (e.g. cp1252 on Windows), which would corrupt the load.
    output_path = Path(Config.output_dir) / "story.json"
    if not output_path.exists():
        return "Error: Story generation failed.", {}
    with open(output_path, encoding="utf-8") as f:
        data = json.load(f)
    story_text = data.get("story", "")
    metadata = {
        "title": data.get("story_title", ""),
        "type": data.get("story_type", ""),
        "language": language,
        "multilingual_mode": multilingual_mode,
        "provider": provider,
        "model": model,
        "temperature": temperature,
        "word_count": len(story_text.split()),
        "style": style_preset or "default",
    }
    return story_text, metadata
def build_ui():
    """Assemble the Gradio Blocks UI and wire all event handlers.

    Returns:
        The configured ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks(title="StoryBox — AI Story Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📖 StoryBox")
        gr.Markdown("Generate long-form stories with multi-agent simulation. Supports 20+ languages and multiple LLM providers.")
        with gr.Row():
            # Left panel: Configuration
            with gr.Column(scale=1):
                gr.Markdown("## 🤖 Model")
                provider = gr.Dropdown(
                    choices=list(PROVIDER_MODELS.keys()),
                    value="Ollama",  # local default: works without an API key
                    label="Provider",
                )
                model = gr.Dropdown(
                    choices=PROVIDER_MODELS["Ollama"]["models"],
                    value="llama3.1:8b",
                    label="Model",
                )
                api_key = gr.Textbox(
                    label="API Key",
                    type="password",
                    visible=False,  # revealed by update_key_visibility when needed
                    placeholder="Set env var or paste key",
                )
                rpm_warning = gr.Markdown(visible=False)
                # Keep the model list, key field, and rate-limit note in sync
                # whenever the provider selection changes.
                provider.change(update_model_dropdown, inputs=provider, outputs=model)
                provider.change(update_key_visibility, inputs=provider, outputs=api_key)
                provider.change(update_rpm_warning, inputs=provider, outputs=rpm_warning)
                gr.Markdown("## 🌍 Language")
                language = gr.Dropdown(
                    choices=list(LANGUAGES.keys()),
                    value="English",
                    label="Output Language",
                )
                multilingual_mode = gr.Radio(
                    choices=["native", "translate"],
                    value="native",
                    label="Generation Mode",
                    info="Native = generate directly in target language. Translate = English then translate.",
                )
                gr.Markdown("## ⚙️ Simulation")
                story_setting = gr.Dropdown(
                    # Settings story01..story20 — assumed to match directories
                    # under Config.data_dir; TODO confirm.
                    choices=[f"story{i:02d}" for i in range(1, 21)],
                    value="story01",
                    label="Story Setting",
                )
                num_days = gr.Slider(1, 14, value=3, step=1, label="Simulation Days")
                # NOTE(review): default 0 ("all") lies below the slider minimum
                # of 2 — confirm Gradio renders/returns this as intended.
                num_personas = gr.Slider(2, 6, value=0, step=1,
                                         label="Number of Characters (0 = all)")
                gr.Markdown("## 🎨 Style")
                style_preset = gr.Dropdown(
                    choices=[""] + list_style_presets(),  # "" = default style
                    value="",
                    label="Writing Style Preset",
                )
                enable_features = gr.Checkbox(
                    label="Enable Advanced Features (arcs, subplots, pacing)",
                    value=False,
                )
                gr.Markdown("## 🔧 Parameters")
                temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
                max_tokens = gr.Slider(1000, 16000, value=8000, step=1000, label="Max Tokens")
                generate_btn = gr.Button("🚀 Generate Story", variant="primary", size="lg")
            # Right panel: Output
            with gr.Column(scale=2):
                gr.Markdown("## 📜 Generated Story")
                output_story = gr.Textbox(
                    label="",
                    lines=35,
                    max_lines=50,
                    show_copy_button=True,
                    autoscroll=False,
                )
                output_metadata = gr.JSON(label="Story Metadata")
                with gr.Row():
                    # NOTE(review): these buttons are created hidden and no
                    # handler ever populates or reveals them — confirm whether
                    # download wiring is still to be added.
                    download_json = gr.DownloadButton("⬇ Download JSON", visible=False)
                    download_txt = gr.DownloadButton("⬇ Download TXT", visible=False)
        # Examples — each row matches the `inputs` order below.
        gr.Markdown("## 💡 Quick Examples")
        gr.Examples(
            examples=[
                ["Ollama", "llama3.1:8b", "", "story01", "English", "native", 3, 0, 0.8, "", False, 8000],
                ["MLX (Apple Silicon)", "llama3.1-8b-mlx", "", "story01", "English", "native", 7, 0, 0.8, "classic_fantasy", False, 8000],
                ["Ollama", "qwen2.5", "", "story01", "Hindi (हिन्दी)", "native", 3, 0, 0.7, "hindi_mythological", False, 8000],
                ["Ollama", "llama3.1:8b", "", "story01", "Hindi (हिन्दी)", "translate", 3, 0, 0.8, "bollywood_masala", False, 8000],
                ["NVIDIA NIM", "nvidia/meta/llama-3.1-8b-instruct", "YOUR_KEY", "story02", "English", "native", 5, 0, 0.8, "noir_detective", False, 8000],
                ["OpenAI", "gpt-4o-mini", "YOUR_KEY", "story03", "English", "native", 3, 0, 0.9, "cyberpunk_neon", True, 8000],
            ],
            inputs=[provider, model, api_key, story_setting, language, multilingual_mode,
                    num_days, num_personas, temperature, style_preset, enable_features, max_tokens],
            label="Click to load example",
        )
        # Footer
        gr.Markdown("---")
        gr.Markdown("""
**Supported Providers:** OpenAI, Ollama, MLX (Apple Silicon), NVIDIA NIM, HuggingFace
**Supported Languages:** English, Hindi, Arabic, Chinese, Japanese, Spanish, French,
Tamil, Telugu, Bengali, Marathi, Gujarati, Urdu, Malayalam, Kannada, Punjabi,
Korean, Portuguese, German, Russian
**Native vs Translate:** Native generation produces culturally authentic stories.
Translate mode is a fallback for models with weak multilingual capability.
**Rate Limits:** NIM free tier = 40 RPM (auto-throttled). OpenAI tier 1 = 60 RPM.
Local providers (Ollama, MLX) have no limits.
""")
        # Wire the main action: inputs are passed positionally to
        # generate_story in its parameter order.
        generate_btn.click(
            fn=generate_story,
            inputs=[provider, model, api_key, story_setting, language, multilingual_mode,
                    num_days, num_personas, temperature, style_preset, enable_features, max_tokens],
            outputs=[output_story, output_metadata],
        )
    return demo
if __name__ == "__main__":
    # Bind on all interfaces so the UI is reachable from outside a container.
    app = build_ui()
    app.launch(server_name="0.0.0.0", server_port=7860, share=False)