ismdrobiul489 committed on
Commit
347b758
·
1 Parent(s): 93d86d5

Add Groq API support as primary script generator with Gemini fallback

Browse files
.env.example CHANGED
@@ -52,6 +52,11 @@ CF_URL=https://image-api.yourworker.workers.dev
52
  # Cloudflare API Key (FALLBACK)
53
  CF_API=your_api_key_here
54
 
55
- # Gemini API Key (Required for AI script generation)
56
  # Get from: https://aistudio.google.com/apikey
57
- GEMINI_API_KEY=your_gemini_api_key_here
 
 
 
 
 
 
52
  # Cloudflare API Key (FALLBACK)
53
  CF_API=your_api_key_here
54
 
55
+ # Gemini API Key (Required for AI script generation - Fallback)
56
  # Get from: https://aistudio.google.com/apikey
57
+ GEMINI_API_KEY=your_gemini_key_here
58
+
59
+ # Groq API Key (Primary for AI script generation)
60
+ # Get from: https://console.groq.com/keys
61
+ GROQ_API=gsk_your_groq_key_here
62
+ ```
config.py CHANGED
@@ -75,7 +75,8 @@ class NCAkitConfig(BaseConfig):
75
  nvidia_api_key: Optional[str] = None # NVIDIA API key (primary)
76
  cf_url: Optional[str] = None # Cloudflare Worker URL (fallback)
77
  cf_api: Optional[str] = None # Cloudflare API key (fallback)
78
- gemini_api_key: Optional[str] = None # For AI script generation
 
79
 
80
  @property
81
  def videos_dir_path(self) -> Path:
 
75
  nvidia_api_key: Optional[str] = None # NVIDIA API key (primary)
76
  cf_url: Optional[str] = None # Cloudflare Worker URL (fallback)
77
  cf_api: Optional[str] = None # Cloudflare API key (fallback)
78
+ gemini_api_key: Optional[str] = None # For AI script generation (fallback)
79
+ groq_api: Optional[str] = None # Groq API key (primary for script generation)
80
 
81
  @property
82
  def videos_dir_path(self) -> Path:
modules/story_reels/__init__.py CHANGED
@@ -52,7 +52,11 @@ def register(app: FastAPI, config):
52
 
53
  # Initialize Script Generator (Gemini)
54
  logger.info("Initializing script generator (Gemini)...")
55
- script_generator = ScriptGenerator(config.gemini_api_key or "")
 
 
 
 
56
 
57
  # Initialize NVIDIA client (PRIMARY)
58
  nvidia_client = None
 
52
 
53
  # Initialize Script Generator (Gemini)
54
  logger.info("Initializing script generator (Gemini)...")
55
+ script_generator = ScriptGenerator(
56
+ gemini_api_key=config.gemini_api_key,
57
+ groq_api_key=config.groq_api
58
+ )
59
+
60
 
61
  # Initialize NVIDIA client (PRIMARY)
62
  nvidia_client = None
modules/story_reels/services/script_generator.py CHANGED
@@ -1,19 +1,19 @@
1
  """
2
- Script Generator using Google GenAI SDK
3
  Generates story scripts from topics for TTS narration
 
4
  """
5
  import logging
6
  import json
 
7
  from typing import Optional
8
 
9
- from google import genai
10
-
11
  logger = logging.getLogger(__name__)
12
 
13
 
14
  class ScriptGenerator:
15
  """
16
- Generates story scripts using Google Gemini API via google-genai SDK.
17
 
18
  Features:
19
  - Topic → Full narration script (<=1000 chars)
@@ -22,7 +22,8 @@ class ScriptGenerator:
22
  - Optimized for TTS output
23
  """
24
 
25
- MODEL = "gemini-2.5-flash"
 
26
 
27
  # System prompt for script generation
28
  SYSTEM_PROMPT = """You are a professional script writer for short-form video content (TikTok, Reels, Shorts).
@@ -38,9 +39,32 @@ RULES:
38
 
39
  If a character is provided, write the story from their perspective or about them."""
40
 
41
- def __init__(self, api_key: str):
42
- self.api_key = api_key
43
- self.client = genai.Client(api_key=api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def generate_script(
46
  self,
@@ -50,14 +74,7 @@ If a character is provided, write the story from their perspective or about them
50
  ) -> str:
51
  """
52
  Generate a story script from topic.
53
-
54
- Args:
55
- topic: Story topic/idea
56
- character_name: Optional character name to include
57
- max_chars: Maximum character limit (default 1000)
58
-
59
- Returns:
60
- Generated script text
61
  """
62
  # Build the prompt
63
  user_prompt = f"Topic: {topic}"
@@ -69,30 +86,58 @@ If a character is provided, write the story from their perspective or about them
69
 
70
  logger.info(f"Generating script for topic: {topic[:50]}...")
71
 
72
- try:
73
- response = self.client.models.generate_content(
74
- model=self.MODEL,
75
- contents=self.SYSTEM_PROMPT + "\n\n" + user_prompt
76
- )
77
-
78
- script = response.text
79
-
80
- # Enforce character limit
81
- if len(script) > max_chars:
82
- script = script[:max_chars].rsplit(' ', 1)[0] + "."
83
-
84
- logger.info(f"Generated script: {len(script)} chars")
85
- return script.strip()
86
-
87
- except Exception as e:
88
- logger.error(f"Gemini API error: {e}")
89
- raise Exception(f"Script generation failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  @staticmethod
92
- def test_connection(api_key: str) -> bool:
93
  """Test API connection"""
94
  try:
95
- gen = ScriptGenerator(api_key)
96
  gen.generate_script("test", max_chars=50)
97
  return True
98
  except:
@@ -133,20 +178,10 @@ Return ONLY valid JSON array, no markdown, no explanation:
133
  ) -> list:
134
  """
135
  Generate detailed image prompts for all 2-second chunks.
136
-
137
- Args:
138
- full_script: Complete narration script (for context)
139
- chunks: List of {chunk_id, text, duration} from SRTParser
140
- character_profile: Optional character dict
141
- max_batch: Max chunks per API call (default 30)
142
-
143
- Returns:
144
- List of {chunk_id, prompt} dicts
145
  """
146
  all_prompts = []
147
  total_chunks = len(chunks)
148
 
149
- # Split into batches if too many chunks
150
  for batch_start in range(0, total_chunks, max_batch):
151
  batch_end = min(batch_start + max_batch, total_chunks)
152
  batch_chunks = chunks[batch_start:batch_end]
@@ -179,12 +214,13 @@ IMPORTANT: Include this character description in EVERY prompt!
179
  user_prompt += "\nGenerate detailed image prompts for each chunk. Return ONLY JSON array."
180
 
181
  try:
182
- response = self.client.models.generate_content(
183
- model=self.MODEL,
184
- contents=self.IMAGE_PROMPT_SYSTEM + "\n\n" + user_prompt
185
- )
186
-
187
- text = response.text
 
188
 
189
  # Clean response - remove markdown if present
190
  text = text.strip()
@@ -202,15 +238,13 @@ IMPORTANT: Include this character description in EVERY prompt!
202
 
203
  except json.JSONDecodeError as e:
204
  logger.error(f"Failed to parse JSON response: {e}")
205
- # Fallback: create simple prompts
206
  for chunk in batch_chunks:
207
  all_prompts.append({
208
  "chunk_id": chunk["chunk_id"],
209
  "prompt": f"{chunk['text']}, semi-realistic style, high quality, detailed"
210
  })
211
  except Exception as e:
212
- logger.error(f"Gemini API error: {e}")
213
- # Fallback
214
  for chunk in batch_chunks:
215
  all_prompts.append({
216
  "chunk_id": chunk["chunk_id"],
@@ -219,3 +253,25 @@ IMPORTANT: Include this character description in EVERY prompt!
219
 
220
  logger.info(f"Generated {len(all_prompts)} total image prompts")
221
  return all_prompts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Script Generator using Groq and Gemini APIs
3
  Generates story scripts from topics for TTS narration
4
+ Tries Groq first (works in all regions), falls back to Gemini
5
  """
6
  import logging
7
  import json
8
+ import os
9
  from typing import Optional
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
13
 
14
  class ScriptGenerator:
15
  """
16
+ Generates story scripts using Groq API (primary) or Gemini API (fallback).
17
 
18
  Features:
19
  - Topic → Full narration script (<=1000 chars)
 
22
  - Optimized for TTS output
23
  """
24
 
25
+ GROQ_MODEL = "llama-3.3-70b-versatile" # Fast and good quality
26
+ GEMINI_MODEL = "gemini-2.5-flash"
27
 
28
  # System prompt for script generation
29
  SYSTEM_PROMPT = """You are a professional script writer for short-form video content (TikTok, Reels, Shorts).
 
39
 
40
  If a character is provided, write the story from their perspective or about them."""
41
 
42
+ def __init__(self, gemini_api_key: str = None, groq_api_key: str = None):
43
+ self.gemini_api_key = gemini_api_key
44
+ self.groq_api_key = groq_api_key
45
+
46
+ # Initialize clients based on available keys
47
+ self.groq_client = None
48
+ self.gemini_client = None
49
+
50
+ if groq_api_key:
51
+ try:
52
+ from groq import Groq
53
+ self.groq_client = Groq(api_key=groq_api_key)
54
+ logger.info("Groq client initialized (primary)")
55
+ except ImportError:
56
+ logger.warning("Groq package not installed, using Gemini only")
57
+
58
+ if gemini_api_key:
59
+ try:
60
+ from google import genai
61
+ self.gemini_client = genai.Client(api_key=gemini_api_key)
62
+ logger.info("Gemini client initialized (fallback)")
63
+ except ImportError:
64
+ logger.warning("google-genai package not installed")
65
+
66
+ if not self.groq_client and not self.gemini_client:
67
+ raise ValueError("At least one API key (GROQ_API or GEMINI_API_KEY) is required")
68
 
69
  def generate_script(
70
  self,
 
74
  ) -> str:
75
  """
76
  Generate a story script from topic.
77
+ Tries Groq first, falls back to Gemini if Groq fails.
 
 
 
 
 
 
 
78
  """
79
  # Build the prompt
80
  user_prompt = f"Topic: {topic}"
 
86
 
87
  logger.info(f"Generating script for topic: {topic[:50]}...")
88
 
89
+ # Try Groq first (works in all regions)
90
+ if self.groq_client:
91
+ try:
92
+ script = self._generate_with_groq(user_prompt)
93
+ if len(script) > max_chars:
94
+ script = script[:max_chars].rsplit(' ', 1)[0] + "."
95
+ logger.info(f"Generated script with Groq: {len(script)} chars")
96
+ return script.strip()
97
+ except Exception as e:
98
+ logger.warning(f"Groq failed: {e}, trying Gemini...")
99
+
100
+ # Fallback to Gemini
101
+ if self.gemini_client:
102
+ try:
103
+ script = self._generate_with_gemini(user_prompt)
104
+ if len(script) > max_chars:
105
+ script = script[:max_chars].rsplit(' ', 1)[0] + "."
106
+ logger.info(f"Generated script with Gemini: {len(script)} chars")
107
+ return script.strip()
108
+ except Exception as e:
109
+ logger.error(f"Gemini also failed: {e}")
110
+ raise Exception(f"Script generation failed: {e}")
111
+
112
+ raise Exception("No AI backend available for script generation")
113
+
114
+ def _generate_with_groq(self, user_prompt: str) -> str:
115
+ """Generate using Groq API"""
116
+ completion = self.groq_client.chat.completions.create(
117
+ model=self.GROQ_MODEL,
118
+ messages=[
119
+ {"role": "system", "content": self.SYSTEM_PROMPT},
120
+ {"role": "user", "content": user_prompt}
121
+ ],
122
+ temperature=0.7,
123
+ max_tokens=500,
124
+ top_p=0.9
125
+ )
126
+ return completion.choices[0].message.content
127
+
128
+ def _generate_with_gemini(self, user_prompt: str) -> str:
129
+ """Generate using Gemini API"""
130
+ response = self.gemini_client.models.generate_content(
131
+ model=self.GEMINI_MODEL,
132
+ contents=self.SYSTEM_PROMPT + "\n\n" + user_prompt
133
+ )
134
+ return response.text
135
 
136
  @staticmethod
137
+ def test_connection(gemini_api_key: str = None, groq_api_key: str = None) -> bool:
138
  """Test API connection"""
139
  try:
140
+ gen = ScriptGenerator(gemini_api_key=gemini_api_key, groq_api_key=groq_api_key)
141
  gen.generate_script("test", max_chars=50)
142
  return True
143
  except:
 
178
  ) -> list:
179
  """
180
  Generate detailed image prompts for all 2-second chunks.
 
 
 
 
 
 
 
 
 
181
  """
182
  all_prompts = []
183
  total_chunks = len(chunks)
184
 
 
185
  for batch_start in range(0, total_chunks, max_batch):
186
  batch_end = min(batch_start + max_batch, total_chunks)
187
  batch_chunks = chunks[batch_start:batch_end]
 
214
  user_prompt += "\nGenerate detailed image prompts for each chunk. Return ONLY JSON array."
215
 
216
  try:
217
+ # Try Groq first
218
+ if self.groq_client:
219
+ text = self._generate_image_prompts_groq(user_prompt)
220
+ elif self.gemini_client:
221
+ text = self._generate_image_prompts_gemini(user_prompt)
222
+ else:
223
+ raise Exception("No AI backend available")
224
 
225
  # Clean response - remove markdown if present
226
  text = text.strip()
 
238
 
239
  except json.JSONDecodeError as e:
240
  logger.error(f"Failed to parse JSON response: {e}")
 
241
  for chunk in batch_chunks:
242
  all_prompts.append({
243
  "chunk_id": chunk["chunk_id"],
244
  "prompt": f"{chunk['text']}, semi-realistic style, high quality, detailed"
245
  })
246
  except Exception as e:
247
+ logger.error(f"AI API error: {e}")
 
248
  for chunk in batch_chunks:
249
  all_prompts.append({
250
  "chunk_id": chunk["chunk_id"],
 
253
 
254
  logger.info(f"Generated {len(all_prompts)} total image prompts")
255
  return all_prompts
256
+
257
+ def _generate_image_prompts_groq(self, user_prompt: str) -> str:
258
+ """Generate image prompts using Groq"""
259
+ completion = self.groq_client.chat.completions.create(
260
+ model=self.GROQ_MODEL,
261
+ messages=[
262
+ {"role": "system", "content": self.IMAGE_PROMPT_SYSTEM},
263
+ {"role": "user", "content": user_prompt}
264
+ ],
265
+ temperature=0.7,
266
+ max_tokens=4000,
267
+ top_p=0.9
268
+ )
269
+ return completion.choices[0].message.content
270
+
271
+ def _generate_image_prompts_gemini(self, user_prompt: str) -> str:
272
+ """Generate image prompts using Gemini"""
273
+ response = self.gemini_client.models.generate_content(
274
+ model=self.GEMINI_MODEL,
275
+ contents=self.IMAGE_PROMPT_SYSTEM + "\n\n" + user_prompt
276
+ )
277
+ return response.text
requirements.txt CHANGED
@@ -19,6 +19,7 @@ numpy<2.0.0
19
  # AI/ML
20
  faster-whisper
21
  google-genai
 
22
 
23
  # Utilities
24
  python-multipart
 
19
  # AI/ML
20
  faster-whisper
21
  google-genai
22
+ groq
23
 
24
  # Utilities
25
  python-multipart