Kawaquader committed
Commit 5a3b397 · verified · 1 Parent(s): f440d61

Update app.py

Files changed (1): app.py +111 -162
app.py CHANGED
Old version (removed lines marked "-"):

@@ -1,26 +1,24 @@
  import os
- import json
- import math
  import traceback
  import tensorflow as tf
  from tensorflow.keras import layers
  import tiktoken
  from fastapi import FastAPI
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
  from fastapi.responses import StreamingResponse

- # Must be set before importing tensorflow
- os.environ["TF_USE_LEGACY_KERAS"] = "1"
-
- # ==============================================================================
- # FASTAPI SETUP
- # ==============================================================================
  app = FastAPI()

  app.add_middleware(
      CORSMiddleware,
-     allow_origins=["*"],
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
@@ -30,207 +28,158 @@ class ChatRequest(BaseModel):
      message: str

  # ==============================================================================
- # MODEL ARCHITECTURE
  # ==============================================================================
  class CausalSelfAttention(layers.Layer):
-     def __init__(self, config):
-         super().__init__()
          self.n_head = config['n_head']
          self.n_embd = config['n_embd']
          self.head_dim = self.n_embd // self.n_head
-
          self.c_attn = layers.Dense(3 * self.n_embd)
          self.c_proj = layers.Dense(self.n_embd)

-     def call(self, x):
          B = tf.shape(x)[0]
          T = tf.shape(x)[1]
-
          qkv = self.c_attn(x)
          q, k, v = tf.split(qkv, 3, axis=-1)
-
-         q = tf.transpose(tf.reshape(q, (B, T, self.n_head, self.head_dim)), (0,2,1,3))
-         k = tf.transpose(tf.reshape(k, (B, T, self.n_head, self.head_dim)), (0,2,1,3))
-         v = tf.transpose(tf.reshape(v, (B, T, self.n_head, self.head_dim)), (0,2,1,3))
-
-         att = tf.matmul(q, k, transpose_b=True) / math.sqrt(self.head_dim)
-
          mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
-         att = tf.where(mask == 0, -1e9, att)
-
          att = tf.nn.softmax(att, axis=-1)
          y = tf.matmul(att, v)
-
-         y = tf.reshape(tf.transpose(y, (0,2,1,3)), (B, T, self.n_embd))
-         return self.c_proj(y)

  class MLP(layers.Layer):
-     def __init__(self, config):
-         super().__init__()
-         self.fc = layers.Dense(4 * config['n_embd'])
-         self.proj = layers.Dense(config['n_embd'])
-
-     def call(self, x):
-         return self.proj(tf.nn.gelu(self.fc(x)))

  class Block(layers.Layer):
-     def __init__(self, config):
-         super().__init__()
-         self.ln1 = layers.LayerNormalization(epsilon=1e-5)
          self.attn = CausalSelfAttention(config)
-         self.ln2 = layers.LayerNormalization(epsilon=1e-5)
          self.mlp = MLP(config)

-     def call(self, x):
-         x = x + self.attn(self.ln1(x))
-         x = x + self.mlp(self.ln2(x))
          return x

  class GPT2(tf.keras.Model):
-     def __init__(self, config):
-         super().__init__()
          self.wte = layers.Embedding(config['vocab_size'], config['n_embd'])
          self.wpe = layers.Embedding(config['block_size'], config['n_embd'])
-         self.blocks = [Block(config) for _ in range(config['n_layer'])]
          self.ln_f = layers.LayerNormalization(epsilon=1e-5)
-         self.head = layers.Dense(config['vocab_size'], use_bias=False)

-     def call(self, idx):
          T = tf.shape(idx)[1]
-         pos = tf.range(0, T)
-
-         x = self.wte(idx) + self.wpe(pos)
-
-         for b in self.blocks:
-             x = b(x)
-
          x = self.ln_f(x)
-         return self.head(x)

  # ==============================================================================
- # LOAD MODEL
  # ==============================================================================
  enc = tiktoken.get_encoding("gpt2")

- config = {
-     'vocab_size': 50257,
-     'block_size': 256,
-     'n_layer': 6,
-     'n_head': 6,
-     'n_embd': 384
- }

- model = GPT2(config)
- _ = model(tf.zeros((1,1), dtype=tf.int32))
  model.load_weights("gpt2_finetuned_10mb")
-
- print("✅ Model loaded")
-
- # ==============================================================================
- # SAMPLING FUNCTION
- # ==============================================================================
- def sample_token(logits, input_ids, temperature=0.7, top_k=40, top_p=0.9):
-     logits = logits / temperature
-
-     # Repetition penalty
-     for token in set(input_ids):
-         logits = tf.tensor_scatter_nd_update(
-             logits,
-             [[token]],
-             [logits[token] * 0.9]
-         )
-
-     # Top-k
-     values, _ = tf.math.top_k(logits, k=top_k)
-     min_val = values[-1]
-     logits = tf.where(logits < min_val, -1e10, logits)
-
-     # Top-p
-     sorted_logits = tf.sort(logits, direction='DESCENDING')
-     probs = tf.nn.softmax(sorted_logits)
-     cumulative = tf.cumsum(probs)
-
-     cutoff = tf.reduce_sum(tf.cast(cumulative <= top_p, tf.int32))
-     threshold = sorted_logits[cutoff]
-
-     logits = tf.where(logits < threshold, -1e10, logits)
-
-     logits = tf.expand_dims(logits, 0)
-     return int(tf.random.categorical(logits, 1)[0,0].numpy())

  # ==============================================================================
- # GENERATION HELPERS
  # ==============================================================================
- def generate_ids(input_ids, max_tokens):
-     for _ in range(max_tokens):
-         context = input_ids[-config['block_size']:]
-         x = tf.constant([context], dtype=tf.int32)
-
-         logits = model(x)[0, -1]
-         next_token = sample_token(logits, input_ids)
-
-         input_ids.append(next_token)
-
-     return input_ids

- # ==============================================================================
- # CHAT ENDPOINT (HIDDEN THINKING)
- # ==============================================================================
  @app.post("/chat")
- def chat(req: ChatRequest):
-
-     def stream():
          try:
-             # ---------------------------
-             # STEP 1: INTERNAL THINKING
-             # ---------------------------
-             thinking_prompt = (
-                 "<system> You are an expert AI. "
-                 "Solve the problem internally, then respond clearly.\n"
-                 "<user> " + req.message + "\n"
-                 "<ai>"
-             )
-
-             thought_ids = enc.encode(thinking_prompt)
-             thought_ids = generate_ids(thought_ids, 80)
-
-             # ---------------------------
-             # STEP 2: FINAL ANSWER ONLY
-             # ---------------------------
-             final_prompt = enc.decode(thought_ids) + "\nAnswer:\n"
-
-             input_ids = enc.encode(final_prompt)
              original_len = len(input_ids)
-
-             for _ in range(120):
-                 context = input_ids[-config['block_size']:]
                  x = tf.constant([context], dtype=tf.int32)
-
-                 logits = model(x)[0, -1]
-                 next_token = sample_token(logits, input_ids)
-
-                 input_ids.append(next_token)
-
-                 text = enc.decode(input_ids[original_len:])
-                 clean = text.replace("\ufffd", "").strip()
-
-                 if "<user>" in clean:
-                     clean = clean.split("<user>")[0].strip()
-                     yield f"data: {json.dumps({'text': clean})}\n\n"
                      break
-
-                 yield f"data: {json.dumps({'text': clean})}\n\n"
-
          except Exception as e:
              error_details = traceback.format_exc()
              print(error_details)
-             yield f"data: {json.dumps({'error': str(e)})}\n\n"
-
-     return StreamingResponse(stream(), media_type="text/event-stream")

- # ==============================================================================
- # ROOT
- # ==============================================================================
- @app.get("/")
- def root():
-     return {"status": "✅ AI server running"}
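For reference, the removed sample_token above chained temperature scaling, a repetition penalty, top-k, and top-p (nucleus) filtering before sampling. A minimal NumPy sketch of that same filtering order follows; it is an illustration only, not part of the commit, and the name sample_token_sketch is hypothetical:

# Illustrative sketch of the removed sampler's filtering order (hypothetical helper).
import numpy as np

def sample_token_sketch(logits, prev_ids, temperature=0.7, top_k=40, top_p=0.9):
    logits = np.asarray(logits, dtype=np.float64) / temperature
    # Repetition penalty: scale logits of already-generated tokens by 0.9,
    # as the removed TensorFlow code did
    for t in set(prev_ids):
        logits[t] *= 0.9
    # Top-k: mask everything below the k-th largest logit
    kth_largest = np.sort(logits)[-top_k]
    logits[logits < kth_largest] = -1e10
    # Top-p: order tokens by descending probability, mask the tail once the
    # cumulative probability passes top_p (mirrors the removed cutoff logic)
    order = np.argsort(logits)[::-1]
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    cum = np.cumsum(probs[order])
    cutoff = int((cum <= top_p).sum())
    logits[order[cutoff + 1:]] = -1e10
    # Sample from the renormalized distribution
    p = np.exp(logits - logits.max())
    p /= p.sum()
    return int(np.random.choice(len(p), p=p))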
 
New version (added lines marked "+"):

  import os
  import traceback
+ import json
+
+ # Must be set before importing tensorflow
+ os.environ["TF_USE_LEGACY_KERAS"] = "1"
+
  import tensorflow as tf
  from tensorflow.keras import layers
+ import math
  import tiktoken
  from fastapi import FastAPI
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
  from fastapi.responses import StreamingResponse

  app = FastAPI()

  app.add_middleware(
      CORSMiddleware,
+     allow_origins=["*"],
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],

      message: str

  # ==============================================================================
+ # 1. ARCHITECTURE DEFINITION (10MB Config)
  # ==============================================================================
  class CausalSelfAttention(layers.Layer):
+     def __init__(self, config, **kwargs):
+         super().__init__(**kwargs)
          self.n_head = config['n_head']
          self.n_embd = config['n_embd']
          self.head_dim = self.n_embd // self.n_head
          self.c_attn = layers.Dense(3 * self.n_embd)
          self.c_proj = layers.Dense(self.n_embd)
+         self.attn_dropout = layers.Dropout(config['dropout'])
+         self.resid_dropout = layers.Dropout(config['dropout'])

+     def call(self, x, training=False):
          B = tf.shape(x)[0]
          T = tf.shape(x)[1]
          qkv = self.c_attn(x)
          q, k, v = tf.split(qkv, 3, axis=-1)
+         q = tf.transpose(tf.reshape(q, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
+         k = tf.transpose(tf.reshape(k, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
+         v = tf.transpose(tf.reshape(v, (B, T, self.n_head, self.head_dim)), (0, 2, 1, 3))
+         att = tf.matmul(q, k, transpose_b=True) * (1.0 / math.sqrt(float(self.head_dim)))
          mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
+         att = tf.where(mask == 0, tf.cast(-1e9, att.dtype), att)
          att = tf.nn.softmax(att, axis=-1)
+         att = self.attn_dropout(att, training=training)
          y = tf.matmul(att, v)
+         y = tf.reshape(tf.transpose(y, (0, 2, 1, 3)), (B, T, self.n_embd))
+         return self.resid_dropout(self.c_proj(y), training=training)

  class MLP(layers.Layer):
+     def __init__(self, config, **kwargs):
+         super().__init__(**kwargs)
+         self.c_fc = layers.Dense(4 * config['n_embd'])
+         self.c_proj = layers.Dense(config['n_embd'])
+         self.dropout = layers.Dropout(config['dropout'])
+
+     def call(self, x, training=False):
+         x = self.c_fc(x)
+         x = tf.nn.gelu(x)
+         x = self.c_proj(x)
+         return self.dropout(x, training=training)

  class Block(layers.Layer):
+     def __init__(self, config, **kwargs):
+         super().__init__(**kwargs)
+         self.ln_1 = layers.LayerNormalization(epsilon=1e-5)
          self.attn = CausalSelfAttention(config)
+         self.ln_2 = layers.LayerNormalization(epsilon=1e-5)
          self.mlp = MLP(config)

+     def call(self, x, training=False):
+         x = x + self.attn(self.ln_1(x), training=training)
+         x = x + self.mlp(self.ln_2(x), training=training)
          return x

  class GPT2(tf.keras.Model):
+     def __init__(self, config, **kwargs):
+         super().__init__(**kwargs)
+         self.config = config
          self.wte = layers.Embedding(config['vocab_size'], config['n_embd'])
          self.wpe = layers.Embedding(config['block_size'], config['n_embd'])
+         self.drop = layers.Dropout(config['dropout'])
+         self.h = [Block(config) for _ in range(config['n_layer'])]
          self.ln_f = layers.LayerNormalization(epsilon=1e-5)
+         self.lm_head = layers.Dense(config['vocab_size'], use_bias=False)

+     def call(self, idx, training=False):
          T = tf.shape(idx)[1]
+         pos = tf.range(0, T, dtype=tf.int32)
+         tok_emb = self.wte(idx)
+         pos_emb = self.wpe(pos)
+         x = self.drop(tok_emb + pos_emb, training=training)
+         for block in self.h:
+             x = block(x, training=training)
          x = self.ln_f(x)
+         logits = self.lm_head(x)
+         return logits

  # ==============================================================================
+ # 2. LOAD MODEL FLEXIBLY (Bypassing static shape errors)
  # ==============================================================================
  enc = tiktoken.get_encoding("gpt2")
+ gpt_config = {'vocab_size': 50257, 'block_size': 256, 'n_layer': 6, 'n_head': 6, 'n_embd': 384, 'dropout': 0.1}

+ print("Instantiating flexible architecture...")
+ model = GPT2(gpt_config)
+ # Build the graph with a dynamic size (1 token) so it doesn't lock to 256
+ _ = model(tf.zeros((1, 1), dtype=tf.int32))

+ print("Loading weights directly...")
+ # Loading weights onto the fresh architecture avoids the SavedModel shape restrictions!
  model.load_weights("gpt2_finetuned_10mb")
+ print("Model loaded successfully.")

  # ==============================================================================
+ # 3. API ENDPOINTS
  # ==============================================================================
+ @app.get("/")
+ def read_root():
+     return {"status": "RAI Engine Python API is running!"}

  @app.post("/chat")
+ def chat_endpoint(req: ChatRequest):
+     def generate():
          try:
+             # Reverted to Colab-style strict QA recall settings
+             temperature = 0.1
+             max_new_tokens = 60
+
+             formatted_prompt = f"<user> {req.message} <ai>"
+             input_ids = enc.encode(formatted_prompt)
              original_len = len(input_ids)
+
+             for _ in range(max_new_tokens):
+                 # We no longer need to pad with zeros!
+                 context = input_ids[-gpt_config['block_size']:]
                  x = tf.constant([context], dtype=tf.int32)
+
+                 logits = model(x, training=False)
+
+                 # Extract raw logits for the last token
+                 next_token_logits = logits[0, -1, :]
+
+                 # Strict Temperature Sampling (like Colab)
+                 scaled_logits = next_token_logits / temperature
+                 scaled_logits = tf.expand_dims(scaled_logits, 0)
+
+                 next_token = tf.random.categorical(scaled_logits, num_samples=1).numpy()[0, 0]
+
+                 next_token_int = int(next_token)
+                 input_ids.append(next_token_int)
+
+                 # Decode the ENTIRE generation so far
+                 current_generation = enc.decode(input_ids[original_len:])
+
+                 # Clean up the text stream
+                 clean_generation = current_generation.replace("\ufffd", "")
+                 clean_generation = clean_generation.lstrip()
+
+                 # Stop word logic
+                 if "<user>" in current_generation:
+                     final_text = clean_generation.split("<user>")[0].rstrip()
+                     yield f"data: {json.dumps({'text': final_text})}\n\n"
                      break
+
+                 yield f"data: {json.dumps({'text': clean_generation})}\n\n"
+
          except Exception as e:
              error_details = traceback.format_exc()
+             print("CRITICAL ERROR IN /chat ENDPOINT:")
              print(error_details)
+             yield f"data: {json.dumps({'error': f'🔥 CRASH: {str(e)}'})}\n\n"

+     return StreamingResponse(generate(), media_type="text/event-stream")
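The new /chat endpoint streams Server-Sent Events, where each frame is a "data: {...}" line followed by a blank line, and each frame carries the full generation so far rather than a delta. A minimal client sketch, assuming the app is served at http://localhost:8000 and using the requests library (both are assumptions, not part of the commit):

# Illustrative SSE client for the streaming /chat endpoint (host/port assumed).
import json
import requests

with requests.post(
    "http://localhost:8000/chat",
    json={"message": "Hello, what can you do?"},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank separators between SSE frames
        payload = json.loads(line[len("data: "):])
        if "error" in payload:
            print("server error:", payload["error"])
            break
        print(payload["text"])  # cumulative text so far, not a delta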