RAI-BETA / app.py
Kawaquader's picture
Update app.py
b216923 verified
import os
import traceback
import json
# Must be set before importing tensorflow
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
from tensorflow.keras import layers
import math
import tiktoken
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
app = FastAPI()

# Cross-origin policy: accept requests from any origin, method and header.
# NOTE(review): pairing a wildcard origin with allow_credentials=True makes
# Starlette's CORSMiddleware reflect the caller's Origin on credentialed
# requests. No cookie/auth usage is visible here, so allow_credentials=False
# may be the safer setting -- confirm the front-end does not rely on it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class ChatRequest(BaseModel):
    """JSON request body accepted by ``POST /chat``."""

    # Raw user prompt; the endpoint tokenizes it and uses it as the model context.
    message: str
# ==============================================================================
# 1. ARCHITECTURE DEFINITION (10MB Config)
# ==============================================================================
class CausalSelfAttention(layers.Layer):
    """Multi-head masked self-attention in the GPT-2 layout.

    A single Dense layer (``c_attn``) produces queries, keys and values
    together. Each position may attend only to itself and earlier positions.
    The per-head outputs are merged and projected back to ``n_embd`` by
    ``c_proj``.

    NOTE(review): attribute names and the order in which sublayers are
    created are left exactly as in the training code. Weight restoration via
    ``load_weights`` most likely matches variables on them.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.n_head = config['n_head']
        self.n_embd = config['n_embd']
        self.head_dim = self.n_embd // self.n_head
        self.c_attn = layers.Dense(3 * self.n_embd)
        self.c_proj = layers.Dense(self.n_embd)
        self.attn_dropout = layers.Dropout(config['dropout'])
        self.resid_dropout = layers.Dropout(config['dropout'])

    def _to_heads(self, tensor, batch, seq_len):
        """Reshape (batch, seq, n_embd) into (batch, n_head, seq, head_dim)."""
        split = tf.reshape(tensor, (batch, seq_len, self.n_head, self.head_dim))
        return tf.transpose(split, (0, 2, 1, 3))

    def call(self, x, training=False):
        # Assumes x is (batch, seq, n_embd) -- TODO confirm against callers.
        batch, seq_len = tf.shape(x)[0], tf.shape(x)[1]
        q, k, v = [
            self._to_heads(part, batch, seq_len)
            for part in tf.split(self.c_attn(x), 3, axis=-1)
        ]

        scale = 1.0 / math.sqrt(float(self.head_dim))
        scores = tf.matmul(q, k, transpose_b=True) * scale

        # Lower-triangular ones: 1 where key index <= query index. Every other
        # score is replaced by a large negative number so softmax ignores it.
        causal = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        scores = tf.where(causal == 0, tf.cast(-1e9, scores.dtype), scores)

        weights = self.attn_dropout(tf.nn.softmax(scores, axis=-1), training=training)

        # Merge heads: (batch, n_head, seq, head_dim) -> (batch, seq, n_embd).
        merged = tf.transpose(tf.matmul(weights, v), (0, 2, 1, 3))
        merged = tf.reshape(merged, (batch, seq_len, self.n_embd))
        return self.resid_dropout(self.c_proj(merged), training=training)
class MLP(layers.Layer):
    """Position-wise feed-forward sublayer.

    The pipeline is Dense(4 * n_embd), then GELU, then Dense(n_embd), then
    dropout.

    NOTE(review): sublayer attribute names and creation order are kept
    unchanged, since checkpoint restoration likely depends on them.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.c_fc = layers.Dense(4 * config['n_embd'])
        self.c_proj = layers.Dense(config['n_embd'])
        self.dropout = layers.Dropout(config['dropout'])

    def call(self, x, training=False):
        expanded = tf.nn.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded), training=training)
class Block(layers.Layer):
    """One pre-LayerNorm transformer block.

    It applies ``x + attn(ln_1(x))``, then ``x + mlp(ln_2(x))``.

    NOTE(review): sublayer attribute names and creation order are kept
    unchanged, since checkpoint restoration likely depends on them.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.ln_1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = MLP(config)

    def call(self, x, training=False):
        attn_out = self.attn(self.ln_1(x), training=training)
        after_attn = x + attn_out
        return after_attn + self.mlp(self.ln_2(after_attn), training=training)
class GPT2(tf.keras.Model):
    """Decoder-only GPT-2 language model.

    The forward pass is: token embedding plus learned position embedding,
    then dropout, then ``n_layer`` Blocks, then a final LayerNorm, then a
    bias-free Dense head producing ``vocab_size`` logits per position. The
    head has its own weights; it is not tied to ``wte``.

    NOTE(review): attribute names and creation order are kept unchanged,
    since checkpoint restoration likely depends on them.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.wte = layers.Embedding(config['vocab_size'], config['n_embd'])
        self.wpe = layers.Embedding(config['block_size'], config['n_embd'])
        self.drop = layers.Dropout(config['dropout'])
        self.h = [Block(config) for _ in range(config['n_layer'])]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)
        self.lm_head = layers.Dense(config['vocab_size'], use_bias=False)

    def call(self, idx, training=False):
        # idx is indexed as (batch, seq) of int token ids. seq must not exceed
        # block_size, because wpe has only that many rows.
        seq_len = tf.shape(idx)[1]
        positions = tf.range(0, seq_len, dtype=tf.int32)
        hidden = self.drop(self.wte(idx) + self.wpe(positions), training=training)
        for block in self.h:
            hidden = block(hidden, training=training)
        return self.lm_head(self.ln_f(hidden))
