import os
import traceback
import json
# Must be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1"
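# (TF_USE_LEGACY_KERAS=1 keeps tf.keras on legacy Keras 2 / tf-keras even when
# Keras 3 is installed; this subclassed-layer code appears to target Keras 2.)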
import tensorflow as tf
from tensorflow.keras import layers
import math
import tiktoken
from fastapi import FastAPI
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ChatRequest(BaseModel):
    message: str

# ==============================================================================
# 1. ARCHITECTURE DEFINITION (10MB Config)
# ==============================================================================
class CausalSelfAttention(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.n_head = config['n_head']
        self.n_embd = config['n_embd']
        self.head_dim = self.n_embd // self.n_head
        self.c_attn = layers.Dense(3 * self.n_embd)
        self.c_proj = layers.Dense(self.n_embd)
        self.attn_dropout = layers.Dropout(config['dropout'])
        self.resid_dropout = layers.Dropout(config['dropout'])

    def call(self, x, training=False):
        B = tf.shape(x)[0]
        T = tf.shape(x)[1]
        qkv = self.c_attn(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
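        # Reshape each of q, k, v to (B, n_head, T, head_dim) so attention is computed per head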
        q = tf.transpose(tf.reshape(q, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
        k = tf.transpose(tf.reshape(k, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
        v = tf.transpose(tf.reshape(v, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
        att = tf.matmul(q, k, transpose_b=True) * (1.0 / math.sqrt(float(self.head_dim)))
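        # Lower-triangular causal mask: each position attends only to itself and earlier tokens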
        mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
        att = tf.where(mask == 0, tf.cast(-1e9, att.dtype), att)
        att = tf.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att, training=training)
        y = tf.matmul(att, v)
        y = tf.reshape(tf.transpose(y, (0, 2, 1, 3)), (B, T, self.n_embd))
        return self.resid_dropout(self.c_proj(y), training=training)

class MLP(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.c_fc = layers.Dense(4 * config['n_embd'])
        self.c_proj = layers.Dense(config['n_embd'])
        self.dropout = layers.Dropout(config['dropout'])
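
    # Standard GPT-2 feed-forward block: expand to 4*n_embd, apply GELU, project back down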
    def call(self, x, training=False):
        x = self.c_fc(x)
        x = tf.nn.gelu(x)
        x = self.c_proj(x)
        return self.dropout(x, training=training)

class Block(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.ln_1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = MLP(config)
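
    # Pre-LayerNorm residual wiring as in GPT-2: normalize, transform, then add back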
    def call(self, x, training=False):
        x = x + self.attn(self.ln_1(x), training=training)
        x = x + self.mlp(self.ln_2(x), training=training)
        return x

class GPT2(tf.keras.Model):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.wte = layers.Embedding(config['vocab_size'], config['n_embd'])
        self.wpe = layers.Embedding(config['block_size'], config['n_embd'])
        self.drop = layers.Dropout(config['dropout'])
        self.h = [Block(config) for _ in range(config['n_layer'])]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)
        self.lm_head = layers.Dense(config['vocab_size'], use_bias=False)
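        # Note: lm_head is its own Dense layer, so the input embeddings and the
        # output projection are untied here (the original GPT-2 shares these weights)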

    def call(self, idx, training=False):
        T = tf.shape(idx)[1]
        pos = tf.range(0, T, dtype=tf.int32)
        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb, training=training)
        for block in self.h:
            x = block(x, training=training)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits

# ==============================================================================
# 2. LOAD MODEL FLEXIBLY (Bypassing static shape errors)
# ==============================================================================
enc = tiktoken.get_encoding("gpt2")
gpt_config = {'vocab_size': 50257, 'block_size': 256, 'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'dropout': 0.1}
print("Instantiating flexible architecture...")
model = GPT2(gpt_config)
# Build the graph with a dynamic size (1 token) so it doesn't lock to 256
_ = model(tf.zeros((1, 1), dtype=tf.int32))
print("Loading weights directly...")
# Loading weights onto the fresh architecture avoids the SavedModel shape restrictions!
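# (Assumption: "gpt2_finetuned_10mb" is a TF checkpoint prefix, i.e. matching
# .index/.data-* files sit next to this script; an .h5 file would need its extension.)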
model.load_weights("gpt2_finetuned_10mb")
print("Model loaded successfully.")
# ==============================================================================
# 3. API ENDPOINTS
# ==============================================================================
@app.get("/")
def read_root():
return FileResponse("index.html")
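
# POST /chat streams the reply as Server-Sent Events ("data: {...}\n\n" frames).
# Each event carries the full cleaned text generated so far, so a client can
# simply replace its displayed text on every event instead of appending deltas.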
@app.post("/chat")
def chat_endpoint(req: ChatRequest):
def generate():
try:
# Reverted to Colab-style strict QA recall settings
temperature = 0.6
max_new_tokens = 10000
formatted_prompt = f"{req.message}"
input_ids = enc.encode(formatted_prompt)
original_len = len(input_ids)
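            # Autoregressive sampling loop: re-run the model on (at most) the last
            # block_size tokens and sample one new token per iteration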
            for _ in range(max_new_tokens):
                # We no longer need to pad with zeros!
                context = input_ids[-gpt_config['block_size']:]
                x = tf.constant([context], dtype=tf.int32)
                logits = model(x, training=False)
                # Extract raw logits for the last token
                next_token_logits = logits[0, -1, :]
                # Strict temperature sampling (like Colab)
                scaled_logits = next_token_logits / temperature
                scaled_logits = tf.expand_dims(scaled_logits, 0)
                next_token = tf.random.categorical(scaled_logits, num_samples=1).numpy()[0, 0]
                next_token_int = int(next_token)
                input_ids.append(next_token_int)
                # Decode the ENTIRE generation so far
                current_generation = enc.decode(input_ids[original_len:])
                # Clean up the text stream
                clean_generation = current_generation.replace("\ufffd", "")
                clean_generation = clean_generation.lstrip()
                # Stop-word logic: cut the reply at the "<user>" turn delimiter
                if "<user>" in current_generation:
                    final_text = clean_generation.split("<user>")[0].rstrip()
                    yield f"data: {json.dumps({'text': final_text})}\n\n"
                    break
                yield f"data: {json.dumps({'text': clean_generation})}\n\n"
        except Exception as e:
            error_details = traceback.format_exc()
            print("CRITICAL ERROR IN /chat ENDPOINT:")
            print(error_details)
            yield f"data: {json.dumps({'error': f'🔥 CRASH: {str(e)}'})}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")
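
# Example usage (a sketch — the module name "main" and port 8000 are assumptions;
# adjust to however this Space actually launches the app):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl -N -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!"}'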