"""FastAPI server that streams completions from a small GPT-2-style TensorFlow model."""

import json
import math
import os
import traceback

# Must be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf  # noqa: E402
import tiktoken  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import FileResponse, StreamingResponse  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from tensorflow.keras import layers  # noqa: E402

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive combination (browsers reject "*" together with credentials, and the
# middleware may work around that by echoing the caller's origin). Nothing in
# this file uses cookies or auth — confirm whether credentials are needed at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatRequest(BaseModel):
    """JSON body accepted by the ``/chat`` endpoint."""

    message: str  # the prompt text the model continues


# ==============================================================================
# 1. ARCHITECTURE DEFINITION (10MB Config)
# ==============================================================================
class CausalSelfAttention(layers.Layer):
    """Multi-head self-attention in which each position sees only itself and earlier positions.

    Args:
        config: mapping with keys ``n_head``, ``n_embd`` and ``dropout``.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.n_head = config['n_head']
        self.n_embd = config['n_embd']
        if self.n_embd % self.n_head != 0:
            # The per-head reshape in call() needs n_head * head_dim == n_embd.
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
        self.head_dim = self.n_embd // self.n_head
        self.c_attn = layers.Dense(3 * self.n_embd)  # fused Q, K, V projection
        self.c_proj = layers.Dense(self.n_embd)
        self.attn_dropout = layers.Dropout(config['dropout'])
        self.resid_dropout = layers.Dropout(config['dropout'])

    def _split_heads(self, t, B, T):
        """Reshape ``t`` from (B, T, n_embd) to (B, n_head, T, head_dim)."""
        return tf.transpose(tf.reshape(t, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))

    def call(self, x, training=False):
        """Attend over ``x`` of shape (B, T, n_embd); returns a tensor of the same shape."""
        B = tf.shape(x)[0]
        T = tf.shape(x)[1]

        qkv = self.c_attn(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        q = self._split_heads(q, B, T)
        k = self._split_heads(k, B, T)
        v = self._split_heads(v, B, T)

        # Scaled dot-product scores, shape (B, n_head, T, T).
        att = tf.matmul(q, k, transpose_b=True) * (1.0 / math.sqrt(float(self.head_dim)))
        # Lower-triangular ones: query i may attend to keys j <= i.
        mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
        # Masked scores get a large negative value so softmax gives them ~0 weight.
        att = tf.where(mask == 0, tf.cast(-1e9, att.dtype), att)
        att = tf.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att, training=training)

        y = tf.matmul(att, v)  # (B, n_head, T, head_dim)
        # Merge the heads back into (B, T, n_embd).
        y = tf.reshape(tf.transpose(y, (0, 2, 1, 3)), (B, T, self.n_embd))
        return self.resid_dropout(self.c_proj(y), training=training)


class MLP(layers.Layer):
    """Position-wise feed-forward: Dense(4 * n_embd) -> GELU -> Dense(n_embd) -> dropout."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.c_fc = layers.Dense(4 * config['n_embd'])
        self.c_proj = layers.Dense(config['n_embd'])
        self.dropout = layers.Dropout(config['dropout'])

    def call(self, x, training=False):
        x = self.c_fc(x)
        x = tf.nn.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x, training=training)


class Block(layers.Layer):
    """Transformer block; LayerNorm is applied before each sub-layer (pre-norm) with residual adds."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.ln_1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = MLP(config)

    def call(self, x, training=False):
        x = x + self.attn(self.ln_1(x), training=training)
        x = x + self.mlp(self.ln_2(x), training=training)
        return x


class GPT2(tf.keras.Model):
    """GPT-2-style decoder-only language model.

    Args:
        config: mapping with keys ``vocab_size``, ``block_size``, ``n_layer``,
            ``n_head``, ``n_embd`` and ``dropout``.

    NOTE(review): weights are restored with ``load_weights`` onto a fresh
    instance, which presumably matches variables by these attribute names
    (wte, wpe, h, ln_f, lm_head, ...). Keep them stable.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.wte = layers.Embedding(config['vocab_size'], config['n_embd'])
        self.wpe = layers.Embedding(config['block_size'], config['n_embd'])
        self.drop = layers.Dropout(config['dropout'])
        self.h = [Block(config) for _ in range(config['n_layer'])]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)
        self.lm_head = layers.Dense(config['vocab_size'], use_bias=False)

    def call(self, idx, training=False):
        """Return next-token logits of shape (B, T, vocab_size).

        ``idx`` holds integer token ids of shape (B, T). T must not exceed
        ``block_size``, because ``wpe`` has only ``block_size`` rows.
        """
        T = tf.shape(idx)[1]
        pos = tf.range(0, T, dtype=tf.int32)
        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)  # (T, n_embd), broadcast over the batch
        x = self.drop(tok_emb + pos_emb, training=training)
        for block in self.h:
            x = block(x, training=training)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits
# ==============================================================================
# 2. LOAD MODEL FLEXIBLY (Bypassing static shape errors)
# ==============================================================================
enc = tiktoken.get_encoding("gpt2")

gpt_config = {
    'vocab_size': 50257,
    'block_size': 256,
    'n_layer': 6,
    'n_head': 6,
    'n_embd': 384,
    'dropout': 0.1,
}

# Marks the end of a generated answer. GPT-2's end-of-text token
# (enc.eot_token) decodes to exactly this text.
# NOTE(review): the stop check previously used an empty string, which matched
# on every step and then crashed in str.split(""). "<|endoftext|>" is a best
# guess at the intended marker — confirm it matches the fine-tuning data.
STOP_SEQUENCE = "<|endoftext|>"

print("Instantiating flexible architecture...")
model = GPT2(gpt_config)

# Build the graph with a dynamic size (1 token) so it doesn't lock to 256.
# This dummy forward pass also creates every layer's variables before loading.
_ = model(tf.zeros((1, 1), dtype=tf.int32))

print("Loading weights directly...")
# Loading weights onto the fresh architecture avoids the SavedModel shape restrictions!
model.load_weights("gpt2_finetuned_10mb")
print("Model loaded successfully.")


# ==============================================================================
# 3. API ENDPOINTS
# ==============================================================================
@app.get("/")
def read_root():
    """Serve the front-end page ``index.html`` (path relative to the working directory)."""
    return FileResponse("index.html")


def _sse(payload):
    """Format ``payload`` as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


@app.post("/chat")
def chat_endpoint(req: ChatRequest):
    """Stream a sampled continuation of ``req.message`` as Server-Sent Events.

    Each ``{"text": ...}`` event carries the *entire* text generated so far
    (not a delta), so clients should replace rather than append. Generation
    stops at ``STOP_SEQUENCE`` / the end-of-text token, or after
    ``max_new_tokens``. Any failure is reported as a single ``{"error": ...}``
    event.
    """
    def generate():
        try:
            # Reverted to Colab-style strict QA recall settings
            temperature = 0.6
            max_new_tokens = 10000

            # disallowed_special=() makes special-token text typed by the user
            # (e.g. "<|endoftext|>") encode as ordinary text instead of raising.
            input_ids = enc.encode(req.message, disallowed_special=())
            if not input_ids:
                # With no context there are no logits to sample from.
                yield _sse({'error': 'Message is empty.'})
                return
            original_len = len(input_ids)

            for _ in range(max_new_tokens):
                # Sliding window: wpe only has block_size positions. No padding needed.
                context = input_ids[-gpt_config['block_size']:]
                x = tf.constant([context], dtype=tf.int32)
                logits = model(x, training=False)

                # Extract raw logits for the last token
                next_token_logits = logits[0, -1, :]

                # Strict temperature sampling (like Colab)
                scaled_logits = tf.expand_dims(next_token_logits / temperature, 0)
                next_token = int(
                    tf.random.categorical(scaled_logits, num_samples=1).numpy()[0, 0]
                )
                input_ids.append(next_token)

                # Decode the ENTIRE generation so far: a byte-level token can
                # split a multi-byte character, which decodes to U+FFFD until
                # the rest arrives, so re-decoding lets it resolve.
                current_generation = enc.decode(input_ids[original_len:])

                # Clean up the text stream
                clean_generation = current_generation.replace("\ufffd", "").lstrip()

                # Stop word logic
                if next_token == enc.eot_token or STOP_SEQUENCE in clean_generation:
                    final_text = clean_generation.partition(STOP_SEQUENCE)[0].rstrip()
                    yield _sse({'text': final_text})
                    break

                yield _sse({'text': clean_generation})

        except Exception as e:
            error_details = traceback.format_exc()
            print("CRITICAL ERROR IN /chat ENDPOINT:")
            print(error_details)
            yield _sse({'error': f'🔥 CRASH: {str(e)}'})

    return StreamingResponse(generate(), media_type="text/event-stream")