| from typing import Dict, Optional |
| import os |
| from zoneinfo import ZoneInfo |
|
|
| import mlflow |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from datetime import datetime, date |
|
|
|
|
class TrainerLogger:
    """Thin wrapper around MLflow for logging a training run.

    Starts an MLflow run on construction and exposes helpers for logging
    parameters, metrics, checkpoint tables, model checkpoints, and HTML
    artifacts. Usable as a context manager; the run is ended on exit.
    """

    # NOTE(review): appears unused anywhere in this file; kept for backward
    # compatibility — confirm no external code reads TrainerLogger.table_dict
    # before removing.
    table_dict = {
        "entrada": ["Pergunta A", "Pergunta B"],
        "saida": ["Resposta A", "Resposta B"],
        "nota": [0.75, 0.40],
    }

    def __init__(
        self,
        tracking_uri: str,
        experiment: str,
        total_params: int,
        model_name: Optional[str] = None,
        run_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ):
        """Configure MLflow and start a new run.

        Parameters
        ----------
        tracking_uri : str
            URI of the MLflow tracking server.
        experiment : str
            Name of the MLflow experiment to log under.
        total_params : int
            Total number of model parameters (logged as a run param).
        model_name : str, optional
            Model identifier; used as the ``model_type`` tag and as the
            registered-model name in :meth:`checkpoint_model`.
        run_name : str, optional
            Human-readable name for the MLflow run.
        tags : dict, optional
            Extra tags merged over the defaults (may override ``model_type``).
        """
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment)

        # Enable automatic PyTorch logging (metrics, models) for this run.
        mlflow.pytorch.autolog(log_models=True)

        self.run = mlflow.start_run(run_name=run_name)
        self.run_id = self.run.info.run_id
        self.experiment = experiment
        self.model_name = model_name
        self.total_params = total_params

        # Caller-supplied tags take precedence over the defaults.
        default_tags = {"model_type": self.model_name}
        if tags:
            default_tags.update(tags)
        mlflow.set_tags(default_tags)

        self.log_parameters(
            {"model_name": self.model_name, "total_params": self.total_params}
        )

    def log_parameters(self, parameters: dict) -> None:
        """Log a batch of run parameters to MLflow."""
        mlflow.log_params(parameters)

    def log_metrics(self, metrics: dict, step: Optional[int] = None) -> None:
        """Log a batch of metrics, optionally associated with a training step."""
        mlflow.log_metrics(metrics, step=step)

    def log_checkpoint_table(
        self,
        current_lr: float,
        loss: float,
        perplexity: float,
        last_batch: int,
    ) -> None:
        """Append a checkpoint record to an MLflow table artifact.

        Records the America/Sao_Paulo-local month, day, and time together
        with the last processed batch, learning rate, perplexity, and loss.

        Parameters
        ----------
        current_lr : float
            Current learning rate (rounded to 7 decimal places).
        loss : float
            Training loss (rounded to 4 decimal places).
        perplexity : float
            Perplexity metric (rounded to 4 decimal places).
        last_batch : int
            Index of the last batch processed at this checkpoint.
        """
        artifact_dir = "checkpoint_table/model"
        os.makedirs(artifact_dir, exist_ok=True)

        now = datetime.now(ZoneInfo("America/Sao_Paulo"))
        record = {
            "month": now.month,
            "day": now.day,
            "hour": f"{now.hour:02d}:{now.minute:02d}",
            "last_batch": last_batch,
            "current_lr": round(current_lr, 7),
            "perplexity": round(perplexity, 4),
            "loss": round(loss, 4),
        }
        df_record = pd.DataFrame([record])

        artifact_file = f"{artifact_dir}/checkpoint_table.json"

        # mlflow.log_table merges this row into the existing table artifact
        # at the same path, so repeated calls accumulate a history.
        mlflow.log_table(data=df_record, artifact_file=artifact_file)

    def checkpoint_model(self, model: nn.Module, step: int = 1) -> None:
        """Save a model checkpoint locally, then log (and register) it in MLflow.

        Parameters
        ----------
        model : nn.Module
            Model whose ``state_dict`` is checkpointed.
        step : int, optional
            Checkpoint index used in the artifact paths. Defaults to 1,
            matching the previously hard-coded behavior.
        """
        checkpoint_dir = f"checkpoints/model_{step}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(model.state_dict(), checkpoint_path)

        mlflow.log_artifact(checkpoint_path, f"model_checkpoints/epoch_{step}")

        # Register the full model only when a model name was provided.
        if self.model_name:
            mlflow.pytorch.log_model(
                pytorch_model=model,
                artifact_path=f"models/epoch_{step}",
                registered_model_name=self.model_name,
                pip_requirements=["torch>=1.9.0"],
                code_paths=["tynerox/"],
                signature=None,
            )

    def log_html(self, html: str, step: Optional[int] = None) -> None:
        """Write *html* to a local file and log it as an MLflow artifact.

        NOTE(review): ``step`` is accepted for interface compatibility but is
        currently unused; successive calls overwrite the same local file.
        """
        file_path = "visualizations/sample.html"
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        # Explicit UTF-8 so the HTML artifact is byte-identical across platforms.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(html)

        mlflow.log_artifact(file_path)

    def finish(self) -> None:
        """End the active MLflow run."""
        mlflow.end_run()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close the run, even when an exception propagates.
        self.finish()
|
|