"""Hugging Face AutoModel wrapper for consumer-facing PK runtime bundles."""

from __future__ import annotations

from typing import Any, Dict, Optional, Sequence, Union

import torch
from transformers import PreTrainedModel

from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
from sim_priors_pk.hub_runtime.runtime_contract import (
    RuntimeBuilderConfig,
    build_batch_from_studies,
    infer_supported_tasks,
    instantiate_backbone_from_hub_config,
    normalize_studies_input,
    split_runtime_samples,
    validate_studies_for_task,
)
from sim_priors_pk.models.amortized_inference.generative_pk import (
    NewGenerativeMixin,
    NewPredictiveMixin,
)


class PKHubModel(PreTrainedModel):
    """Thin wrapper exposing a stable StudyJSON runtime API on top of PK models.

    The wrapped ``backbone`` does the modelling work. This class adds input
    normalization/validation, batch construction, task dispatch, and a uniform
    result envelope via :meth:`run_task`.
    """

    config_class = PKHubConfig
    base_model_prefix = "backbone"

    def __init__(self, config: PKHubConfig, backbone: Optional[torch.nn.Module] = None) -> None:
        """Wrap ``backbone``, or build one from ``config`` when none is given.

        Args:
            config: Hub configuration describing the runtime bundle.
            backbone: Optional pre-built PK model. When ``None``, a backbone is
                created with ``instantiate_backbone_from_hub_config(config)``.
        """
        super().__init__(config)
        self.backbone = backbone if backbone is not None else instantiate_backbone_from_hub_config(config)
        # run_task executes under torch.inference_mode, so start the backbone in eval mode.
        self.backbone.eval()

    def forward(self, *args, **kwargs):
        """Delegate raw forward calls to the wrapped PK backbone."""
        return self.backbone(*args, **kwargs)

    @property
    def supported_tasks(self) -> Sequence[str]:
        """Tasks supported by this runtime model.

        Uses ``config.supported_tasks`` when it is non-empty. Otherwise falls
        back to ``infer_supported_tasks(self.backbone)``.
        """
        return tuple(getattr(self.config, "supported_tasks", []) or infer_supported_tasks(self.backbone))

    @torch.inference_mode()
    def run_task(
        self,
        *,
        task: str,
        studies: Union[StudyJSON, Sequence[StudyJSON]],
        num_samples: int = 1,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the public StudyJSON inference contract for the requested task.

        Args:
            task: One of :attr:`supported_tasks` (this method dispatches
                ``"generate"`` and ``"predict"``).
            studies: A single StudyJSON or a sequence of them.
            num_samples: Number of samples to draw; must be >= 1 after ``int()``.
            **kwargs: Extra options. Only ``num_steps`` is read, and only by
                the ``"generate"`` task. Other keys are ignored.

        Returns:
            A dict with keys ``task``, ``io_schema_version``, ``model_info``
            and ``results``. ``results`` holds one entry per output study, each
            with ``input_index`` and ``samples``.

        Raises:
            ValueError: If the task is unsupported, ``num_samples`` < 1, or the
                backbone lacks the mixin required by the task.
        """
        supported_tasks = list(self.supported_tasks)
        if task not in supported_tasks:
            raise ValueError(
                f"Unsupported task {task!r}. Supported tasks: {supported_tasks or 'none'}."
            )
        sample_size = int(num_samples)
        if sample_size < 1:
            raise ValueError("num_samples must be >= 1.")

        canonical_studies = normalize_studies_input(studies)
        builder_config = RuntimeBuilderConfig.from_dict(self.config.builder_config)
        validate_studies_for_task(canonical_studies, task=task, builder_config=builder_config)

        batch = build_batch_from_studies(
            canonical_studies,
            builder_config=builder_config,
            meta_dosing=self._resolve_meta_dosing(),
        )
        batch = batch.to(self.device)

        output_studies = self._sample_studies(
            task,
            batch,
            sample_size,
            num_steps=kwargs.get("num_steps"),
        )

        results = [
            {
                "input_index": index,
                "samples": split_runtime_samples(task, study),
            }
            for index, study in enumerate(output_studies)
        ]
        return {
            "task": task,
            "io_schema_version": self.config.io_schema_version,
            "model_info": {
                "architecture_name": self.config.architecture_name,
                "experiment_type": self.config.experiment_type,
                "supported_tasks": supported_tasks,
                "runtime_repo_id": self.config.runtime_repo_id,
                "original_repo_id": self.config.original_repo_id,
            },
            "results": results,
        }

    def _resolve_meta_dosing(self) -> Any:
        """Return the dosing meta-config used to build runtime batches.

        A non-empty ``experiment_config["dosing"]`` payload on the hub config
        overrides the backbone's ``meta_dosing`` (rebuilt with the same class).
        Otherwise the backbone's own object is reused unchanged.
        """
        # The attribute may be missing *or* explicitly None; both mean "no overrides".
        experiment_config_payload = getattr(self.config, "experiment_config", None) or {}
        meta_dosing_payload = experiment_config_payload.get("dosing") or {}
        default_meta_dosing = self.backbone.meta_dosing
        if meta_dosing_payload:
            return default_meta_dosing.__class__(**meta_dosing_payload)
        return default_meta_dosing

    def _sample_studies(
        self,
        task: str,
        batch: Any,
        sample_size: int,
        num_steps: Any = None,
    ) -> Sequence[Any]:
        """Dispatch ``task`` to the matching backbone sampler and return output studies.

        Raises:
            ValueError: If the backbone lacks the mixin required by ``task``,
                or ``task`` is neither ``"generate"`` nor ``"predict"``.
        """
        if task == "generate":
            if not isinstance(self.backbone, NewGenerativeMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support generate.")
            return self.backbone.sample_new_individuals_to_studyjson(
                batch,
                sample_size=sample_size,
                num_steps=num_steps,
            )
        if task == "predict":
            if not isinstance(self.backbone, NewPredictiveMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support predict.")
            # The predictive API takes a list of batches; one batch is sent, so take its sole output.
            return self.backbone.sample_individual_prediction_from_batch_list_to_studyjson(
                [batch],
                sample_size=sample_size,
            )[0]
        # Reachable when the config lists a task name that this wrapper does not dispatch.
        raise ValueError(f"Unsupported task {task!r}.")


__all__ = ["PKHubModel"]