# ==============================================================================
# 2. LOAD MODEL FLEXIBLY (Bypassing static shape errors)
# ==============================================================================
# GPT-2 BPE tokenizer; its vocabulary size matches vocab_size below.
enc = tiktoken.get_encoding("gpt2")

gpt_config = dict(
    vocab_size=50257,
    block_size=256,   # maximum context length (rows in the position table)
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.1,
)

print("Instantiating flexible architecture...")
model = GPT2(gpt_config)

# One dummy forward pass on a (1, 1) batch creates every layer variable.
# load_weights then has something to restore into. The length of 1 is
# arbitrary; the subclassed model still accepts any length up to block_size.
_ = model(tf.zeros((1, 1), dtype=tf.int32))

print("Loading weights directly...")
# Restores weights into this freshly built architecture instead of loading
# a SavedModel, which sidesteps that format's baked-in input shapes.
model.load_weights("gpt2_finetuned_10mb")
print("Model loaded successfully.")
# ==============================================================================
# 3. API ENDPOINTS
# ==============================================================================
@app.get("/")
def read_root():
    """Serve the chat front-end page.

    The path is relative. ``index.html`` is therefore resolved against the
    process's current working directory.
    """
    page = "index.html"
    return FileResponse(page)
def _sse_event(payload):
    """Serialize *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _sample_next_token(input_ids, temperature):
    """Run the model on the tail of *input_ids* and sample one token id.

    Only the last ``gpt_config['block_size']`` ids are fed in, because the
    position-embedding table has exactly that many rows.

    Args:
        input_ids: non-empty list of int token ids (prompt plus generation so
            far).
        temperature: the logits are divided by this before sampling. Lower
            values make sampling more deterministic.

    Returns:
        The sampled token id as a Python ``int``.
    """
    context = input_ids[-gpt_config['block_size']:]
    logits = model(tf.constant([context], dtype=tf.int32), training=False)
    # Keep the logits of the last position only and scale them. The batch axis
    # is re-added because tf.random.categorical expects [batch, num_classes].
    scaled_logits = tf.expand_dims(logits[0, -1, :] / temperature, 0)
    sample = tf.random.categorical(scaled_logits, num_samples=1)
    return int(sample.numpy()[0, 0])


@app.post("/chat")
def chat_endpoint(req: ChatRequest):
    """Stream a model continuation of ``req.message`` as Server-Sent Events.

    Each event is ``data: {"text": ...}`` and carries the *entire* cleaned
    generation so far, not a delta. The client should replace its text
    rather than append to it.

    Generation stops in three cases:
    - after ``max_new_tokens`` tokens;
    - when the model emits GPT-2's end-of-text token;
    - when the text contains ``<user>``.

    In the last two cases a final event is sent with trailing whitespace
    removed; for ``<user>`` that event also has the stop marker cut off.
    Errors, including an empty message, are sent as ``data: {"error": ...}``.
    """

    def generate():
        try:
            # Reverted to Colab-style strict QA recall settings
            temperature = 0.6
            max_new_tokens = 10000

            # disallowed_special=() encodes special-token text typed by the
            # user (e.g. "<|endoftext|>") as ordinary text. tiktoken's default
            # raises ValueError for such input.
            input_ids = enc.encode(req.message, disallowed_special=())
            if not input_ids:
                # A zero-length context would make logits[0, -1, :] fail.
                yield _sse_event({'error': 'Message must not be empty.'})
                return
            original_len = len(input_ids)

            for _ in range(max_new_tokens):
                next_token = _sample_next_token(input_ids, temperature)

                if next_token == enc.eot_token:
                    # The model ended the text. Send the final text without
                    # the literal "<|endoftext|>" marker.
                    finished = enc.decode(input_ids[original_len:])
                    final_text = finished.replace("\ufffd", "").strip()
                    yield _sse_event({'text': final_text})
                    break

                input_ids.append(next_token)

                # Re-decode the whole generation each step. An incomplete
                # multi-byte character decodes to U+FFFD, which is hidden until
                # the following tokens complete it.
                # NOTE: this (and re-sending the full text per event) makes
                # the total cost grow quadratically with output length.
                current_generation = enc.decode(input_ids[original_len:])
                clean_generation = current_generation.replace("\ufffd", "").lstrip()

                # Stop word logic
                if "<user>" in current_generation:
                    final_text = clean_generation.split("<user>")[0].rstrip()
                    yield _sse_event({'text': final_text})
                    break

                yield _sse_event({'text': clean_generation})
        except Exception as e:
            error_details = traceback.format_exc()
            print("CRITICAL ERROR IN /chat ENDPOINT:")
            print(error_details)
            yield _sse_event({'error': f'🔥 CRASH: {str(e)}'})

    return StreamingResponse(generate(), media_type="text/event-stream")