"""AI API client for processing text input."""
import os
import sys
import threading
import httpx  # type: ignore
from openai import OpenAI  # type: ignore

# Add config to path
config_path = os.path.join(os.path.dirname(__file__), '..', '..', 'config')
sys.path.insert(0, config_path)

from config import API_URL, MODELS_CONFIG  # type: ignore

# Add utils to path
utils_path = os.path.join(os.path.dirname(__file__), '..', 'utils')
sys.path.insert(0, utils_path)

from cache_manager import CacheManager  # type: ignore
from retry_handler import retry_with_backoff  # type: ignore

# Initialize cache manager
cache = CacheManager()
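
# Each MODELS_CONFIG entry is expected to carry the keys used below
# (illustrative sketch only; the real entries live in the config package):
#
#     'default': {
#         'model': '<model-name>',
#         'api_key': '<key>',
#         'temperature': 0.2,
#         'top_p': 0.9,
#         'max_tokens': 8192,
#         'seed': 42,               # optional
#         'extra_body': {...},      # optional, endpoint-specific params
#     }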


def load_system_prompt():
    """Load the system prompt from file."""
    try:
        prompt_path = os.path.join(os.path.dirname(__file__), '..', '..', 'config', 'system_prompt.txt')
        with open(prompt_path, 'r', encoding='utf-8') as f:
            return f.read()
    except FileNotFoundError:
        return """You are an expert web developer. Convert the raw text notes into properly formatted HTML content using CSS classes: .exercise-title, .question, .answer, .vocabulary-item, .section-number. Output ONLY the HTML content without DOCTYPE, html, head, or body tags."""


@retry_with_backoff(max_retries=3, base_delay=2, max_delay=30)
def _make_ai_request(client, system_prompt, user_text, model_config, cancel_event=None):
    """Make the actual AI request with proper system/user roles (wrapped with retry logic)."""
    
    # Build the request; optional params (extra_body, seed) are added below
    # only when the model config actually defines them.
    kwargs = {
        "model": model_config['model'],
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_text}
        ],
        "temperature": model_config['temperature'],
        "top_p": model_config['top_p'],
        "max_tokens": model_config['max_tokens'],
        "stream": True,
    }
    
    if model_config.get('extra_body'):
        kwargs['extra_body'] = model_config['extra_body']
        
    if model_config.get('seed') is not None:  # 0 is a valid seed
        kwargs['seed'] = model_config['seed']
        
    completion = client.chat.completions.create(**kwargs)
    
    # Collect streamed response
    full_response = ""
    print("πŸ“₯ Receiving response:", flush=True)
    
    last_print_len: int = 0
    for chunk in completion:
        # Check for cancellation
        if cancel_event and cancel_event.is_set():
            print("\n⚠️ Request cancelled by user", flush=True)
            raise CancelledError("Generation cancelled by user")
        
        if not getattr(chunk, "choices", None):
            continue
            
        # Handle reasoning content (thought process) if present
        reasoning = getattr(chunk.choices[0].delta, "reasoning_content", None)
        if reasoning:
            print(reasoning, end="", flush=True)
            continue  # Don't add reasoning to full_response to avoid corrupting HTML
            
        if chunk.choices[0].delta.content is not None:
            content = chunk.choices[0].delta.content
            full_response += content
            
            # Print a dot for every 50 characters received
            if len(full_response) >= last_print_len + 50:
                print(".", end="", flush=True)
                last_print_len = len(full_response)
    
    print("\n", flush=True)
    return full_response


class CancelledError(Exception):
    """Raised when a generation is cancelled by the user."""
    pass


# Store active cancel events keyed by operation_id
_active_operations = {}
_operations_lock = threading.Lock()


def register_operation(operation_id):
    """Register a new operation and return its cancel event."""
    event = threading.Event()
    with _operations_lock:
        _active_operations[operation_id] = event
    return event


def cancel_operation(operation_id):
    """Cancel an active operation by setting its cancel event."""
    with _operations_lock:
        event = _active_operations.get(operation_id)
        if event:
            event.set()
            return True
    return False


def unregister_operation(operation_id):
    """Clean up a completed operation."""
    with _operations_lock:
        _active_operations.pop(operation_id, None)
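
# Intended caller pattern (illustrative sketch; the real route handler lives
# elsewhere in the backend):
#
#     cancel_event = register_operation(op_id)
#     try:
#         html = get_ai_response(text, cancel_event=cancel_event)
#     finally:
#         unregister_operation(op_id)
#
# Another thread can then call cancel_operation(op_id) to stop the stream
# mid-generation.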


def verify_html_content(input_text, html_content, cancel_event=None, model_choice='default'):
    """Verify that the generated HTML preserves all content from the input text."""
    print("=" * 60, flush=True)
    print("πŸ€– Verifying HTML content against original text...", flush=True)
    
    verification_sys_prompt = (
        "You are an expert quality assurance reviewer. Your job is to compare the original raw text with the generated HTML output.\n"
        "Check line by line to ensure NO content, questions, answers, or vocabulary from the original text has been skipped, summarized, or omitted in the HTML.\n"
        "If ALL content is carefully preserved in the HTML, output EXACTLY the word 'PASS' and nothing else.\n"
        "If ANY content was removed, summarized, or omitted, output a list of the specific missing content and instructions on what needs to be added back. Do not output 'PASS'."
    )
    
    verification_user_prompt = (
        f"--- ORIGINAL RAW TEXT ---\n{input_text}\n\n"
        f"--- GENERATED HTML ---\n{html_content}\n\n"
        "Did the HTML preserve all the content? Output 'PASS' or list the missing content."
    )
    
    try:
        model_config = MODELS_CONFIG.get(model_choice, MODELS_CONFIG['default'])
        # Match get_ai_response: _make_ai_request already retries via
        # retry_with_backoff, so disable the client's built-in retries.
        client = OpenAI(base_url=API_URL, api_key=model_config['api_key'], max_retries=0)
        response = _make_ai_request(client, verification_sys_prompt, verification_user_prompt, model_config, cancel_event=cancel_event)
        
        response = response.strip()
        print(f"βœ… Verification result: {response[:100]}...", flush=True)
        
        if response.upper() == "PASS" or response.upper().startswith("PASS"):
            return "PASS"
        else:
            return response
            
    except CancelledError:
        print("⚠️ Verification was cancelled")
        return None
    except Exception as e:
        print(f"❌ Verification failed: {e}")
        return "PASS"  # Fail open if verification errors


def get_ai_revision(input_text, previous_html, feedback, cancel_event=None, model_choice='default'):
    """Ask the AI to revise the HTML based on verification feedback."""
    print("=" * 60, flush=True)
    print("πŸ€– Requesting AI revision based on feedback...", flush=True)
    
    base_sys_prompt = load_system_prompt()
    revision_sys_prompt = (
        f"{base_sys_prompt}\n\n"
        "CRITICAL REVISION INSTRUCTIONS:\n"
        "You previously generated HTML for this text, but the quality assurance reviewer found that you skipped or summarized some content.\n"
        "Here is the exact feedback on what is missing:\n"
        "-------------------------------------\n"
        f"{feedback}\n"
        "-------------------------------------\n"
        "Your task:\n"
        "1. Rewrite the ENTIRE HTML document from start to finish.\n"
        "2. You MUST include ALL content from the original text.\n"
        "3. Pay special attention to the feedback above and guarantee that all missing parts are inserted in the correct locations.\n"
        "4. This is a strict test. If you skip, omit, or summarize ANY paragraph, question, or option, you will fail.\n"
        "DO NOT output anything other than raw HTML. No markdown code blocks, no explanations. Start with <!DOCTYPE html>."
    )
    
    return get_ai_response(input_text, use_cache=False, cancel_event=cancel_event, system_prompt=revision_sys_prompt, model_choice=model_choice)
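
# Illustrative generate -> verify -> revise flow a caller might run (assumed
# orchestration; the actual pipeline lives outside this module):
#
#     html = get_ai_response(text, model_choice=choice)
#     feedback = verify_html_content(text, html, model_choice=choice)
#     if feedback not in (None, "PASS"):
#         html = get_ai_revision(text, html, feedback, model_choice=choice)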


def _cache_key(input_text, model_choice, system_prompt):
    """Cache key varies on (model_choice, system_prompt, input_text) β€” switching
    model or adding a custom system prompt must produce a different key so the
    cache doesn't return the response from a previous configuration."""
    return f"{model_choice}|{system_prompt or ''}|{input_text}"


def get_ai_response(input_text, use_cache=True, cancel_event=None, system_prompt=None, model_choice='default'):
    """Send text to AI model and get response with proper system/user message roles."""
    print("=" * 60, flush=True)
    print("πŸ€– Sending request to AI using OpenAI library...", flush=True)

    # Check cache first β€” key includes model + system prompt, not just the input.
    resolved_system_prompt = system_prompt if system_prompt is not None else load_system_prompt()
    cache_key = _cache_key(input_text, model_choice, resolved_system_prompt)
    if use_cache:
        cached_response = cache.get(cache_key)
        if cached_response:
            print("=" * 60, flush=True)
            return cached_response

    try:
        model_config = MODELS_CONFIG.get(model_choice, MODELS_CONFIG['default'])

        # Fail fast on an obviously-unconfigured key rather than making the
        # UI sit at 0% while the OpenAI client retries a bogus endpoint.
        placeholder_keys = {'', 'your-api-key-here', 'REPLACE_ME'}
        resolved_key = (model_config.get('api_key') or '').strip()
        if resolved_key in placeholder_keys:
            raise RuntimeError(
                "AI is not configured: API_KEY is missing or a placeholder. "
                "Edit backend/config/config.py (or set the API_KEY env var) "
                "with a real key and restart the backend."
            )

        # Initialize OpenAI client with per-phase timeouts. A single
        # scalar `timeout=60` was previously used, which treated the
        # whole streaming completion as one 60-second budget and meant
        # any AI response taking longer than 60s (common for 1000-char
        # inputs on slower endpoints) timed out, got retried 3x by the
        # backoff decorator, and pushed a normal 3-4 min run past 8 min.
        # httpx.Timeout gives us fast-fail on connect + generous read
        # time for the streamed body. Override via env vars.
        connect_timeout = float(os.environ.get('AI_CONNECT_TIMEOUT', '15'))
        read_timeout = float(os.environ.get('AI_READ_TIMEOUT', '600'))
        client = OpenAI(
            base_url=API_URL,
            api_key=resolved_key,
            timeout=httpx.Timeout(
                connect=connect_timeout,
                read=read_timeout,
                write=30.0,
                pool=30.0,
            ),
            max_retries=0,  # _make_ai_request already handles retries
        )
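        # e.g. exporting AI_READ_TIMEOUT=900 before starting the backend
        # stretches the streaming-read budget to 15 minutes (illustrative value).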

        system_prompt = resolved_system_prompt
        
        print(f"πŸ“ Input length: {len(input_text)} characters", flush=True)
        print(f"πŸ“ System prompt length: {len(system_prompt)} characters", flush=True)
        print(f"🌐 API URL: {API_URL}", flush=True)
        print(f"πŸ”‘ Using model config: {model_choice} -> {model_config['model']}", flush=True)
        print(f"⏳ Sending request with streaming...\n", flush=True)
        
        # Make request with retry logic and proper roles
        full_response = _make_ai_request(client, system_prompt, input_text, model_config, cancel_event=cancel_event)
        print(f"βœ… Response received successfully", flush=True)
        print(f"πŸ“„ Content length: {len(full_response)} characters", flush=True)
        
        # Clean up response - remove markdown code blocks if present
        full_response = full_response.strip()
        if full_response.startswith("```html"):
            full_response = full_response[7:]  # Remove ```html
        if full_response.startswith("```"):
            full_response = full_response[3:]  # Remove ```
        if full_response.endswith("```"):
            full_response = full_response[:-3]  # Remove trailing ```
        full_response = full_response.strip()
        
        print(f"πŸ“„ First 100 chars: {full_response[:100]}...", flush=True)
        
        # Cache the response using the composite key so later requests with a
        # different model / system_prompt don't silently get this response back.
        if use_cache:
            cache.set(cache_key, full_response)
        
        return full_response
        
    except CancelledError:
        print("⚠️ Generation was cancelled")
        return None
    except Exception as e:
        print(f"❌ Request failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        print("=" * 60)