| """ |
| SAM-Z-1 Distributed Worker Node v4.0 |
| Optimized for distributed gen/decode pipeline |
| """ |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import StreamingResponse, HTMLResponse |
| from pydantic import BaseModel |
| import tensorflow as tf |
| import keras |
| from huggingface_hub import hf_hub_download |
| import json |
| import os |
| from tokenizers import Tokenizer |
| import numpy as np |
| import time |
| from typing import List, Optional |
| import asyncio |
|
|
# ASGI application object; every endpoint below registers itself on it.
app = FastAPI(
    title="SAM-Z-1 Distributed Worker",
    version="4.0.0",
)
|
|
| |
| |
| |
|
|
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied jointly to queries and keys.

    Args:
        dim: per-head feature size; the cos/sin tables have this width.
        max_len: number of positions covered by the precomputed tables.
        theta: frequency base of the rotary embedding.

    The cos/sin tables are computed with NumPy only (no ``Tensor.numpy()``),
    so building them is safe in eager mode as well as while a graph is being
    traced. They are built once at construction, which normally happens
    eagerly, so traced graphs capture them instead of rebuilding them.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False
        self._build_cache()

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        """Populate ``cos_cached`` / ``sin_cached`` with shape (max_len, dim)."""
        if self.built_cache:
            return
        # Pure NumPy: the previous version called ``.numpy()`` on TF tensors,
        # which raises when this first runs inside tf.function / Keras
        # symbolic building (symbolic tensors carry no value).
        inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
        positions = np.arange(self.max_len, dtype=np.float32)
        freqs = np.outer(positions, inv_freq)
        emb = np.concatenate([freqs, freqs], axis=-1)

        self.cos_cached = tf.constant(np.cos(emb), dtype=tf.float32)
        self.sin_cached = tf.constant(np.sin(emb), dtype=tf.float32)
        self.built_cache = True

    def rotate_half(self, x):
        """Swap the two halves of the last axis, negating the second: [a, b] -> [-b, a]."""
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        """Rotate ``q`` and ``k``; position is read from axis 2 of ``q``.

        Assumes q/k are (batch, heads, seq, dim), as produced by
        TransformerBlock; sequences longer than ``max_len`` are not supported.
        """
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]

        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)

        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
|
|
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square normalisation over the last axis with a learned ``scale``.

    Args:
        epsilon: added to the mean square before the reciprocal square root.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # One scale per feature, initialised to 1 (identity at start).
        feature_dim = input_shape[-1]
        self.scale = self.add_weight(
            name="scale",
            shape=(feature_dim,),
            initializer="ones",
        )

    def call(self, x):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        inv_rms = tf.math.rsqrt(mean_square + self.epsilon)
        return x * inv_rms * self.scale

    def get_config(self):
        base = super().get_config()
        return {**base, "epsilon": self.epsilon}
