| """ |
| Sam-large-2 Distributed Inference - HEAD NODE |
| Edit the CONFIG below, then deploy. |
| """ |
|
|
| |
| |
| |
|
|
# Deployment configuration for this head node — edit before deploying.
CONFIG = {
    "node_id": "head-main",                # identifier shown in the UI header
    "layer_start": 0,                      # first transformer block this node runs
    "layer_end": 6,                        # exclusive end; values <= 0 mean "through the last layer" (see load_model)
    "worker_urls": [],                     # downstream worker base URLs; empty -> standalone mode (this node also runs the LM head)
    "secret_token": "sam2-distributed-secret-change-me",  # shared bearer token sent to workers
    "model_repo": "Smilyai-labs/Sam-large-2",             # Hugging Face repo holding config.json + ckpt.weights.h5
    "cache_dir": "./model_cache",          # local download cache for hf_hub_download
}
|
|
| |
| |
| |
|
|
import os
# Thread count for TF parallelism; fall back to 4 when undetectable.
NUM_CORES = os.cpu_count() or 4


# These must be set BEFORE TensorFlow is imported to take effect.
os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'   # hide all GPUs: force CPU-only inference
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'   # enable oneDNN CPU kernel optimizations
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # silence TF info/warning log spam
|
|
| import json |
| import time |
| import io |
| import base64 |
| from typing import Dict, List, Optional, Tuple, Any |
|
|
| import gradio as gr |
| import numpy as np |
| import requests |
| import tensorflow as tf |
| import keras |
| from huggingface_hub import hf_hub_download |
|
|
# Mirror the env-var thread settings through the runtime API as well.
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)


# NOTE(review): original emoji was mojibake'd in extraction; reconstructed as ✅.
print(f"✅ CPU optimized: {NUM_CORES} threads")
|
|
| |
| |
| |
|
|
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE).

    Lazily precomputes cos/sin tables for ``max_len`` positions, then rotates
    query/key tensors according to their absolute position. ``offset`` lets
    incremental decoding with a KV cache address positions past the cached
    prefix.

    Fix: the cache is now built with ``tf.cos``/``tf.sin`` instead of
    round-tripping through ``emb.numpy()`` + ``np.cos``/``np.sin`` — calling
    ``.numpy()`` on a symbolic tensor fails inside ``tf.function`` (graph
    tracing), and the host round-trip was unnecessary in eager mode. The
    computed float32 values are identical.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim            # rotary dimension (per-head width)
        self.max_len = max_len    # maximum sequence length covered by the cache
        self.theta = theta        # RoPE base frequency
        self.built_cache = False  # cache is built lazily on first call
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        """Build the (max_len, dim) cos/sin tables exactly once."""
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            # Duplicate the frequency table so it spans the full rotary dim.
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.cos(emb)
            self.sin_cached = tf.sin(emb)
            self.built_cache = True

    def rotate_half(self, x):
        """Rotate the last dimension: (x1, x2) -> (-x2, x1)."""
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k, offset=0):
        """Apply RoPE to q and k at absolute positions [offset, offset + seq_len).

        Expects q, k shaped (batch, heads, seq_len, dim) — seq is axis 2.
        Returns the rotated (q, k) pair in the inputs' dtype.
        """
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def get_config(self):
        return {**super().get_config(), "dim": self.dim, "max_len": self.max_len, "theta": self.theta}
