| import importlib |
| import logging |
| import os |
|
|
| import torch |
| from src.data.containers import BatchTimeSeriesContainer |
| from src.data.utils import sample_future_length |
| from src.plotting.plot_timeseries import plot_from_container |
| from src.synthetic_generation.anomalies.anomaly_generator_wrapper import ( |
| AnomalyGeneratorWrapper, |
| ) |
| from src.synthetic_generation.cauker.cauker_generator_wrapper import ( |
| CauKerGeneratorWrapper, |
| ) |
| from src.synthetic_generation.forecast_pfn_prior.forecast_pfn_generator_wrapper import ( |
| ForecastPFNGeneratorWrapper, |
| ) |
| from src.synthetic_generation.generator_params import ( |
| AnomalyGeneratorParams, |
| CauKerGeneratorParams, |
| FinancialVolatilityAudioParams, |
| ForecastPFNGeneratorParams, |
| GPGeneratorParams, |
| KernelGeneratorParams, |
| MultiScaleFractalAudioParams, |
| NetworkTopologyAudioParams, |
| OrnsteinUhlenbeckProcessGeneratorParams, |
| SawToothGeneratorParams, |
| SineWaveGeneratorParams, |
| SpikesGeneratorParams, |
| StepGeneratorParams, |
| StochasticRhythmAudioParams, |
| ) |
| from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper |
| from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import ( |
| KernelGeneratorWrapper, |
| ) |
| from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import ( |
| OrnsteinUhlenbeckProcessGeneratorWrapper, |
| ) |
| from src.synthetic_generation.sawtooth.sawtooth_generator_wrapper import ( |
| SawToothGeneratorWrapper, |
| ) |
| from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import ( |
| SineWaveGeneratorWrapper, |
| ) |
| from src.synthetic_generation.spikes.spikes_generator_wrapper import ( |
| SpikesGeneratorWrapper, |
| ) |
| from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper |
|
|
| PYO_AVAILABLE = False |
| spec = importlib.util.find_spec("pyo") |
| if spec is not None: |
| try: |
| _pyo = importlib.import_module("pyo") |
| except (ImportError, OSError): |
| PYO_AVAILABLE = False |
| else: |
| PYO_AVAILABLE = True |
|
|
| if PYO_AVAILABLE: |
| from src.synthetic_generation.audio_generators.financial_volatility_wrapper import ( |
| FinancialVolatilityAudioWrapper, |
| ) |
| from src.synthetic_generation.audio_generators.multi_scale_fractal_wrapper import ( |
| MultiScaleFractalAudioWrapper, |
| ) |
| from src.synthetic_generation.audio_generators.network_topology_wrapper import ( |
| NetworkTopologyAudioWrapper, |
| ) |
| from src.synthetic_generation.audio_generators.stochastic_rhythm_wrapper import ( |
| StochasticRhythmAudioWrapper, |
| ) |
|
|
| logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") |
| logger = logging.getLogger(__name__) |
|
|
|
|
| def visualize_batch_sample( |
| generator, |
| batch_size: int = 8, |
| output_dir: str = "outputs/plots", |
| sample_idx: int | None = None, |
| prefix: str = "", |
| seed: int | None = None, |
| ) -> None: |
| os.makedirs(output_dir, exist_ok=True) |
| name = generator.__class__.__name__ |
| logger.info(f"[{name}] Generating batch of size {batch_size}") |
|
|
| batch = generator.generate_batch(batch_size=batch_size, seed=seed) |
| values = torch.from_numpy(batch.values) |
| if values.ndim == 2: |
| values = values.unsqueeze(-1) |
|
|
| future_length = sample_future_length(range="gift_eval") |
| history_values = values[:, :-future_length, :] |
| future_values = values[:, -future_length:, :] |
|
|
| container = BatchTimeSeriesContainer( |
| history_values=history_values, |
| future_values=future_values, |
| start=batch.start, |
| frequency=batch.frequency, |
| ) |
|
|
| indices = [sample_idx] if sample_idx is not None else range(batch_size) |
| for i in indices: |
| filename = f"{prefix}_{name.lower().replace('generatorwrapper', '')}_sample_{i}.png" |
| output_file = os.path.join(output_dir, filename) |
| title = f"{prefix.capitalize()} {name.replace('GeneratorWrapper', '')} Synthetic Series (Sample {i})" |
| plot_from_container(container, sample_idx=i, output_file=output_file, show=False, title=title) |
| logger.info(f"[{name}] Saved plot to {output_file}") |
|
|
|
|
| def generator_factory(global_seed: int, total_length: int) -> list: |
| generators = [ |
| KernelGeneratorWrapper(KernelGeneratorParams(global_seed=global_seed, length=total_length)), |
| GPGeneratorWrapper(GPGeneratorParams(global_seed=global_seed, length=total_length)), |
| ForecastPFNGeneratorWrapper(ForecastPFNGeneratorParams(global_seed=global_seed, length=total_length)), |
| SineWaveGeneratorWrapper(SineWaveGeneratorParams(global_seed=global_seed, length=total_length)), |
| SawToothGeneratorWrapper(SawToothGeneratorParams(global_seed=global_seed, length=total_length)), |
| StepGeneratorWrapper(StepGeneratorParams(global_seed=global_seed, length=total_length)), |
| AnomalyGeneratorWrapper(AnomalyGeneratorParams(global_seed=global_seed, length=total_length)), |
| SpikesGeneratorWrapper(SpikesGeneratorParams(global_seed=global_seed, length=total_length)), |
| CauKerGeneratorWrapper(CauKerGeneratorParams(global_seed=global_seed, length=total_length, num_channels=5)), |
| OrnsteinUhlenbeckProcessGeneratorWrapper( |
| OrnsteinUhlenbeckProcessGeneratorParams(global_seed=global_seed, length=total_length) |
| ), |
| ] |
|
|
| if PYO_AVAILABLE: |
| generators.extend( |
| [ |
| StochasticRhythmAudioWrapper(StochasticRhythmAudioParams(global_seed=global_seed, length=total_length)), |
| FinancialVolatilityAudioWrapper( |
| FinancialVolatilityAudioParams(global_seed=global_seed, length=total_length) |
| ), |
| MultiScaleFractalAudioWrapper( |
| MultiScaleFractalAudioParams(global_seed=global_seed, length=total_length) |
| ), |
| NetworkTopologyAudioWrapper(NetworkTopologyAudioParams(global_seed=global_seed, length=total_length)), |
| ] |
| ) |
| else: |
| logger.warning("Audio generators skipped (pyo not available)") |
|
|
| return generators |
|
|
|
|
| if __name__ == "__main__": |
| batch_size = 2 |
| total_length = 2048 |
| output_dir = "outputs/plots" |
| global_seed = 2025 |
|
|
| logger.info(f"Saving plots to {output_dir}") |
|
|
| for gen in generator_factory(global_seed, total_length): |
| prefix = "multivariate" if getattr(gen.params, "num_channels", 1) > 1 else "" |
| visualize_batch_sample( |
| gen, |
| batch_size=batch_size, |
| output_dir=output_dir, |
| prefix=prefix, |
| seed=global_seed, |
| ) |
|
|