|
|
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """One pre-norm decoder block.

    Runs causal multi-head self-attention with rotary position embeddings,
    then a gated (SwiGLU) feed-forward network. Each sublayer is wrapped in a
    residual add.

    Args:
        d_model: hidden width; split evenly across ``n_heads``.
        n_heads: number of attention heads (``head_dim = d_model // n_heads``).
        ff_dim: inner width of the gated feed-forward network.
        dropout: dropout rate applied to both sublayer outputs.
        max_len: number of positions covered by the rotary tables.
        rope_theta: rotary frequency base.
        layer_idx: index of this block in the stack (only stored in the config).

    NOTE(review): sublayer attribute names and creation order likely determine
    the weight paths inside the ``.weights.h5`` checkpoint. Do not rename or
    reorder them without re-exporting the weights (confirm).
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()

        # Bias-free attention projections.
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")

        # RoPE operates per head, hence head_dim.
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)

        # Gated feed-forward: down(silu(gate(y)) * up(y)).
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")

        # A single Dropout instance is reused for both sublayer outputs.
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        # x is expected to be (B, T, d_model); D comes from the config, not the tensor.
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype

        # --- attention sublayer: pre-norm, then residual add ---
        res = x
        y = self.pre_attn_norm(x)

        # Project and split heads: (B, T, D) -> (B, n_heads, T, head_dim).
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        # Rotary position embedding is applied to queries and keys only.
        q, k = self.rope(q, k)

        # Scaled dot-product scores: (B, n_heads, T, T).
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: band_part(-1, 0) keeps the lower triangle incl. the diagonal;
        # entries above it (future positions) get a large negative bias.
        # NOTE(review): -1e9 overflows to -inf in float16; fine in float32.
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype)
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)

        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, D).
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)

        # --- feed-forward sublayer (SwiGLU): pre-norm, then residual add ---
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))

        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config
|
|
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only language model.

    The stack is a token embedding, ``n_layers`` TransformerBlocks, a final
    RMSNorm and a separate (untied) ``lm_head`` projection to vocabulary
    logits.

    The hyper-parameter dict ``self.cfg`` is accepted in three forms:
    ``config=dict`` (the form ``get_config`` emits), flat keyword arguments
    that include ``vocab_size``, or ``cfg=dict``. Keys read: vocab_size,
    d_model, ff_mult, n_heads, dropout, max_len, rope_theta, n_layers.
    """

    def __init__(self, **kwargs):
        # NOTE: no kwargs are forwarded to keras.Model, so options such as a
        # ``name`` coming from from_config are not applied.
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            # Falls back to kwargs itself when no 'cfg' key is present.
            self.cfg = kwargs.get('cfg', kwargs)

        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")

        # Feed-forward width is expressed as a multiple of d_model.
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }

        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)

        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        """Map token ids to next-token logits over the vocabulary.

        Presumably ``input_ids`` is an int tensor of shape (batch, seq); the
        result then has shape (batch, seq, vocab_size). Confirm against callers.
        """
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        # Stored under 'config', which the first branch of __init__ reads back.
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
|
|
| |
| |
| |
|
|
# Runtime handles; all five are filled in by the startup hook ``load_model``.
model = tokenizer = config = eos_token_id = fast_forward = None

# Hugging Face Hub repository and local download cache for the model files.
MODEL_REPO = "Smilyai-labs/Sam-Z-1-tensorflow"
CACHE_DIR = "./model_cache"

# Cumulative counters served by /stats (uptime is measured from import time).
worker_stats = dict(
    total_requests=0,
    total_tokens=0,
    decode_requests=0,
    uptime_start=time.time(),
)
|
|
| |
| |
| |
|
|
class GenerateRequest(BaseModel):
    # Request body for POST /generate (completion of a raw prompt).
    prompt: str
    max_tokens: int = 512  # upper bound on newly generated tokens
    temperature: float = 0.8  # sampling temperature
    top_k: int = 40  # keep only the k most likely tokens; <= 0 disables
    top_p: float = 0.9  # nucleus-sampling threshold; >= 1.0 disables
    repetition_penalty: float = 1.1  # 1.0 disables the penalty
    stream: bool = False  # True -> Server-Sent Events (text/event-stream) response
    return_token_ids: bool = False  # emit token ids instead of decoded text
|
|
class ChatMessage(BaseModel):
    # One chat turn; format_chat_prompt renders only the "user" and
    # "assistant" roles.
    role: str
    content: str
|
|
class ChatRequest(BaseModel):
    # Request body for POST /chat; messages are rendered into a ChatML-style
    # prompt before generation. Sampling fields mirror GenerateRequest.
    messages: List[ChatMessage]
    max_tokens: int = 512  # upper bound on newly generated tokens
    temperature: float = 0.8  # sampling temperature
    top_k: int = 40  # keep only the k most likely tokens; <= 0 disables
    top_p: float = 0.9  # nucleus-sampling threshold; >= 1.0 disables
    repetition_penalty: float = 1.1  # 1.0 disables the penalty
    stream: bool = False  # True -> Server-Sent Events (text/event-stream) response
    return_token_ids: bool = False  # emit token ids instead of decoded text
|
|
class DecodeRequest(BaseModel):
    # Request body for POST /decode: one sequence of token ids to turn into text.
    token_ids: List[int]
|
|
class BatchDecodeRequest(BaseModel):
    # Request body for POST /decode/batch: several token-id sequences, each
    # decoded independently into one output string.
    batches: List[List[int]]
|
|
| |
| |
| |
|
|
def _softmax(logits):
    """Return a numerically stable softmax of a 1-D array as float64.

    float64 keeps the probabilities summing to 1 within the tolerance that
    ``np.random.choice`` enforces.
    """
    shifted = np.asarray(logits, dtype=np.float64)
    shifted = shifted - shifted.max()
    weights = np.exp(shifted)
    return weights / weights.sum()


