| """
|
| code-chef ModelOps Training Client
|
|
|
| Usage from agent_orchestrator/agents/infrastructure/modelops/training.py
|
| """
|
|
|
| import base64
|
| import time
|
| from typing import Any, Dict, Optional
|
|
|
| import requests
|
| from langsmith import traceable
|
|
|
|
|
| class ModelOpsTrainerClient:
|
| """Client for code-chef ModelOps HuggingFace Space"""
|
|
|
| def __init__(self, space_url: str, hf_token: Optional[str] = None):
|
| """
|
| Initialize client.
|
|
|
| Args:
|
| space_url: HF Space URL (e.g., https://appsmithery-code-chef-modelops-trainer.hf.space)
|
| hf_token: Optional HF token for private spaces
|
| """
|
| self.space_url = space_url.rstrip("/")
|
| self.hf_token = hf_token
|
| self.headers = {"Content-Type": "application/json"}
|
| if hf_token:
|
| self.headers["Authorization"] = f"Bearer {hf_token}"
|
|
|
| @traceable(name="submit_training_job")
|
| def submit_training_job(
|
| self,
|
| agent_name: str,
|
| base_model: str,
|
| dataset_csv_path: str,
|
| training_method: str = "sft",
|
| demo_mode: bool = False,
|
| config_overrides: Optional[Dict[str, Any]] = None,
|
| ) -> Dict[str, str]:
|
| """
|
| Submit a training job to the HF Space.
|
|
|
| Args:
|
| agent_name: Agent to train (feature_dev, code_review, etc.)
|
| base_model: HF model repo (e.g., Qwen/Qwen2.5-Coder-7B)
|
| dataset_csv_path: Path to CSV file with text/response columns
|
| training_method: sft, dpo, or reward
|
| demo_mode: If True, runs quick validation (100 examples, 1 epoch)
|
| config_overrides: Optional training config overrides
|
|
|
| Returns:
|
| Dict with job_id, status, message
|
| """
|
|
|
| with open(dataset_csv_path, "rb") as f:
|
| csv_content = f.read()
|
| encoded_csv = base64.b64encode(csv_content).decode()
|
|
|
|
|
| response = requests.post(
|
| f"{self.space_url}/train",
|
| headers=self.headers,
|
| json={
|
| "agent_name": agent_name,
|
| "base_model": base_model,
|
| "dataset_csv": encoded_csv,
|
| "training_method": training_method,
|
| "demo_mode": demo_mode,
|
| "config_overrides": config_overrides or {},
|
| },
|
| )
|
| response.raise_for_status()
|
| return response.json()
|
|
|
| @traceable(name="get_job_status")
|
| def get_job_status(self, job_id: str) -> Dict[str, Any]:
|
| """
|
| Get status of a training job.
|
|
|
| Args:
|
| job_id: Job ID returned from submit_training_job
|
|
|
| Returns:
|
| Dict with job status, progress, metrics
|
| """
|
| response = requests.get(
|
| f"{self.space_url}/status/{job_id}", headers=self.headers
|
| )
|
| response.raise_for_status()
|
| return response.json()
|
|
|
| @traceable(name="wait_for_completion")
|
| def wait_for_completion(
|
| self,
|
| job_id: str,
|
| poll_interval: int = 60,
|
| timeout: int = 7200,
|
| callback: Optional[callable] = None,
|
| ) -> Dict[str, Any]:
|
| """
|
| Wait for training job to complete.
|
|
|
| Args:
|
| job_id: Job ID to monitor
|
| poll_interval: Seconds between status checks
|
| timeout: Maximum seconds to wait
|
| callback: Optional callback(status_dict) called on each poll
|
|
|
| Returns:
|
| Final job status dict
|
|
|
| Raises:
|
| TimeoutError: If job doesn't complete within timeout
|
| RuntimeError: If job fails
|
| """
|
| elapsed = 0
|
| while elapsed < timeout:
|
| status = self.get_job_status(job_id)
|
|
|
| if callback:
|
| callback(status)
|
|
|
| if status["status"] == "completed":
|
| return status
|
| elif status["status"] == "failed":
|
| raise RuntimeError(f"Training job failed: {status.get('error')}")
|
|
|
| time.sleep(poll_interval)
|
| elapsed += poll_interval
|
|
|
| raise TimeoutError(f"Training job {job_id} did not complete within {timeout}s")
|
|
|
| def health_check(self) -> Dict[str, Any]:
|
| """
|
| Check if the Space is healthy.
|
|
|
| Returns:
|
| Dict with health status
|
| """
|
| response = requests.get(f"{self.space_url}/health", headers=self.headers)
|
| response.raise_for_status()
|
| return response.json()
|
|
|
|
|
|
|
| if __name__ == "__main__":
|
| import os
|
|
|
| client = ModelOpsTrainerClient(
|
| space_url="https://appsmithery-code-chef-modelops-trainer.hf.space",
|
| hf_token=os.environ.get("HF_TOKEN"),
|
| )
|
|
|
|
|
| health = client.health_check()
|
| print(f"Health: {health}")
|
|
|
|
|
| result = client.submit_training_job(
|
| agent_name="feature_dev",
|
| base_model="Qwen/Qwen2.5-Coder-7B",
|
| dataset_csv_path="/tmp/demo_dataset.csv",
|
| training_method="sft",
|
| demo_mode=True,
|
| )
|
|
|
| job_id = result["job_id"]
|
| print(f"Job submitted: {job_id}")
|
|
|
|
|
| def progress_callback(status):
|
| print(f"Status: {status['status']}, Progress: {status.get('progress_pct', 0)}%")
|
|
|
| final_status = client.wait_for_completion(
|
| job_id=job_id, poll_interval=30, callback=progress_callback
|
| )
|
|
|
| print(f"Training complete! Model: {final_status['hub_repo']}")
|
|
|