Commit ·
347b758
1
Parent(s): 93d86d5
Add Groq API support as primary script generator with Gemini fallback
Browse files- .env.example +7 -2
- config.py +2 -1
- modules/story_reels/__init__.py +5 -1
- modules/story_reels/services/script_generator.py +111 -55
- requirements.txt +1 -0
.env.example
CHANGED
|
@@ -52,6 +52,11 @@ CF_URL=https://image-api.yourworker.workers.dev
|
|
| 52 |
# Cloudflare API Key (FALLBACK)
|
| 53 |
CF_API=your_api_key_here
|
| 54 |
|
| 55 |
-
# Gemini API Key (Required for AI script generation)
|
| 56 |
# Get from: https://aistudio.google.com/apikey
|
| 57 |
-
GEMINI_API_KEY=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
# Cloudflare API Key (FALLBACK)
|
| 53 |
CF_API=your_api_key_here
|
| 54 |
|
| 55 |
+
# Gemini API Key (Optional — fallback for AI script generation when Groq is unavailable)
|
| 56 |
# Get from: https://aistudio.google.com/apikey
|
| 57 |
+
GEMINI_API_KEY=your_gemini_key_here
|
| 58 |
+
|
| 59 |
+
# Groq API Key (Primary for AI script generation)
|
| 60 |
+
# Get from: https://console.groq.com/keys
|
| 61 |
+
GROQ_API=gsk_your_groq_key_here
|
| 62 |
+
```
|
config.py
CHANGED
|
@@ -75,7 +75,8 @@ class NCAkitConfig(BaseConfig):
|
|
| 75 |
nvidia_api_key: Optional[str] = None # NVIDIA API key (primary)
|
| 76 |
cf_url: Optional[str] = None # Cloudflare Worker URL (fallback)
|
| 77 |
cf_api: Optional[str] = None # Cloudflare API key (fallback)
|
| 78 |
-
gemini_api_key: Optional[str] = None # For AI script generation
|
|
|
|
| 79 |
|
| 80 |
@property
|
| 81 |
def videos_dir_path(self) -> Path:
|
|
|
|
| 75 |
nvidia_api_key: Optional[str] = None # NVIDIA API key (primary)
|
| 76 |
cf_url: Optional[str] = None # Cloudflare Worker URL (fallback)
|
| 77 |
cf_api: Optional[str] = None # Cloudflare API key (fallback)
|
| 78 |
+
gemini_api_key: Optional[str] = None # For AI script generation (fallback)
|
| 79 |
+
groq_api: Optional[str] = None # Groq API key (primary for script generation)
|
| 80 |
|
| 81 |
@property
|
| 82 |
def videos_dir_path(self) -> Path:
|
modules/story_reels/__init__.py
CHANGED
|
@@ -52,7 +52,11 @@ def register(app: FastAPI, config):
|
|
| 52 |
|
| 53 |
# Initialize Script Generator (Gemini)
|
| 54 |
logger.info("Initializing script generator (Gemini)...")
|
| 55 |
-
script_generator = ScriptGenerator(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# Initialize NVIDIA client (PRIMARY)
|
| 58 |
nvidia_client = None
|
|
|
|
| 52 |
|
| 53 |
# Initialize Script Generator (Gemini)
|
| 54 |
logger.info("Initializing script generator (Gemini)...")
|
| 55 |
+
script_generator = ScriptGenerator(
|
| 56 |
+
gemini_api_key=config.gemini_api_key,
|
| 57 |
+
groq_api_key=config.groq_api
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
|
| 61 |
# Initialize NVIDIA client (PRIMARY)
|
| 62 |
nvidia_client = None
|
modules/story_reels/services/script_generator.py
CHANGED
|
@@ -1,19 +1,19 @@
|
|
| 1 |
"""
|
| 2 |
-
Script Generator using
|
| 3 |
Generates story scripts from topics for TTS narration
|
|
|
|
| 4 |
"""
|
| 5 |
import logging
|
| 6 |
import json
|
|
|
|
| 7 |
from typing import Optional
|
| 8 |
|
| 9 |
-
from google import genai
|
| 10 |
-
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
class ScriptGenerator:
|
| 15 |
"""
|
| 16 |
-
Generates story scripts using
|
| 17 |
|
| 18 |
Features:
|
| 19 |
- Topic → Full narration script (<=1000 chars)
|
|
@@ -22,7 +22,8 @@ class ScriptGenerator:
|
|
| 22 |
- Optimized for TTS output
|
| 23 |
"""
|
| 24 |
|
| 25 |
-
|
|
|
|
| 26 |
|
| 27 |
# System prompt for script generation
|
| 28 |
SYSTEM_PROMPT = """You are a professional script writer for short-form video content (TikTok, Reels, Shorts).
|
|
@@ -38,9 +39,32 @@ RULES:
|
|
| 38 |
|
| 39 |
If a character is provided, write the story from their perspective or about them."""
|
| 40 |
|
| 41 |
-
def __init__(self,
|
| 42 |
-
self.
|
| 43 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
def generate_script(
|
| 46 |
self,
|
|
@@ -50,14 +74,7 @@ If a character is provided, write the story from their perspective or about them
|
|
| 50 |
) -> str:
|
| 51 |
"""
|
| 52 |
Generate a story script from topic.
|
| 53 |
-
|
| 54 |
-
Args:
|
| 55 |
-
topic: Story topic/idea
|
| 56 |
-
character_name: Optional character name to include
|
| 57 |
-
max_chars: Maximum character limit (default 1000)
|
| 58 |
-
|
| 59 |
-
Returns:
|
| 60 |
-
Generated script text
|
| 61 |
"""
|
| 62 |
# Build the prompt
|
| 63 |
user_prompt = f"Topic: {topic}"
|
|
@@ -69,30 +86,58 @@ If a character is provided, write the story from their perspective or about them
|
|
| 69 |
|
| 70 |
logger.info(f"Generating script for topic: {topic[:50]}...")
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
@staticmethod
|
| 92 |
-
def test_connection(
|
| 93 |
"""Test API connection"""
|
| 94 |
try:
|
| 95 |
-
gen = ScriptGenerator(
|
| 96 |
gen.generate_script("test", max_chars=50)
|
| 97 |
return True
|
| 98 |
except:
|
|
@@ -133,20 +178,10 @@ Return ONLY valid JSON array, no markdown, no explanation:
|
|
| 133 |
) -> list:
|
| 134 |
"""
|
| 135 |
Generate detailed image prompts for all 2-second chunks.
|
| 136 |
-
|
| 137 |
-
Args:
|
| 138 |
-
full_script: Complete narration script (for context)
|
| 139 |
-
chunks: List of {chunk_id, text, duration} from SRTParser
|
| 140 |
-
character_profile: Optional character dict
|
| 141 |
-
max_batch: Max chunks per API call (default 30)
|
| 142 |
-
|
| 143 |
-
Returns:
|
| 144 |
-
List of {chunk_id, prompt} dicts
|
| 145 |
"""
|
| 146 |
all_prompts = []
|
| 147 |
total_chunks = len(chunks)
|
| 148 |
|
| 149 |
-
# Split into batches if too many chunks
|
| 150 |
for batch_start in range(0, total_chunks, max_batch):
|
| 151 |
batch_end = min(batch_start + max_batch, total_chunks)
|
| 152 |
batch_chunks = chunks[batch_start:batch_end]
|
|
@@ -179,12 +214,13 @@ IMPORTANT: Include this character description in EVERY prompt!
|
|
| 179 |
user_prompt += "\nGenerate detailed image prompts for each chunk. Return ONLY JSON array."
|
| 180 |
|
| 181 |
try:
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
|
|
|
| 188 |
|
| 189 |
# Clean response - remove markdown if present
|
| 190 |
text = text.strip()
|
|
@@ -202,15 +238,13 @@ IMPORTANT: Include this character description in EVERY prompt!
|
|
| 202 |
|
| 203 |
except json.JSONDecodeError as e:
|
| 204 |
logger.error(f"Failed to parse JSON response: {e}")
|
| 205 |
-
# Fallback: create simple prompts
|
| 206 |
for chunk in batch_chunks:
|
| 207 |
all_prompts.append({
|
| 208 |
"chunk_id": chunk["chunk_id"],
|
| 209 |
"prompt": f"{chunk['text']}, semi-realistic style, high quality, detailed"
|
| 210 |
})
|
| 211 |
except Exception as e:
|
| 212 |
-
logger.error(f"
|
| 213 |
-
# Fallback
|
| 214 |
for chunk in batch_chunks:
|
| 215 |
all_prompts.append({
|
| 216 |
"chunk_id": chunk["chunk_id"],
|
|
@@ -219,3 +253,25 @@ IMPORTANT: Include this character description in EVERY prompt!
|
|
| 219 |
|
| 220 |
logger.info(f"Generated {len(all_prompts)} total image prompts")
|
| 221 |
return all_prompts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Script Generator using Groq and Gemini APIs
|
| 3 |
Generates story scripts from topics for TTS narration
|
| 4 |
+
Tries Groq first (works in all regions), falls back to Gemini
|
| 5 |
"""
|
| 6 |
import logging
|
| 7 |
import json
|
| 8 |
+
import os
|
| 9 |
from typing import Optional
|
| 10 |
|
|
|
|
|
|
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
class ScriptGenerator:
|
| 15 |
"""
|
| 16 |
+
Generates story scripts using Groq API (primary) or Gemini API (fallback).
|
| 17 |
|
| 18 |
Features:
|
| 19 |
- Topic → Full narration script (<=1000 chars)
|
|
|
|
| 22 |
- Optimized for TTS output
|
| 23 |
"""
|
| 24 |
|
| 25 |
+
GROQ_MODEL = "llama-3.3-70b-versatile" # Fast and good quality
|
| 26 |
+
GEMINI_MODEL = "gemini-2.5-flash"
|
| 27 |
|
| 28 |
# System prompt for script generation
|
| 29 |
SYSTEM_PROMPT = """You are a professional script writer for short-form video content (TikTok, Reels, Shorts).
|
|
|
|
| 39 |
|
| 40 |
If a character is provided, write the story from their perspective or about them."""
|
| 41 |
|
| 42 |
+
def __init__(self, gemini_api_key: str = None, groq_api_key: str = None):
    """Set up the AI backends for script generation.

    Args:
        gemini_api_key: Google Gemini API key (fallback backend).
        groq_api_key: Groq API key (primary backend).

    Raises:
        ValueError: If neither backend could be initialized.
    """
    self.gemini_api_key = gemini_api_key
    self.groq_api_key = groq_api_key

    # Initialize clients based on available keys. A failure while
    # constructing one client must not abort __init__ — otherwise the
    # primary/fallback design is defeated before the first request.
    self.groq_client = None
    self.gemini_client = None

    if groq_api_key:
        try:
            from groq import Groq
            self.groq_client = Groq(api_key=groq_api_key)
            logger.info("Groq client initialized (primary)")
        except ImportError:
            logger.warning("Groq package not installed, using Gemini only")
        except Exception as e:
            # e.g. SDK rejects a malformed key at construction time —
            # log and fall through so Gemini can still serve requests.
            logger.warning(f"Groq client init failed: {e}")

    if gemini_api_key:
        try:
            from google import genai
            self.gemini_client = genai.Client(api_key=gemini_api_key)
            logger.info("Gemini client initialized (fallback)")
        except ImportError:
            logger.warning("google-genai package not installed")
        except Exception as e:
            logger.warning(f"Gemini client init failed: {e}")

    if not self.groq_client and not self.gemini_client:
        raise ValueError("At least one API key (GROQ_API or GEMINI_API_KEY) is required")
|
| 68 |
|
| 69 |
def generate_script(
|
| 70 |
self,
|
|
|
|
| 74 |
) -> str:
|
| 75 |
"""
|
| 76 |
Generate a story script from topic.
|
| 77 |
+
Tries Groq first, falls back to Gemini if Groq fails.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
"""
|
| 79 |
# Build the prompt
|
| 80 |
user_prompt = f"Topic: {topic}"
|
|
|
|
| 86 |
|
| 87 |
logger.info(f"Generating script for topic: {topic[:50]}...")
|
| 88 |
|
| 89 |
+
# Try Groq first (works in all regions)
|
| 90 |
+
if self.groq_client:
|
| 91 |
+
try:
|
| 92 |
+
script = self._generate_with_groq(user_prompt)
|
| 93 |
+
if len(script) > max_chars:
|
| 94 |
+
script = script[:max_chars].rsplit(' ', 1)[0] + "."
|
| 95 |
+
logger.info(f"Generated script with Groq: {len(script)} chars")
|
| 96 |
+
return script.strip()
|
| 97 |
+
except Exception as e:
|
| 98 |
+
logger.warning(f"Groq failed: {e}, trying Gemini...")
|
| 99 |
+
|
| 100 |
+
# Fallback to Gemini
|
| 101 |
+
if self.gemini_client:
|
| 102 |
+
try:
|
| 103 |
+
script = self._generate_with_gemini(user_prompt)
|
| 104 |
+
if len(script) > max_chars:
|
| 105 |
+
script = script[:max_chars].rsplit(' ', 1)[0] + "."
|
| 106 |
+
logger.info(f"Generated script with Gemini: {len(script)} chars")
|
| 107 |
+
return script.strip()
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logger.error(f"Gemini also failed: {e}")
|
| 110 |
+
raise Exception(f"Script generation failed: {e}")
|
| 111 |
+
|
| 112 |
+
raise Exception("No AI backend available for script generation")
|
| 113 |
+
|
| 114 |
+
def _generate_with_groq(self, user_prompt: str) -> str:
    """Generate a script via the Groq chat-completions endpoint.

    Uses the class-level GROQ_MODEL together with SYSTEM_PROMPT; length
    trimming and fallback handling are the caller's responsibility.
    """
    conversation = [
        {"role": "system", "content": self.SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    result = self.groq_client.chat.completions.create(
        model=self.GROQ_MODEL,
        messages=conversation,
        temperature=0.7,
        max_tokens=500,
        top_p=0.9,
    )
    first_choice = result.choices[0]
    return first_choice.message.content
|
| 127 |
+
|
| 128 |
+
def _generate_with_gemini(self, user_prompt: str) -> str:
    """Generate a script via the Gemini API (fallback path).

    No separate system-message channel is used here: the system prompt
    is prepended to the user prompt and sent as one request.
    """
    combined_prompt = "\n\n".join([self.SYSTEM_PROMPT, user_prompt])
    reply = self.gemini_client.models.generate_content(
        model=self.GEMINI_MODEL,
        contents=combined_prompt,
    )
    return reply.text
|
| 135 |
|
| 136 |
@staticmethod
|
| 137 |
+
def test_connection(gemini_api_key: str = None, groq_api_key: str = None) -> bool:
|
| 138 |
"""Test API connection"""
|
| 139 |
try:
|
| 140 |
+
gen = ScriptGenerator(gemini_api_key=gemini_api_key, groq_api_key=groq_api_key)
|
| 141 |
gen.generate_script("test", max_chars=50)
|
| 142 |
return True
|
| 143 |
except:
|
|
|
|
| 178 |
) -> list:
|
| 179 |
"""
|
| 180 |
Generate detailed image prompts for all 2-second chunks.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
"""
|
| 182 |
all_prompts = []
|
| 183 |
total_chunks = len(chunks)
|
| 184 |
|
|
|
|
| 185 |
for batch_start in range(0, total_chunks, max_batch):
|
| 186 |
batch_end = min(batch_start + max_batch, total_chunks)
|
| 187 |
batch_chunks = chunks[batch_start:batch_end]
|
|
|
|
| 214 |
user_prompt += "\nGenerate detailed image prompts for each chunk. Return ONLY JSON array."
|
| 215 |
|
| 216 |
try:
|
| 217 |
+
# Try Groq first
|
| 218 |
+
if self.groq_client:
|
| 219 |
+
text = self._generate_image_prompts_groq(user_prompt)
|
| 220 |
+
elif self.gemini_client:
|
| 221 |
+
text = self._generate_image_prompts_gemini(user_prompt)
|
| 222 |
+
else:
|
| 223 |
+
raise Exception("No AI backend available")
|
| 224 |
|
| 225 |
# Clean response - remove markdown if present
|
| 226 |
text = text.strip()
|
|
|
|
| 238 |
|
| 239 |
except json.JSONDecodeError as e:
|
| 240 |
logger.error(f"Failed to parse JSON response: {e}")
|
|
|
|
| 241 |
for chunk in batch_chunks:
|
| 242 |
all_prompts.append({
|
| 243 |
"chunk_id": chunk["chunk_id"],
|
| 244 |
"prompt": f"{chunk['text']}, semi-realistic style, high quality, detailed"
|
| 245 |
})
|
| 246 |
except Exception as e:
|
| 247 |
+
logger.error(f"AI API error: {e}")
|
|
|
|
| 248 |
for chunk in batch_chunks:
|
| 249 |
all_prompts.append({
|
| 250 |
"chunk_id": chunk["chunk_id"],
|
|
|
|
| 253 |
|
| 254 |
logger.info(f"Generated {len(all_prompts)} total image prompts")
|
| 255 |
return all_prompts
|
| 256 |
+
|
| 257 |
+
def _generate_image_prompts_groq(self, user_prompt: str) -> str:
    """Request a batch of image prompts from Groq.

    Returns the raw model text; JSON parsing and markdown-fence cleanup
    happen in the caller.
    """
    chat_messages = [
        {"role": "system", "content": self.IMAGE_PROMPT_SYSTEM},
        {"role": "user", "content": user_prompt},
    ]
    response = self.groq_client.chat.completions.create(
        model=self.GROQ_MODEL,
        messages=chat_messages,
        temperature=0.7,
        max_tokens=4000,
        top_p=0.9,
    )
    return response.choices[0].message.content
|
| 270 |
+
|
| 271 |
+
def _generate_image_prompts_gemini(self, user_prompt: str) -> str:
    """Request a batch of image prompts from Gemini (fallback path).

    The image-prompt system instructions are folded into the single
    request body, mirroring the script-generation fallback.
    """
    combined_prompt = "\n\n".join([self.IMAGE_PROMPT_SYSTEM, user_prompt])
    reply = self.gemini_client.models.generate_content(
        model=self.GEMINI_MODEL,
        contents=combined_prompt,
    )
    return reply.text
|
requirements.txt
CHANGED
|
@@ -19,6 +19,7 @@ numpy<2.0.0
|
|
| 19 |
# AI/ML
|
| 20 |
faster-whisper
|
| 21 |
google-genai
|
|
|
|
| 22 |
|
| 23 |
# Utilities
|
| 24 |
python-multipart
|
|
|
|
| 19 |
# AI/ML
|
| 20 |
faster-whisper
|
| 21 |
google-genai
|
| 22 |
+
groq
|
| 23 |
|
| 24 |
# Utilities
|
| 25 |
python-multipart
|