def _sample_next_token(logits, token_freq, temperature, top_k, top_p, repetition_penalty):
    """Choose the next token id from one step of vocabulary logits.

    Args:
        logits: 1-D array of raw logits over the vocabulary (left unmodified).
        token_freq: mapping ``token_id -> count`` of tokens generated so far.
        temperature: softmax temperature; ``<= 0`` selects greedy (argmax) decoding.
        top_k: keep only the ``top_k`` highest-scoring tokens; ``<= 0`` disables.
        top_p: nucleus threshold applied to the remaining candidates; ``>= 1.0`` disables.
        repetition_penalty: ``> 1`` discourages already generated tokens,
            compounding with their count; ``1.0`` disables.

    Returns:
        The chosen token id as a plain Python ``int`` (JSON-serialisable).
    """
    scores = np.array(logits, dtype=np.float64)  # private copy
    vocab_size = scores.shape[0]

    if repetition_penalty != 1.0:
        for token_id, freq in token_freq.items():
            if 0 <= token_id < vocab_size:
                factor = repetition_penalty ** freq
                # Sign-aware penalty: dividing a *negative* logit would move it
                # toward zero and make the repeated token MORE likely.
                if scores[token_id] > 0:
                    scores[token_id] /= factor
                else:
                    scores[token_id] *= factor

    if temperature <= 0:
        # Dividing by zero would produce NaN probabilities; treat as greedy.
        return int(np.argmax(scores))
    scores = scores / temperature

    if top_k > 0:
        k = min(top_k, vocab_size)  # argpartition rejects k > vocab_size
        candidates = np.argpartition(scores, -k)[-k:]
    else:
        candidates = np.arange(vocab_size)
    candidate_scores = scores[candidates]

    if top_p < 1.0:
        probs = _softmax(candidate_scores)
        order = np.argsort(probs)[::-1]
        cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p))
        # Smallest most-probable prefix whose cumulative mass reaches top_p.
        keep = order[:cutoff + 1]
        candidates = candidates[keep]
        candidate_scores = candidate_scores[keep]

    probs = _softmax(candidate_scores)
    return int(candidates[np.random.choice(len(probs), p=probs)])


def generate_tokens(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    return_token_ids: bool = False
):
    """Core generation - yields ``(token_id, token_text or None)``.

    Encodes ``prompt`` (dropping any EOS ids), then samples up to
    ``max_tokens`` tokens with the globally loaded ``fast_forward``. Generation
    stops early when the model emits ``eos_token_id``; that token is not
    yielded.

    Yields:
        ``(token_id, None)`` when ``return_token_ids`` is true. Otherwise
        ``(token_id, text_delta)``, where ``text_delta`` is the newly completed
        text. It may be ``""`` while a multi-byte character is still
        incomplete. ``token_id`` is always a Python ``int``.
    """
    global model, tokenizer, config, eos_token_id, fast_forward

    input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
    if not input_ids:
        return

    max_pos = config['max_position_embeddings']
    # Leave room for max_tokens new tokens. Always keep at least the last prompt
    # token and never more than the context window. The previous slice
    # misbehaved once max_tokens >= max_position_embeddings.
    prompt_budget = min(max_pos, max(1, max_pos - max_tokens))
    if len(input_ids) > prompt_budget:
        input_ids = input_ids[-prompt_budget:]

    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    token_freq = {}
    generated_ids = []
    emitted_text = ""

    for _ in range(max_tokens):
        logits = fast_forward(input_tensor)
        next_token_logits = logits[0, -1, :].numpy()
        next_token_id = _sample_next_token(
            next_token_logits, token_freq, temperature, top_k, top_p, repetition_penalty
        )

        if next_token_id == eos_token_id:
            break

        token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1

        if return_token_ids:
            yield (next_token_id, None)
        else:
            # Decode the whole continuation and emit only the new suffix.
            # Byte-level BPE tokens can split a UTF-8 character, and decoding
            # such a token on its own yields U+FFFD.
            generated_ids.append(next_token_id)
            full_text = tokenizer.decode(generated_ids)
            if full_text.endswith("\ufffd"):
                delta = ""  # incomplete character: wait for the remaining bytes
            elif full_text.startswith(emitted_text):
                delta = full_text[len(emitted_text):]
                emitted_text = full_text
            else:
                # Defensive fallback; resynchronise on the full decode.
                delta = tokenizer.decode([next_token_id])
                emitted_text = full_text
            yield (next_token_id, delta)

        input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)

        # Sliding window: never feed more than the context length.
        if input_tensor.shape[1] > max_pos:
            input_tensor = input_tensor[:, -max_pos:]
