# AICME-runtime / modeling_sim_priors_pk.py
# Uploaded by cesarali — manual runtime bundle push from load_and_push.ipynb
# (commit 5686f5b, verified)
"""Hugging Face AutoModel wrapper for consumer-facing PK runtime bundles."""
from __future__ import annotations
from typing import Any, Dict, Optional, Sequence, Union
import torch
from transformers import PreTrainedModel
from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
from sim_priors_pk.hub_runtime.runtime_contract import (
RuntimeBuilderConfig,
build_batch_from_studies,
infer_supported_tasks,
instantiate_backbone_from_hub_config,
normalize_studies_input,
split_runtime_samples,
validate_studies_for_task,
)
from sim_priors_pk.models.amortized_inference.generative_pk import (
NewGenerativeMixin,
NewPredictiveMixin,
)
class PKHubModel(PreTrainedModel):
    """Thin wrapper exposing a stable StudyJSON runtime API on top of PK models.

    Raw ``forward`` calls are delegated untouched to the wrapped backbone;
    :meth:`run_task` is the public, schema-versioned StudyJSON inference
    contract (tasks such as "generate" and "predict").
    """

    config_class = PKHubConfig
    base_model_prefix = "backbone"

    def __init__(self, config: PKHubConfig, backbone: Optional[torch.nn.Module] = None) -> None:
        """Wrap *backbone* (or instantiate one from *config*) for inference.

        Args:
            config: Hub configuration describing the backbone and runtime contract.
            backbone: Pre-built PK model. When ``None``, the backbone is
                instantiated from ``config`` via the runtime contract helper.
        """
        super().__init__(config)
        self.backbone = backbone if backbone is not None else instantiate_backbone_from_hub_config(config)
        # Runtime bundles are inference-only; eval() makes dropout/norm layers deterministic.
        self.backbone.eval()

    def forward(self, *args, **kwargs):
        """Delegate raw forward calls to the wrapped PK backbone."""
        return self.backbone(*args, **kwargs)

    @property
    def supported_tasks(self) -> Sequence[str]:
        """Tasks supported by this runtime model.

        Prefers the explicit list on the config; falls back to inferring
        tasks from the backbone's capabilities when the config omits them.
        """
        return tuple(getattr(self.config, "supported_tasks", []) or infer_supported_tasks(self.backbone))

    @torch.inference_mode()
    def run_task(
        self,
        *,
        task: str,
        studies: Union[StudyJSON, Sequence[StudyJSON]],
        num_samples: int = 1,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the public StudyJSON inference contract for the requested task.

        Args:
            task: One of :attr:`supported_tasks` (e.g. "generate" or "predict").
            studies: A single StudyJSON or a sequence of them.
            num_samples: Number of samples to draw per study; must be >= 1.
            **kwargs: Task-specific options (e.g. ``num_steps`` for "generate").

        Returns:
            Dict with the task name, IO schema version, model metadata, and a
            per-input ``results`` list of split runtime samples.

        Raises:
            ValueError: If the task is unsupported, ``num_samples`` < 1, or the
                backbone lacks the mixin required by the requested task.
        """
        supported_tasks = list(self.supported_tasks)
        if task not in supported_tasks:
            raise ValueError(
                f"Unsupported task {task!r}. Supported tasks: {supported_tasks or 'none'}."
            )
        # Coerce once so validation and downstream use see the same integer
        # (the original converted separately at each use site).
        num_samples = int(num_samples)
        if num_samples < 1:
            raise ValueError("num_samples must be >= 1.")
        canonical_studies = normalize_studies_input(studies)
        builder_config = RuntimeBuilderConfig.from_dict(self.config.builder_config)
        validate_studies_for_task(canonical_studies, task=task, builder_config=builder_config)
        # ``or {}`` also guards a config that carries experiment_config=None,
        # which getattr's default alone would not catch.
        experiment_config_payload = getattr(self.config, "experiment_config", {}) or {}
        meta_dosing_payload = experiment_config_payload.get("dosing", {})
        # When the config overrides dosing metadata, rebuild the backbone's
        # meta-dosing object from that payload; otherwise reuse it as-is.
        meta_dosing = (
            self.backbone.meta_dosing.__class__(**meta_dosing_payload)
            if meta_dosing_payload
            else self.backbone.meta_dosing
        )
        batch = build_batch_from_studies(
            canonical_studies,
            builder_config=builder_config,
            meta_dosing=meta_dosing,
        )
        batch = batch.to(self.device)
        if task == "generate":
            if not isinstance(self.backbone, NewGenerativeMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support generate.")
            output_studies = self.backbone.sample_new_individuals_to_studyjson(
                batch,
                sample_size=num_samples,
                num_steps=kwargs.get("num_steps"),
            )
        elif task == "predict":
            if not isinstance(self.backbone, NewPredictiveMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support predict.")
            # The predictive API is batch-list based; wrap and unwrap the single batch.
            output_studies = self.backbone.sample_individual_prediction_from_batch_list_to_studyjson(
                [batch],
                sample_size=num_samples,
            )[0]
        else:
            # Reachable when the config advertises a task this wrapper does not implement.
            raise ValueError(f"Unsupported task {task!r}.")
        results = [
            {
                "input_index": index,
                "samples": split_runtime_samples(task, study),
            }
            for index, study in enumerate(output_studies)
        ]
        return {
            "task": task,
            "io_schema_version": self.config.io_schema_version,
            "model_info": {
                "architecture_name": self.config.architecture_name,
                "experiment_type": self.config.experiment_type,
                "supported_tasks": supported_tasks,
                "runtime_repo_id": self.config.runtime_repo_id,
                "original_repo_id": self.config.original_repo_id,
            },
            "results": results,
        }
# Explicit public API: only the wrapper class is re-exported from this module.
__all__ = ["PKHubModel"]