| """
|
| Coordinator for distributed model training.
|
| """
|
| import torch
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
| from typing import Dict, List, Any, Optional
|
| import asyncio
|
| import logging
|
| from huggingface_hub import snapshot_download
|
| import os
|
| import ray
|
| from .couchdb_client import CouchDBClient
|
| from .config import settings
|
| from .tensor_ops import TensorOps
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
class Coordinator:
    """Coordinator for distributed training of OpenPeerLLM."""

    def __init__(self):
        self.db_client = CouchDBClient()
        self.model_id = settings.MODEL_ID
        self.batch_size = settings.BATCH_SIZE
        self.gradient_accumulation_steps = settings.GRADIENT_ACCUMULATION_STEPS
        # Assumption: settings may provide LEARNING_RATE; fall back to a
        # conservative default when it does not.
        self.learning_rate = getattr(settings, 'LEARNING_RATE', 1e-5)
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the model and tokenizer."""
        try:
            # Fetch the model snapshot once, then load both the model and the
            # tokenizer from the local cache directory.
            cache_dir = snapshot_download(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(cache_dir)
            self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)

            # Publish the initial weights so agents share a common starting point.
            initial_state = {
                'model_state': self.model.state_dict(),
                'step': 0,
                'epoch': 0
            }
            self.db_client.store_model_state(initial_state)
        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            raise

    async def coordinate_training(self, training_config: Dict[str, Any]):
        """Coordinate distributed training across agents."""
        try:
            num_epochs = training_config.get('num_epochs', 1)
            steps_per_epoch = training_config.get('steps_per_epoch', 100)

            for epoch in range(num_epochs):
                logger.info(f"Starting epoch {epoch}")
                await self._train_epoch(epoch, steps_per_epoch)
                self._save_checkpoint(epoch)
        except Exception as e:
            logger.error(f"Training coordination error: {e}")
            raise

    async def _train_epoch(self, epoch: int, steps_per_epoch: int):
        """Train for one epoch."""
        for step in range(steps_per_epoch):
            # Wait until at least one agent is available before scheduling work.
            active_agents = self.db_client.get_active_agents()
            if not active_agents:
                logger.warning("No active agents available")
                await asyncio.sleep(5)
                continue

            # Fan gradient jobs out to the agents, then wait for their results.
            gradient_jobs = await self._distribute_gradient_computation(
                active_agents,
                self.batch_size
            )

            gradients = await self._collect_gradients(gradient_jobs)
            if gradients:
                self._update_model_parameters(gradients)

            # Broadcast the (possibly updated) weights back to the agents.
            await self._distribute_model_update()

    async def _distribute_gradient_computation(
        self,
        agents: List[Dict[str, Any]],
        batch_size: int
    ) -> List[str]:
        """Distribute gradient computation jobs to available agents."""
        job_ids = []

        current_state = self.db_client.get_latest_model_state()
        if not current_state:
            raise RuntimeError("No model state available")

        # One gradient job per active agent, each carrying the current weights
        # (stored under 'model_state', matching store_model_state above).
        for agent in agents:
            job_id = self.db_client.create_job(
                'gradient_computation',
                {
                    'batch_size': batch_size,
                    'state': current_state['model_state']
                }
            )
            job_ids.append(job_id)

        return job_ids

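    # Jobs created by the coordinator are assumed to move through a simple
    # lifecycle in CouchDB: 'pending' -> 'completed' (with result['gradient_id']
    # pointing at the stored gradients) or 'failed' (with result['error']).
    # The exact schema is owned by CouchDBClient; the poller below relies only
    # on the fields named here.
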
    async def _collect_gradients(self, job_ids: List[str]) -> List[Dict[str, Any]]:
        """Collect gradients from completed jobs."""
        timeout = 300  # seconds to wait per job before giving up

        async def wait_for_job(job_id: str) -> Optional[Dict[str, Any]]:
            loop = asyncio.get_running_loop()
            start_time = loop.time()
            while True:
                if loop.time() - start_time > timeout:
                    logger.warning(f"Job {job_id} timed out")
                    return None

                job = self.db_client.get_job(job_id)
                if job['status'] == 'completed':
                    gradient_id = job['result']['gradient_id']
                    return self.db_client.get_gradients(gradient_id)
                elif job['status'] == 'failed':
                    logger.error(f"Job {job_id} failed: {job.get('result', {}).get('error')}")
                    return None

                # Still pending or running; poll again shortly.
                await asyncio.sleep(1)

        # Poll all jobs concurrently and drop any that failed or timed out.
        gradient_tasks = [wait_for_job(job_id) for job_id in job_ids]
        gradients = await asyncio.gather(*gradient_tasks)
        return [g for g in gradients if g is not None]

    def _update_model_parameters(self, gradients: List[Dict[str, Any]]):
        """Update model parameters with collected gradients."""
        try:
            # Average the per-agent gradients parameter by parameter.
            avg_gradients = TensorOps.average_gradients([
                {k: torch.tensor(v) for k, v in g.items()}
                for g in gradients
            ])

            # Clip before applying so one noisy agent cannot blow up the update.
            clipped_gradients = TensorOps.gradient_clipping(avg_gradients, max_norm=1.0)

            # Plain SGD step: param <- param - lr * grad.
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if name in clipped_gradients:
                        param.sub_(clipped_gradients[name] * self.learning_rate)
        except Exception as e:
            logger.error(f"Error updating model parameters: {e}")
            raise

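    # Note: the state documents below contain raw torch tensors, while CouchDB
    # stores JSON. CouchDBClient.store_model_state is assumed to handle tensor
    # serialization (e.g. via TensorOps) before writing; if it does not, the
    # state_dict() would need to be converted to plain lists here first.
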
    async def _distribute_model_update(self):
        """Distribute updated model state to all agents."""
        try:
            state = {
                'model_state': self.model.state_dict(),
                'timestamp': datetime.utcnow().isoformat()
            }
            state_id = self.db_client.store_model_state(state)

            # Queue a model_update job for every active agent.
            active_agents = self.db_client.get_active_agents()
            for agent in active_agents:
                self.db_client.create_job(
                    'model_update',
                    {
                        'state_id': state_id,
                        'state': state
                    }
                )
        except Exception as e:
            logger.error(f"Error distributing model update: {e}")
            raise

    def _save_checkpoint(self, epoch: int):
        """Save a checkpoint of the current model state."""
        try:
            checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints')
            os.makedirs(checkpoint_dir, exist_ok=True)

            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                # No optimizer exists on this coordinator today; keep the slot
                # so the checkpoint format stays forward compatible.
                'optimizer_state_dict': self.optimizer.state_dict() if hasattr(self, 'optimizer') else None
            }, checkpoint_path)

            logger.info(f"Saved checkpoint for epoch {epoch}")
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            raise
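

# Minimal usage sketch (illustrative only, not a production entry point):
# assumes a reachable CouchDB instance and a fully populated `settings`.
if __name__ == "__main__":
    coordinator = Coordinator()
    asyncio.run(coordinator.coordinate_training({
        'num_epochs': 1,
        'steps_per_epoch': 10,
    }))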