|
|
def format_chat_prompt(messages: List[ChatMessage]) -> str:
    """Render ``messages`` as a ChatML-style prompt ending with an open assistant turn.

    Only messages whose role is "user" or "assistant" are rendered; any other
    role is skipped.
    """
    turns = [
        f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
        for msg in messages
        if msg.role in ("user", "assistant")
    ]
    # Leave the assistant turn open so the model continues from here.
    turns.append("<|im_start|>assistant\n")
    return "".join(turns)
|
|
| |
| |
| |
|
|
@app.get("/", response_class=HTMLResponse)
async def status_page():
    """Worker status page.

    Returns a static HTML dashboard. Its inline script fetches ``/health`` and
    ``/stats`` every second to refresh the status badge and the counters.
    The READY badge literal was previously split across two lines by mojibake,
    which is a JavaScript syntax error that disabled the whole script.
    """
    return """
<!DOCTYPE html>
<html>
<head>
    <title>SAM-Z-1 Worker Node</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: 'Courier New', monospace;
            background: linear-gradient(135deg, #1a1f3a 0%, #0a0e27 100%);
            color: #00bfff;
            padding: 20px;
            min-height: 100vh;
        }
        .container {
            max-width: 900px;
            margin: 0 auto;
        }
        .header {
            text-align: center;
            padding: 30px;
            background: rgba(0, 191, 255, 0.1);
            border: 2px solid #00bfff;
            border-radius: 10px;
            margin-bottom: 30px;
            box-shadow: 0 0 20px rgba(0, 191, 255, 0.3);
        }
        .header h1 {
            font-size: 2.5em;
            text-transform: uppercase;
            letter-spacing: 3px;
            animation: glow 2s ease-in-out infinite alternate;
        }
        @keyframes glow {
            from { text-shadow: 0 0 10px #00bfff; }
            to { text-shadow: 0 0 20px #00bfff, 0 0 30px #00bfff; }
        }
        .badge {
            display: inline-block;
            padding: 5px 15px;
            border-radius: 15px;
            font-size: 0.9em;
            margin-top: 10px;
        }
        .badge-ready {
            background: rgba(0, 255, 136, 0.2);
            border: 1px solid #00ff88;
            color: #00ff88;
        }
        .badge-loading {
            background: rgba(255, 165, 0, 0.2);
            border: 1px solid #ffa500;
            color: #ffa500;
        }
        .stats-grid {
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
            gap: 20px;
            margin-bottom: 30px;
        }
        .stat-card {
            background: rgba(0, 191, 255, 0.05);
            border: 1px solid #00bfff;
            border-radius: 8px;
            padding: 20px;
            text-align: center;
        }
        .stat-label {
            font-size: 0.8em;
            opacity: 0.7;
            text-transform: uppercase;
            margin-bottom: 10px;
        }
        .stat-value {
            font-size: 2em;
            font-weight: bold;
        }
        .features {
            background: rgba(0, 191, 255, 0.05);
            border: 1px solid #00bfff;
            border-radius: 8px;
            padding: 20px;
        }
        .features h3 {
            margin-bottom: 15px;
        }
        .feature-list {
            list-style: none;
            padding: 0;
        }
        .feature-list li {
            padding: 10px;
            margin: 5px 0;
            background: rgba(0, 191, 255, 0.1);
            border-radius: 5px;
        }
        .feature-list li:before {
            content: "β‘ ";
            color: #00ff88;
        }
        .timestamp {
            text-align: center;
            margin-top: 20px;
            opacity: 0.5;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>βοΈ WORKER NODE βοΈ</h1>
            <div>SAM-Z-1 Distributed Worker v4.0</div>
            <div class="badge" id="status-badge">CHECKING STATUS...</div>
        </div>

        <div class="stats-grid" id="stats">
            <div class="stat-card">
                <div class="stat-label">Total Requests</div>
                <div class="stat-value" id="total-req">--</div>
            </div>
            <div class="stat-card">
                <div class="stat-label">Total Tokens</div>
                <div class="stat-value" id="total-tokens">--</div>
            </div>
            <div class="stat-card">
                <div class="stat-label">Decode Requests</div>
                <div class="stat-value" id="decode-req">--</div>
            </div>
            <div class="stat-card">
                <div class="stat-label">Uptime</div>
                <div class="stat-value" id="uptime">--</div>
            </div>
        </div>

        <div class="features">
            <h3>π CAPABILITIES</h3>
            <ul class="feature-list">
                <li>Full Text Generation</li>
                <li>Token-Only Mode (for distributed pipeline)</li>
                <li>High-Speed Batch Decoding</li>
                <li>Chat Completion</li>
                <li>Streaming & Non-Streaming</li>
            </ul>
        </div>

        <div class="timestamp" id="timestamp">Initializing...</div>
    </div>

    <script>
        async function updateStats() {
            try {
                const response = await fetch('/health');
                const data = await response.json();

                const badge = document.getElementById('status-badge');
                if (data.model_loaded) {
                    badge.textContent = '✅ READY FOR INFERENCE';
                    badge.className = 'badge badge-ready';
                } else {
                    badge.textContent = 'β³ LOADING MODEL...';
                    badge.className = 'badge badge-loading';
                }

                // Fetch stats
                const statsRes = await fetch('/stats');
                const stats = await statsRes.json();

                document.getElementById('total-req').textContent = stats.total_requests;
                document.getElementById('total-tokens').textContent = stats.total_tokens;
                document.getElementById('decode-req').textContent = stats.decode_requests;

                const uptime = Math.floor(stats.uptime);
                const h = Math.floor(uptime / 3600);
                const m = Math.floor((uptime % 3600) / 60);
                const s = uptime % 60;
                document.getElementById('uptime').textContent = `${h}h ${m}m ${s}s`;

                document.getElementById('timestamp').textContent =
                    `Last update: ${new Date().toLocaleTimeString()}`;
            } catch (e) {
                console.error('Failed to update stats:', e);
            }
        }

        // Update every second
        setInterval(updateStats, 1000);
        updateStats();
    </script>
</body>
</html>
"""
|
|
| |
| |
| |
|
|
@app.get("/health")
async def health():
    """Readiness probe; the status page polls it to choose its badge."""
    loaded = model is not None
    return {
        "status": "healthy" if loaded else "loading",
        "model_loaded": loaded,
    }
