| """ |
| Shared Configuration Classes for TD3B Finetuning |
| |
| This module contains all configuration dataclasses used by both: |
| - finetune_v1.py (single-target training) |
| - finetune_multi_target.py (multi-target training) |
| |
| Extracted to avoid code duplication and ensure consistency. |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Optional |
|
|
|
|
| @dataclass |
| class RoFormerConfig: |
| """Configuration for RoFormer model architecture.""" |
| hidden_size: int |
| n_layers: int |
| n_heads: int |
| max_position_embeddings: int = 1035 |
|
|
|
|
| @dataclass(frozen=True) |
| class NoiseConfig: |
| """Configuration for noise scheduling.""" |
| type: str = 'loglinear' |
| sigma_min: float = 1e-4 |
| sigma_max: float = 20.0 |
|
|
|
|
| @dataclass(frozen=True) |
| class TrainingConfig: |
| """Configuration for training parameters.""" |
| sampling_eps: float |
|
|
|
|
| @dataclass(frozen=True) |
| class SamplingConfig: |
| """Configuration for sampling parameters.""" |
| steps: int |
| sampling_eps: float |
| predictor: str = 'ddpm_cache' |
|
|
|
|
| @dataclass(frozen=True) |
| class EvalConfig: |
| """Configuration for evaluation parameters.""" |
| gen_ppl_eval_model_name_or_path: str = 'gpt2-large' |
|
|
|
|
| @dataclass(frozen=True) |
| class OptimConfig: |
| """Configuration for optimizer parameters.""" |
| lr: float |
|
|
|
|
| @dataclass(frozen=True) |
| class MCTSConfig: |
| """Configuration for MCTS parameters.""" |
| sampling: int = 0 |
|
|
|
|
| class DiffusionConfig: |
| """ |
| Complete configuration for Diffusion model. |
| |
| This class encapsulates all nested configuration objects required |
| by the Diffusion model, providing a clean interface and type safety. |
| """ |
|
|
| def __init__( |
| self, |
| roformer: RoFormerConfig, |
| noise: NoiseConfig, |
| training: TrainingConfig, |
| sampling: SamplingConfig, |
| eval_cfg: EvalConfig, |
| optim: OptimConfig, |
| mcts: MCTSConfig |
| ): |
| |
| self.roformer = type('RoFormerObj', (), { |
| 'hidden_size': roformer.hidden_size, |
| 'n_layers': roformer.n_layers, |
| 'n_heads': roformer.n_heads, |
| 'max_position_embeddings': roformer.max_position_embeddings |
| })() |
|
|
| self.noise = type('NoiseObj', (), { |
| 'type': noise.type, |
| 'sigma_min': noise.sigma_min, |
| 'sigma_max': noise.sigma_max |
| })() |
|
|
| self.training = type('TrainingObj', (), { |
| 'sampling_eps': training.sampling_eps |
| })() |
|
|
| self.sampling = type('SamplingObj', (), { |
| 'steps': sampling.steps, |
| 'sampling_eps': sampling.sampling_eps, |
| 'predictor': sampling.predictor |
| })() |
|
|
| self.eval = type('EvalObj', (), { |
| 'gen_ppl_eval_model_name_or_path': eval_cfg.gen_ppl_eval_model_name_or_path |
| })() |
|
|
| self.optim = type('OptimObj', (), { |
| 'lr': optim.lr |
| })() |
|
|
| self.mcts = type('MCTSObj', (), { |
| 'sampling': mcts.sampling |
| })() |
|
|
| |
| self.backbone = 'roformer' |
| self.parameterization = 'subs' |
| self.time_conditioning = False |
| self.T = 0 |
|
|