import os
import json
import math
import traceback

# Must be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
from tensorflow.keras import layers

import tiktoken
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ChatRequest(BaseModel):
    message: str
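# A /chat request body (section 3) must match ChatRequest above,
# e.g. {"message": "What is the capital of France?"}.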

# ==============================================================================
# 1. ARCHITECTURE DEFINITION (10MB Config)
# ==============================================================================
class CausalSelfAttention(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.n_head = config['n_head']
        self.n_embd = config['n_embd']
        self.head_dim = self.n_embd // self.n_head
        self.c_attn = layers.Dense(3 * self.n_embd)
        self.c_proj = layers.Dense(self.n_embd)
        self.attn_dropout = layers.Dropout(config['dropout'])
        self.resid_dropout = layers.Dropout(config['dropout'])

    def call(self, x, training=False):
        B = tf.shape(x)[0]
        T = tf.shape(x)[1]
        qkv = self.c_attn(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        q = tf.transpose(tf.reshape(q, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
        k = tf.transpose(tf.reshape(k, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
        v = tf.transpose(tf.reshape(v, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
        att = tf.matmul(q, k, transpose_b=True) * (1.0 / math.sqrt(float(self.head_dim)))
        mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
        att = tf.where(mask == 0, tf.cast(-1e9, att.dtype), att)
        att = tf.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att, training=training)
        y = tf.matmul(att, v)
        y = tf.reshape(tf.transpose(y, (0, 2, 1, 3)), (B, T, self.n_embd))
        return self.resid_dropout(self.c_proj(y), training=training)
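
# Shape walkthrough for one attention call (illustrative, using this file's
# config: n_embd=384, n_head=6, head_dim=64):
#   x:     (B, T, 384)
#   qkv:   (B, T, 1152), split into three (B, T, 384) tensors
#   q/k/v: (B, 6, T, 64) after reshape + transpose
#   att:   (B, 6, T, T); the lower-triangular mask hides future positions
#   y:     (B, 6, T, 64), merged back to (B, T, 384)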

class MLP(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.c_fc = layers.Dense(4 * config['n_embd'])
        self.c_proj = layers.Dense(config['n_embd'])
        self.dropout = layers.Dropout(config['dropout'])

    def call(self, x, training=False):
        x = self.c_fc(x)
        x = tf.nn.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x, training=training)

class Block(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.ln_1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = MLP(config)

    def call(self, x, training=False):
        x = x + self.attn(self.ln_1(x), training=training)
        x = x + self.mlp(self.ln_2(x), training=training)
        return x

class GPT2(tf.keras.Model):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.wte = layers.Embedding(config['vocab_size'], config['n_embd'])
        self.wpe = layers.Embedding(config['block_size'], config['n_embd'])
        self.drop = layers.Dropout(config['dropout'])
        self.h = [Block(config) for _ in range(config['n_layer'])]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)
        self.lm_head = layers.Dense(config['vocab_size'], use_bias=False)

    def call(self, idx, training=False):
        T = tf.shape(idx)[1]
        pos = tf.range(0, T, dtype=tf.int32)
        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb, training=training)
        for block in self.h:
            x = block(x, training=training)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits
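
# Forward pass output: logits has shape (B, T, vocab_size); the sampling loop
# in section 3 only ever reads the last time step, logits[0, -1, :].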

# ==============================================================================
# 2. LOAD MODEL FLEXIBLY (Bypassing static shape errors)
# ==============================================================================
enc = tiktoken.get_encoding("gpt2")
gpt_config = {'vocab_size': 50257, 'block_size': 256, 'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'dropout': 0.1}

print("Instantiating flexible architecture...")
model = GPT2(gpt_config)
# Build the graph with a dynamic size (1 token) so it doesn't lock to 256
_ = model(tf.zeros((1, 1), dtype=tf.int32))

print("Loading weights directly...")
# Loading weights onto the fresh architecture avoids the SavedModel shape restrictions
model.load_weights("gpt2_finetuned_10mb")
print("Model loaded successfully.")

# ==============================================================================
# 3. API ENDPOINTS
# ==============================================================================
@app.get("/")
def read_root():
    return FileResponse("index.html")

@app.post("/chat")
def chat_endpoint(req: ChatRequest):
    def generate():
        try:
            # Reverted to Colab-style strict QA recall settings
            temperature = 0.6
            max_new_tokens = 10000
            formatted_prompt = req.message
            input_ids = enc.encode(formatted_prompt)
            original_len = len(input_ids)
            for _ in range(max_new_tokens):
                # No zero-padding needed anymore: the model accepts dynamic lengths
                context = input_ids[-gpt_config['block_size']:]
                x = tf.constant([context], dtype=tf.int32)
                logits = model(x, training=False)
                # Extract raw logits for the last token
                next_token_logits = logits[0, -1, :]
                # Strict temperature sampling (as in the Colab notebook)
                scaled_logits = next_token_logits / temperature
                scaled_logits = tf.expand_dims(scaled_logits, 0)
                next_token = tf.random.categorical(scaled_logits, num_samples=1).numpy()[0, 0]
                input_ids.append(int(next_token))
                # Decode the ENTIRE generation so far
                current_generation = enc.decode(input_ids[original_len:])
                # Clean up the text stream
                clean_generation = current_generation.replace("\ufffd", "").lstrip()
                # Stop-word logic: cut off at the next user turn
                if "<user>" in clean_generation:
                    final_text = clean_generation.split("<user>")[0].rstrip()
                    yield f"data: {json.dumps({'text': final_text})}\n\n"
                    break
                yield f"data: {json.dumps({'text': clean_generation})}\n\n"
        except Exception as e:
            error_details = traceback.format_exc()
            print("CRITICAL ERROR IN /chat ENDPOINT:")
            print(error_details)
            yield f"data: {json.dumps({'error': f'🔥 CRASH: {str(e)}'})}\n\n"
    return StreamingResponse(generate(), media_type="text/event-stream")
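
# ==============================================================================
# 4. LOCAL ENTRY POINT (a minimal sketch; assumes this file is saved as app.py,
#    uvicorn is installed, and port 7860 -- the Spaces default -- is free)
# ==============================================================================
# Once running, stream a reply; each SSE "data:" line carries the full text so far:
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello"}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)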