|
|
@app.get("/stats")
async def stats():
    """Report cumulative counters, uptime (seconds) and average tokens/second."""
    elapsed = time.time() - worker_stats["uptime_start"]
    total_tokens = worker_stats["total_tokens"]
    throughput = total_tokens / elapsed if elapsed > 0 else 0
    return {
        "total_requests": worker_stats["total_requests"],
        "total_tokens": total_tokens,
        "decode_requests": worker_stats["decode_requests"],
        "uptime": elapsed,
        "tokens_per_second": throughput,
    }
|
|
@app.post("/decode")
async def decode(request: DecodeRequest):
    """Fast single decode of one token-id sequence back into text.

    Raises:
        HTTPException: 503 before the tokenizer is loaded, 500 if decoding fails.
    """
    if tokenizer is None:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")

    try:
        worker_stats["decode_requests"] += 1
        decoded = tokenizer.decode(request.token_ids)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Decode error: {str(e)}")
    return {"text": decoded}
|
|
@app.post("/decode/batch")
async def batch_decode(request: BatchDecodeRequest):
    """Decode several token-id sequences in one call (one text per sequence).

    Each sequence counts as one decode request in ``worker_stats``.

    Raises:
        HTTPException: 503 before the tokenizer is loaded, 500 if any decode fails.
    """
    if tokenizer is None:
        raise HTTPException(status_code=503, detail="Tokenizer not loaded")

    try:
        worker_stats["decode_requests"] += len(request.batches)
        texts = list(map(tokenizer.decode, request.batches))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch decode error: {str(e)}")
    return {"texts": texts}
