import os
import sys
import json
import math
from datetime import datetime
import asyncio
import aiohttp
from typing import Dict, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, HttpUrl
import uvicorn


class Settings:
    # Controller network settings
    CONTROLLER_HOST = "0.0.0.0"
    CONTROLLER_PORT = 8000
    CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", "http://192.168.1.100:8000")

    # Tensor servers: a comma-separated TENSOR_SERVER_URLS env var overrides
    # the defaults below, e.g. TENSOR_SERVER_URLS="https://a.example,https://b.example"
    TENSOR_SERVER_URLS = [
        url for url in os.getenv("TENSOR_SERVER_URLS", "").split(",") if url
    ] or [
        "https://fred808-ilob.hf.space",
        "https://fred808-tserv.hf.space",
        "https://fred808-tserve2.hf.space",
    ]

    # Aggregator defaults; from_env() below reads these class attributes.
    AGGREGATOR_HOST = "192.168.1.104"
    AGGREGATOR_PORT = 8002
    AGGREGATOR_URL = os.getenv("AGGREGATOR_URL", "http://192.168.1.104:8002")

    # Model source repository
    MODEL_REPO = "https://huggingface.co/inference-net/Schematron-8B"

    # Timeouts (seconds) and health thresholds
    TENSOR_SERVER_TIMEOUT = 30
    MAX_ERROR_THRESHOLD = 5
    SERVER_TIMEOUT = 60
    MONITORING_INTERVAL = 15

    @classmethod
    def get_optimal_chunk_size(cls, total_params: int, num_servers: int) -> int:
        """Calculate optimal chunk size based on number of servers."""
        # Aim for roughly two chunks per server so load can be balanced.
        target_chunks = num_servers * 2
        return max(1, total_params // target_chunks)
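
    # Illustrative arithmetic (example values, not from the source): for an
    # 8B-parameter model on 4 servers, target_chunks = 8, so
    #   Settings.get_optimal_chunk_size(8_000_000_000, 4) == 1_000_000_000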

    @classmethod
    def get_min_servers_required(cls) -> int:
        """Dynamically calculate minimum servers needed based on registered servers."""
        return max(2, len(cls.TENSOR_SERVER_URLS) // 3)

    @classmethod
    def get_min_replica_count(cls, num_servers: int) -> int:
        """Calculate minimum replicas per chunk based on server count."""
        return max(2, num_servers // 4)
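
    # Illustrative arithmetic (not from the source): with the 3 default server
    # URLs, get_min_servers_required() == max(2, 3 // 3) == 2, and with 12
    # healthy servers, get_min_replica_count(12) == max(2, 12 // 4) == 3.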

    # Sequence/vocabulary defaults
    MAX_SEQUENCE_LENGTH = 2048
    VOCAB_SIZE = 50257

    @classmethod
    def from_env(cls):
        """Load settings from environment variables."""
        cls.CONTROLLER_HOST = os.getenv("CONTROLLER_HOST", cls.CONTROLLER_HOST)
        cls.CONTROLLER_PORT = int(os.getenv("CONTROLLER_PORT", cls.CONTROLLER_PORT))
        cls.CONTROLLER_BASE_URL = os.getenv("CONTROLLER_BASE_URL", cls.CONTROLLER_BASE_URL)

        tensor_urls = os.getenv("TENSOR_SERVER_URLS")
        if tensor_urls:
            cls.TENSOR_SERVER_URLS = [url for url in tensor_urls.split(",") if url]

        cls.AGGREGATOR_HOST = os.getenv("AGGREGATOR_HOST", cls.AGGREGATOR_HOST)
        cls.AGGREGATOR_PORT = int(os.getenv("AGGREGATOR_PORT", cls.AGGREGATOR_PORT))
        cls.AGGREGATOR_URL = os.getenv("AGGREGATOR_URL",
                                       f"http://{cls.AGGREGATOR_HOST}:{cls.AGGREGATOR_PORT}")

        return cls


class ServerMetrics(BaseModel):
    """Metrics for tensor server performance and load"""
    cpu_usage: float = 0.0
    memory_usage: float = 0.0
    gpu_usage: Optional[float] = None
    active_requests: int = 0
    total_requests: int = 0
    average_response_time: float = 0.0
    last_error: Optional[str] = None
    error_count: int = 0


class TensorServer(BaseModel):
    """Represents a registered tensor server"""
    url: HttpUrl
    status: str = "initializing"
    # default_factory gives each instance a fresh value; a bare datetime.now()
    # default is evaluated once at import time and shared by every server.
    last_heartbeat: datetime = Field(default_factory=datetime.now)
    model_chunks: List[int] = Field(default_factory=list)
    metrics: ServerMetrics = Field(default_factory=ServerMetrics)
    version: str = "1.0.0"
    capabilities: Dict[str, bool] = Field(default_factory=lambda: {
        "gpu_available": False,
        "quantization_support": False,
        "tensor_parallelism": False,
    })


class ModelChunk(BaseModel):
    """Represents a chunk of the model to be sent to a tensor server"""
    chunk_id: int
    files: List[str]
    config: Dict
    size_bytes: int = 0
    server_assignments: List[str] = Field(default_factory=list)
    status: str = "unassigned"
    metrics: Dict[str, float] = Field(default_factory=lambda: {
        "load_time": 0.0,
        "memory_usage": 0.0,
        "average_inference_time": 0.0,
    })
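

# Sketch of a chunk config as produced by split_model_weights() below (field
# values are illustrative, not from a real run):
#   {
#       "start_offset": 0,
#       "size_bytes": 2147483648,
#       "is_last_chunk": False,
#       "total_chunks": 3,
#       "original_file": "model.safetensors",
#       "vocab_offset": 0,
#       "shard_dim": 1,
#   }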


app = FastAPI(
    title="Florence-2 Model Controller",
    description="Controls model distribution across tensor servers",
    version="1.0.0"
)


class ControllerState:
    """Mutable in-process state shared by all endpoints."""

    def __init__(self):
        self.model_files: Dict[str, str] = {}
        self.model_config: Dict = {}
        self.tensor_servers: Dict[str, TensorServer] = {}
        self.model_chunks: Dict[int, ModelChunk] = {}
        self.is_model_loaded = False
        self.model_path: str = ""
        self.chunks_dir: str = ""
        self.operation_results: Dict[str, Dict] = {}
        self.pending_operations: Dict[str, asyncio.Task] = {}


state = ControllerState()


async def split_model_weights():
    """Split the model file into byte-range chunks, one per available server,
    without loading the weights into memory."""
    try:
        weight_files = [f for f in state.model_files.values()
                        if f.endswith(('.safetensors', '.bin'))]

        if not weight_files:
            raise Exception("No model weight files found")

        # If the repo ships multiple weight files, split only the largest one.
        model_file = max(weight_files, key=os.path.getsize) if len(weight_files) > 1 else weight_files[0]

        if len(weight_files) > 1:
            print(f"[WARN] Found multiple weight files. Selecting the largest one for splitting: {model_file}")
        else:
            print(f"[INFO] Found model weight file: {model_file}")

        # Probe the file: measure its size and make sure it is readable.
        try:
            with open(model_file, 'rb') as f:
                f.seek(0, os.SEEK_END)
                file_size = f.tell()
                f.seek(0)

                header = f.read(8)
                if len(header) == 0:
                    raise ValueError(f"File is empty: {model_file}")
        except Exception as e:
            raise Exception(f"Failed to read model file {model_file}: {str(e)}")

        if file_size < 1024:
            raise ValueError(f"Model file suspiciously small ({file_size} bytes). "
                             "Possible corruption or incomplete download.")

        num_servers = len(state.tensor_servers) or len(Settings.TENSOR_SERVER_URLS)
        num_chunks = num_servers
        chunk_size = math.ceil(file_size / num_chunks)

        def format_size(size_bytes):
            if size_bytes >= 1024 * 1024 * 1024:
                return f"{size_bytes / (1024 * 1024 * 1024):.2f} GB ({size_bytes:,} bytes)"
            elif size_bytes >= 1024 * 1024:
                return f"{size_bytes / (1024 * 1024):.2f} MB ({size_bytes:,} bytes)"
            elif size_bytes >= 1024:
                return f"{size_bytes / 1024:.2f} KB ({size_bytes:,} bytes)"
            else:
                return f"{size_bytes:,} bytes"

        print(f"[INFO] Model file size: {format_size(file_size)}")
        print(f"[INFO] Creating {num_chunks} chunks of approximately {format_size(chunk_size)} each")

        os.makedirs(state.chunks_dir, exist_ok=True)

        # Copy each byte range into its own chunk file.
        with open(model_file, 'rb') as f:
            chunk_sizes = []
            for chunk_id in range(num_chunks):
                chunk_path = os.path.join(state.chunks_dir, f"chunk_{chunk_id}.bin")

                start_pos = chunk_id * chunk_size
                remaining = file_size - start_pos
                current_chunk_size = min(chunk_size, remaining)

                if current_chunk_size <= 0:
                    break

                try:
                    f.seek(start_pos)
                    chunk_data = f.read(current_chunk_size)
                    actual_chunk_size = len(chunk_data)

                    if actual_chunk_size != current_chunk_size:
                        print(f"[WARN] Chunk {chunk_id} size mismatch. "
                              f"Expected: {current_chunk_size}, Got: {actual_chunk_size}")

                    with open(chunk_path, 'wb') as chunk_file:
                        chunk_file.write(chunk_data)

                    chunk_sizes.append(actual_chunk_size)
                    print(f"[DEBUG] Chunk {chunk_id} data: First few bytes: {chunk_data[:20].hex()}")
                except Exception as e:
                    raise Exception(f"Failed to process chunk {chunk_id} at offset {start_pos}: {str(e)}")

                # Raw byte splitting carries no tensor structure, so vocab_offset
                # and shard_dim are placeholders kept for schema compatibility
                # with tensor-aware chunk configs.
                vocab_offset = sum(int(c.config.get('shard_dim', 1))
                                   for c in state.model_chunks.values())

                cfg = {
                    "start_offset": start_pos,
                    "size_bytes": current_chunk_size,
                    "is_last_chunk": chunk_id == num_chunks - 1,
                    "total_chunks": num_chunks,
                    "original_file": os.path.basename(model_file),
                    "vocab_offset": vocab_offset,
                    "shard_dim": 1,
                }

                state.model_chunks[chunk_id] = ModelChunk(
                    chunk_id=chunk_id,
                    files=[f"chunk_{chunk_id}.bin"],
                    config=cfg,
                    size_bytes=current_chunk_size,
                    status="ready"
                )

                print(f"[INFO] Created chunk {chunk_id}: {format_size(current_chunk_size)} "
                      f"({current_chunk_size:,} bytes)")

        # Sanity-check that no bytes were lost or duplicated.
        total_size_actual = sum(chunk_sizes)
        if total_size_actual != file_size:
            print(f"[WARN] Total chunk size ({format_size(total_size_actual)}) differs "
                  f"from original file size ({format_size(file_size)})")
            print(f"[WARN] Difference: {format_size(abs(total_size_actual - file_size))}")

        avg_chunk_size = sum(chunk_sizes) / len(chunk_sizes) if chunk_sizes else 0
        min_chunk_size = min(chunk_sizes) if chunk_sizes else 0
        max_chunk_size = max(chunk_sizes) if chunk_sizes else 0

        print("\n[INFO] Distribution Summary:")
        print(f"- Original file: {os.path.basename(model_file)}")
        print(f"- Total size: {format_size(file_size)} ({file_size:,} bytes)")
        print(f"- Number of chunks: {len(state.model_chunks)}")
        print(f"- Chunks directory: {state.chunks_dir}")
        print(f"- Average chunk size: {format_size(avg_chunk_size)}")
        print(f"- Smallest chunk: {format_size(min_chunk_size)}")
        print(f"- Largest chunk: {format_size(max_chunk_size)}")
        if avg_chunk_size:
            print(f"- Size variance: {((max_chunk_size - min_chunk_size) / avg_chunk_size * 100):.1f}%")

        return True

    except Exception as e:
        print(f"[ERROR] Failed to split model weights: {str(e)}")
        return False
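

# Sanity-check sketch (an assumption, not part of the original flow): because
# chunks are plain byte ranges, concatenating chunk_0..chunk_{n-1} in order
# must reproduce the source file exactly. A helper like this can verify a
# split locally before distributing it.
def verify_chunk_split(original_file: str, chunks_dir: str, num_chunks: int) -> bool:
    """Return True if the concatenated chunks match the original file byte-for-byte."""
    import hashlib

    def sha256_of(paths):
        # Stream all paths, in order, into a single digest.
        h = hashlib.sha256()
        for path in paths:
            with open(path, 'rb') as f:
                for block in iter(lambda: f.read(1 << 20), b""):
                    h.update(block)
        return h.hexdigest()

    chunk_paths = [os.path.join(chunks_dir, f"chunk_{i}.bin") for i in range(num_chunks)]
    return sha256_of(chunk_paths) == sha256_of([original_file])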


async def send_chunk_to_server(server_url: str, chunk_id: int, chunk_info: Dict):
    """Send a model chunk to a tensor server using the two-step protocol:
    register the chunk metadata, then upload the raw bytes."""
    try:
        print(f"[INFO] Sending chunk {chunk_id} to server {server_url}")
        chunk_path = os.path.join(state.chunks_dir, f"chunk_{chunk_id}.bin")

        if not os.path.exists(chunk_path):
            raise Exception(f"Chunk file not found: {chunk_path}")

        chunk = state.model_chunks[chunk_id]
        chunk_data = {
            'chunk_id': chunk_id,
            'files': [os.path.basename(chunk_path)],
            'config': chunk.config
        }

        async with aiohttp.ClientSession() as session:
            # Step 1: register the chunk metadata and wait for the server's go-ahead.
            async with session.post(
                f"{server_url}/load_chunk",
                json=chunk_data,
                timeout=Settings.TENSOR_SERVER_TIMEOUT
            ) as response:
                if response.status != 200:
                    error_msg = await response.text()
                    raise Exception(f"Failed to register chunk: {error_msg}")

                result = await response.json()
                if not result.get("ready_for_data", False):
                    raise Exception("Server not ready for chunk data")

            # Step 2: upload the chunk bytes as multipart form data.
            with open(chunk_path, 'rb') as f:
                chunk_file = f.read()

            form = aiohttp.FormData()
            form.add_field('file',
                           chunk_file,
                           filename=os.path.basename(chunk_path),
                           content_type='application/octet-stream')

            async with session.post(
                f"{server_url}/upload_chunk_data/{chunk_id}",
                data=form,
                timeout=Settings.TENSOR_SERVER_TIMEOUT
            ) as upload_response:
                if upload_response.status != 200:
                    error_msg = await upload_response.text()
                    raise Exception(f"Failed to upload chunk data: {error_msg}")

                upload_result = await upload_response.json()
                print(f"[INFO] Successfully uploaded chunk {chunk_id} to {server_url} "
                      f"({upload_result.get('size_bytes', 0)} bytes)")
                return True

    except Exception as e:
        print(f"[ERROR] Failed to send chunk {chunk_id} to {server_url}: {str(e)}")
        return False
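

# Wire-protocol sketch (assumed server behavior, inferred from the calls above):
#   POST {server}/load_chunk                       -> {"ready_for_data": true}
#   POST {server}/upload_chunk_data/{chunk_id}
#        (multipart field "file")                  -> {"size_bytes": N}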


async def distribute_model_chunks():
    """Distribute model chunks across available tensor servers."""
    try:
        available_servers = [
            server for server in state.tensor_servers.values()
            if server.status in ["ready", "busy"]
            and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD
        ]

        min_required = Settings.get_min_servers_required()
        if len(available_servers) < min_required:
            raise Exception(f"Not enough healthy servers. Need {min_required}, got {len(available_servers)}")

        # Re-split if there are no chunks yet, or far more chunks than servers.
        if not state.model_chunks or len(state.model_chunks) > len(available_servers) * 3:
            if not await split_model_weights():
                raise Exception("Failed to split model weights")

        min_replicas = Settings.get_min_replica_count(len(available_servers))
        chunks_per_server = len(state.model_chunks) / len(available_servers)
        print(f"[INFO] Distributing chunks with min {min_replicas} replicas per chunk")
        print(f"[INFO] Target chunks per server: {chunks_per_server:.1f}")

        for chunk_id, chunk in state.model_chunks.items():
            target_replicas = min_replicas

            # Keep only assignments to servers that are still registered and healthy.
            current_assignments = set(chunk.server_assignments)
            current_healthy = [url for url in current_assignments
                               if url in state.tensor_servers
                               and state.tensor_servers[url].status in ["ready", "busy"]]
            chunk.server_assignments = current_healthy

            # Top up replicas until the target is met or no eligible server remains.
            while len(chunk.server_assignments) < target_replicas:
                eligible_servers = [
                    server for server in available_servers
                    if str(server.url) not in chunk.server_assignments
                    and len(server.model_chunks) < (len(state.model_chunks) / len(available_servers) * 1.5)
                ]

                if not eligible_servers:
                    break

                # Prefer the least-loaded, least-error-prone server.
                eligible_servers.sort(key=lambda s: (
                    len(s.model_chunks),
                    s.metrics.error_count,
                    s.metrics.cpu_usage
                ))

                best_server = eligible_servers[0]
                chunk.server_assignments.append(str(best_server.url))
                best_server.model_chunks.append(chunk_id)
                print(f"[INFO] Assigned chunk {chunk_id} to server {best_server.url}")

        return True

    except Exception as e:
        print(f"[ERROR] Failed to distribute model chunks: {str(e)}")
        return False


async def monitor_tensor_servers():
    """Periodically check health and update metrics of all tensor servers."""
    while True:
        for server_url, server in state.tensor_servers.items():
            try:
                is_healthy = await check_tensor_server_health(server_url)

                if not is_healthy:
                    server.status = "error"
                    server.metrics.error_count += 1
                    print(f"[WARN] Server {server_url} is unhealthy")
                    continue

                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{server_url}/metrics",
                                           timeout=Settings.TENSOR_SERVER_TIMEOUT) as response:
                        if response.status == 200:
                            metrics = await response.json()
                            server.metrics = ServerMetrics(**metrics)

                # Derive a status from the fresh metrics.
                if server.metrics.error_count > Settings.MAX_ERROR_THRESHOLD:
                    server.status = "degraded"
                elif server.metrics.cpu_usage > 90 or server.metrics.memory_usage > 90:
                    server.status = "busy"
                else:
                    server.status = "ready"

                server.last_heartbeat = datetime.now()

            except Exception as e:
                print(f"[ERROR] Failed to monitor server {server_url}: {str(e)}")
                server.status = "error"
                server.metrics.last_error = str(e)
                server.metrics.error_count += 1

        # Flag servers that have not responded within the timeout window.
        # Use total_seconds(): timedelta.seconds ignores the days component.
        current_time = datetime.now()
        for server_url, server in state.tensor_servers.items():
            if (current_time - server.last_heartbeat).total_seconds() > Settings.SERVER_TIMEOUT:
                print(f"[WARN] Server {server_url} hasn't responded in {Settings.SERVER_TIMEOUT} seconds")
                server.status = "error"

        await asyncio.sleep(Settings.MONITORING_INTERVAL)


def get_next_model_version(base_dir: str, model_name: str) -> int:
    """Get the next available version number for the model"""
    existing_versions = []
    model_base_dir = os.path.join(base_dir, model_name)
    if os.path.exists(model_base_dir):
        for d in os.listdir(model_base_dir):
            if d.startswith('v') and d[1:].isdigit():
                existing_versions.append(int(d[1:]))
    return max(existing_versions + [0]) + 1
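

# Illustrative example (hypothetical layout, not from the source): given
#   models/Schematron-8B/v1/ and models/Schematron-8B/v2/
# get_next_model_version("models", "Schematron-8B") returns 3; with no
# existing versions it returns 1.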


def check_existing_model(model_path: str) -> bool:
    """Check if a model exists locally and has the required files"""
    if not os.path.exists(model_path):
        return False

    required_files = ['config.json']
    model_files = os.listdir(model_path)

    # A usable model needs a config plus at least one weight file.
    has_weights = any(f.endswith(('.bin', '.safetensors')) for f in model_files)

    return all(f in model_files for f in required_files) and has_weights


async def download_model_files():
    """Download the model files using the Hugging Face Hub API"""
    try:
        print(f"[INFO] Processing model from {Settings.MODEL_REPO}...")

        # Install huggingface_hub on the fly if it is missing.
        try:
            import huggingface_hub  # noqa: F401
        except ImportError:
            print("[INFO] Installing huggingface_hub...")
            import subprocess
            subprocess.check_call([sys.executable, "-m", "pip", "install", "huggingface_hub"])

        from huggingface_hub import snapshot_download

        models_dir = os.path.join(os.getcwd(), "models")
        os.makedirs(models_dir, exist_ok=True)
        print(f"[INFO] Models directory: {models_dir}")

        # e.g. "https://huggingface.co/inference-net/Schematron-8B" -> "inference-net/Schematron-8B"
        repo_id = "/".join(Settings.MODEL_REPO.split('/')[-2:])
        model_name = repo_id.split('/')[-1]

        version = get_next_model_version(models_dir, model_name)
        model_base_dir = os.path.join(models_dir, model_name)
        model_version_dir = os.path.join(model_base_dir, f"v{version}")

        # Reuse the previous version if it is complete; otherwise download anew.
        if version > 1:
            prev_version_dir = os.path.join(model_base_dir, f"v{version - 1}")
            if check_existing_model(prev_version_dir):
                print(f"[INFO] Using existing model from {prev_version_dir}")
                model_path = prev_version_dir
                state.is_model_loaded = True
            else:
                os.makedirs(model_version_dir, exist_ok=True)
                model_path = model_version_dir
        else:
            os.makedirs(model_version_dir, exist_ok=True)
            model_path = model_version_dir

        if not state.is_model_loaded:
            try:
                print(f"[INFO] Downloading model files from {repo_id}...")
                print("[INFO] Downloading all model files (this may take a while)...")

                model_path = snapshot_download(
                    repo_id=repo_id,
                    local_dir=model_path,
                    allow_patterns=["*.bin", "*.safetensors", "*.json", "*.txt", "tokenizer.model"],
                    ignore_patterns=["*.msgpack", "*.onnx"],
                    force_download=True
                )

                print(f"[INFO] All files downloaded to {model_path}")
                state.is_model_loaded = True

            except Exception as e:
                raise Exception(f"Failed to download model files: {str(e)}")

        state.model_path = model_path
        state.chunks_dir = os.path.join(model_path, "chunks")
        os.makedirs(state.chunks_dir, exist_ok=True)

        # Load the model configuration if present.
        config_path = os.path.join(model_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                state.model_config = json.load(f)
            print("[INFO] Loaded model configuration")
            print(f"[INFO] Model type: {state.model_config.get('model_type', 'unknown')}")
            print(f"[INFO] Architecture: {state.model_config.get('architectures', ['unknown'])[0]}")
        else:
            print("[WARN] No config.json found in model directory")

        # Index every model file found on disk.
        print("[INFO] Scanning for model files...")
        for root, _, files in os.walk(model_path):
            for file in files:
                if file.endswith(('.bin', '.json', '.safetensors')):
                    file_path = os.path.join(root, file)
                    state.model_files[file] = file_path
                    print(f"[INFO] Found model file: {file}")

        if state.model_files:
            state.is_model_loaded = True
            print(f"[INFO] Model files found successfully! Total files: {len(state.model_files)}")
            print(f"[INFO] Model location: {model_path}")
            return True
        else:
            raise ValueError("No model files were found in the repository")

    except Exception as e:
        print(f"[ERROR] Failed to process model files: {e}")
        state.is_model_loaded = False
        raise


async def check_tensor_server_health(url: HttpUrl) -> bool:
    """Check whether a tensor server responds to its /health endpoint."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{url}/health", timeout=Settings.TENSOR_SERVER_TIMEOUT) as response:
                return response.status == 200
    except Exception:
        return False


async def execute_tensor_operation(operation_id: str, server_url: HttpUrl, operation: str, data: Dict):
    """Execute an operation on a tensor server and wait for results."""
    try:
        async with aiohttp.ClientSession() as session:
            # Kick off the operation.
            async with session.post(
                f"{server_url}/{operation}",
                json=data,
                timeout=Settings.TENSOR_SERVER_TIMEOUT
            ) as response:
                if response.status != 200:
                    error_msg = await response.text()
                    raise HTTPException(
                        status_code=response.status,
                        detail=f"Operation failed on server {server_url}: {error_msg}"
                    )

                initial_response = await response.json()
                if initial_response.get("status") == "completed":
                    state.operation_results[operation_id] = initial_response
                    return initial_response

            # Poll until the server reports a terminal state.
            while True:
                await asyncio.sleep(1)
                async with session.get(
                    f"{server_url}/operation/{initial_response['operation_id']}",
                    timeout=Settings.TENSOR_SERVER_TIMEOUT
                ) as status_response:
                    if status_response.status != 200:
                        raise HTTPException(
                            status_code=status_response.status,
                            detail=f"Failed to get operation status from {server_url}"
                        )

                    status_data = await status_response.json()
                    if status_data["status"] in ["completed", "failed"]:
                        state.operation_results[operation_id] = status_data
                        if status_data["status"] == "failed":
                            raise HTTPException(
                                status_code=500,
                                detail=f"Operation failed on server {server_url}: {status_data.get('error')}"
                            )
                        return status_data

    except HTTPException:
        # Preserve the original status code instead of wrapping it in a 500.
        raise
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=504,
            detail=f"Operation timed out on server {server_url}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error executing operation on {server_url}: {str(e)}"
        )
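

# Operation lifecycle sketch (assumed server behavior, inferred from the
# polling loop above):
#   POST {server}/{operation}                  -> {"status": "running", "operation_id": "..."}
#   GET  {server}/operation/{operation_id}     -> {"status": "completed", ...}
#                                              or {"status": "failed", "error": "..."}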


@app.post("/execute/{operation}")
async def execute_operation(operation: str, data: Dict):
    """Execute an operation across tensor servers and collect results"""
    operation_id = f"{operation}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{len(state.operation_results)}"

    available_servers = [
        server for server in state.tensor_servers.values()
        if server.status in ["ready", "busy"]
        and server.metrics.error_count < Settings.MAX_ERROR_THRESHOLD
    ]

    if not available_servers:
        raise HTTPException(
            status_code=503,
            detail="No available tensor servers"
        )

    # Fan the operation out to every server that holds the required chunks.
    tasks = []
    for server in available_servers:
        if operation in ["compute", "forward"]:
            required_chunks = data.get("required_chunks", [])
            if not all(chunk_id in server.model_chunks for chunk_id in required_chunks):
                continue

        task = asyncio.create_task(
            execute_tensor_operation(
                f"{operation_id}_{server.url}",
                server.url,
                operation,
                data
            )
        )
        tasks.append(task)
        state.pending_operations[f"{operation_id}_{server.url}"] = task

    if not tasks:
        raise HTTPException(
            status_code=400,
            detail="No servers available with required model chunks"
        )

    try:
        # Wait for every server to finish, then aggregate.
        results = await asyncio.gather(*tasks)

        aggregated_result = {
            "operation_id": operation_id,
            "status": "completed",
            "server_results": results,
            "timestamp": datetime.now().isoformat()
        }

        for task_id in list(state.pending_operations.keys()):
            if task_id.startswith(operation_id):
                del state.pending_operations[task_id]

        return aggregated_result

    except Exception as e:
        # Cancel whatever is still running and clear the bookkeeping.
        for task in tasks:
            if not task.done():
                task.cancel()

        for task_id in list(state.pending_operations.keys()):
            if task_id.startswith(operation_id):
                del state.pending_operations[task_id]

        raise HTTPException(
            status_code=500,
            detail=f"Operation failed: {str(e)}"
        )


@app.get("/operation/{operation_id}")
async def get_operation_status(operation_id: str):
    """Get the status of an operation"""
    # Completed sub-operations are keyed by "<operation_id>_<server_url>".
    results = {
        k: v for k, v in state.operation_results.items()
        if k.startswith(operation_id)
    }

    if results:
        return {
            "operation_id": operation_id,
            "status": "completed",
            "results": results
        }

    # Otherwise report any sub-operations that are still running.
    pending = {
        k: "running" for k in state.pending_operations.keys()
        if k.startswith(operation_id)
    }

    if pending:
        return {
            "operation_id": operation_id,
            "status": "running",
            "pending_servers": list(pending.keys())
        }

    raise HTTPException(
        status_code=404,
        detail=f"Operation {operation_id} not found"
    )


@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "status": "running",
        "model_loaded": state.is_model_loaded,
        "registered_servers": len(state.tensor_servers),
        "downloaded_files": len(state.model_files),
        "config_loaded": bool(state.model_config)
    }


@app.get("/health")
async def health_check():
    """Detailed health check"""
    return {
        "status": "healthy",
        "model_loaded": state.is_model_loaded,
        "registered_servers": len(state.tensor_servers),
        "downloaded_files": list(state.model_files.keys()),
        "config_loaded": bool(state.model_config),
        "model_type": state.model_config.get("model_type", "unknown")
    }


@app.post("/register_tensor_server")
async def register_tensor_server(server_url: HttpUrl):
    """Register a new tensor server"""
    if not await check_tensor_server_health(server_url):
        raise HTTPException(status_code=400, detail="Tensor server is not healthy")

    state.tensor_servers[str(server_url)] = TensorServer(url=server_url)
    print(f"[INFO] Registered new tensor server at {server_url}")

    # If the model is already loaded, fold the new server into the distribution.
    if state.is_model_loaded:
        print(f"[INFO] Model is loaded, starting distribution for new server {server_url}")
        try:
            if not state.model_chunks:
                if await split_model_weights():
                    print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks")
                else:
                    print("[ERROR] Failed to split model weights")

            if await distribute_model_chunks():
                print("[INFO] Successfully distributed chunks to tensor servers")
            else:
                print("[ERROR] Failed to distribute chunks")
        except Exception as e:
            print(f"[ERROR] Distribution error during server registration: {str(e)}")

    return {
        "status": "registered",
        "registered_servers": len(state.tensor_servers),
        "server_id": str(server_url),
        "model_loaded": state.is_model_loaded,
        "chunks_distributed": len(state.model_chunks) if state.model_chunks else 0
    }
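

# Example client flow (illustrative; host and payload shape are assumptions):
#   curl -X POST "http://localhost:8000/register_tensor_server?server_url=https://fred808-ilob.hf.space"
#   curl -X POST "http://localhost:8000/initialize"
#   curl -X POST "http://localhost:8000/execute/forward" \
#        -H "Content-Type: application/json" \
#        -d '{"required_chunks": [0], "inputs": [[1, 2, 3]]}'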


@app.delete("/unregister_tensor_server")
async def unregister_tensor_server(server_url: HttpUrl):
    """Unregister a tensor server and rebalance its chunks"""
    if str(server_url) in state.tensor_servers:
        # Drop the server from every chunk's assignment list.
        for chunk in state.model_chunks.values():
            if str(server_url) in chunk.server_assignments:
                chunk.server_assignments.remove(str(server_url))

        del state.tensor_servers[str(server_url)]
        print(f"[INFO] Unregistered tensor server at {server_url}")

        # Rebalance the remaining servers.
        await distribute_model_chunks()
        return {"status": "unregistered"}
    raise HTTPException(status_code=404, detail="Server not found")


@app.get("/server/{server_url}/chunks")
async def get_server_chunks(server_url: HttpUrl):
    """Get the chunks assigned to a specific server"""
    if str(server_url) not in state.tensor_servers:
        raise HTTPException(status_code=404, detail="Server not found")

    server = state.tensor_servers[str(server_url)]
    assigned_chunks = [
        state.model_chunks[chunk_id]
        for chunk_id in server.model_chunks
    ]

    return {
        "server_status": server.status,
        "assigned_chunks": assigned_chunks,
        "metrics": server.metrics.dict()
    }


@app.post("/redistribute")
async def redistribute_chunks():
    """Manually trigger redistribution of model chunks"""
    success = await distribute_model_chunks()
    if not success:
        raise HTTPException(status_code=500, detail="Failed to redistribute chunks")

    return {
        "status": "redistributed",
        "chunk_assignments": {
            chunk_id: chunk.server_assignments
            for chunk_id, chunk in state.model_chunks.items()
        }
    }


@app.get("/chunks/{chunk_id}/status")
async def get_chunk_status(chunk_id: int):
    """Get the status and assignments of a specific chunk"""
    if chunk_id not in state.model_chunks:
        raise HTTPException(status_code=404, detail="Chunk not found")

    chunk = state.model_chunks[chunk_id]
    return {
        "chunk_id": chunk_id,
        "status": chunk.status,
        "server_assignments": chunk.server_assignments,
        "metrics": chunk.metrics
    }


@app.post("/initialize")
async def initialize_system():
    """Download model files and prepare for distribution"""
    await download_model_files()

    # Verify every indexed file actually exists on disk.
    files_status = {}
    total_size = 0
    for filename, filepath in state.model_files.items():
        exists = os.path.exists(filepath)
        if exists:
            size = os.path.getsize(filepath)
            total_size += size
            files_status[filename] = {"exists": exists, "size_bytes": size}
        else:
            files_status[filename] = {"exists": exists, "size_bytes": 0}

    # Split and distribute immediately if servers are already registered.
    distribution_status = "not_started"
    if state.tensor_servers:
        print("[INFO] Starting automatic model distribution...")
        try:
            if await split_model_weights():
                print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks")

                if await distribute_model_chunks():
                    print("[INFO] Successfully distributed chunks to tensor servers")
                    distribution_status = "completed"
                else:
                    print("[ERROR] Failed to distribute chunks")
                    distribution_status = "distribution_failed"
            else:
                print("[ERROR] Failed to split model weights")
                distribution_status = "split_failed"
        except Exception as e:
            print(f"[ERROR] Distribution error: {str(e)}")
            distribution_status = f"error: {str(e)}"
    else:
        print("[INFO] No tensor servers registered yet. Will distribute when servers register.")

    return {
        "status": "initialized",
        "model_loaded": state.is_model_loaded,
        "files_status": files_status,
        "total_size_bytes": total_size,
        "config_loaded": bool(state.model_config),
        "model_type": state.model_config.get("model_type", "unknown"),
        "architecture": state.model_config.get("architectures", ["unknown"])[0],
        "distribution_status": distribution_status,
        "registered_servers": len(state.tensor_servers),
        "chunks_created": len(state.model_chunks) if state.model_chunks else 0
    }


@app.on_event("startup")
async def startup_event():
    """Initialize the server, split the model, and start background tasks"""
    print("[INFO] Initializing system...")
    try:
        # Start the background health monitor (assumed intent; the loop is
        # defined above but nothing else schedules it).
        asyncio.create_task(monitor_tensor_servers())

        await initialize_system()
        print("[INFO] Model initialization complete")

        # Split the model unless initialize_system() already created chunks.
        if state.model_chunks or await split_model_weights():
            print(f"[INFO] Successfully split model into {len(state.model_chunks)} chunks")

            print("[INFO] Starting chunk distribution...")
            distribution_tasks = []

            # Round-robin the chunks across the configured server URLs.
            for chunk_id, chunk in state.model_chunks.items():
                server_index = chunk_id % len(Settings.TENSOR_SERVER_URLS)
                server_url = Settings.TENSOR_SERVER_URLS[server_index]

                task = asyncio.create_task(
                    send_chunk_to_server(server_url, chunk_id, {"chunk_id": chunk_id})
                )
                distribution_tasks.append(task)
                print(f"[INFO] Sending chunk {chunk_id} to {server_url}")

                chunk.server_assignments.append(server_url)

            if distribution_tasks:
                print(f"[INFO] Distributing {len(distribution_tasks)} chunks...")
                results = await asyncio.gather(*distribution_tasks, return_exceptions=True)
                success_count = sum(1 for r in results if r is True)
                print(f"[INFO] Successfully distributed {success_count} chunks "
                      f"out of {len(distribution_tasks)} attempts")
        else:
            print("[ERROR] Failed to split model weights")

    except Exception as e:
        print(f"[ERROR] Startup error: {str(e)}")

    print("[INFO] Startup complete")


if __name__ == "__main__":
    port = int(os.getenv("PORT", Settings.CONTROLLER_PORT))
    print(f"[INFO] Starting controller server on port {port}")
    print(f"[INFO] API documentation available at http://localhost:{port}/docs")

    uvicorn.run(
        "controller_server_new:app",
        host=Settings.CONTROLLER_HOST,
        port=port,
        reload=False
    )