|
|
|
|
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square layer normalization: no mean-centering, learned scale."""

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon  # numerical-stability floor under the rsqrt

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        # One learned gain per feature, initialized to identity.
        self.scale = self.add_weight(name="scale", shape=(feature_dim,), initializer="ones")
        super().build(input_shape)

    def call(self, x):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        normalized = x * tf.math.rsqrt(mean_square + self.epsilon)
        return normalized * self.scale

    def get_config(self):
        config = super().get_config()
        config["epsilon"] = self.epsilon
        return config
|
|
|
|
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block: RoPE multi-head self-attention + SwiGLU FFN.

    Supports incremental decoding through an optional (k, v) cache:
    ``call`` returns ``(hidden_states, new_kv)`` where ``new_kv`` is None
    unless ``use_cache=True``.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads  # per-head width; assumes d_model % n_heads == 0
        self.layer_idx = layer_idx          # position in the stack (serialized, used for naming elsewhere)

    def build(self, input_shape):
        # Attention path: separate bias-free Q/K/V/O projections (LLaMA-style).
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        # SwiGLU feed-forward: down(silu(gate(x)) * up(x)).
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        # x: (B, T, d_model); past_kv: optional (k, v), each (B, H, past_len, head_dim).
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype

        res = x
        y = self.pre_attn_norm(x)

        # Project and reshape to (B, H, T, head_dim).
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        # RoPE is applied to the NEW positions only, offset past the cached prefix.
        past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
        q, k = self.rope(q, k, offset=past_len)

        # Prepend cached keys/values so attention sees the full history.
        if past_kv is not None:
            k = tf.concat([past_kv[0], k], axis=2)
            v = tf.concat([past_kv[1], v], axis=2)

        new_kv = (k, v) if use_cache else None

        # Scaled dot-product attention with a causal mask built from absolute
        # positions, so it stays correct when a cache prefix is present
        # (query i may attend to keys at positions <= past_len + i).
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        full_len = tf.shape(k)[2]
        q_pos = tf.range(past_len, past_len + T)
        k_pos = tf.range(full_len)
        mask = tf.where(q_pos[:, None] >= k_pos[None, :], 0.0, -1e9)  # -1e9 ~ masked out after softmax
        scores = scores + tf.cast(mask[None, None, :, :], dtype)

        attn = tf.nn.softmax(scores, axis=-1)
        # Merge heads back to (B, T, d_model), project, and add the residual.
        attn_out = tf.reshape(tf.transpose(tf.matmul(attn, v), [0, 2, 1, 3]), [B, T, self.d_model])
        x = res + self.dropout(self.out_proj(attn_out), training=training)

        # SwiGLU feed-forward sub-layer with its own pre-norm and residual.
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training), new_kv

    def get_config(self):
        return {**super().get_config(), "d_model": self.d_model, "n_heads": self.n_heads,
                "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len,
                "rope_theta": self.rope_theta, "layer_idx": self.layer_idx}
|
|
|
|
| |
| |
| |
|
|
class ModelState:
    """Mutable process-wide container for the loaded model pieces on this node."""

    def __init__(self):
        self.config = None          # parsed config.json from the model repo
        self.tokenizer = None       # tokenizers.Tokenizer instance
        self.eos_token_id = 50256   # default GPT-2 EOS; overwritten from config in load_model
        self.embedding = None       # token embedding layer
        self.blocks: List = []      # the slice of transformer blocks this node owns
        self.final_norm = None      # set only in standalone mode (no workers)
        self.lm_head = None         # set only in standalone mode (no workers)
        self.my_block_start = 0     # first block index this node runs
        self.my_block_end = 0       # exclusive end block index
|
|
STATE = ModelState()     # singleton model state shared by all handlers
stop_generation = False  # cooperative cancel flag, polled each token by generate_stream
|
|
| |
| |
| |
|
|
def serialize_tensor(tensor: tf.Tensor) -> str:
    """Encode a tensor as base64(.npy bytes) for JSON transport to workers."""
    raw = io.BytesIO()
    np.save(raw, tensor.numpy(), allow_pickle=False)
    encoded = base64.b64encode(raw.getvalue())
    return encoded.decode('utf-8')
|
|
def deserialize_tensor(data: str) -> tf.Tensor:
    """Inverse of serialize_tensor: base64 .npy string -> tf.Tensor."""
    raw = base64.b64decode(data)
    array = np.load(io.BytesIO(raw), allow_pickle=False)
    return tf.constant(array)
|
|
def serialize_kv_cache(past_kv):
    """Serialize a list of (k, v) tensor pairs for JSON; preserves None slots."""
    if past_kv is None:
        return None
    entries = []
    for k, v in past_kv:
        if k is None:
            entries.append(None)
        else:
            entries.append({"k": serialize_tensor(k), "v": serialize_tensor(v)})
    return entries
|
|
def deserialize_kv_cache(data):
    """Inverse of serialize_kv_cache: JSON list -> list of (k, v) tensor pairs."""
    if data is None:
        return None
    pairs = []
    for item in data:
        if item:
            pairs.append((deserialize_tensor(item["k"]), deserialize_tensor(item["v"])))
        else:
            pairs.append(None)
    return pairs
|
|
| |
| |
| |
|
|
def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
    """POST hidden states to a worker's /api/forward and return (output, new_kv).

    Args:
        url: worker base URL (trailing slash tolerated).
        hidden_states: activations to forward through the worker's blocks.
        past_kv: optional serialized-compatible KV cache for the worker.
        use_cache: ask the worker to return its updated KV cache.

    Raises:
        RuntimeError: on transport failure, a non-200 response, or an
            unparseable response body.

    Fix: the original wrapped its whole body in one blanket ``except``, so
    the RuntimeError it raised for a non-200 status was immediately caught
    and re-wrapped ("Worker call failed: Worker returned N"), and exception
    chaining was lost. The status error is now raised outside the handlers,
    and wrapped errors chain the original via ``from``.
    """
    payload = {
        "hidden_states": serialize_tensor(hidden_states),
        "past_kv": serialize_kv_cache(past_kv),
        "use_cache": use_cache,
    }
    try:
        response = requests.post(
            f"{url.rstrip('/')}/api/forward",
            json=payload,
            headers={"Authorization": f"Bearer {CONFIG['secret_token']}"},
            timeout=120,
        )
    except Exception as e:
        raise RuntimeError(f"Worker call failed: {e}") from e

    # Raised outside the try blocks so it is NOT re-wrapped.
    if response.status_code != 200:
        raise RuntimeError(f"Worker returned {response.status_code}")

    try:
        result = response.json()
        output = deserialize_tensor(result["hidden_states"])
        new_kv = deserialize_kv_cache(result.get("past_kv"))
    except Exception as e:
        raise RuntimeError(f"Worker call failed: {e}") from e
    return output, new_kv
|
|
| |
| |
| |
|
|
def load_model():
    """Download config/tokenizer/weights, build the model, and populate STATE.

    Side effects: writes ./temp_tokenizer, downloads into CONFIG['cache_dir'],
    and mutates the module-level STATE. Returns True on success.
    """
    print("π Loading model...")

    # Model hyperparameters come from the repo's config.json.
    config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
    with open(config_path, 'r') as f:
        model_config = json.load(f)
    STATE.config = model_config

    # Imported lazily so these heavy deps load only when the model is built.
    from transformers import AutoTokenizer
    from tokenizers import Tokenizer

    # GPT-2 tokenizer extended with the chat/reasoning special tokens, then
    # reloaded through the fast `tokenizers` backend via a temp save.
    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    hf_tokenizer.add_special_tokens({"additional_special_tokens":
        ["<|im_start|>", "<|im_end|>", "<think>", "</think>", "<CONTINUE>", "<im end for model tun>"]})
    os.makedirs("./temp_tokenizer", exist_ok=True)
    hf_tokenizer.save_pretrained("./temp_tokenizer")
    STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
    STATE.eos_token_id = model_config.get('eos_token_id', 50256)

    weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])

    n_layers = model_config['num_hidden_layers']
    d_model = model_config['hidden_size']
    n_heads = model_config['num_attention_heads']
    ff_dim = model_config['intermediate_size']
    max_len = model_config['max_position_embeddings']
    rope_theta = model_config['rope_theta']
    vocab_size = model_config['vocab_size']

    # Build the FULL stack (all layers) so the checkpoint can be loaded; the
    # slice this node keeps is selected below.
    embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
    blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
              for i in range(n_layers)]
    final_norm = RMSNorm(name="final_norm")
    lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")

    # Dummy forward pass so every layer creates its weights before loading.
    dummy = tf.zeros((1, 16), dtype=tf.int32)
    x = embedding(dummy)
    for block in blocks:
        x, _ = block(x)
    x = final_norm(x)
    _ = lm_head(x)

    # Wrapper model used purely so Keras maps checkpoint weight names onto
    # these layers. NOTE(review): attribute names presumably mirror the
    # training script's model — confirm if load_weights ever fails.
    class TempModel(keras.Model):
        def __init__(self):
            super().__init__()
            self.embed = embedding
            self.blocks = blocks
            self.norm = final_norm
            self.lm_head = lm_head
        def call(self, x):
            x = self.embed(x)
            for b in self.blocks:
                x, _ = b(x)
            return self.lm_head(self.norm(x))

    temp_model = TempModel()
    temp_model(dummy)
    temp_model.load_weights(weights_path)
    print("✅ Weights loaded")

    # This node keeps only its configured slice of blocks.
    STATE.my_block_start = CONFIG["layer_start"]
    STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers

    STATE.embedding = embedding
    STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
    print(f"✅ Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")

    # Without workers this node is the whole pipeline, so it also needs the
    # final norm and LM head.
    has_workers = len(CONFIG["worker_urls"]) > 0
    if not has_workers:
        STATE.final_norm = final_norm
        STATE.lm_head = lm_head
        print("✅ Loaded final norm and LM head (standalone mode)")

    # Warm-up pass so first-request latency isn't paid by a user.
    print("π₯ Warming up...")
    dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
    x = STATE.embedding(dummy)
    for block in STATE.blocks:
        x, _ = block(x, use_cache=False)
    if STATE.lm_head:
        _ = STATE.lm_head(STATE.final_norm(x))

    print("✅ Model ready!")
    return True
|
|
| |
| |
| |
|
|
def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
    """Run one forward step: embedding -> local blocks -> remote workers.

    Returns (logits_or_hidden, new_local_kv, new_worker_kv). When this node
    has no LM head (distributed mode without the tail), the first element is
    the raw hidden states instead of logits. The KV returns are None unless
    use_cache is True.
    """
    hidden = STATE.embedding(input_ids)

    # Local transformer blocks, threading each block's cache through.
    local_cache = [] if use_cache else None
    for idx, block in enumerate(STATE.blocks):
        previous = past_kv_local[idx] if past_kv_local else None
        hidden, block_cache = block(hidden, past_kv=previous, use_cache=use_cache)
        if use_cache:
            local_cache.append(block_cache)

    # Pipeline the activations through each worker in order.
    worker_cache = {} if use_cache else None
    for worker_url in CONFIG["worker_urls"]:
        previous = past_kv_workers.get(worker_url) if past_kv_workers else None
        hidden, remote_cache = call_worker(worker_url, hidden, previous, use_cache)
        if use_cache:
            worker_cache[worker_url] = remote_cache

    # Standalone (or tail) node: project to vocabulary logits.
    if STATE.lm_head:
        return STATE.lm_head(STATE.final_norm(hidden)), local_cache, worker_cache
    return hidden, local_cache, worker_cache
|
|
| |
| |
| |
|
|
def sample_token(logits, temperature, top_k, top_p, token_freq, rep_penalty):
    """Sample the next token id from raw logits.

    Applies, in order: temperature scaling, repetition penalty, top-k
    filtering, then nucleus (top-p) filtering.

    Bug fix: the repetition penalty previously divided every repeated
    token's logit by ``rep_penalty ** freq``. For NEGATIVE logits, division
    moves them toward zero — *increasing* their probability, the opposite of
    the intent. Following the CTRL-paper convention, positive logits are
    divided and negative logits are multiplied, so repeats are always
    discouraged.

    Args:
        logits: 1-D array-like of vocabulary logits.
        temperature: softmax temperature (> 0).
        top_k: keep only the k highest logits; 0 or >= vocab size disables.
        top_p: nucleus probability-mass cutoff; 1.0 disables.
        token_freq: {token_id: count} of tokens generated so far.
        rep_penalty: penalty base (> 1 discourages repetition).

    Returns:
        int: the sampled token id (index into the full vocabulary).
    """
    logits = np.array(logits) / temperature

    # Repetition penalty (CTRL-style): always push repeated tokens down.
    for tid, freq in token_freq.items():
        if tid < len(logits):
            penalty = rep_penalty ** freq
            if logits[tid] > 0:
                logits[tid] /= penalty
            else:
                logits[tid] *= penalty

    # Top-k: restrict to the k largest logits, remembering vocab indices.
    if 0 < top_k < len(logits):
        top_k_idx = np.argpartition(logits, -top_k)[-top_k:]
        top_k_logits = logits[top_k_idx]
    else:
        top_k_idx = np.arange(len(logits))
        top_k_logits = logits

    # Numerically-stable softmax over the surviving logits.
    top_k_logits = top_k_logits - np.max(top_k_logits)
    probs = np.exp(top_k_logits)
    probs /= probs.sum()

    if top_p < 1.0:
        # Nucleus: smallest prefix of descending probs with mass >= top_p.
        sorted_idx = np.argsort(probs)[::-1]
        cumsum = np.cumsum(probs[sorted_idx])
        cutoff = np.searchsorted(cumsum, top_p) + 1
        nucleus_idx = sorted_idx[:cutoff]
        nucleus_probs = probs[nucleus_idx]
        nucleus_probs /= nucleus_probs.sum()
        sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
        return int(top_k_idx[nucleus_idx[sampled]])

    return int(top_k_idx[np.random.choice(len(probs), p=probs)])
|
|
|
|
def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_p=0.9, rep_penalty=1.1):
    """Yield progressively longer generated text for `prompt`.

    One full prefill pass primes the local and worker KV caches; each
    subsequent step feeds only the newest token. Yields the accumulated
    string after every token, appending a stats footer at the end. Honors
    the module-level `stop_generation` flag for cooperative cancellation.
    """
    global stop_generation
    stop_generation = False

    # Drop any EOS ids the tokenizer may have produced inside the prompt.
    input_ids = [i for i in STATE.tokenizer.encode(prompt).ids if i != STATE.eos_token_id]
    if not input_ids:
        yield "Error: Empty prompt"
        return

    generated = ""
    token_freq = {}  # token id -> count, feeds the repetition penalty

    # Stop on EOS or either chat end-of-turn marker; discard None for any
    # marker absent from the vocabulary.
    stop_ids = {STATE.eos_token_id, STATE.tokenizer.token_to_id("<|im_end|>"),
                STATE.tokenizer.token_to_id("<im end for model tun>")}
    stop_ids.discard(None)

    # Trim the prompt from the left so prompt + generation fits the context.
    max_ctx = STATE.config['max_position_embeddings']
    if len(input_ids) > max_ctx - max_tokens:
        input_ids = input_ids[-(max_ctx - max_tokens):]

    start = time.time()

    # Prefill: process the whole prompt once with caching enabled.
    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    try:
        logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
    except Exception as e:
        yield f"Error: {e}"
        return

    next_logits = logits[0, -1, :].numpy()
    prefill_time = time.time() - start
    print(f"β‘ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")

    decode_start = time.time()
    tokens_generated = 0

    for _ in range(max_tokens):
        if stop_generation:
            yield generated + "\n\n*[Stopped]*"
            return

        next_id = sample_token(next_logits, temperature, top_k, top_p, token_freq, rep_penalty)

        if next_id in stop_ids:
            break

        token_freq[next_id] = token_freq.get(next_id, 0) + 1
        generated += STATE.tokenizer.decode([next_id])
        tokens_generated += 1
        yield generated

        # Decode step: only the new token goes in; caches carry the history.
        next_input = tf.constant([[next_id]], dtype=tf.int32)
        try:
            logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
        except Exception as e:
            yield generated + f"\n\n*[Error: {e}]*"
            return

        next_logits = logits[0, -1, :].numpy()

    # Append a throughput footer (decode-phase tokens/sec, total wall time).
    if tokens_generated > 0:
        total = time.time() - start
        tps = tokens_generated / (time.time() - decode_start)
        workers = len(CONFIG["worker_urls"])
        mode = f", {workers} workers" if workers else " standalone"
        generated += f"\n\n*[{tokens_generated} tokens in {total:.1f}s ({tps:.1f} tok/s){mode}]*"

    yield generated
|
|
|
|
def format_prompt(message: str, history: list, reasoning: bool) -> str:
    """Render chat history plus the new user message as ChatML-style text.

    Assistant turns have any trailing '*[...]' stats footer stripped. When
    `reasoning` is enabled, an opening <think> tag is appended so the model
    emits a reasoning trace first.
    """
    parts = []
    for turn in history:
        role = turn["role"]
        if role == "user":
            parts.append(f"<|im_start|>user\n{turn['content']}<|im_end|>\n")
        elif role == "assistant":
            body = turn["content"].split('*[')[0].strip()
            parts.append(f"<|im_start|>assistant\n{body}<|im_end|>\n")
    parts.append(f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n")
    if reasoning:
        parts.append("<think>")
    return "".join(parts)
|
|
|
|
def chat_respond(message, history, max_tokens, temp, top_k, top_p, rep_pen, reasoning):
    """Gradio chat handler: stream assistant replies into the history.

    Yields the full messages-format history on every step so the Chatbot
    component re-renders progressively.
    """
    if not message.strip():
        yield history
        return

    prompt = format_prompt(message, history, reasoning)


    history = history + [{"role": "user", "content": message}]

    for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
        display = text


        # Cut the reply at any end-of-turn marker, but keep the trailing
        # "\n\n*[...]" stats footer if it appears after the marker.
        for tag in ["<|im_end|>", "<im end for model tun>"]:
            if tag in display:
                idx = display.find(tag)
                stats = display.find("\n\n*[")
                display = display[:idx] + (display[stats:] if stats > idx else "")


        # Once a complete <think>...</think> span exists, fold it into a
        # collapsible HTML details block (offsets 7/8 = tag lengths).
        if reasoning and '<think>' in display and '</think>' in display:
            s, e = display.find('<think>'), display.find('</think>')
            if s < e:
                thought = display[s+7:e].strip()
                display = display[:s] + f'<details><summary>π§ Reasoning</summary><p>{thought}</p></details>' + display[e+8:]


        yield history + [{"role": "assistant", "content": display.strip()}]
|
|
|
|
def stop():
    """Signal generate_stream to halt at the next token boundary."""
    global stop_generation
    stop_generation = True
|
|
| |
| |
| |
|
|
def create_ui():
    """Build and return the Gradio Blocks chat UI for this head node."""
    workers = CONFIG["worker_urls"]
    mode = f"Distributed ({len(workers)} workers)" if workers else "Standalone"

    with gr.Blocks(title="Sam-large-2 HEAD") as app:
        gr.Markdown(f"""
        # π Sam-large-2 - HEAD NODE
        **Mode:** {mode} | **Blocks:** {CONFIG['layer_start']}-{CONFIG['layer_end']-1} | **ID:** {CONFIG['node_id']}
        """)

        if workers:
            gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))

        # Per-session toggle: prepend <think> to prompts (see format_prompt).
        reasoning = gr.State(False)

        chatbot = gr.Chatbot(
            height=500,
            type="messages"
        )

        with gr.Row():
            reason_btn = gr.Button("π‘", size="sm", scale=0)
            msg = gr.Textbox(placeholder="Type message...", show_label=False, scale=8)
            send = gr.Button("Send", variant="primary", scale=1)
            stop_btn = gr.Button("βΉοΈ", scale=0)

        with gr.Accordion("βοΈ Settings", open=False):
            max_tok = gr.Slider(50, 1024, 512, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
            topk = gr.Slider(1, 100, 40, label="Top-K")
            topp = gr.Slider(0.1, 1.0, 0.9, label="Top-P")
            rep = gr.Slider(1.0, 2.0, 1.1, label="Repetition Penalty")

        def toggle(r):
            # Flip the reasoning flag and restyle the button to show state.
            return not r, gr.update(variant="primary" if not r else "secondary")

        reason_btn.click(toggle, [reasoning], [reasoning, reason_btn])

        inputs = [msg, chatbot, max_tok, temp, topk, topp, rep, reasoning]
        # Keep the event handles so the stop button can cancel in-flight streams.
        submit = msg.submit(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        stop_btn.click(stop, cancels=[submit, click])

        gr.Button("ποΈ Clear").click(lambda: [], outputs=[chatbot])

    return app
|
|
| |
| |
| |
|
|
# ---- Module-level startup: runs on import (e.g. under a Spaces launcher) ----
print("=" * 60)
print("π Sam-large-2 HEAD Node Starting")
print(f" Blocks: {CONFIG['layer_start']} to {CONFIG['layer_end']}")
print(f" Workers: {CONFIG['worker_urls'] or 'None (standalone)'}")
print("=" * 60)


load_model()
app = create_ui()
app.queue()  # enable request queueing so generator handlers can stream
app.launch(server_name="0.0.0.0", server_port=7860)