|
|
@app.post("/generate")
async def generate(request: GenerateRequest):
    """Generate text for a raw prompt.

    Streaming requests return Server-Sent Events:
    - one ``data:`` event per token: ``{'token_id': ...}`` with
      ``return_token_ids``, otherwise ``{'text': ..., 'total': ...}``;
    - then a final ``{'done': True, 'tokens': ..., 'time': ...}`` event,
      or an ``{'error': ...}`` event if generation fails mid-stream.

    Non-streaming requests return a JSON summary. With ``return_token_ids`` it
    also includes the generated ``token_ids``; these used to be discarded,
    leaving the caller with an empty ``text`` and no ids.

    Raises:
        HTTPException: 503 before the model is loaded; 500 if non-streaming
            generation fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    worker_stats["total_requests"] += 1
    start_time = time.time()

    def run_generation():
        # Both response modes sample with the same request parameters.
        return generate_tokens(
            request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=request.repetition_penalty,
            return_token_ids=request.return_token_ids
        )

    if request.stream:
        async def stream_tokens():
            generated_text = ""
            token_count = 0

            try:
                for token_id, token_text in run_generation():
                    token_count += 1
                    worker_stats["total_tokens"] += 1

                    if request.return_token_ids:
                        # int(): numpy integers are not JSON serialisable.
                        yield f"data: {json.dumps({'token_id': int(token_id)})}\n\n"
                    else:
                        generated_text += token_text
                        yield f"data: {json.dumps({'text': token_text, 'total': generated_text})}\n\n"

                    # Hand control back to the event loop between tokens.
                    await asyncio.sleep(0.001)

                elapsed = time.time() - start_time
                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"

            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(stream_tokens(), media_type="text/event-stream")

    pieces = []
    token_ids = []

    try:
        for token_id, token_text in run_generation():
            token_ids.append(int(token_id))
            if not request.return_token_ids:
                pieces.append(token_text)
            worker_stats["total_tokens"] += 1
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e

    elapsed = time.time() - start_time
    token_count = len(token_ids)
    result = {
        "text": "".join(pieces),
        "tokens": token_count,
        "time": elapsed,
        "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
    }
    if request.return_token_ids:
        result["token_ids"] = token_ids
    return result
|
|
@app.post("/chat")
async def chat(request: ChatRequest):
    """Chat completion over a ChatML-formatted conversation.

    Generation stops at the first ``<|im_end|>``. It is detected by token id,
    and also as decoded text in case the model spells it out in pieces.
    Streaming responses emit ``{'token_id': ...}`` or
    ``{'delta': ..., 'content': ...}`` events followed by a ``done`` (or
    ``error``) event. Non-streaming responses return the assistant message,
    plus ``token_ids`` when ``return_token_ids`` is set.

    Raises:
        HTTPException: 503 before the model is loaded; 500 if non-streaming
            generation fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    worker_stats["total_requests"] += 1
    prompt = format_chat_prompt(request.messages)
    start_time = time.time()

    # <|im_end|> is added to the tokenizer as a special token, and the
    # tokenizers library's decode() skips special tokens by default. When the
    # model emits it as a single token, the text check alone never sees it,
    # so stop on its id too. token_to_id returns None if the token is unknown.
    im_end_id = tokenizer.token_to_id("<|im_end|>")

    def run_generation():
        # Both response modes sample with the same request parameters.
        return generate_tokens(
            prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=request.repetition_penalty,
            return_token_ids=request.return_token_ids
        )

    if request.stream:
        async def stream_tokens():
            generated_text = ""
            token_count = 0

            try:
                for token_id, token_text in run_generation():
                    if token_id == im_end_id:
                        break

                    token_count += 1
                    worker_stats["total_tokens"] += 1

                    if request.return_token_ids:
                        # int(): numpy integers are not JSON serialisable.
                        yield f"data: {json.dumps({'token_id': int(token_id)})}\n\n"
                    else:
                        generated_text += token_text

                        # Fallback: end marker produced as ordinary text pieces.
                        if "<|im_end|>" in generated_text:
                            generated_text = generated_text.split("<|im_end|>")[0]
                            break

                        yield f"data: {json.dumps({'delta': token_text, 'content': generated_text})}\n\n"

                    # Hand control back to the event loop between tokens.
                    await asyncio.sleep(0.001)

                elapsed = time.time() - start_time
                yield f"data: {json.dumps({'done': True, 'tokens': token_count, 'time': elapsed})}\n\n"

            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(stream_tokens(), media_type="text/event-stream")

    generated_text = ""
    token_ids = []

    try:
        for token_id, token_text in run_generation():
            if token_id == im_end_id:
                break

            if not request.return_token_ids:
                generated_text += token_text

                # Fallback: end marker produced as ordinary text pieces.
                if "<|im_end|>" in generated_text:
                    generated_text = generated_text.split("<|im_end|>")[0]
                    break

            token_ids.append(int(token_id))
            worker_stats["total_tokens"] += 1
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e

    elapsed = time.time() - start_time
    token_count = len(token_ids)
    result = {
        "message": {
            "role": "assistant",
            "content": generated_text.strip()
        },
        "tokens": token_count,
        "time": elapsed,
        "tokens_per_second": token_count / elapsed if elapsed > 0 else 0
    }
    if request.return_token_ids:
        result["token_ids"] = token_ids
    return result
|
|
| |
| |
| |
|
|
@app.on_event("startup")
async def load_model():
    """Download config, tokenizer and weights; publish them as module globals.

    Sets ``config``, ``tokenizer``, ``eos_token_id``, ``model`` and
    ``fast_forward``. The ``ckpt.weights.h5`` checkpoint is preferred: its
    weights are loaded into a freshly built ``SAM1Model``. If the checkpoint
    cannot be downloaded, the full ``model.keras`` archive is loaded instead.

    Raises:
        Exception: any failure is printed with its traceback and re-raised.
    """
    global model, tokenizer, config, eos_token_id, fast_forward

    print("π Loading SAM-Z-1 Model...")

    try:
        config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)

        # Prefer the raw weights checkpoint; fall back to the full Keras archive.
        # `except Exception` rather than a bare `except:` so KeyboardInterrupt /
        # SystemExit still propagate during the download.
        try:
            weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
        except Exception:
            print("β οΈ Checkpoint not found, using model.keras")
            model_path = hf_hub_download(MODEL_REPO, "model.keras", cache_dir=CACHE_DIR)
            use_checkpoint = False
        else:
            print("✅ Found checkpoint weights")
            use_checkpoint = True

        with open(config_path, 'r') as f:
            config = json.load(f)

        print(f"π¦ Config loaded: {config['num_hidden_layers']} layers")

        print("π¦ Creating tokenizer...")
        from transformers import AutoTokenizer

        # GPT-2 BPE plus the chat/think markers, registered as special tokens.
        hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        custom_tokens = ["<|im_start|>", "<|im_end|>", "<think>", "<think/>"]
        hf_tokenizer.add_special_tokens({"additional_special_tokens": custom_tokens})

        # Round-trip through tokenizer.json to get a fast `tokenizers.Tokenizer`.
        os.makedirs("./temp_tokenizer", exist_ok=True)
        hf_tokenizer.save_pretrained("./temp_tokenizer")
        tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")

        eos_token_id = config.get('eos_token_id', 50256)

        print(f"✅ Tokenizer ready: vocab size {tokenizer.get_vocab_size()}")

        print("π Loading model...")

        if use_checkpoint:
            # Map the HF-style config keys onto SAM1Model's own config keys.
            model_config = {
                'vocab_size': config['vocab_size'],
                'd_model': config['hidden_size'],
                'n_layers': config['num_hidden_layers'],
                'n_heads': config['num_attention_heads'],
                'ff_mult': config['intermediate_size'] / config['hidden_size'],
                'max_len': config['max_position_embeddings'],
                'dropout': 0.1,
                'rope_theta': config['rope_theta']
            }

            model = SAM1Model(config=model_config)
            # One eager forward pass creates every weight before load_weights.
            dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
            _ = model(dummy_input, training=False)

            print(f"✅ Architecture built: {model.count_params():,} parameters")

            model.load_weights(weights_path)
            print("✅ Weights loaded!")

        else:
            model = keras.models.load_model(model_path, compile=False)
            print("✅ Model loaded!")

        @tf.function(reduce_retracing=True)
        def optimized_forward(input_tensor):
            return model(input_tensor, training=False)

        fast_forward = optimized_forward

        print("✅ SAM-Z-1 Distributed Worker ready! π")
        print("π₯ Features enabled:")
        print("   - Full text generation")
        print("   - Token-only mode (distributed pipeline)")
        print("   - Batch decoding optimization")
        print("   - Streaming support")

    except Exception as e:
        print(f"β Failed to load model: {e}")
        import traceback
        traceback.print_exc()
        raise
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    import uvicorn

    # Listen on every interface on port 7860.
    server_options = dict(host="0.0.0.0", port=7860, log_level="info")
    uvicorn.run(app, **server_options)