manual runtime bundle push from load_and_push.ipynb
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- LICENSE +21 -0
- README.md +131 -0
- config.json +445 -0
- configuration_sim_priors_pk.py +42 -0
- modeling_sim_priors_pk.py +123 -0
- pytorch_model.bin +3 -0
- requirements.txt +4 -0
- sim_priors_pk/.DS_Store +0 -0
- sim_priors_pk/__init__.py +43 -0
- sim_priors_pk/config_classes/__init__.py +0 -0
- sim_priors_pk/config_classes/data_config.py +375 -0
- sim_priors_pk/config_classes/diffusion_pk_config.py +327 -0
- sim_priors_pk/config_classes/flow_pk_config.py +534 -0
- sim_priors_pk/config_classes/node_pk_config.py +518 -0
- sim_priors_pk/config_classes/source_process_config.py +52 -0
- sim_priors_pk/config_classes/training_config.py +96 -0
- sim_priors_pk/config_classes/utils.py +14 -0
- sim_priors_pk/config_classes/yaml_fallback.py +143 -0
- sim_priors_pk/data/README.md +86 -0
- sim_priors_pk/data/__init__.py +12 -0
- sim_priors_pk/data/data_empirical/__init__.py +35 -0
- sim_priors_pk/data/data_empirical/builder.py +1139 -0
- sim_priors_pk/data/data_empirical/json_schema.py +372 -0
- sim_priors_pk/data/data_empirical/json_stats.py +201 -0
- sim_priors_pk/data/data_empirical/simulx_to_json.py +71 -0
- sim_priors_pk/data/data_generation/__init__.py +0 -0
- sim_priors_pk/data/data_generation/compartment_models.py +721 -0
- sim_priors_pk/data/data_generation/compartment_models_management.py +1338 -0
- sim_priors_pk/data/data_generation/dosing_models.py +0 -0
- sim_priors_pk/data/data_generation/observations_classes.py +1776 -0
- sim_priors_pk/data/data_generation/observations_functions.py +69 -0
- sim_priors_pk/data/data_generation/study_population_stats.py +185 -0
- sim_priors_pk/data/data_preprocessing/__init__.py +0 -0
- sim_priors_pk/data/data_preprocessing/data_preprocessing_utils.py +321 -0
- sim_priors_pk/data/data_preprocessing/raw_to_tensors_bundles.py +360 -0
- sim_priors_pk/data/data_preprocessing/tensors_to_databatch.py +72 -0
- sim_priors_pk/data/datasets/aicme_batch.py +167 -0
- sim_priors_pk/data/datasets/aicme_datasets.py +1874 -0
- sim_priors_pk/data/extra/compartment_models_vectorized.py +182 -0
- sim_priors_pk/data/extra/kernels.py +28 -0
- sim_priors_pk/hub_runtime/README.md +187 -0
- sim_priors_pk/hub_runtime/__init__.py +19 -0
- sim_priors_pk/hub_runtime/configuration_sim_priors_pk.py +42 -0
- sim_priors_pk/hub_runtime/modeling_sim_priors_pk.py +123 -0
- sim_priors_pk/hub_runtime/runtime_bundle.py +269 -0
- sim_priors_pk/hub_runtime/runtime_contract.py +662 -0
- sim_priors_pk/metrics/__init__.py +0 -0
- sim_priors_pk/metrics/pk_metrics.py +490 -0
- sim_priors_pk/metrics/quantiles_coverage.py +310 -0
- sim_priors_pk/metrics/sampling_quality.py +409 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 César A. Ojeda
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
license: apache-2.0
|
| 5 |
+
library_name: generative-pk
|
| 6 |
+
datasets:
|
| 7 |
+
- simulated
|
| 8 |
+
metrics:
|
| 9 |
+
- rmse
|
| 10 |
+
- npde
|
| 11 |
+
tags:
|
| 12 |
+
- generative
|
| 13 |
+
- predictive
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# Hierarchical Neural Process for Pharmacokinetic Data
|
| 17 |
+
|
| 18 |
+
## Overview
|
| 19 |
+
An Amortized Context Neural Process Generative model for Pharmacokinetic Modelling
|
| 20 |
+
|
| 21 |
+
**Model details:**
|
| 22 |
+
- **Authors:** César Ojeda (@cesarali)
|
| 23 |
+
- **License:** Apache 2.0
|
| 24 |
+
|
| 25 |
+
## Intended use
|
| 26 |
+
Sample Drug Concentration Behavior and Sample and Prediction of New Points or new Individual
|
| 27 |
+
## Runtime Bundle
|
| 28 |
+
|
| 29 |
+
This repository is the consumer-facing runtime bundle for this PK model.
|
| 30 |
+
|
| 31 |
+
- Runtime repo: `cesarali/AICME-runtime`
|
| 32 |
+
- Native training/artifact repo: `cesarali/AICMEPK_cluster`
|
| 33 |
+
- Supported tasks: `generate`, `predict`
|
| 34 |
+
- Default task: `generate`
|
| 35 |
+
- Load path: `AutoModel.from_pretrained(..., trust_remote_code=True)`
|
| 36 |
+
|
| 37 |
+
### Installation
|
| 38 |
+
|
| 39 |
+
You do **not** need to install `sim_priors_pk` to use this runtime bundle.
|
| 40 |
+
|
| 41 |
+
`transformers` is the public loading entrypoint, but `transformers` alone is
|
| 42 |
+
not sufficient because this is a PyTorch model with custom runtime code. A
|
| 43 |
+
reliable consumer environment is:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
pip install torch transformers huggingface_hub lightning datasets pandas torchtyping gpytorch pot torchdiffeq torchsde ruamel.yaml pyyaml
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### Python Usage
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
from transformers import AutoModel
|
| 53 |
+
|
| 54 |
+
model = AutoModel.from_pretrained("cesarali/AICME-runtime", trust_remote_code=True)
|
| 55 |
+
|
| 56 |
+
studies = [
|
| 57 |
+
{
|
| 58 |
+
"context": [
|
| 59 |
+
{
|
| 60 |
+
"name_id": "ctx_0",
|
| 61 |
+
"observations": [0.2, 0.5, 0.3],
|
| 62 |
+
"observation_times": [0.5, 1.0, 2.0],
|
| 63 |
+
"dosing": [1.0],
|
| 64 |
+
"dosing_type": ["oral"],
|
| 65 |
+
"dosing_times": [0.0],
|
| 66 |
+
"dosing_name": ["oral"],
|
| 67 |
+
}
|
| 68 |
+
],
|
| 69 |
+
"target": [],
|
| 70 |
+
"meta_data": {"study_name": "demo", "substance_name": "drug_x"},
|
| 71 |
+
}
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
outputs = model.run_task(
|
| 75 |
+
task="generate",
|
| 76 |
+
studies=studies,
|
| 77 |
+
num_samples=4,
|
| 78 |
+
)
|
| 79 |
+
print(outputs["results"][0]["samples"])
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
### Predictive Sampling
|
| 83 |
+
|
| 84 |
+
```python
|
| 85 |
+
from transformers import AutoModel
|
| 86 |
+
|
| 87 |
+
model = AutoModel.from_pretrained("cesarali/AICME-runtime", trust_remote_code=True)
|
| 88 |
+
|
| 89 |
+
predict_studies = [
|
| 90 |
+
{
|
| 91 |
+
"context": [
|
| 92 |
+
{
|
| 93 |
+
"name_id": "ctx_0",
|
| 94 |
+
"observations": [0.2, 0.5, 0.3],
|
| 95 |
+
"observation_times": [0.5, 1.0, 2.0],
|
| 96 |
+
"dosing": [1.0],
|
| 97 |
+
"dosing_type": ["oral"],
|
| 98 |
+
"dosing_times": [0.0],
|
| 99 |
+
"dosing_name": ["oral"],
|
| 100 |
+
}
|
| 101 |
+
],
|
| 102 |
+
"target": [
|
| 103 |
+
{
|
| 104 |
+
"name_id": "tgt_0",
|
| 105 |
+
"observations": [0.25, 0.31],
|
| 106 |
+
"observation_times": [0.5, 1.0],
|
| 107 |
+
"remaining": [0.0, 0.0, 0.0],
|
| 108 |
+
"remaining_times": [2.0, 4.0, 8.0],
|
| 109 |
+
"dosing": [1.0],
|
| 110 |
+
"dosing_type": ["oral"],
|
| 111 |
+
"dosing_times": [0.0],
|
| 112 |
+
"dosing_name": ["oral"],
|
| 113 |
+
}
|
| 114 |
+
],
|
| 115 |
+
"meta_data": {"study_name": "demo", "substance_name": "drug_x"},
|
| 116 |
+
}
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
outputs = model.run_task(
|
| 120 |
+
task="predict",
|
| 121 |
+
studies=predict_studies,
|
| 122 |
+
num_samples=4,
|
| 123 |
+
)
|
| 124 |
+
print(outputs["results"][0]["samples"][0]["target"][0]["prediction_samples"])
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### Notes
|
| 128 |
+
|
| 129 |
+
- `trust_remote_code=True` is required because this model uses custom Hugging Face Hub runtime code.
|
| 130 |
+
- The consumer API is `transformers` + `run_task(...)`; the consumer does not need a local clone of this repository.
|
| 131 |
+
- This runtime bundle is intentionally separate from the native training export so you can evaluate both distribution paths in parallel.
|
config.json
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architecture_name": "AICMEPK",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"PKHubModel"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_sim_priors_pk.PKHubConfig",
|
| 8 |
+
"AutoModel": "modeling_sim_priors_pk.PKHubModel"
|
| 9 |
+
},
|
| 10 |
+
"builder_config": {
|
| 11 |
+
"max_context_individuals": 10,
|
| 12 |
+
"max_context_observations": 15,
|
| 13 |
+
"max_context_remaining": 15,
|
| 14 |
+
"max_target_individuals": 1,
|
| 15 |
+
"max_target_observations": 5,
|
| 16 |
+
"max_target_remaining": 12
|
| 17 |
+
},
|
| 18 |
+
"default_task": "generate",
|
| 19 |
+
"experiment_config": {
|
| 20 |
+
"comet_ai_key": null,
|
| 21 |
+
"context_observations": {
|
| 22 |
+
"add_rem": true,
|
| 23 |
+
"drop_time_zero_observations": false,
|
| 24 |
+
"empirical_number_of_obs": false,
|
| 25 |
+
"generative_bias": false,
|
| 26 |
+
"max_num_obs": 15,
|
| 27 |
+
"max_past": 5,
|
| 28 |
+
"min_past": 3,
|
| 29 |
+
"past_time_ratio": 0.1,
|
| 30 |
+
"split_past_future": false,
|
| 31 |
+
"type": "pk_peak_half_life"
|
| 32 |
+
},
|
| 33 |
+
"debug_test": false,
|
| 34 |
+
"dosing": {
|
| 35 |
+
"logdose_mean_range": [
|
| 36 |
+
-2.0,
|
| 37 |
+
2.0
|
| 38 |
+
],
|
| 39 |
+
"logdose_std_range": [
|
| 40 |
+
0.1,
|
| 41 |
+
0.5
|
| 42 |
+
],
|
| 43 |
+
"num_individuals": 10,
|
| 44 |
+
"route_options": [
|
| 45 |
+
"oral",
|
| 46 |
+
"iv"
|
| 47 |
+
],
|
| 48 |
+
"route_weights": [
|
| 49 |
+
0.8,
|
| 50 |
+
0.2
|
| 51 |
+
],
|
| 52 |
+
"same_route": true,
|
| 53 |
+
"time": 0.0
|
| 54 |
+
},
|
| 55 |
+
"experiment_dir": "/work/ojedamarin/Projects/Pharma/Results/comet/uai/7195d8f55b5d4684a766a69d5a736d28",
|
| 56 |
+
"experiment_indentifier": null,
|
| 57 |
+
"experiment_name": "uai",
|
| 58 |
+
"experiment_type": "nodepk",
|
| 59 |
+
"hf_model_card_path": [
|
| 60 |
+
"hf_model_cards",
|
| 61 |
+
"AICME-PK_Readme.md"
|
| 62 |
+
],
|
| 63 |
+
"hf_model_name": "AICMEPK_cluster",
|
| 64 |
+
"hugging_face_token": null,
|
| 65 |
+
"meta_study": {
|
| 66 |
+
"V_tmag_range": [
|
| 67 |
+
0.001,
|
| 68 |
+
0.001
|
| 69 |
+
],
|
| 70 |
+
"V_tscl_range": [
|
| 71 |
+
1,
|
| 72 |
+
5
|
| 73 |
+
],
|
| 74 |
+
"drug_id_options": [
|
| 75 |
+
"Drug_A",
|
| 76 |
+
"Drug_B",
|
| 77 |
+
"Drug_C"
|
| 78 |
+
],
|
| 79 |
+
"k_1p_tmag_range": [
|
| 80 |
+
0.01,
|
| 81 |
+
0.02
|
| 82 |
+
],
|
| 83 |
+
"k_1p_tscl_range": [
|
| 84 |
+
1,
|
| 85 |
+
5
|
| 86 |
+
],
|
| 87 |
+
"k_a_tmag_range": [
|
| 88 |
+
0.01,
|
| 89 |
+
0.02
|
| 90 |
+
],
|
| 91 |
+
"k_a_tscl_range": [
|
| 92 |
+
1,
|
| 93 |
+
5
|
| 94 |
+
],
|
| 95 |
+
"k_e_tmag_range": [
|
| 96 |
+
0.01,
|
| 97 |
+
0.02
|
| 98 |
+
],
|
| 99 |
+
"k_e_tscl_range": [
|
| 100 |
+
1,
|
| 101 |
+
5
|
| 102 |
+
],
|
| 103 |
+
"k_p1_tmag_range": [
|
| 104 |
+
0.01,
|
| 105 |
+
0.02
|
| 106 |
+
],
|
| 107 |
+
"k_p1_tscl_range": [
|
| 108 |
+
1,
|
| 109 |
+
5
|
| 110 |
+
],
|
| 111 |
+
"log_V_mean_range": [
|
| 112 |
+
2,
|
| 113 |
+
8
|
| 114 |
+
],
|
| 115 |
+
"log_V_std_range": [
|
| 116 |
+
0.2,
|
| 117 |
+
0.6
|
| 118 |
+
],
|
| 119 |
+
"log_k_1p_mean_range": [
|
| 120 |
+
-4,
|
| 121 |
+
0
|
| 122 |
+
],
|
| 123 |
+
"log_k_1p_std_range": [
|
| 124 |
+
0.2,
|
| 125 |
+
0.6
|
| 126 |
+
],
|
| 127 |
+
"log_k_a_mean_range": [
|
| 128 |
+
-1,
|
| 129 |
+
2
|
| 130 |
+
],
|
| 131 |
+
"log_k_a_std_range": [
|
| 132 |
+
0.2,
|
| 133 |
+
0.6
|
| 134 |
+
],
|
| 135 |
+
"log_k_e_mean_range": [
|
| 136 |
+
-5,
|
| 137 |
+
0
|
| 138 |
+
],
|
| 139 |
+
"log_k_e_std_range": [
|
| 140 |
+
0.2,
|
| 141 |
+
0.6
|
| 142 |
+
],
|
| 143 |
+
"log_k_p1_mean_range": [
|
| 144 |
+
-4,
|
| 145 |
+
-1
|
| 146 |
+
],
|
| 147 |
+
"log_k_p1_std_range": [
|
| 148 |
+
0.2,
|
| 149 |
+
0.6
|
| 150 |
+
],
|
| 151 |
+
"num_individuals_range": [
|
| 152 |
+
5,
|
| 153 |
+
10
|
| 154 |
+
],
|
| 155 |
+
"num_peripherals_range": [
|
| 156 |
+
1,
|
| 157 |
+
3
|
| 158 |
+
],
|
| 159 |
+
"rel_ruv_range": [
|
| 160 |
+
0.001,
|
| 161 |
+
0.01
|
| 162 |
+
],
|
| 163 |
+
"solver_method": "rk4",
|
| 164 |
+
"time_num_steps": 100,
|
| 165 |
+
"time_start": 0.0,
|
| 166 |
+
"time_stop": 16.0
|
| 167 |
+
},
|
| 168 |
+
"mix_data": {
|
| 169 |
+
"evaluate_prediction_steps_past": 5,
|
| 170 |
+
"keep_tempfile": false,
|
| 171 |
+
"log_and_max": false,
|
| 172 |
+
"log_and_z": false,
|
| 173 |
+
"log_transform": false,
|
| 174 |
+
"n_of_databatches": null,
|
| 175 |
+
"n_of_permutations": 3,
|
| 176 |
+
"n_of_target_individuals": 1,
|
| 177 |
+
"normalize_by_max": true,
|
| 178 |
+
"normalize_time": true,
|
| 179 |
+
"recreate_tempfile": false,
|
| 180 |
+
"sample_size_for_generative_evaluation": null,
|
| 181 |
+
"sample_size_for_generative_evaluation_end_of_training": 500,
|
| 182 |
+
"sample_size_for_generative_evaluation_val": 10,
|
| 183 |
+
"store_in_tempfile": false,
|
| 184 |
+
"tempfile_path": [
|
| 185 |
+
"preprocessed",
|
| 186 |
+
"simulated_ou_as_rates"
|
| 187 |
+
],
|
| 188 |
+
"test_empirical_datasets": [
|
| 189 |
+
"cesarali/lenuzza-2016",
|
| 190 |
+
"cesarali/Indometacin",
|
| 191 |
+
"cesarali/Theophylline"
|
| 192 |
+
],
|
| 193 |
+
"test_size": 64,
|
| 194 |
+
"tqdm_progress": false,
|
| 195 |
+
"train_size": 12800,
|
| 196 |
+
"val_size": 256,
|
| 197 |
+
"z_score_normalization": false
|
| 198 |
+
},
|
| 199 |
+
"my_results_path": "/work/ojedamarin/Projects/Pharma/Results/",
|
| 200 |
+
"name_str": "AICMEPK",
|
| 201 |
+
"network": {
|
| 202 |
+
"activation": "ReLU",
|
| 203 |
+
"aggregator_num_heads": 8,
|
| 204 |
+
"aggregator_type": "mean",
|
| 205 |
+
"combine_latent_mode": "mlp",
|
| 206 |
+
"cov_proj_dim": 16,
|
| 207 |
+
"decoder_attention_layers": 2,
|
| 208 |
+
"decoder_hidden_dim": 512,
|
| 209 |
+
"decoder_name": "TransformerDecoder",
|
| 210 |
+
"decoder_num_layers": 4,
|
| 211 |
+
"decoder_rnn_hidden_dim": 256,
|
| 212 |
+
"drift_activation": "Tanh",
|
| 213 |
+
"drift_num_layers": 2,
|
| 214 |
+
"dropout": 0.1,
|
| 215 |
+
"encoder_rnn_hidden_dim": 256,
|
| 216 |
+
"exclusive_node_step": true,
|
| 217 |
+
"ignore_logvar": true,
|
| 218 |
+
"individual_encoder_name": "RNNContextEncoder",
|
| 219 |
+
"individual_encoder_number_of_heads": 4,
|
| 220 |
+
"init_hidden_num_layers": 2,
|
| 221 |
+
"input_encoding_hidden_dim": 128,
|
| 222 |
+
"kl_weight": 1.0,
|
| 223 |
+
"loss_name": "log_nll",
|
| 224 |
+
"node_step": true,
|
| 225 |
+
"norm": "layer",
|
| 226 |
+
"output_head_num_layers": 3,
|
| 227 |
+
"prediction_latent_deterministic": false,
|
| 228 |
+
"prediction_only": false,
|
| 229 |
+
"reconstruction_only": false,
|
| 230 |
+
"rnn_decoder_number_of_layers": 4,
|
| 231 |
+
"rnn_individual_encoder_number_of_layers": 4,
|
| 232 |
+
"scale_dosing_amounts": true,
|
| 233 |
+
"study_latent_deterministic": false,
|
| 234 |
+
"time_obs_encoder_hidden_dim": 256,
|
| 235 |
+
"time_obs_encoder_output_dim": 256,
|
| 236 |
+
"use_attention": true,
|
| 237 |
+
"use_invariance_loss": false,
|
| 238 |
+
"use_kl_i": true,
|
| 239 |
+
"use_kl_i_np": true,
|
| 240 |
+
"use_kl_init": true,
|
| 241 |
+
"use_kl_s": true,
|
| 242 |
+
"use_self_attention": true,
|
| 243 |
+
"use_time_deltas": true,
|
| 244 |
+
"zi_latent_dim": 128
|
| 245 |
+
},
|
| 246 |
+
"run_index": 0,
|
| 247 |
+
"tags": [
|
| 248 |
+
"AICME",
|
| 249 |
+
"AISTATS-2026",
|
| 250 |
+
"camera-ready"
|
| 251 |
+
],
|
| 252 |
+
"target_observations": {
|
| 253 |
+
"add_rem": true,
|
| 254 |
+
"drop_time_zero_observations": false,
|
| 255 |
+
"empirical_number_of_obs": 2,
|
| 256 |
+
"generative_bias": false,
|
| 257 |
+
"max_num_obs": 15,
|
| 258 |
+
"max_past": 5,
|
| 259 |
+
"min_past": 3,
|
| 260 |
+
"past_time_ratio": 0.1,
|
| 261 |
+
"split_past_future": true,
|
| 262 |
+
"type": "pk_peak_half_life"
|
| 263 |
+
},
|
| 264 |
+
"train": {
|
| 265 |
+
"amsgrad": false,
|
| 266 |
+
"batch_size": 64,
|
| 267 |
+
"betas": [
|
| 268 |
+
0.9,
|
| 269 |
+
0.999
|
| 270 |
+
],
|
| 271 |
+
"callbacks_scheduler": {
|
| 272 |
+
"checkpoint_used_in_end": [
|
| 273 |
+
"end",
|
| 274 |
+
"best",
|
| 275 |
+
"log_rmse"
|
| 276 |
+
],
|
| 277 |
+
"include_end": true,
|
| 278 |
+
"keep_temp_files": false,
|
| 279 |
+
"max_samples_per_group": 500,
|
| 280 |
+
"percent_step": 0.1,
|
| 281 |
+
"skip_sanity_check": true,
|
| 282 |
+
"store_samples": true,
|
| 283 |
+
"task_during": [
|
| 284 |
+
{
|
| 285 |
+
"fn_key": "pk.predictive.images",
|
| 286 |
+
"log_prefix": "Synthetic",
|
| 287 |
+
"n_samples": 1,
|
| 288 |
+
"name": "synthetic/predictive_images",
|
| 289 |
+
"sample_source": "val_batch",
|
| 290 |
+
"save_to_disk": true,
|
| 291 |
+
"split": "val",
|
| 292 |
+
"task_cfg": {
|
| 293 |
+
"label": "Synthetic",
|
| 294 |
+
"milestone_stride": 1
|
| 295 |
+
}
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"fn_key": "pk.generative.images",
|
| 299 |
+
"log_prefix": "Synthetic",
|
| 300 |
+
"n_samples": 10,
|
| 301 |
+
"name": "synthetic/new_individuals_images",
|
| 302 |
+
"sample_source": "val_batch",
|
| 303 |
+
"save_to_disk": true,
|
| 304 |
+
"split": "val",
|
| 305 |
+
"task_cfg": {
|
| 306 |
+
"label": "Synthetic",
|
| 307 |
+
"milestone_stride": 1
|
| 308 |
+
}
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"fn_key": "pk.predictive.metrics",
|
| 312 |
+
"log_prefix": "Empirical",
|
| 313 |
+
"n_samples": 1,
|
| 314 |
+
"name": "empirical/predictive_metrics",
|
| 315 |
+
"sample_source": "empirical_set",
|
| 316 |
+
"save_to_disk": false,
|
| 317 |
+
"split": "empirical_heldout",
|
| 318 |
+
"task_cfg": {
|
| 319 |
+
"label": "Empirical",
|
| 320 |
+
"milestone_stride": 5
|
| 321 |
+
}
|
| 322 |
+
},
|
| 323 |
+
{
|
| 324 |
+
"checkpoint_metric": true,
|
| 325 |
+
"checkpoint_metric_name": "log_rmse",
|
| 326 |
+
"checkpoint_mode": "min",
|
| 327 |
+
"fn_key": "pk.empirical.summary",
|
| 328 |
+
"log_prefix": "Empirical",
|
| 329 |
+
"n_samples": 0,
|
| 330 |
+
"name": "empirical/summary",
|
| 331 |
+
"sample_source": "val_batch",
|
| 332 |
+
"save_to_disk": false,
|
| 333 |
+
"split": "val",
|
| 334 |
+
"task_cfg": {
|
| 335 |
+
"label": "Empirical",
|
| 336 |
+
"milestone_stride": 5,
|
| 337 |
+
"selected_summary_drugs": [
|
| 338 |
+
"paracetamol glucuronide",
|
| 339 |
+
"midazolam"
|
| 340 |
+
],
|
| 341 |
+
"summary_metric": "log_rmse"
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
],
|
| 345 |
+
"tasks_end": [
|
| 346 |
+
{
|
| 347 |
+
"fn_key": "pk.predictive.metrics",
|
| 348 |
+
"log_prefix": "Empirical",
|
| 349 |
+
"n_samples": 1,
|
| 350 |
+
"name": "empirical/predictive_metrics",
|
| 351 |
+
"sample_source": "empirical_set",
|
| 352 |
+
"save_to_disk": false,
|
| 353 |
+
"split": "empirical_heldout",
|
| 354 |
+
"task_cfg": {
|
| 355 |
+
"label": "Empirical"
|
| 356 |
+
}
|
| 357 |
+
},
|
| 358 |
+
{
|
| 359 |
+
"fn_key": "pk.predictive.images",
|
| 360 |
+
"log_prefix": "Empirical",
|
| 361 |
+
"n_samples": 1,
|
| 362 |
+
"name": "empirical/predictive_images",
|
| 363 |
+
"sample_source": "empirical_set",
|
| 364 |
+
"save_to_disk": true,
|
| 365 |
+
"split": "empirical_heldout",
|
| 366 |
+
"task_cfg": {
|
| 367 |
+
"label": "Empirical"
|
| 368 |
+
}
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"fn_key": "pk.vpc.npde_pvalues",
|
| 372 |
+
"log_prefix": "Empirical",
|
| 373 |
+
"n_samples": 500,
|
| 374 |
+
"name": "empirical/vpc_npde_pvalues",
|
| 375 |
+
"sample_source": "empirical_set",
|
| 376 |
+
"save_to_disk": false,
|
| 377 |
+
"split": "empirical_no_heldout",
|
| 378 |
+
"task_cfg": {
|
| 379 |
+
"label": "Empirical"
|
| 380 |
+
}
|
| 381 |
+
},
|
| 382 |
+
{
|
| 383 |
+
"fn_key": "pk.vpc.images",
|
| 384 |
+
"log_prefix": "Empirical",
|
| 385 |
+
"n_samples": 500,
|
| 386 |
+
"name": "empirical/vpc_images",
|
| 387 |
+
"sample_source": "empirical_set",
|
| 388 |
+
"save_to_disk": true,
|
| 389 |
+
"split": "empirical_no_heldout",
|
| 390 |
+
"task_cfg": {
|
| 391 |
+
"label": "Empirical"
|
| 392 |
+
}
|
| 393 |
+
},
|
| 394 |
+
{
|
| 395 |
+
"fn_key": "pk.empirical.summary",
|
| 396 |
+
"log_prefix": "Empirical",
|
| 397 |
+
"n_samples": 0,
|
| 398 |
+
"name": "empirical/summary",
|
| 399 |
+
"sample_source": "val_batch",
|
| 400 |
+
"save_to_disk": false,
|
| 401 |
+
"split": "val",
|
| 402 |
+
"task_cfg": {
|
| 403 |
+
"label": "Empirical",
|
| 404 |
+
"selected_summary_drugs": [
|
| 405 |
+
"paracetamol glucuronide",
|
| 406 |
+
"midazolam"
|
| 407 |
+
],
|
| 408 |
+
"summary_metric": "log_rmse"
|
| 409 |
+
}
|
| 410 |
+
}
|
| 411 |
+
],
|
| 412 |
+
"tasks_validation": []
|
| 413 |
+
},
|
| 414 |
+
"epochs": 100,
|
| 415 |
+
"eps": 1e-08,
|
| 416 |
+
"gradient_clip_val": 0.5,
|
| 417 |
+
"learning_rate": 0.0001,
|
| 418 |
+
"log_interval": 1,
|
| 419 |
+
"num_batch_plot": 1,
|
| 420 |
+
"num_workers": 8,
|
| 421 |
+
"optimizer_name": "AdamW",
|
| 422 |
+
"persistent_workers": true,
|
| 423 |
+
"scheduler_name": "CosineAnnealingLR",
|
| 424 |
+
"scheduler_params": {
|
| 425 |
+
"T_max": 1000,
|
| 426 |
+
"eta_min": 5e-05,
|
| 427 |
+
"last_epoch": -1
|
| 428 |
+
},
|
| 429 |
+
"shuffle_val": true,
|
| 430 |
+
"weight_decay": 0.0001
|
| 431 |
+
},
|
| 432 |
+
"upload_to_hf_hub": true,
|
| 433 |
+
"verbose": false
|
| 434 |
+
},
|
| 435 |
+
"experiment_type": "nodepk",
|
| 436 |
+
"io_schema_version": "studyjson-v1",
|
| 437 |
+
"model_type": "sim_priors_pk",
|
| 438 |
+
"original_repo_id": "cesarali/AICMEPK_cluster",
|
| 439 |
+
"runtime_repo_id": "cesarali/AICME-runtime",
|
| 440 |
+
"supported_tasks": [
|
| 441 |
+
"generate",
|
| 442 |
+
"predict"
|
| 443 |
+
],
|
| 444 |
+
"transformers_version": "4.52.4"
|
| 445 |
+
}
|
configuration_sim_priors_pk.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face configuration for self-contained PK runtime bundles."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from transformers import PretrainedConfig
|
| 8 |
+
|
| 9 |
+
from sim_priors_pk.hub_runtime.runtime_contract import STUDY_JSON_IO_VERSION
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PKHubConfig(PretrainedConfig):
|
| 13 |
+
"""Public Hub config describing a consumer-facing PK runtime bundle."""
|
| 14 |
+
|
| 15 |
+
model_type = "sim_priors_pk"
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
architecture_name: Optional[str] = None,
|
| 20 |
+
experiment_type: str = "nodepk",
|
| 21 |
+
experiment_config: Optional[Dict[str, Any]] = None,
|
| 22 |
+
builder_config: Optional[Dict[str, Any]] = None,
|
| 23 |
+
supported_tasks: Optional[List[str]] = None,
|
| 24 |
+
default_task: Optional[str] = None,
|
| 25 |
+
io_schema_version: str = STUDY_JSON_IO_VERSION,
|
| 26 |
+
original_repo_id: Optional[str] = None,
|
| 27 |
+
runtime_repo_id: Optional[str] = None,
|
| 28 |
+
**kwargs,
|
| 29 |
+
) -> None:
|
| 30 |
+
super().__init__(**kwargs)
|
| 31 |
+
self.architecture_name = architecture_name
|
| 32 |
+
self.experiment_type = experiment_type
|
| 33 |
+
self.experiment_config = dict(experiment_config or {})
|
| 34 |
+
self.builder_config = dict(builder_config or {})
|
| 35 |
+
self.supported_tasks = list(supported_tasks or [])
|
| 36 |
+
self.default_task = default_task or (self.supported_tasks[0] if self.supported_tasks else None)
|
| 37 |
+
self.io_schema_version = io_schema_version
|
| 38 |
+
self.original_repo_id = original_repo_id
|
| 39 |
+
self.runtime_repo_id = runtime_repo_id
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
__all__ = ["PKHubConfig"]
|
modeling_sim_priors_pk.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face AutoModel wrapper for consumer-facing PK runtime bundles."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, Optional, Sequence, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
|
| 10 |
+
from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
|
| 11 |
+
from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
|
| 12 |
+
from sim_priors_pk.hub_runtime.runtime_contract import (
|
| 13 |
+
RuntimeBuilderConfig,
|
| 14 |
+
build_batch_from_studies,
|
| 15 |
+
infer_supported_tasks,
|
| 16 |
+
instantiate_backbone_from_hub_config,
|
| 17 |
+
normalize_studies_input,
|
| 18 |
+
split_runtime_samples,
|
| 19 |
+
validate_studies_for_task,
|
| 20 |
+
)
|
| 21 |
+
from sim_priors_pk.models.amortized_inference.generative_pk import (
|
| 22 |
+
NewGenerativeMixin,
|
| 23 |
+
NewPredictiveMixin,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class PKHubModel(PreTrainedModel):
    """Thin wrapper exposing a stable StudyJSON runtime API on top of PK models."""

    config_class = PKHubConfig
    base_model_prefix = "backbone"

    def __init__(self, config: PKHubConfig, backbone: Optional[torch.nn.Module] = None) -> None:
        """Wrap an existing backbone, or build one from ``config`` when omitted."""
        super().__init__(config)
        if backbone is None:
            backbone = instantiate_backbone_from_hub_config(config)
        self.backbone = backbone
        # Runtime bundles are inference-only; switch to eval mode up front.
        self.backbone.eval()

    def forward(self, *args, **kwargs):
        """Delegate raw forward calls to the wrapped PK backbone."""
        return self.backbone(*args, **kwargs)

    @property
    def supported_tasks(self) -> Sequence[str]:
        """Tasks supported by this runtime model."""
        configured = getattr(self.config, "supported_tasks", [])
        return tuple(configured or infer_supported_tasks(self.backbone))

    @torch.inference_mode()
    def run_task(
        self,
        *,
        task: str,
        studies: Union[StudyJSON, Sequence[StudyJSON]],
        num_samples: int = 1,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the public StudyJSON inference contract for the requested task.

        Parameters
        ----------
        task:
            One of :attr:`supported_tasks` (currently ``"generate"`` or ``"predict"``).
        studies:
            A single StudyJSON payload or a sequence of them.
        num_samples:
            Number of samples drawn per study; must be >= 1.
        kwargs:
            Extra task options; ``num_steps`` is forwarded to generation.

        Raises
        ------
        ValueError
            For unsupported tasks, a non-positive ``num_samples``, or a
            backbone lacking the mixin required by ``task``.
        """
        allowed = list(self.supported_tasks)
        if task not in allowed:
            raise ValueError(
                f"Unsupported task {task!r}. Supported tasks: {allowed or 'none'}."
            )
        if int(num_samples) < 1:
            raise ValueError("num_samples must be >= 1.")

        study_list = normalize_studies_input(studies)
        builder_config = RuntimeBuilderConfig.from_dict(self.config.builder_config)
        validate_studies_for_task(study_list, task=task, builder_config=builder_config)

        experiment_payload = getattr(self.config, "experiment_config", {})
        dosing_payload = experiment_payload.get("dosing", {})
        if dosing_payload:
            # Rebuild a dosing object of the same class as the backbone's from
            # the config payload so hub-side overrides take effect.
            meta_dosing = self.backbone.meta_dosing.__class__(**dosing_payload)
        else:
            meta_dosing = self.backbone.meta_dosing
        batch = build_batch_from_studies(
            study_list,
            builder_config=builder_config,
            meta_dosing=meta_dosing,
        ).to(self.device)

        if task == "generate":
            if not isinstance(self.backbone, NewGenerativeMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support generate.")
            output_studies = self.backbone.sample_new_individuals_to_studyjson(
                batch,
                sample_size=int(num_samples),
                num_steps=kwargs.get("num_steps"),
            )
        elif task == "predict":
            if not isinstance(self.backbone, NewPredictiveMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support predict.")
            output_studies = self.backbone.sample_individual_prediction_from_batch_list_to_studyjson(
                [batch],
                sample_size=int(num_samples),
            )[0]
        else:
            raise ValueError(f"Unsupported task {task!r}.")

        results = []
        for index, study in enumerate(output_studies):
            results.append(
                {
                    "input_index": index,
                    "samples": split_runtime_samples(task, study),
                }
            )

        return {
            "task": task,
            "io_schema_version": self.config.io_schema_version,
            "model_info": {
                "architecture_name": self.config.architecture_name,
                "experiment_type": self.config.experiment_type,
                "supported_tasks": allowed,
                "runtime_repo_id": self.config.runtime_repo_id,
                "original_repo_id": self.config.original_repo_id,
            },
            "results": results,
        }


__all__ = ["PKHubModel"]
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec19d3a6970fcda03332a75ea0b12bb53e17e2d945088ef46e28c74f73195c84
|
| 3 |
+
size 37495779
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytest==8.3.5
|
| 2 |
+
ipython==9.2.0
|
| 3 |
+
comet_ml==3.49.6
|
| 4 |
+
matplotlib==3.10.1 # plotting only; not required for inference
|
sim_priors_pk/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
sim_priors_pk/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def _load_key_file(path: Path) -> str | None:
    """Return the contents of a key file if it exists, otherwise ``None``."""
    try:
        text = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return None
    except OSError:
        # Unreadable files are treated the same as missing ones, so callers
        # can decide how to handle absent credentials.
        return None
    return text.strip()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Resolved locations of the package and its sibling project directories.
base_dir = Path(__file__).resolve().parent
# ``base_dir`` is fully resolved, so ``parent`` equals the old
# ``(base_dir / "..").resolve()`` without the extra filesystem round-trip.
project_dir = base_dir.parent
data_dir = project_dir / "data"
test_resources_dir = project_dir / "tests" / "resources"
results_dir = project_dir / "results"
reports_dir = project_dir / "reports"
config_dir = project_dir / "config_files"

# Optional credential files; a missing or unreadable file yields a ``None`` key.
comet_keys_file = project_dir / "COMET_KEYS.txt"
hf_keys_file = project_dir / "KEYS.txt"

COMET_KEY = _load_key_file(comet_keys_file)
HUGGINGFACE_KEY = _load_key_file(hf_keys_file)

__all__ = [
    "COMET_KEY",
    "HUGGINGFACE_KEY",
    "base_dir",
    "comet_keys_file",
    "config_dir",
    "data_dir",
    "hf_keys_file",
    "project_dir",
    "reports_dir",
    "results_dir",
    "test_resources_dir",
]
|
sim_priors_pk/config_classes/__init__.py
ADDED
|
File without changes
|
sim_priors_pk/config_classes/data_config.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import List, Dict, Tuple, Optional, Union
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
try: # pragma: no cover - exercised indirectly via configuration loading
|
| 7 |
+
import yaml # type: ignore
|
| 8 |
+
except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
|
| 9 |
+
from sim_priors_pk.config_classes import yaml_fallback as yaml
|
| 10 |
+
|
| 11 |
+
try: # pragma: no cover - optional dependency for downstream modules
|
| 12 |
+
import torch # type: ignore
|
| 13 |
+
except ModuleNotFoundError: # pragma: no cover - torch is not required for configuration loading
|
| 14 |
+
torch = None # type: ignore
|
| 15 |
+
|
| 16 |
+
@dataclass
class SimpleMetaStudyConfig:
    """
    Minimal configuration for the synthetic (non-mechanistic) PK simulator.
    Used when `simple_mode=True` is detected in the YAML file.
    """

    simple_mode: bool = True

    # --- keep same naming as MetaStudyConfig for compatibility ---
    num_individuals: int = 16
    # Mirrors ``num_individuals`` so downstream code expecting a range works.
    num_individuals_range: Tuple[int, int] = (16, 16)

    # Synthetic time grid.
    time_start: float = 0.0
    time_stop: float = 24.0
    time_num_steps: int = 40

    # Shape parameters of the synthetic curves.
    band_scale_range: Tuple[float, float] = (0.1, 0.3)
    baseline_range: Tuple[float, float] = (0.0, 0.1)
    decay_rate_range: Tuple[float, float] = (0.3, 0.6)
    # Probability of drawing the exponential curve shape (pulse otherwise).
    p1: float = 0.5
    num_peripherals_range: Tuple[int, int] = (1, 3)

    solver_method: str = "dummy"
    drug_id_options: List[str] = field(default_factory=lambda: ["DummyDrug"])

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "SimpleMetaStudyConfig":
        """Build the config from a YAML file, optionally nested under ``meta_study``."""
        with open(file_path, "r", encoding="utf-8") as handle:
            payload = yaml.safe_load(handle) or {}
        payload = payload.get("meta_study", payload)

        # Backward compatibility: derive the range when only num_individuals is set.
        if "num_individuals_range" not in payload and "num_individuals" in payload:
            count = payload["num_individuals"]
            payload["num_individuals_range"] = (count, count)

        return cls(**payload)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass
class MetaStudyConfig:
    """
    Configuration for the compartment study.

    Specifies the hyper-priors used to sample a population, which in turn
    samples the individual PK parameters.
    """

    drug_id_options: List[str] = field(default_factory=lambda: ["Drug_A", "Drug_B", "Drug_C"])
    num_individuals_range: Tuple[int, int] = (20, 20)

    # Per-parameter hyper-prior ranges: population log-mean/log-std plus the
    # tmag/tscl knobs, for k_a, k_e, V, k_1p and k_p1.
    num_peripherals_range: Tuple[int, int] = (1, 3)
    log_k_a_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_k_a_std_range: Tuple[float, float] = (0.1, 0.5)
    k_a_tmag_range: Tuple[float, float] = (0.01, 0.1)
    k_a_tscl_range: Tuple[float, float] = (1.0, 5.0)
    log_k_e_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_k_e_std_range: Tuple[float, float] = (0.1, 0.5)
    k_e_tmag_range: Tuple[float, float] = (0.01, 0.1)
    k_e_tscl_range: Tuple[float, float] = (1.0, 5.0)
    log_V_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_V_std_range: Tuple[float, float] = (0.1, 0.5)
    V_tmag_range: Tuple[float, float] = (0.01, 0.1)
    V_tscl_range: Tuple[float, float] = (1.0, 5.0)
    log_k_1p_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_k_1p_std_range: Tuple[float, float] = (0.1, 0.5)
    k_1p_tmag_range: Tuple[float, float] = (0.01, 0.1)
    k_1p_tscl_range: Tuple[float, float] = (1.0, 5.0)
    log_k_p1_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_k_p1_std_range: Tuple[float, float] = (0.1, 0.5)
    k_p1_tmag_range: Tuple[float, float] = (0.01, 0.1)
    k_p1_tscl_range: Tuple[float, float] = (1.0, 5.0)

    # Parameters for observation noise.
    rel_ruv_range: Tuple[float, float] = (0.05, 0.3)

    # Parameters for generating time points.
    time_start: float = 0.0
    time_stop: float = 10.0
    time_num_steps: int = 100

    # ODE solver selection.
    solver_method: str = "rk4"

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "MetaStudyConfig":
        """Instantiate the meta-study configuration from a YAML file."""
        with open(file_path, "r", encoding="utf-8") as handle:
            payload = yaml.safe_load(handle) or {}

        # Allow the config to live under a top-level ``meta_study`` section.
        if isinstance(payload, dict) and "meta_study" in payload:
            payload = payload.get("meta_study") or {}

        if not isinstance(payload, dict):
            raise TypeError("Expected 'meta_study' section in YAML to be a mapping.")

        return cls(**payload)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclass
class ObservationsConfig:
    """High-level knobs describing an observation strategy."""

    # ``None`` (e.g. YAML ``type: null``) is treated as the legacy
    # ``pk_peak_half_life`` strategy by the observation factory.
    type: Optional[str] = "pk_peak_half_life"
    add_rem: bool = True
    split_past_future: bool = False
    min_past: Optional[int] = None
    max_past: Optional[int] = None
    max_num_obs: int = 10
    empirical_number_of_obs: int = 2
    # When True, entries at non-positive times are excluded from sampled
    # observations (e.g. concentration at dosing time t=0).
    drop_time_zero_observations: bool = False

    # Strategy specific semantic controls (do not affect tensor shapes directly)
    past_time_ratio: float = 0.1  # Used by random strategies with fixed boundary
    # Sampling policy for split-past/future strategies:
    # - False: sample uniformly in [min_past, max_past]
    # - True: sample 0 with prob. 0.5, otherwise sample uniformly
    #   in [max(1, min_past), max_past]
    generative_bias: bool = False

    def __post_init__(self):
        """Validate the field combination and normalize split-mode flags."""
        if not isinstance(self.generative_bias, bool):
            raise ValueError("generative_bias must be a boolean (true/false)")

        if not self.split_past_future:
            return
        if self.min_past is None or self.max_past is None:
            raise ValueError(
                "min_past and max_past must be provided when split_past_future=True"
            )
        if self.min_past < 0:
            raise ValueError("min_past must be non-negative")
        if self.max_past < self.min_past:
            raise ValueError("max_past must be >= min_past")
        # Split strategies always require the remaining-observation channel.
        self.add_rem = True

    @classmethod
    def from_yaml(
        cls,
        file_path: Union[str, os.PathLike],
        section: Optional[str] = None,
    ) -> "ObservationsConfig":
        """Instantiate an observation configuration from a YAML file.

        When ``section`` is omitted, a single ``context_observations`` or
        ``target_observations`` section is auto-detected; if both exist the
        caller must disambiguate.
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            payload = yaml.safe_load(handle) or {}

        if not isinstance(payload, dict):
            raise TypeError("Expected YAML content to be a mapping.")

        if section is not None:
            if section not in payload:
                raise KeyError(f"Section '{section}' not found in YAML file '{file_path}'.")
            payload = payload.get(section) or {}
        else:
            candidates = [
                key for key in ("context_observations", "target_observations") if key in payload
            ]
            if len(candidates) > 1:
                raise ValueError(
                    "Multiple observation sections found; specify which one to load using the 'section' argument."
                )
            if candidates:
                payload = payload.get(candidates[0]) or {}

        if not isinstance(payload, dict):
            raise TypeError("Expected observation configuration to be provided as a mapping.")

        return cls(**payload)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@dataclass
class MixDataConfig:
    """
    Specifies how the mix databatch is constructed, i.e. whether the decoder
    variable is one full path or only the future steps of a path.
    """

    test_empirical_datasets: List[str] = field(default_factory=lambda: ["cesarali/lenuzza-2016"])
    # Deprecated fields removed (unused in current training flow):
    # pretraining_*, val_protocol, test_protocol, split_strategy, split_seed.
    evaluate_prediction_steps_past: int = 4  # length of past is kept fixed for evaluation
    # Number of generative samples (S) used for validation-time callback
    # evaluation (new individuals and VPC/NPDE consumers). Defaults to 10.
    sample_size_for_generative_evaluation_val: Optional[int] = None
    # Number of generative samples (S) used for end-of-training callback
    # evaluation (empirical end hooks). Defaults to 500.
    sample_size_for_generative_evaluation_end_of_training: Optional[int] = None
    # Deprecated legacy alias for both values above. When set and the new
    # fields are not provided, the same value is applied to both stages.
    sample_size_for_generative_evaluation: Optional[int] = None

    # Value/time normalization flags consumed by PKScaler.
    # Precedence for value scaling:
    #   1) log_and_z=True     -> "log_and_z"
    #   2) log_and_max=True   -> "log_and_max"
    #   3) log_transform=True -> "log"
    #   4) z_score_normalization=True -> "zscore"
    #   5) normalize_by_max=True      -> "max"
    #   6) otherwise                  -> "none"
    z_score_normalization: bool = False
    # Explicit single switch for log + z-score scaling in PKScaler.
    log_and_z: bool = False
    # Explicit single switch for log + max scaling in PKScaler.
    log_and_max: bool = False
    normalize_by_max: bool = True
    normalize_time: bool = True

    n_of_permutations: int = 1
    n_of_databatches: Optional[int] = None  # deprecated alias
    n_of_target_individuals: int = 1  # ignored for LOO/NO_TARGET
    # Log-only transform flag consumed by PKScaler (value_method="log").
    # This is no longer handled in the dataset/datamodule path.
    log_transform: bool = False  # Matches node-pk-1804.yaml

    store_in_tempfile: bool = False  # When True dataset is generated and saved to a temporary file
    keep_tempfile: bool = False  # Don't delete the temporary file on cleanup
    recreate_tempfile: bool = False  # Regenerate file even if it already exists

    tempfile_path: Tuple[str, str] = (
        "preprocessed",
        "simulated_ou_as_rates.tr",
    )

    tqdm_progress: bool = False  # Show progress bar when generating temp files
    # DATA SIZES
    train_size: int = 1000
    val_size: int = 100
    test_size: int = 100

    def __post_init__(self) -> None:
        """Resolve deprecated aliases, fill defaults and validate counts."""
        # Honor the deprecated batch-count alias only when the new field
        # still carries its default value.
        if self.n_of_databatches is not None and self.n_of_permutations == 1:
            self.n_of_permutations = self.n_of_databatches
            warnings.warn(
                "n_of_databatches is deprecated; use n_of_permutations",
                DeprecationWarning,
            )

        legacy = self.sample_size_for_generative_evaluation
        if legacy is not None:
            if self.sample_size_for_generative_evaluation_val is None:
                self.sample_size_for_generative_evaluation_val = int(legacy)
            if self.sample_size_for_generative_evaluation_end_of_training is None:
                self.sample_size_for_generative_evaluation_end_of_training = int(legacy)

        # Stage-specific defaults when neither the new fields nor the legacy
        # alias supplied a value.
        if self.sample_size_for_generative_evaluation_val is None:
            self.sample_size_for_generative_evaluation_val = 10
        if self.sample_size_for_generative_evaluation_end_of_training is None:
            self.sample_size_for_generative_evaluation_end_of_training = 500

        if int(self.sample_size_for_generative_evaluation_val) < 1:
            raise ValueError("sample_size_for_generative_evaluation_val must be >= 1")
        if int(self.sample_size_for_generative_evaluation_end_of_training) < 1:
            raise ValueError("sample_size_for_generative_evaluation_end_of_training must be >= 1")

        self.sample_size_for_generative_evaluation_val = int(
            self.sample_size_for_generative_evaluation_val
        )
        self.sample_size_for_generative_evaluation_end_of_training = int(
            self.sample_size_for_generative_evaluation_end_of_training
        )

        if legacy is not None:
            warnings.warn(
                "sample_size_for_generative_evaluation is deprecated; use "
                "sample_size_for_generative_evaluation_val and "
                "sample_size_for_generative_evaluation_end_of_training",
                DeprecationWarning,
            )
        if self.n_of_permutations < 1:
            raise ValueError("n_of_permutations must be >= 1")

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "MixDataConfig":
        """Instantiate the mix-data configuration from a YAML file."""
        with open(file_path, "r", encoding="utf-8") as handle:
            payload = yaml.safe_load(handle) or {}

        # Accept either a bare mapping or one nested under a known section key.
        if isinstance(payload, dict):
            for key in ("mix_data", "mix_data_config"):
                if key in payload and isinstance(payload[key], dict):
                    payload = payload[key]
                    break

        if not isinstance(payload, dict):
            raise TypeError("Expected mix data configuration to be provided as a mapping.")

        return cls(**payload)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@dataclass
class MetaDosingConfig:
    """
    Config for specifying meta dosing information.
    """

    num_individuals: int = 10
    same_route: bool = True
    logdose_mean_range: Tuple[float, float] = (-2, 2)
    logdose_std_range: Tuple[float, float] = (0.1, 0.5)
    route_options: List[str] = field(default_factory=lambda: ["oral", "iv"])
    route_weights: List[float] = field(default_factory=lambda: [0.8, 0.2])
    time: float = 0.0

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "MetaDosingConfig":
        """Instantiate the meta-dosing configuration from a YAML file."""
        with open(file_path, "r", encoding="utf-8") as handle:
            payload = yaml.safe_load(handle) or {}

        # Accept configs nested under a top-level ``dosing`` section.
        if isinstance(payload, dict) and "dosing" in payload:
            payload = payload.get("dosing") or {}

        if not isinstance(payload, dict):
            raise TypeError("Expected 'dosing' section in YAML to be a mapping.")

        return cls(**payload)


@dataclass
class MetaDosingWithDurationConfig(MetaDosingConfig):
    """
    Config for specifying meta dosing information including iv infusions.
    """

    # Per-route probability of drawing an infusion duration:
    # no duration for oral, 50% chance of infusion for iv.
    route_duration_weights: Dict[str, float] = field(default_factory=lambda: {"oral": 0.0, "iv": 0.5})
    duration_range: Tuple[float, float] = (0.5, 2.0)  # Duration of infusion; 0.0 means bolus
+
|
| 353 |
+
@dataclass
class DosingConfig:
    """
    Config for specifying dosing information. For now, it just holds the amount D of a single oral dose
    given at time t = 0.
    """

    dose: float = 1.0  # dose amount D
    route: str = "oral"  # administration route
    time: float = 0.0  # dosing time
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
@dataclass
class DosingWithDurationConfig:
    """
    Config for specifying dosing information. It holds the amount D of a dose
    given at time t = 0, optionally with an infusion duration.
    """

    dose: float = 1.0  # dose amount D
    route: str = "oral"  # administration route
    time: float = 0.0  # dosing time
    duration: float = 0.0  # Duration of infusion; 0.0 means bolus
|
sim_priors_pk/config_classes/diffusion_pk_config.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import yaml # type: ignore
|
| 7 |
+
|
| 8 |
+
from sim_priors_pk.config_classes.data_config import (
|
| 9 |
+
MetaDosingConfig,
|
| 10 |
+
MetaStudyConfig,
|
| 11 |
+
MixDataConfig,
|
| 12 |
+
ObservationsConfig,
|
| 13 |
+
SimpleMetaStudyConfig,
|
| 14 |
+
)
|
| 15 |
+
from sim_priors_pk.config_classes.node_pk_config import EncoderDecoderNetworkConfig
|
| 16 |
+
from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
|
| 17 |
+
from sim_priors_pk.config_classes.training_config import TrainingConfig
|
| 18 |
+
from sim_priors_pk.config_classes.utils import TupleSafeLoader
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class DiffusionPKExperimentConfig:
    """Experiment configuration dedicated to diffusion PK models."""

    experiment_type: str = "diffusionpk"
    name_str: str = "ContinuousDiffusionPK"
    diffusion_type: str = "continuous"  # "continuous" or "discrete"

    # Experiment-tracking / model-hub settings (populated from YAML).
    comet_ai_key: Optional[str] = None
    experiment_name: str = "diffusion_pk_compartments"
    hugging_face_token: Optional[str] = None
    upload_to_hf_hub: bool = True
    hf_model_name: str = "DiffusionPK_test"
    # Path components later joined to locate the model card; the default has
    # two parts, so the annotation is a variadic tuple.
    hf_model_card_path: Tuple[str, ...] = ("hf_model_card", "DIFFUSION-PK_Readme.md")

    tags: List[str] = field(default_factory=lambda: ["diffusion-pk", "B-0"])
    # NOTE(review): "indentifier" is a typo, but it is also the YAML key read
    # by from_yaml below — kept for backward compatibility.
    experiment_indentifier: Optional[str] = None
    my_results_path: Optional[str] = None
    experiment_dir: Optional[str] = None
    verbose: bool = False
    run_index: int = 0
    debug_test: bool = False

    # Diffusion training knob: predict unit Gaussian noise or correlated noise.
    predict_gaussian_noise: bool = True

    # Nested sub-configurations; each may be overridden from YAML sections.
    network: EncoderDecoderNetworkConfig = field(default_factory=EncoderDecoderNetworkConfig)
    source_process: SourceProcessConfig = field(default_factory=SourceProcessConfig)
    mix_data: MixDataConfig = field(default_factory=MixDataConfig)

    context_observations: ObservationsConfig = field(default_factory=ObservationsConfig)
    target_observations: ObservationsConfig = field(default_factory=ObservationsConfig)

    meta_study: MetaStudyConfig = field(default_factory=MetaStudyConfig)
    dosing: MetaDosingConfig = field(default_factory=MetaDosingConfig)

    train: TrainingConfig = field(default_factory=TrainingConfig)

    @staticmethod
    def from_yaml(file_path: str) -> "DiffusionPKExperimentConfig":
        """Initializes the class from a YAML file."""

        # TupleSafeLoader restores tuples that the companion dump helpers
        # serialize into the YAML.
        with open(file_path, "r") as file:
            config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}

        if not isinstance(config_dict, dict):
            raise TypeError("Expected experiment YAML to be a mapping.")

        # Guard against loading a YAML written for a different experiment type.
        exp_type = config_dict.get("experiment_type")
        if exp_type is not None and str(exp_type).lower() != "diffusionpk":
            raise ValueError(
                "Expected experiment_type 'diffusionpk' for DiffusionPKExperimentConfig, "
                f"got {exp_type!r}."
            )

        # Relative config references are resolved against the experiment file.
        base_dir = os.path.dirname(os.path.abspath(file_path))

        data_cfg_dict = (
            DiffusionPKExperimentConfig._load_ref_yaml(config_dict.get("data_config"), base_dir)
            or {}
        )
        training_cfg_dict = (
            DiffusionPKExperimentConfig._load_ref_yaml(config_dict.get("training_config"), base_dir)
            or {}
        )
        model_cfg_dict = (
            DiffusionPKExperimentConfig._load_ref_yaml(config_dict.get("model_config"), base_dir)
            or {}
        )

        # Observation sections may live in the experiment YAML, the data YAML,
        # or directly at the top level of the data YAML — in that priority.
        observations_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "observations_config"
        )
        if observations_section is None:
            observations_section = DiffusionPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "observations_config"
            )
        if observations_section is not None:
            context_observations_base = observations_section.get("context_observations")
            target_observations_base = observations_section.get("target_observations")
        else:
            context_observations_base = data_cfg_dict.get("context_observations")
            target_observations_base = data_cfg_dict.get("target_observations")

        # Same priority scheme for the mix-data section.
        mix_data_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "mix_data_config"
        )
        if mix_data_section is None:
            mix_data_section = DiffusionPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "mix_data_config"
            )
        if mix_data_section is None:
            mix_data_section = data_cfg_dict.get("mix_data")

        meta_study_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_study_config"
        )
        if meta_study_section is None:
            meta_study_section = DiffusionPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_study_config"
            )
        meta_dosing_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_dosing_config"
        )
        if meta_dosing_section is None:
            meta_dosing_section = DiffusionPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_dosing_config"
            )

        # The meta_study mapping may be nested inside either section; fall back
        # to the data YAML's top-level key last.
        meta_study_base = DiffusionPKExperimentConfig._extract_config_mapping(
            meta_study_section, "meta_study"
        )
        if meta_study_base is None and meta_dosing_section is not None:
            meta_study_base = DiffusionPKExperimentConfig._extract_config_mapping(
                meta_dosing_section, "meta_study"
            )
        if meta_study_base is None:
            meta_study_base = data_cfg_dict.get("meta_study")

        dosing_base = DiffusionPKExperimentConfig._extract_config_mapping(
            meta_dosing_section, "dosing"
        )
        if dosing_base is None:
            dosing_base = data_cfg_dict.get("dosing")

        # Inline sections in the experiment YAML override referenced files.
        mix_data_cfg = DiffusionPKExperimentConfig._merge_dicts(
            mix_data_section, config_dict.get("mix_data")
        )
        context_obs_cfg = DiffusionPKExperimentConfig._merge_dicts(
            context_observations_base, config_dict.get("context_observations")
        )
        target_obs_cfg = DiffusionPKExperimentConfig._merge_dicts(
            target_observations_base, config_dict.get("target_observations")
        )
        meta_study_cfg = DiffusionPKExperimentConfig._merge_dicts(
            meta_study_base, config_dict.get("meta_study")
        )
        dosing_cfg = DiffusionPKExperimentConfig._merge_dicts(
            dosing_base, config_dict.get("dosing")
        )

        # A training YAML may either nest its values under "train" or be flat.
        train_section = training_cfg_dict.get("train", training_cfg_dict)
        train_cfg = DiffusionPKExperimentConfig._merge_dicts(
            train_section, config_dict.get("train")
        )

        # Same flat-or-nested convention for the model YAML's network block.
        network_section = model_cfg_dict.get("network", model_cfg_dict)
        network_cfg = DiffusionPKExperimentConfig._merge_dicts(
            network_section, config_dict.get("network")
        )

        # Source-process (noise model) config: accept several legacy spellings
        # ("source_config", "source_process", "noise_model").
        source_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "source_config"
        )
        if source_section is None:
            source_section = DiffusionPKExperimentConfig._resolve_config_section(
                model_cfg_dict, base_dir, "source_config"
            )
        if source_section is None:
            source_section = model_cfg_dict.get("source_process") or model_cfg_dict.get("noise_model")
        source_section = DiffusionPKExperimentConfig._extract_config_mapping(
            source_section, "source_process"
        )
        if isinstance(source_section, dict) and "noise_model" in source_section:
            source_section = source_section.get("noise_model")
        source_cfg = DiffusionPKExperimentConfig._merge_dicts(
            source_section, config_dict.get("source_process")
        )

        # Choose the simple vs full meta-study class based on the YAML flag.
        if meta_study_cfg.get("simple_mode", False):
            meta_study_instance = SimpleMetaStudyConfig(**meta_study_cfg)
        else:
            meta_study_instance = MetaStudyConfig(**meta_study_cfg)

        # Drop keys TrainingConfig does not accept.
        train_cfg = TrainingConfig._filter_kwargs(train_cfg)

        # NOTE(review): `verbose` and `run_index` are not read from the YAML
        # here and keep their dataclass defaults.
        return DiffusionPKExperimentConfig(
            experiment_type=str(config_dict.get("experiment_type", "diffusionpk")).lower(),
            name_str=config_dict.get("name_str", "ContinuousDiffusionPK"),
            diffusion_type=config_dict.get("diffusion_type", "continuous"),
            tags=config_dict.get("tags", ["diffusion-pk", "B-0"]),
            experiment_name=config_dict.get("experiment_name", "diffusion_pk_compartments"),
            experiment_indentifier=config_dict.get("experiment_indentifier", None),
            my_results_path=config_dict.get("my_results_path", None),
            experiment_dir=config_dict.get("experiment_dir", None),
            comet_ai_key=config_dict.get("comet_ai_key", None),
            hugging_face_token=config_dict.get("hugging_face_token", None),
            upload_to_hf_hub=config_dict.get("upload_to_hf_hub", True),
            hf_model_name=config_dict.get("hf_model_name", "DiffusionPK_test"),
            hf_model_card_path=tuple(
                config_dict.get(
                    "hf_model_card_path", ("hf_model_card", "DIFFUSION-PK_Readme.md")
                )
            ),
            debug_test=config_dict.get("debug_test", False),
            predict_gaussian_noise=bool(config_dict.get("predict_gaussian_noise", True)),
            network=EncoderDecoderNetworkConfig(**network_cfg),
            source_process=SourceProcessConfig(**source_cfg),
            mix_data=MixDataConfig(**mix_data_cfg),
            context_observations=ObservationsConfig(**context_obs_cfg),
            target_observations=ObservationsConfig(**target_obs_cfg),
            meta_study=meta_study_instance,
            dosing=MetaDosingConfig(**dosing_cfg),
            train=TrainingConfig(**train_cfg),
        )

    @staticmethod
    def _merge_dicts(
        base_dict: Optional[Dict[str, Any]], override_dict: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Merge two optional dictionaries returning a new dictionary."""

        merged: Dict[str, Any] = {}

        if base_dict:
            if not isinstance(base_dict, dict):
                raise TypeError(
                    "Expected base_dict to be a mapping when merging configuration sections."
                )
            # Deep copy so callers may mutate the result without aliasing the base.
            merged = deepcopy(base_dict)

        if override_dict:
            if not isinstance(override_dict, dict):
                raise TypeError(
                    "Expected override_dict to be a mapping when merging configuration sections."
                )
            # Override values win on key collisions (shallow update).
            merged.update(override_dict)

        return merged

    @staticmethod
    def _extract_config_mapping(
        section: Optional[Dict[str, Any]], nested_key: str
    ) -> Optional[Dict[str, Any]]:
        """Return a nested configuration mapping or the section itself."""

        if section is None:
            return None

        if not isinstance(section, dict):
            raise TypeError(
                "Expected configuration section to be a mapping when extracting nested"
                f" '{nested_key}' values."
            )

        if nested_key in section:
            nested_value = section[nested_key]
            # An explicit `nested_key: null` means "no configuration".
            if nested_value is None:
                return None
            if not isinstance(nested_value, dict):
                raise TypeError(
                    f"Expected '{nested_key}' section to be a mapping when extracting configuration values."
                )
            return nested_value

        return section

    @staticmethod
    def _load_ref_yaml(
        ref: Optional[Union[str, Dict[str, Any]]], base_dir: str
    ) -> Optional[Dict[str, Any]]:
        """Load a referenced YAML block or return inline dictionaries as-is."""

        if ref is None:
            return None

        if isinstance(ref, dict):
            return ref

        if isinstance(ref, str):
            # String references are file paths, resolved relative to base_dir.
            ref_path = ref
            if not os.path.isabs(ref_path):
                ref_path = os.path.join(base_dir, ref_path)

            with open(ref_path, "r") as handle:
                return yaml.load(handle, Loader=TupleSafeLoader) or {}

        raise TypeError("Expected configuration reference to be a mapping or string path.")

    @staticmethod
    def _resolve_config_section(
        cfg_dict: Dict[str, Any], base_dir: str, key: str
    ) -> Optional[Dict[str, Any]]:
        """Resolve nested configuration references within a configuration block."""

        if key not in cfg_dict:
            return None

        section = cfg_dict[key]

        if section is None:
            return None

        if isinstance(section, dict):
            # A mapping may delegate to another file via its "_ref" entry.
            ref_value = section.get("_ref") if "_ref" in section else None
            if ref_value is not None:
                loaded = DiffusionPKExperimentConfig._load_ref_yaml(ref_value, base_dir)
                return loaded or {}
            return section

        if isinstance(section, str):
            # A bare string is treated as a file path to load.
            loaded = DiffusionPKExperimentConfig._load_ref_yaml(section, base_dir)
            return loaded or {}

        raise TypeError(
            f"Expected configuration section '{key}' to be a mapping or string reference."
        )
|
sim_priors_pk/config_classes/flow_pk_config.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from dataclasses import asdict, dataclass, field, fields
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
try: # pragma: no cover - exercised indirectly via configuration loading
|
| 8 |
+
import yaml # type: ignore
|
| 9 |
+
except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
|
| 10 |
+
from sim_priors_pk.config_classes import yaml_fallback as yaml
|
| 11 |
+
|
| 12 |
+
try:  # pragma: no cover - optional dependency for HF integration
    from transformers import PretrainedConfig  # type: ignore
except ModuleNotFoundError:  # pragma: no cover - allow configuration utilities without transformers

    class PretrainedConfig:  # type: ignore
        """Minimal stand-in used when ``transformers`` is not installed.

        Accepts and ignores arbitrary keyword arguments so code that
        instantiates it with HF-style kwargs still imports and runs.
        """

        def __init__(self, **kwargs):
            super().__init__()
|
| 19 |
+
|
| 20 |
+
from sim_priors_pk.config_classes.data_config import (
|
| 21 |
+
MetaDosingConfig,
|
| 22 |
+
MetaStudyConfig,
|
| 23 |
+
MixDataConfig,
|
| 24 |
+
ObservationsConfig,
|
| 25 |
+
SimpleMetaStudyConfig,
|
| 26 |
+
)
|
| 27 |
+
from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
|
| 28 |
+
from sim_priors_pk.config_classes.training_config import TrainingConfig
|
| 29 |
+
from sim_priors_pk.config_classes.utils import TupleSafeLoader
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _to_float(x: Any) -> float:
|
| 33 |
+
try:
|
| 34 |
+
v = float(x)
|
| 35 |
+
except Exception:
|
| 36 |
+
return math.inf
|
| 37 |
+
# guard against NaN
|
| 38 |
+
if math.isnan(v):
|
| 39 |
+
return math.inf
|
| 40 |
+
return v
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _raise_flowpk_network_migration() -> None:
|
| 44 |
+
raise ValueError(
|
| 45 |
+
"FlowPK configs no longer accept a 'network' section. "
|
| 46 |
+
"Please rename 'network' to 'vector_field' and set 'experiment_type: flowpk' in your YAML."
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class VectorFieldPKConfig:
    """Hyper-parameters of the transformer vector field used by FlowPK."""

    # Core transformer / Fourier-feature dimensions.
    hidden_dim: int = 64
    fourier_modes: int = 16
    use_spectral_qkv: bool = False
    time_fourier_max_freq: int = 64
    encoder_num_heads: int = 4
    decoder_num_heads: int = 4
    encoder_attention_layers: int = 2
    decoder_attention_layers: int = 2
    dropout: float = 0.0

    # Latent/conditioning settings required by the vector field implementation.
    cov_proj_dim: int = 16  # p in the paper
    combine_latent_mode: str = "mlp"  # Options: "mlp", "sum"
    zi_latent_dim: int = 200

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "VectorFieldPKConfig":
        """Build the configuration from a YAML file.

        Accepts either a bare mapping of field values or a file whose values
        live under a top-level ``vector_field`` key. A legacy ``network``
        section aborts with a migration error.
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            payload = yaml.safe_load(handle) or {}

        if isinstance(payload, dict):
            if "network" in payload:
                _raise_flowpk_network_migration()
            if "vector_field" in payload:
                payload = payload.get("vector_field") or {}

        if not isinstance(payload, dict):
            raise TypeError("Expected 'vector_field' section in YAML to be a mapping.")

        return cls(**payload)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@dataclass
class FlowPKExperimentConfig:
    """Experiment configuration for FlowPK (vector field only)."""

    experiment_type: str = "flowpk"
    name_str: str = "FlowPK"
    # Experiment-tracking / model-hub settings (populated from YAML).
    comet_ai_key: Optional[str] = None
    experiment_name: str = "flow_pk_compartments"
    hugging_face_token: Optional[str] = None
    upload_to_hf_hub: bool = True
    hf_model_name: str = "FlowPK_test"
    # Path components later joined to locate the model card; the default has
    # two parts, so the annotation is a variadic tuple.
    hf_model_card_path: Tuple[str, ...] = ("hf_model_card", "CVAE_Readme.md")

    tags: List[str] = field(default_factory=lambda: ["flow-pk", "B-0"])
    # NOTE(review): "indentifier" is a typo, but it is also the YAML key read
    # by from_yaml — kept for backward compatibility.
    experiment_indentifier: Optional[str] = None
    my_results_path: Optional[str] = None
    experiment_dir: Optional[str] = None
    verbose: bool = False
    run_index: int = 0
    debug_test: bool = False
    # Default Euler integration steps used by FlowPK sampling when callers
    # do not provide ``num_steps`` explicitly (for example VPC callbacks).
    flow_num_steps: int = 50

    # Nested sub-configurations; each may be overridden from YAML sections.
    vector_field: VectorFieldPKConfig = field(default_factory=VectorFieldPKConfig)
    source_process: SourceProcessConfig = field(default_factory=SourceProcessConfig)
    mix_data: MixDataConfig = field(default_factory=MixDataConfig)

    context_observations: ObservationsConfig = field(default_factory=ObservationsConfig)
    target_observations: ObservationsConfig = field(default_factory=ObservationsConfig)

    meta_study: MetaStudyConfig = field(default_factory=MetaStudyConfig)
    dosing: MetaDosingConfig = field(default_factory=MetaDosingConfig)

    train: TrainingConfig = field(default_factory=TrainingConfig)
|
| 124 |
+
|
| 125 |
+
    @staticmethod
    def from_yaml(file_path: str) -> "FlowPKExperimentConfig":
        """Initializes the class from a YAML file.

        Supports both monolithic experiment YAML files as well as files that
        reference dedicated data, training, and model configuration YAMLs.
        """

        # TupleSafeLoader restores tuples serialized by the dump helpers.
        with open(file_path, "r") as file:
            config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}

        if not isinstance(config_dict, dict):
            raise TypeError("Expected experiment YAML to be a mapping.")

        # Guard against loading a YAML written for a different experiment type.
        exp_type = config_dict.get("experiment_type")
        if exp_type is not None and str(exp_type).lower() != "flowpk":
            raise ValueError(
                f"Expected experiment_type 'flowpk' for FlowPKExperimentConfig, got {exp_type!r}."
            )

        # Legacy 'network' sections are rejected with a migration hint.
        if "network" in config_dict:
            _raise_flowpk_network_migration()

        # Relative config references are resolved against the experiment file.
        base_dir = os.path.dirname(os.path.abspath(file_path))

        data_cfg_dict = (
            FlowPKExperimentConfig._load_ref_yaml(config_dict.get("data_config"), base_dir) or {}
        )
        training_cfg_dict = (
            FlowPKExperimentConfig._load_ref_yaml(config_dict.get("training_config"), base_dir)
            or {}
        )
        model_cfg_dict = (
            FlowPKExperimentConfig._load_ref_yaml(config_dict.get("model_config"), base_dir) or {}
        )

        if isinstance(model_cfg_dict, dict) and "network" in model_cfg_dict:
            _raise_flowpk_network_migration()

        # Observation sections may live in the experiment YAML, the data YAML,
        # or directly at the top level of the data YAML — in that priority.
        observations_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "observations_config"
        )
        if observations_section is None:
            observations_section = FlowPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "observations_config"
            )
        if observations_section is not None:
            context_observations_base = observations_section.get("context_observations")
            target_observations_base = observations_section.get("target_observations")
        else:
            context_observations_base = data_cfg_dict.get("context_observations")
            target_observations_base = data_cfg_dict.get("target_observations")

        # Same priority scheme for the mix-data section.
        mix_data_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "mix_data_config"
        )
        if mix_data_section is None:
            mix_data_section = FlowPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "mix_data_config"
            )
        if mix_data_section is None:
            mix_data_section = data_cfg_dict.get("mix_data")

        meta_study_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_study_config"
        )
        if meta_study_section is None:
            meta_study_section = FlowPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_study_config"
            )
        meta_dosing_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_dosing_config"
        )
        if meta_dosing_section is None:
            meta_dosing_section = FlowPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_dosing_config"
            )

        # The meta_study mapping may be nested in either section; fall back to
        # the data YAML's top-level key last.
        meta_study_base = FlowPKExperimentConfig._extract_config_mapping(
            meta_study_section, "meta_study"
        )
        if meta_study_base is None and meta_dosing_section is not None:
            meta_study_base = FlowPKExperimentConfig._extract_config_mapping(
                meta_dosing_section, "meta_study"
            )
        if meta_study_base is None:
            meta_study_base = data_cfg_dict.get("meta_study")

        dosing_base = FlowPKExperimentConfig._extract_config_mapping(meta_dosing_section, "dosing")
        if dosing_base is None:
            dosing_base = data_cfg_dict.get("dosing")

        mix_data_inline = config_dict.get("mix_data")
        if mix_data_inline is not None and not isinstance(mix_data_inline, dict):
            raise TypeError("Expected 'mix_data' section in experiment YAML to be a mapping.")

        # Backward compatibility: allow mix-data keys at experiment top-level.
        # Nested `mix_data:` values take precedence over these legacy top-level keys.
        legacy_mix_data_inline = {
            field_meta.name: config_dict[field_meta.name]
            for field_meta in fields(MixDataConfig)
            if field_meta.name in config_dict
        }
        merged_mix_data_inline = dict(legacy_mix_data_inline)
        if isinstance(mix_data_inline, dict):
            merged_mix_data_inline.update(mix_data_inline)

        # Inline sections in the experiment YAML override referenced files.
        mix_data_cfg = FlowPKExperimentConfig._merge_dicts(
            mix_data_section, merged_mix_data_inline
        )
        context_obs_cfg = FlowPKExperimentConfig._merge_dicts(
            context_observations_base, config_dict.get("context_observations")
        )
        target_obs_cfg = FlowPKExperimentConfig._merge_dicts(
            target_observations_base, config_dict.get("target_observations")
        )
        meta_study_cfg = FlowPKExperimentConfig._merge_dicts(
            meta_study_base, config_dict.get("meta_study")
        )
        dosing_cfg = FlowPKExperimentConfig._merge_dicts(dosing_base, config_dict.get("dosing"))

        # A training YAML may either nest its values under "train" or be flat.
        train_section = training_cfg_dict.get("train", training_cfg_dict)
        train_cfg = FlowPKExperimentConfig._merge_dicts(train_section, config_dict.get("train"))

        # Same flat-or-nested convention for the model YAML's vector_field block.
        vector_field_section = model_cfg_dict.get("vector_field", model_cfg_dict)
        if isinstance(vector_field_section, dict) and "network" in vector_field_section:
            _raise_flowpk_network_migration()
        vector_field_cfg = FlowPKExperimentConfig._merge_dicts(
            vector_field_section, config_dict.get("vector_field")
        )

        # Source-process (noise model) config: accept several legacy spellings
        # ("source_config", "source_process", "noise_model").
        source_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "source_config"
        )
        if source_section is None:
            source_section = FlowPKExperimentConfig._resolve_config_section(
                model_cfg_dict, base_dir, "source_config"
            )
        if source_section is None:
            source_section = model_cfg_dict.get("source_process") or model_cfg_dict.get("noise_model")
        source_section = FlowPKExperimentConfig._extract_config_mapping(
            source_section, "source_process"
        )
        if isinstance(source_section, dict) and "noise_model" in source_section:
            source_section = source_section.get("noise_model")

        source_cfg = FlowPKExperimentConfig._merge_dicts(
            source_section, config_dict.get("source_process")
        )

        # -----------------------------------------------------------------
        # Choose MetaStudy class dynamically (simple vs full)
        # -----------------------------------------------------------------
        if meta_study_cfg.get("simple_mode", False):
            meta_study_instance = SimpleMetaStudyConfig(**meta_study_cfg)
        else:
            meta_study_instance = MetaStudyConfig(**meta_study_cfg)

        # Drop keys TrainingConfig does not accept.
        train_cfg = TrainingConfig._filter_kwargs(train_cfg)

        # NOTE(review): `verbose` and `run_index` are not read from the YAML
        # here and keep their dataclass defaults.
        return FlowPKExperimentConfig(
            experiment_type=str(config_dict.get("experiment_type", "flowpk")).lower(),
            name_str=config_dict.get("name_str", "FlowPK"),
            tags=config_dict.get("tags", ["flow-pk", "B-0"]),
            experiment_name=config_dict.get("experiment_name", "flow_pk_compartments"),
            experiment_indentifier=config_dict.get("experiment_indentifier", None),
            my_results_path=config_dict.get("my_results_path", None),
            experiment_dir=config_dict.get("experiment_dir", None),
            comet_ai_key=config_dict.get("comet_ai_key", None),
            hugging_face_token=config_dict.get("hugging_face_token", None),
            upload_to_hf_hub=config_dict.get("upload_to_hf_hub", True),
            hf_model_name=config_dict.get("hf_model_name", "FlowPK_test"),
            hf_model_card_path=tuple(
                config_dict.get("hf_model_card_path", ("hf_model_card", "CVAE_Readme.md"))
            ),
            debug_test=config_dict.get("debug_test", False),
            flow_num_steps=int(config_dict.get("flow_num_steps", 50)),
            vector_field=VectorFieldPKConfig(**vector_field_cfg),
            source_process=SourceProcessConfig(**source_cfg),
            mix_data=MixDataConfig(**mix_data_cfg),
            context_observations=ObservationsConfig(**context_obs_cfg),
            target_observations=ObservationsConfig(**target_obs_cfg),
            meta_study=meta_study_instance,
            dosing=MetaDosingConfig(**dosing_cfg),
            train=TrainingConfig(**train_cfg),
        )
|
| 311 |
+
|
| 312 |
+
@staticmethod
|
| 313 |
+
def _merge_dicts(
|
| 314 |
+
base_dict: Optional[Dict[str, Any]], override_dict: Optional[Dict[str, Any]]
|
| 315 |
+
) -> Dict[str, Any]:
|
| 316 |
+
"""Merge two optional dictionaries returning a new dictionary."""
|
| 317 |
+
|
| 318 |
+
merged: Dict[str, Any] = {}
|
| 319 |
+
|
| 320 |
+
if base_dict:
|
| 321 |
+
if not isinstance(base_dict, dict):
|
| 322 |
+
raise TypeError(
|
| 323 |
+
"Expected base_dict to be a mapping when merging configuration sections."
|
| 324 |
+
)
|
| 325 |
+
merged = deepcopy(base_dict)
|
| 326 |
+
|
| 327 |
+
if override_dict:
|
| 328 |
+
if not isinstance(override_dict, dict):
|
| 329 |
+
raise TypeError(
|
| 330 |
+
"Expected override_dict to be a mapping when merging configuration sections."
|
| 331 |
+
)
|
| 332 |
+
merged.update(override_dict)
|
| 333 |
+
|
| 334 |
+
return merged
|
| 335 |
+
|
| 336 |
+
@staticmethod
|
| 337 |
+
def _extract_config_mapping(
|
| 338 |
+
section: Optional[Dict[str, Any]], nested_key: str
|
| 339 |
+
) -> Optional[Dict[str, Any]]:
|
| 340 |
+
"""Return a nested configuration mapping or the section itself."""
|
| 341 |
+
|
| 342 |
+
if section is None:
|
| 343 |
+
return None
|
| 344 |
+
|
| 345 |
+
if not isinstance(section, dict):
|
| 346 |
+
raise TypeError(
|
| 347 |
+
"Expected configuration section to be a mapping when extracting nested"
|
| 348 |
+
f" '{nested_key}' values."
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
if nested_key in section:
|
| 352 |
+
nested_value = section[nested_key]
|
| 353 |
+
if nested_value is None:
|
| 354 |
+
return None
|
| 355 |
+
if not isinstance(nested_value, dict):
|
| 356 |
+
raise TypeError(
|
| 357 |
+
f"Expected '{nested_key}' section to be a mapping when extracting configuration values."
|
| 358 |
+
)
|
| 359 |
+
return nested_value
|
| 360 |
+
|
| 361 |
+
return section
|
| 362 |
+
|
| 363 |
+
@staticmethod
def _load_ref_yaml(
    ref: Optional[Union[str, Dict[str, Any]]], base_dir: str
) -> Optional[Dict[str, Any]]:
    """Resolve ``ref``: pass inline dicts through, load string paths as YAML.

    Relative paths are interpreted against ``base_dir``. An empty YAML
    document yields ``{}``; ``None`` propagates; other types raise.
    """
    if ref is None:
        return None

    if isinstance(ref, dict):
        return ref

    if not isinstance(ref, str):
        raise TypeError("Expected configuration reference to be a mapping or string path.")

    path = ref if os.path.isabs(ref) else os.path.join(base_dir, ref)
    with open(path, "r") as handle:
        loaded = yaml.load(handle, Loader=TupleSafeLoader)
    return loaded or {}
|
| 384 |
+
|
| 385 |
+
@staticmethod
def _resolve_config_section(
    cfg_dict: Dict[str, Any], base_dir: str, key: str
) -> Optional[Dict[str, Any]]:
    """Fetch ``cfg_dict[key]``, following ``_ref``/string indirections.

    A missing key or explicit ``None`` resolves to ``None``. A string value,
    or a dict carrying a non-``None`` ``_ref`` entry, is loaded as a
    referenced YAML file (empty documents become ``{}``); a plain dict is
    returned unchanged. Anything else raises ``TypeError``.
    """

    section = cfg_dict.get(key)
    if section is None:
        return None

    if isinstance(section, str):
        return FlowPKExperimentConfig._load_ref_yaml(section, base_dir) or {}

    if not isinstance(section, dict):
        raise TypeError(
            f"Expected configuration section '{key}' to be a mapping or string reference."
        )

    ref_value = section.get("_ref")
    if ref_value is None:
        # Inline mapping without indirection: use as-is.
        return section
    return FlowPKExperimentConfig._load_ref_yaml(ref_value, base_dir) or {}
|
| 413 |
+
|
| 414 |
+
def to_yaml(self, file_path: str):
    """Serialize this configuration to ``file_path`` as block-style YAML."""
    payload = asdict(self)
    with open(file_path, "w") as out:
        yaml.dump(payload, out, default_flow_style=False)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class HFFlowPKConfig(PretrainedConfig):
    """
    HF config wrapping FlowPKExperimentConfig plus tracked metrics.

    Canonical storage:
        self.tracking: dict with shape
        {
          "best": { "<metric_name>": {"value": float, "step": int|None, "epoch": int|None} },
          "meta": { ...optional... }
        }

    Backward compat:
        - Accepts legacy keys like best_val_loss / best_val_rmse.
        - Mirrors best["val_rmse"] to `best_val_loss` if you still use that elsewhere.
    """

    # Model type identifier used by the transformers config machinery.
    model_type = "flow_pk"

    def __init__(self, **kwargs):
        # --- extract tracking / legacy keys before super().__init__ ---
        # (otherwise PretrainedConfig would absorb them as plain attributes)
        tracking = kwargs.pop("tracking", None)

        # legacy keys (accept either; normalize into tracking below)
        legacy_best_val_loss = kwargs.pop("best_val_loss", None)
        legacy_best_val_rmse = kwargs.pop("best_val_rmse", None)

        super().__init__(**kwargs)

        # copy remaining config fields
        # NOTE(review): PretrainedConfig.__init__ also setattr()s leftover
        # kwargs, so this loop is presumably redundant but harmless — confirm
        # against the installed transformers version.
        for k, v in kwargs.items():
            setattr(self, k, v)

        # initialize tracking: fall back to an empty canonical structure on
        # missing or malformed input rather than raising
        if tracking is None or not isinstance(tracking, dict):
            tracking = {"best": {}, "meta": {}}
        tracking.setdefault("best", {})
        tracking.setdefault("meta", {})
        self.tracking: Dict[str, Any] = tracking

        # fold legacy into canonical schema if present
        # (best_val_loss takes precedence over best_val_rmse when both given)
        legacy = legacy_best_val_loss if legacy_best_val_loss is not None else legacy_best_val_rmse
        if legacy is not None:
            # canonical metric name for the legacy scalar
            self.set_best("val_rmse", legacy)

        # optional alias for older codepaths
        self._sync_legacy_aliases()

    # --------- public API ----------
    def set_best(
        self,
        metric_name: str,
        value: Any,
        *,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
    ) -> None:
        """Unconditionally record ``value`` as the best seen for ``metric_name``.

        ``value`` is coerced via ``_to_float`` (unparseable/NaN become +inf);
        the legacy ``best_val_loss`` alias is refreshed afterwards.
        """
        v = _to_float(value)
        self.tracking["best"][metric_name] = {"value": v, "step": step, "epoch": epoch}
        self._sync_legacy_aliases()

    def get_best(self, metric_name: str, default: float = math.inf) -> float:
        """Return the stored best value for ``metric_name``, or ``default`` if absent."""
        d = self.tracking.get("best", {}).get(metric_name)
        if not d:
            return float(default)
        return _to_float(d.get("value", default))

    def is_better(
        self,
        metric_name: str,
        candidate_value: Any,
        *,
        higher_is_better: bool = False,
    ) -> bool:
        """Return True when ``candidate_value`` beats the stored best (strictly)."""
        cand = _to_float(candidate_value)
        best = self.get_best(metric_name, default=(-math.inf if higher_is_better else math.inf))
        return cand > best if higher_is_better else cand < best

    def update_if_better(
        self,
        metric_name: str,
        candidate_value: Any,
        *,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
        higher_is_better: bool = False,
    ) -> bool:
        """Store the candidate only if it improves on the best; return whether it did."""
        if self.is_better(metric_name, candidate_value, higher_is_better=higher_is_better):
            self.set_best(metric_name, candidate_value, step=step, epoch=epoch)
            return True
        return False

    # --------- construction ----------
    @classmethod
    def from_flowpk(cls, flowpk_cfg, **tracked_best: float) -> "HFFlowPKConfig":
        """Build from a FlowPK dataclass config plus initial best metrics.

        tracked_best: e.g. val_rmse=..., val_nll=..., val_crps=...
        """
        cfg_dict = asdict(flowpk_cfg)
        cfg = cls(**cfg_dict)
        for k, v in tracked_best.items():
            cfg.set_best(k, v)
        return cfg

    # --------- internal ----------
    def _sync_legacy_aliases(self) -> None:
        """
        Keep a legacy scalar field for older code that expects `best_val_loss`.
        Here we mirror it to `best["val_rmse"]` by convention.
        """
        # if val_rmse exists, mirror it; otherwise inf
        self.best_val_loss = self.get_best("val_rmse", default=math.inf)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
FlowPKConfig = FlowPKExperimentConfig
|
sim_priors_pk/config_classes/node_pk_config.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from dataclasses import asdict, dataclass, field
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import yaml # type: ignore
|
| 8 |
+
from transformers import PretrainedConfig
|
| 9 |
+
|
| 10 |
+
from sim_priors_pk.config_classes.data_config import (
|
| 11 |
+
MetaDosingConfig,
|
| 12 |
+
MetaStudyConfig,
|
| 13 |
+
MixDataConfig,
|
| 14 |
+
ObservationsConfig,
|
| 15 |
+
SimpleMetaStudyConfig,
|
| 16 |
+
)
|
| 17 |
+
from sim_priors_pk.config_classes.training_config import TrainingConfig
|
| 18 |
+
from sim_priors_pk.config_classes.utils import TupleSafeLoader
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _to_float(x: Any) -> float:
    """Coerce ``x`` to a comparable float; unparseable or NaN inputs map to +inf."""
    try:
        value = float(x)
    except Exception:
        # Anything float() rejects is treated as "no usable metric value".
        return math.inf
    # NaN would poison min/max-style comparisons, so normalise it away too.
    return math.inf if math.isnan(value) else value
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class EncoderDecoderNetworkConfig:
    """
    Configuration for the encoder-decoder network.

    Groups encoder, decoder, aggregator, loss, and regularisation
    hyper-parameters for the NodePK encoder-decoder architecture.
    Load from YAML via :meth:`from_yaml`.
    """

    # Encoder configuration
    individual_encoder_name: str = "RNNContextEncoder"
    time_obs_encoder_hidden_dim: int = 200
    time_obs_encoder_output_dim: int = 200
    rnn_individual_encoder_number_of_layers: int = 2
    individual_encoder_number_of_heads: int = 4
    encoder_rnn_hidden_dim: int = 128
    input_encoding_hidden_dim: int = 128
    zi_latent_dim: int = 200  # dimensionality of the individual latent z_i
    use_attention: bool = True
    use_self_attention: bool = False
    use_time_deltas: bool = True  # feed inter-observation time gaps to the encoder

    # Decoder configuration
    decoder_name: str = "RNNDecoder"
    decoder_num_layers: int = 2
    decoder_attention_layers: int = 2
    decoder_hidden_dim: int = 128
    decoder_rnn_hidden_dim: int = 200
    rnn_decoder_number_of_layers: int = 4
    node_step: bool = True
    exclusive_node_step: bool = False
    cov_proj_dim: int = 16  # p in the paper
    ignore_logvar: bool = True  # sampling

    # Aggregator
    aggregator_type: str = "attention"  # attention, mean
    aggregator_num_heads: int = 8

    # Control reconstruction vs prediction losses
    prediction_only: bool = False
    reconstruction_only: bool = False

    # Deterministic study latent (disable sampling)
    study_latent_deterministic: bool = False

    # Deterministic individual latent for prediction
    prediction_latent_deterministic: bool = False

    # How to combine study and individual latents
    combine_latent_mode: str = "mlp"  # Options: "mlp", "sum"

    # MLP configurations (used in init_hidden, output heads, drift)
    init_hidden_num_layers: int = 2
    output_head_num_layers: int = 2
    drift_num_layers: int = 3
    dropout: float = 0.1
    activation: str = "ReLU"  # For init/logvar/mean
    drift_activation: str = "Tanh"
    norm: str = "layer"  # Options: "layer", "batch", None

    # Loss
    loss_name: str = "nll"  # Options: "nll", "log_nll", "rmse", mv_nll

    # latent node pk
    kl_weight: float = 1.0

    # KL regularisation flags
    use_kl_s: bool = True
    use_kl_i: bool = True
    use_kl_i_np: bool = True
    use_kl_init: bool = True
    use_invariance_loss: bool = True

    # Optional scaling for dosing amount inputs (route types remain unscaled)
    scale_dosing_amounts: bool = True

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "EncoderDecoderNetworkConfig":
        """Instantiate the network configuration from a YAML file.

        Accepts either a bare mapping of fields or a document with the
        fields nested under a top-level ``network`` key. An empty file
        yields a default-valued instance.

        Raises:
            TypeError: if the (possibly nested) content is not a mapping.
        """

        with open(file_path, "r", encoding="utf-8") as handle:
            config_dict = yaml.safe_load(handle) or {}

        # Unwrap an optional "network" envelope (empty section -> defaults).
        if isinstance(config_dict, dict) and "network" in config_dict:
            config_dict = config_dict.get("network") or {}

        if not isinstance(config_dict, dict):
            raise TypeError("Expected 'network' section in YAML to be a mapping.")

        return cls(**config_dict)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@dataclass
class NodePKExperimentConfig:
    """Experiment configuration for NodePK-family models.

    Aggregates experiment bookkeeping (names, paths, tokens) with the
    network, data, observation, study, dosing, and training sub-configs.
    Build from YAML via :meth:`from_yaml`; persist via :meth:`to_yaml`.
    """

    experiment_type: str = "nodepk"
    name_str: str = "NodePK"
    comet_ai_key: Optional[str] = None  # Comet ML API key; None disables logging upstream — TODO confirm
    experiment_name: str = "node_pk_compartments"
    hugging_face_token: Optional[str] = None
    upload_to_hf_hub: bool = True
    hf_model_name: str = "NodePK_test"
    # NOTE(review): default has two elements, so the annotation is a
    # variable-length tuple rather than the previous fixed Tuple[str, str, str].
    hf_model_card_path: Tuple[str, ...] = ("hf_model_card", "CVAE_Readme.md")

    tags: List[str] = field(default_factory=lambda: ["node-pk", "B-0"])
    # [sic] "indentifier" misspelling is the public field/YAML key; do not rename.
    experiment_indentifier: Optional[str] = None
    my_results_path: Optional[str] = None
    experiment_dir: Optional[str] = None
    verbose: bool = False
    run_index: int = 0
    debug_test: bool = False

    network: EncoderDecoderNetworkConfig = field(default_factory=EncoderDecoderNetworkConfig)
    mix_data: MixDataConfig = field(default_factory=MixDataConfig)

    context_observations: ObservationsConfig = field(default_factory=ObservationsConfig)
    target_observations: ObservationsConfig = field(default_factory=ObservationsConfig)

    meta_study: MetaStudyConfig = field(default_factory=MetaStudyConfig)
    dosing: MetaDosingConfig = field(default_factory=MetaDosingConfig)

    train: TrainingConfig = field(default_factory=TrainingConfig)

    @staticmethod
    def from_yaml(file_path: str) -> "NodePKExperimentConfig":
        """Initializes the class from a YAML file.

        Supports both monolithic experiment YAML files as well as files that
        reference dedicated data, training, and model configuration YAMLs.

        Resolution order for each section: top-level experiment YAML first,
        then the referenced data/model/training YAMLs; inline top-level
        overrides (e.g. ``mix_data``) are merged on top of the resolved base.

        Raises:
            ValueError: if the file declares a non-"nodepk" experiment_type.
        """

        with open(file_path, "r") as file:
            config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}

        # Guard against loading a config meant for another experiment family.
        exp_type = None
        if isinstance(config_dict, dict):
            exp_type = config_dict.get("experiment_type")
        if exp_type is not None and str(exp_type).lower() != "nodepk":
            raise ValueError(
                f"Expected experiment_type 'nodepk' for NodePKExperimentConfig, got {exp_type!r}."
            )

        # Relative _ref paths are resolved against the experiment file's folder.
        base_dir = os.path.dirname(os.path.abspath(file_path))

        data_cfg_dict = (
            NodePKExperimentConfig._load_ref_yaml(config_dict.get("data_config"), base_dir) or {}
        )
        training_cfg_dict = (
            NodePKExperimentConfig._load_ref_yaml(config_dict.get("training_config"), base_dir)
            or {}
        )
        model_cfg_dict = (
            NodePKExperimentConfig._load_ref_yaml(config_dict.get("model_config"), base_dir) or {}
        )

        # Observations: prefer a top-level observations_config block, then the
        # data config's block, then loose keys inside the data config.
        observations_section = NodePKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "observations_config"
        )
        if observations_section is None:
            observations_section = NodePKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "observations_config"
            )
        if observations_section is not None:
            context_observations_base = observations_section.get("context_observations")
            target_observations_base = observations_section.get("target_observations")
        else:
            context_observations_base = data_cfg_dict.get("context_observations")
            target_observations_base = data_cfg_dict.get("target_observations")

        # Mix-data: same top-level -> data-config -> loose-key fallback chain.
        mix_data_section = NodePKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "mix_data_config"
        )
        if mix_data_section is None:
            mix_data_section = NodePKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "mix_data_config"
            )
        if mix_data_section is None:
            mix_data_section = data_cfg_dict.get("mix_data")

        meta_study_section = NodePKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_study_config"
        )
        if meta_study_section is None:
            meta_study_section = NodePKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_study_config"
            )
        meta_dosing_section = NodePKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_dosing_config"
        )
        if meta_dosing_section is None:
            meta_dosing_section = NodePKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_dosing_config"
            )

        # The dosing config file may also carry the meta_study block.
        meta_study_base = NodePKExperimentConfig._extract_config_mapping(
            meta_study_section, "meta_study"
        )
        if meta_study_base is None and meta_dosing_section is not None:
            meta_study_base = NodePKExperimentConfig._extract_config_mapping(
                meta_dosing_section, "meta_study"
            )
        if meta_study_base is None:
            meta_study_base = data_cfg_dict.get("meta_study")

        dosing_base = NodePKExperimentConfig._extract_config_mapping(meta_dosing_section, "dosing")
        if dosing_base is None:
            dosing_base = data_cfg_dict.get("dosing")

        # Inline top-level blocks override whatever the referenced files gave.
        mix_data_cfg = NodePKExperimentConfig._merge_dicts(
            mix_data_section, config_dict.get("mix_data")
        )
        context_obs_cfg = NodePKExperimentConfig._merge_dicts(
            context_observations_base, config_dict.get("context_observations")
        )
        target_obs_cfg = NodePKExperimentConfig._merge_dicts(
            target_observations_base, config_dict.get("target_observations")
        )
        meta_study_cfg = NodePKExperimentConfig._merge_dicts(
            meta_study_base, config_dict.get("meta_study")
        )
        dosing_cfg = NodePKExperimentConfig._merge_dicts(dosing_base, config_dict.get("dosing"))

        # The training YAML may nest fields under "train" or be flat.
        train_section = training_cfg_dict.get("train", training_cfg_dict)
        train_cfg = NodePKExperimentConfig._merge_dicts(train_section, config_dict.get("train"))

        # The model YAML may nest fields under "network" or be flat.
        network_section = model_cfg_dict.get("network", model_cfg_dict)
        network_cfg = NodePKExperimentConfig._merge_dicts(
            network_section, config_dict.get("network")
        )

        # -----------------------------------------------------------------
        # Choose MetaStudy class dynamically (simple vs full)
        # -----------------------------------------------------------------
        # NOTE(review): simple_mode stays in meta_study_cfg and is forwarded
        # to the constructor — presumably both classes accept it; confirm.
        if meta_study_cfg.get("simple_mode", False):
            meta_study_instance = SimpleMetaStudyConfig(**meta_study_cfg)
        else:
            meta_study_instance = MetaStudyConfig(**meta_study_cfg)

        # Drop keys TrainingConfig does not declare (tolerates stale YAMLs).
        train_cfg = TrainingConfig._filter_kwargs(train_cfg)

        return NodePKExperimentConfig(
            experiment_type=str(config_dict.get("experiment_type", "nodepk")).lower(),
            name_str=config_dict.get("name_str", "ExampleModel"),
            tags=config_dict.get("tags", ["node-pk", "B-0"]),
            experiment_name=config_dict.get("experiment_name", "aicme_compartments"),
            experiment_indentifier=config_dict.get("experiment_indentifier", None),
            my_results_path=config_dict.get("my_results_path", None),
            experiment_dir=config_dict.get("experiment_dir", None),
            comet_ai_key=config_dict.get("comet_ai_key", None),
            hugging_face_token=config_dict.get("hugging_face_token", None),
            upload_to_hf_hub=config_dict.get("upload_to_hf_hub", True),
            hf_model_name=config_dict.get("hf_model_name", "NodePK_test"),
            hf_model_card_path=tuple(
                config_dict.get("hf_model_card_path", ("hf_model_card", "CVAE_Readme.md"))
            ),
            debug_test=config_dict.get("debug_test", False),
            network=EncoderDecoderNetworkConfig(**network_cfg),
            mix_data=MixDataConfig(**mix_data_cfg),
            context_observations=ObservationsConfig(**context_obs_cfg),
            target_observations=ObservationsConfig(**target_obs_cfg),
            meta_study=meta_study_instance,
            dosing=MetaDosingConfig(**dosing_cfg),
            train=TrainingConfig(**train_cfg),
        )

    @staticmethod
    def _merge_dicts(
        base_dict: Optional[Dict[str, Any]], override_dict: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Merge two optional dictionaries returning a new dictionary.

        Override entries win; the base is deep-copied so neither input is
        mutated. Truthy non-dict inputs raise ``TypeError``.
        """

        merged: Dict[str, Any] = {}

        if base_dict:
            if not isinstance(base_dict, dict):
                raise TypeError(
                    "Expected base_dict to be a mapping when merging configuration sections."
                )
            merged = deepcopy(base_dict)

        if override_dict:
            if not isinstance(override_dict, dict):
                raise TypeError(
                    "Expected override_dict to be a mapping when merging configuration sections."
                )
            merged.update(override_dict)

        return merged

    @staticmethod
    def _extract_config_mapping(
        section: Optional[Dict[str, Any]], nested_key: str
    ) -> Optional[Dict[str, Any]]:
        """Return a nested configuration mapping or the section itself.

        ``None`` (section or nested value) propagates; non-dict values raise.
        """

        if section is None:
            return None

        if not isinstance(section, dict):
            raise TypeError(
                "Expected configuration section to be a mapping when extracting nested"
                f" '{nested_key}' values."
            )

        if nested_key in section:
            nested_value = section[nested_key]
            if nested_value is None:
                return None
            if not isinstance(nested_value, dict):
                raise TypeError(
                    f"Expected '{nested_key}' section to be a mapping when extracting configuration values."
                )
            return nested_value

        # No nested block present: the section itself is the mapping.
        return section

    @staticmethod
    def _load_ref_yaml(
        ref: Optional[Union[str, Dict[str, Any]]], base_dir: str
    ) -> Optional[Dict[str, Any]]:
        """Load a referenced YAML block or return inline dictionaries as-is.

        String refs are treated as paths (relative ones joined to
        ``base_dir``); an empty YAML document yields ``{}``.
        """

        if ref is None:
            return None

        if isinstance(ref, dict):
            return ref

        if isinstance(ref, str):
            ref_path = ref
            if not os.path.isabs(ref_path):
                ref_path = os.path.join(base_dir, ref_path)

            with open(ref_path, "r") as handle:
                return yaml.load(handle, Loader=TupleSafeLoader) or {}

        raise TypeError("Expected configuration reference to be a mapping or string path.")

    @staticmethod
    def _resolve_config_section(
        cfg_dict: Dict[str, Any], base_dir: str, key: str
    ) -> Optional[Dict[str, Any]]:
        """Resolve nested configuration references within a configuration block.

        Missing/``None`` -> ``None``; strings and ``_ref`` entries are loaded
        as YAML files; plain dicts pass through unchanged.
        """

        if key not in cfg_dict:
            return None

        section = cfg_dict[key]

        if section is None:
            return None

        if isinstance(section, dict):
            ref_value = section.get("_ref") if "_ref" in section else None
            if ref_value is not None:
                loaded = NodePKExperimentConfig._load_ref_yaml(ref_value, base_dir)
                return loaded or {}
            return section

        if isinstance(section, str):
            loaded = NodePKExperimentConfig._load_ref_yaml(section, base_dir)
            return loaded or {}

        raise TypeError(
            f"Expected configuration section '{key}' to be a mapping or string reference."
        )

    def to_yaml(self, file_path: str):
        """Saves the class to a YAML file (block style, sub-configs as dicts)."""
        with open(file_path, "w") as file:
            yaml.dump(asdict(self), file, default_flow_style=False)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
NodePKConfig = NodePKExperimentConfig
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class HFNodePKConfig(PretrainedConfig):
    """
    HF config wrapping NodePKConfig plus tracked metrics.

    Canonical storage:
        self.tracking: dict with shape
        {
          "best": { "<metric_name>": {"value": float, "step": int|None, "epoch": int|None} },
          "meta": { ...optional... }
        }

    Backward compat:
        - Accepts legacy keys like best_val_loss / best_val_rmse.
        - Mirrors best["val_rmse"] to `best_val_loss` if you still use that elsewhere.
    """

    # Model type identifier used by the transformers config machinery.
    model_type = "node_pk"

    def __init__(self, **kwargs):
        # --- extract tracking / legacy keys before super().__init__ ---
        # (otherwise PretrainedConfig would absorb them as plain attributes)
        tracking = kwargs.pop("tracking", None)

        # legacy keys (accept either; normalize into tracking below)
        legacy_best_val_loss = kwargs.pop("best_val_loss", None)
        legacy_best_val_rmse = kwargs.pop("best_val_rmse", None)

        super().__init__(**kwargs)

        # copy remaining config fields
        # NOTE(review): PretrainedConfig.__init__ also setattr()s leftover
        # kwargs, so this loop is presumably redundant but harmless — confirm
        # against the installed transformers version.
        for k, v in kwargs.items():
            setattr(self, k, v)

        # initialize tracking: fall back to an empty canonical structure on
        # missing or malformed input rather than raising
        if tracking is None or not isinstance(tracking, dict):
            tracking = {"best": {}, "meta": {}}
        tracking.setdefault("best", {})
        tracking.setdefault("meta", {})
        self.tracking: Dict[str, Any] = tracking

        # fold legacy into canonical schema if present
        # (best_val_loss takes precedence over best_val_rmse when both given)
        legacy = legacy_best_val_loss if legacy_best_val_loss is not None else legacy_best_val_rmse
        if legacy is not None:
            # canonical metric name for the legacy scalar
            self.set_best("val_rmse", legacy)

        # optional alias for older codepaths
        self._sync_legacy_aliases()

    # --------- public API ----------
    def set_best(
        self,
        metric_name: str,
        value: Any,
        *,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
    ) -> None:
        """Unconditionally record ``value`` as the best seen for ``metric_name``.

        ``value`` is coerced via ``_to_float`` (unparseable/NaN become +inf);
        the legacy ``best_val_loss`` alias is refreshed afterwards.
        """
        v = _to_float(value)
        self.tracking["best"][metric_name] = {"value": v, "step": step, "epoch": epoch}
        self._sync_legacy_aliases()

    def get_best(self, metric_name: str, default: float = math.inf) -> float:
        """Return the stored best value for ``metric_name``, or ``default`` if absent."""
        d = self.tracking.get("best", {}).get(metric_name)
        if not d:
            return float(default)
        return _to_float(d.get("value", default))

    def is_better(
        self,
        metric_name: str,
        candidate_value: Any,
        *,
        higher_is_better: bool = False,
    ) -> bool:
        """Return True when ``candidate_value`` beats the stored best (strictly)."""
        cand = _to_float(candidate_value)
        best = self.get_best(metric_name, default=(-math.inf if higher_is_better else math.inf))
        return cand > best if higher_is_better else cand < best

    def update_if_better(
        self,
        metric_name: str,
        candidate_value: Any,
        *,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
        higher_is_better: bool = False,
    ) -> bool:
        """Store the candidate only if it improves on the best; return whether it did."""
        if self.is_better(metric_name, candidate_value, higher_is_better=higher_is_better):
            self.set_best(metric_name, candidate_value, step=step, epoch=epoch)
            return True
        return False

    # --------- construction ----------
    @classmethod
    def from_nodepk(cls, nodepk_cfg, **tracked_best: float) -> "HFNodePKConfig":
        """Build from a NodePK dataclass config plus initial best metrics.

        tracked_best: e.g. val_rmse=..., val_nll=..., val_crps=...
        """
        cfg_dict = asdict(nodepk_cfg)
        cfg = cls(**cfg_dict)
        for k, v in tracked_best.items():
            cfg.set_best(k, v)
        return cfg

    # --------- internal ----------
    def _sync_legacy_aliases(self) -> None:
        """
        Keep a legacy scalar field for older code that expects `best_val_loss`.
        Here we mirror it to `best["val_rmse"]` by convention.
        """
        # if val_rmse exists, mirror it; otherwise inf
        self.best_val_loss = self.get_best("val_rmse", default=math.inf)
|
sim_priors_pk/config_classes/source_process_config.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
|
| 5 |
+
try: # pragma: no cover - exercised indirectly via configuration loading
|
| 6 |
+
import yaml # type: ignore
|
| 7 |
+
except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
|
| 8 |
+
from sim_priors_pk.config_classes import yaml_fallback as yaml
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class SourceProcessConfig:
    """Configuration for source processes used by flow and diffusion PK models.

    Supported source_type values (case-insensitive):
    - "gaussian_process" / "gp"
    - "ornstein_uhlenbeck" / "ou"
    - "wiener"
    - "normal" / "gaussian"
    """

    source_type: str = "gaussian_process"

    # Gaussian-process hyper-parameters (RBF or OU kernels).
    gp_length_scale: float = 0.1
    gp_variance: float = 1.0
    gp_eps: float = 1e-8
    # Transformation applied to the sampled noise, e.g. 'softplus' or 'exp'.
    gp_transform: str = 'softplus'

    # Flow-matching additive noise scale (used only in FlowPK).
    flow_sigma: float = 1e-4
    flow_num_steps: int = 100
    use_OT_coupling: bool = False

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "SourceProcessConfig":
        """Instantiate the source-process configuration from a YAML file.

        The file may either be a flat mapping of fields, or nest them under
        one of the section keys ``source_process``, ``source`` or
        ``noise_model`` (first match wins).
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle) or {}

        if isinstance(parsed, dict):
            # Pick the first recognised section key present, if any.
            section = next(
                (k for k in ("source_process", "source", "noise_model") if k in parsed),
                None,
            )
            if section is not None:
                parsed = parsed.get(section) or {}

        if not isinstance(parsed, dict):
            raise TypeError("Expected source process configuration to be a mapping.")

        return cls(**parsed)
|
sim_priors_pk/config_classes/training_config.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field, fields
|
| 3 |
+
from typing import Any, Dict, List, Optional, Union
|
| 4 |
+
|
| 5 |
+
import yaml
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class SchedulerTaskConfig:
    """Typed configuration for one scheduler task."""

    name: str
    # Key selecting which task function the scheduler runs — presumably looked
    # up in a registry by the scheduler; confirm against the callback runner.
    fn_key: str
    n_samples: int = 0
    sample_source: str = "unconditional"
    split: str = "val"
    empirical_name: Optional[str] = None
    save_to_disk: bool = True
    log_prefix: str = "val"
    # Whether the task should run with EMA weights (see SchedulerConfig's
    # checkpoint selectors) — TODO confirm against the scheduler implementation.
    use_ema: bool = False
    # When True, the task's metric participates in checkpointing under
    # ``checkpoint_metric_name`` with direction ``checkpoint_mode`` ("min"/"max").
    checkpoint_metric: bool = False
    checkpoint_metric_name: Optional[str] = None
    checkpoint_mode: str = "min"
    # Free-form task-specific options passed through to the task function.
    task_cfg: Dict[str, Any] = field(default_factory=dict)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class SchedulerConfig:
    """Typed configuration for scheduler-driven callback execution."""

    # Fraction of training between scheduled runs — presumably percent of total
    # steps/epochs; confirm against the scheduler callback.
    percent_step: float = 1.0
    include_end: bool = True
    skip_sanity_check: bool = True
    store_samples: bool = True
    max_samples_per_group: int = 32
    keep_temp_files: bool = False
    cache_dir: Optional[str] = None
    # Supported selectors are:
    # - ``end`` for in-memory train-end weights,
    # - ``last`` / ``best`` for experiment checkpoint callbacks,
    # - scheduler-managed metric checkpoint names emitted by tasks.
    checkpoint_used_in_end: List[str] = field(default_factory=lambda: ["end"])
    tasks_validation: List[SchedulerTaskConfig] = field(default_factory=list)
    # NOTE(review): named ``task_during`` (singular) while its siblings use the
    # ``tasks_*`` plural — kept as-is since the field name is public API.
    task_during: List[SchedulerTaskConfig] = field(default_factory=list)
    tasks_end: List[SchedulerTaskConfig] = field(default_factory=list)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
class TrainingConfig:
    """Top-level training hyper-parameters: loop, optimizer, scheduler, loaders."""

    epochs: int = 20
    batch_size: int = 8
    gradient_clip_val: float = 1.0
    optimizer_name: str = "AdamW"
    learning_rate: float = 0.0001
    weight_decay: float = 1.0e-4
    num_workers: int = 3
    persistent_workers: bool = True
    shuffle_val: bool = True

    num_batch_plot: int = 1
    log_interval: int = 1  # Frequency of logging and visualization.

    # Scheduler-driven PK evaluation and visualization.
    callbacks_scheduler: Optional[Union[SchedulerConfig, Dict[str, Any]]] = None

    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
    eps: float = 1.0e-8
    amsgrad: bool = False
    scheduler_name: str = "CosineAnnealingLR"
    scheduler_params: Dict[str, Union[float, int]] = field(
        default_factory=lambda: {"T_max": 1000, "eta_min": 5.0e-5, "last_epoch": -1}
    )

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "TrainingConfig":
        """Instantiate the training configuration from a YAML file.

        Accepts either a flat mapping of fields or a file with the fields
        nested under a top-level ``train`` key; unknown keys are dropped.
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle) or {}

        if isinstance(parsed, dict) and "train" in parsed:
            parsed = parsed.get("train") or {}

        if not isinstance(parsed, dict):
            raise TypeError("Expected 'train' section in YAML to be a mapping.")

        return cls(**cls._filter_kwargs(parsed))

    @classmethod
    def _filter_kwargs(cls, raw: Dict[str, Any]) -> Dict[str, Any]:
        """Drop unknown keys (including deprecated logging flags)."""
        if not isinstance(raw, dict):
            return {}
        known = {f.name for f in fields(cls)}
        return {name: val for name, val in raw.items() if name in known}
|
sim_priors_pk/config_classes/utils.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try: # pragma: no cover - exercised indirectly via configuration loading
|
| 2 |
+
import yaml # type: ignore
|
| 3 |
+
from yaml import SafeLoader # type: ignore
|
| 4 |
+
except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
|
| 5 |
+
from sim_priors_pk.config_classes import yaml_fallback as yaml
|
| 6 |
+
SafeLoader = yaml.SafeLoader
|
| 7 |
+
|
| 8 |
+
class TupleSafeLoader(SafeLoader):
    """SafeLoader variant that materialises ``!!python/tuple`` nodes as tuples."""

    def construct_python_tuple(self, node):
        # Build the YAML sequence (e.g. [0.01, 0.1]) and freeze it into a tuple.
        return tuple(self.construct_sequence(node))


# Register the constructor for the fully qualified python/tuple tag.
TupleSafeLoader.add_constructor(
    'tag:yaml.org,2002:python/tuple',
    TupleSafeLoader.construct_python_tuple,
)
|
sim_priors_pk/config_classes/yaml_fallback.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal YAML loader fallback used when PyYAML is unavailable."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import ast
|
| 6 |
+
import json
|
| 7 |
+
from typing import Any, Dict, List
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SafeLoader:
    """Compatibility stub mimicking :class:`yaml.SafeLoader`.

    Constructors registered via :meth:`add_constructor` are recorded but not
    consulted by this fallback parser; the hook exists so code written against
    PyYAML keeps importing and running.
    """

    # Registered tag constructors, keyed by YAML tag.
    _constructors: Dict[str, Any] = {}

    @classmethod
    def add_constructor(cls, tag: str, constructor: Any) -> None:
        """Record ``constructor`` for ``tag`` (API-compatibility hook)."""
        cls._constructors[tag] = constructor


# PyYAML also exposes a ``Loader`` name; alias it to the stub.
Loader = SafeLoader
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _convert_scalar(value: str) -> Any:
|
| 24 |
+
lowered = value.lower()
|
| 25 |
+
if lowered in {"true", "yes"}:
|
| 26 |
+
return True
|
| 27 |
+
if lowered in {"false", "no"}:
|
| 28 |
+
return False
|
| 29 |
+
if lowered in {"null", "none", "~"}:
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
if value.startswith("[") or value.startswith("{") or value.startswith("("):
|
| 33 |
+
try:
|
| 34 |
+
return ast.literal_eval(value)
|
| 35 |
+
except (SyntaxError, ValueError):
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
if value.startswith("\"") and value.endswith("\""):
|
| 39 |
+
return value[1:-1]
|
| 40 |
+
if value.startswith("'") and value.endswith("'"):
|
| 41 |
+
return value[1:-1]
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
if "." in value or "e" in lowered:
|
| 45 |
+
return float(value)
|
| 46 |
+
return int(value)
|
| 47 |
+
except ValueError:
|
| 48 |
+
pass
|
| 49 |
+
|
| 50 |
+
return value
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _parse_lines(lines: List[str], indent: int = 0) -> Any:
    """Recursively parse ``lines`` (consumed in place via ``pop(0)``) at ``indent``.

    Returns a dict for a mapping block, a list for a sequence block, or an
    empty dict when nothing is parsed at this level.
    """
    mapping: Dict[str, Any] = {}
    sequence: List[Any] = []
    # Shape of this level: None = undecided, True = sequence, False = mapping.
    is_list: bool | None = None

    while lines:
        line = lines[0]
        stripped = line.lstrip()

        # Blank lines and comments are consumed without affecting structure.
        if not stripped or stripped.startswith("#"):
            lines.pop(0)
            continue

        current_indent = len(line) - len(stripped)

        # A dedent ends this level.  NOTE(review): "- " items are exempt from
        # the dedent check, so less-indented list items are still consumed by
        # this level — presumably a deliberate simplification; confirm.
        if current_indent < indent and not stripped.startswith("- "):
            break

        if stripped.startswith("- "):
            if is_list is False:
                raise ValueError("Mixed mapping and sequence at the same level is unsupported.")
            is_list = True

            lines.pop(0)
            item_value = stripped[2:].strip()

            # Bare "-" introduces a nested block parsed at a deeper indent.
            if not item_value:
                sequence.append(_parse_lines(lines, current_indent + 2))
                continue

            # "- key:" yields a single-key mapping with a nested value.
            if item_value.endswith(":"):
                key = item_value[:-1].strip()
                value = _parse_lines(lines, current_indent + 2)
                sequence.append({key: value})
                continue

            # Plain scalar list item.
            sequence.append(_convert_scalar(item_value))
            continue

        if is_list is True:
            raise ValueError("Mixed mapping and sequence at the same level is unsupported.")

        is_list = False

        lines.pop(0)
        if ":" not in stripped:
            raise ValueError(f"Invalid mapping entry: '{stripped}'.")

        key, value_part = stripped.split(":", 1)
        key = key.strip()
        value_part = value_part.strip()

        # Inline value -> scalar conversion; empty value -> nested block.
        if value_part:
            mapping[key] = _convert_scalar(value_part)
        else:
            mapping[key] = _parse_lines(lines, current_indent + 2)

    if is_list:
        return sequence
    return mapping
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def safe_load(stream: Any) -> Any:
    """Parse YAML content from ``stream`` and return Python data structures."""
    # Accept either a text stream or a plain string, mirroring yaml.safe_load.
    content = stream.read() if hasattr(stream, "read") else stream

    if not isinstance(content, str):
        raise TypeError("YAML content must be a string or text stream.")

    raw_lines = content.splitlines()
    if not raw_lines:
        return None
    # Parse a copy: _parse_lines consumes its input in place.
    return _parse_lines(list(raw_lines))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load(stream: Any, Loader: Any | None = None) -> Any:  # noqa: N803 - API compatibility
    """Compatibility wrapper mirroring :func:`yaml.load` (``Loader`` is ignored)."""
    return safe_load(stream)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def dump(data: Any, stream: Any | None = None, default_flow_style: bool | None = None) -> str:
    """Serialise ``data`` to YAML (JSON style in the fallback implementation).

    ``default_flow_style`` is accepted for PyYAML API compatibility but ignored.
    Returns the serialised text, or "" when it was written to ``stream``.
    """
    text = json.dumps(data, indent=2)
    if stream is None:
        return text
    stream.write(text)
    return ""
|
sim_priors_pk/data/README.md
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# `sim_priors_pk.data` Package Guide
|
| 2 |
+
|
| 3 |
+
This guide documents the purpose of every subpackage that lives under `sim_priors_pk/data`. Use it as a quick reference when wiring new data pipelines or navigating the simulated pharmacokinetic (PK) workflow.
|
| 4 |
+
|
| 5 |
+
## Configuration Preamble
|
| 6 |
+
|
| 7 |
+
Simulations are configured by combining reusable YAML files with the dataclasses that live in `sim_priors_pk.config_classes`. The YAML files populate those classes; configs can be defined either directly from the classes or by reading the files.
|
| 8 |
+
|
| 9 |
+
- **YAML files (`config_files/`)** – Ready-to-use experiment definitions grouped under `config_files/experiment_configs`. For example, the `node-pk` folder contains `base-homogeneous.*.yaml` files that describe meta-study, dosing, and observation settings.
|
| 10 |
+
- **Config dataclasses (`sim_priors_pk/config_classes/`)** – Python dataclasses (`MetaStudyConfig`, `MetaDosingConfig`, `ObservationsConfig`, and friends) that parse those YAML files or can be instantiated directly in code when you need programmatic overrides.
|
| 11 |
+
|
| 12 |
+
When you load configurations in tests or scripts, prefer `MetaStudyConfig.from_yaml(...)` and similar helpers. They keep the simulation code aligned with the canonical YAML layout while still allowing you to craft configurations in pure Python when necessary.
|
| 13 |
+
|
| 14 |
+
## Top-Level Layout
|
| 15 |
+
|
| 16 |
+
These are the files that matter for the handling of simulations and data:
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
sim_priors_pk/
|
| 21 |
+
├── config_files/
|
| 22 |
+
│ └── experiment_configs/
|
| 23 |
+
├── scripts/
|
| 24 |
+
├── sim_priors_pk/
|
| 25 |
+
│ ├── config_classes/
|
| 26 |
+
│ └── data/
|
| 27 |
+
│ ├── data_empirical/
|
| 28 |
+
│ ├── data_generation/
|
| 29 |
+
│ ├── data_preprocessing/
|
| 30 |
+
│ ├── datasets/
|
| 31 |
+
│ └── extra/
|
| 32 |
+
└── tests/
|
| 33 |
+
└── data/
|
| 34 |
+
└── simulation_data/
|
| 35 |
+
└── test_simulations.py
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Each directory is described below together with the most important entry points it exposes.
|
| 39 |
+
|
| 40 |
+
## `data_empirical`
|
| 41 |
+
|
| 42 |
+
Defines the data contracts used across the project.
|
| 43 |
+
These contracts specify the canonical JSON schema (StudyJSON, IndividualJSON) that standardizes how pharmacokinetic studies are represented — both empirical and simulated.
|
| 44 |
+
They serve as the interface between raw datasets, tensor batches, and model-ready data structures, ensuring a unified format throughout the pipeline. These helpers make it straightforward to load Hugging Face datasets or local JSON files, validate them, and materialise PyTorch-compatible batches.
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
## `data_generation`
|
| 48 |
+
|
| 49 |
+
Simulation building blocks used to synthesise PK trajectories under configurable dosing and observation schemes.
|
| 50 |
+
|
| 51 |
+
* [`compartment_models.py`](data_generation/compartment_models.py) implements the stochastic sampling of population/individual PK parameters and the compartmental simulation loops.
|
| 52 |
+
* [`observations_classes.py`](data_generation/observations_classes.py) describe observation strategies (e.g. sparse vs. dense sampling) and utilities to realise them.
|
| 53 |
+
* [`compartment_models_management.py`](data_generation/compartment_models_management.py) orchestrates the full simulation workflow: it takes the meta-configuration, samples individual and dosing configurations, runs the compartmental simulations, applies the observation strategy, and assembles complete ensembles of studies in the data contracts.
|
| 54 |
+
|
| 55 |
+
Together these modules allow you to go from configuration dataclasses to simulated studies that mirror the empirical format.
|
| 56 |
+
|
| 57 |
+
## `data_preprocessing`
|
| 58 |
+
|
| 59 |
+
Deprecated.
|
| 60 |
+
|
| 61 |
+
## `datasets`
|
| 62 |
+
|
| 63 |
+
Lightning-ready dataset/dataloader factories.
|
| 64 |
+
|
| 65 |
+
- [`aicme_datasets.py`](datasets/aicme_datasets.py) defines `AICMECompartmentsDataBatch` and related PyTorch Lightning `DataModule` wrappers that harmonise both empirical and simulated studies for downstream training.
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
## Putting It All Together
|
| 69 |
+
|
| 70 |
+
A typical workflow is:
|
| 71 |
+
|
| 72 |
+
1. **Configure**: Use `sim_priors_pk.config_classes` to describe study, dosing, and observation priors.
|
| 73 |
+
2. **Simulate**: Call into `data_generation` to sample synthetic studies or to augment empirical cohorts.
|
| 74 |
+
3. **Serialise or load**: Store simulations as JSON, or load existing JSON/CSV with `data_empirical` and `data_preprocessing`.
|
| 75 |
+
4. **Batch**: Wrap tensors using `datasets.AICMECompartmentsDataModule` for consumption by modules in `sim_priors_pk.models` and training scripts.
|
| 76 |
+
|
| 77 |
+
Refer back to this document whenever you onboard a new collaborator or reorganise data flows—the sections above stay aligned with the current code base.
|
| 78 |
+
|
| 79 |
+
## Worked Examples and Tests
|
| 80 |
+
|
| 81 |
+
Integration-style tests in `tests/data/simulation_data/test_simulations.py` demonstrate how the configuration pieces fit together:
|
| 82 |
+
|
| 83 |
+
- `test_prepare_full_simulation_to_study_json` shows how YAML-driven configs from `config_files/experiment_configs/node-pk` feed into `prepare_full_simulation_to_study_json` and culminate in a canonical `StudyJSON`.
|
| 84 |
+
- `test_prepare_ensemble_of_simulations` builds on the same configuration files to generate an ensemble of studies and persists them to disk, illustrating how bulk simulations can be orchestrated.
|
| 85 |
+
|
| 86 |
+
Use these tests as executable documentation whenever you need to follow the end-to-end flow from configuration files to simulated study artefacts.
|
sim_priors_pk/data/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility namespace for data-related modules.
|
| 2 |
+
|
| 3 |
+
This package groups empirical, generation, preprocessing, and dataset
|
| 4 |
+
helpers so they can be imported with the ``sim_priors_pk.data`` prefix.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"data_empirical",
|
| 9 |
+
"data_generation",
|
| 10 |
+
"data_preprocessing",
|
| 11 |
+
"datasets",
|
| 12 |
+
]
|
sim_priors_pk/data/data_empirical/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for working with empirical JSON study data."""
|
| 2 |
+
|
| 3 |
+
try: # pragma: no cover - optional torch dependency
|
| 4 |
+
from .builder import (
|
| 5 |
+
JSON2AICMEBuilder,
|
| 6 |
+
EmpiricalBatchConfig,
|
| 7 |
+
held_out_ind_json,
|
| 8 |
+
held_out_list_json,
|
| 9 |
+
load_empirical_json_batches,
|
| 10 |
+
load_empirical_json_batches_as_dm,
|
| 11 |
+
load_empirical_hf_batches_as_dm,
|
| 12 |
+
databatch_to_study_jsons,
|
| 13 |
+
prediction_to_study_jsons,
|
| 14 |
+
)
|
| 15 |
+
except ModuleNotFoundError as exc: # pragma: no cover - allow missing torch
|
| 16 |
+
if exc.name != "torch":
|
| 17 |
+
raise
|
| 18 |
+
JSON2AICMEBuilder = EmpiricalBatchConfig = None # type: ignore
|
| 19 |
+
held_out_ind_json = held_out_list_json = None # type: ignore
|
| 20 |
+
load_empirical_json_batches = load_empirical_json_batches_as_dm = None # type: ignore
|
| 21 |
+
databatch_to_study_jsons = prediction_to_study_jsons = None # type: ignore
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"json_schema",
|
| 25 |
+
"JSON2AICMEBuilder",
|
| 26 |
+
"EmpiricalBatchConfig",
|
| 27 |
+
"held_out_ind_json",
|
| 28 |
+
"held_out_list_json",
|
| 29 |
+
"load_empirical_json_batches",
|
| 30 |
+
"load_empirical_json_batches_as_dm",
|
| 31 |
+
"load_empirical_hf_batches_as_dm",
|
| 32 |
+
"databatch_to_study_jsons",
|
| 33 |
+
"prediction_to_study_jsons",
|
| 34 |
+
"json_stats",
|
| 35 |
+
]
|
sim_priors_pk/data/data_empirical/builder.py
ADDED
|
@@ -0,0 +1,1139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
from torchtyping import TensorType as TT
|
| 11 |
+
|
| 12 |
+
from sim_priors_pk.config_classes.data_config import (
|
| 13 |
+
MetaDosingConfig,
|
| 14 |
+
)
|
| 15 |
+
from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING: # pragma: no cover - imported only for type hints
|
| 18 |
+
from sim_priors_pk.data.datasets.aicme_datasets import AICMECompartmentsDataModule
|
| 19 |
+
|
| 20 |
+
from .json_schema import IndividualJSON, StudyJSON, canonicalize_study
|
| 21 |
+
from .json_stats import EmpiricalJSONStats, compute_json_stats
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class EmpiricalBatchConfig:
    """Configuration for empirical batch construction.

    Attributes
    ----------
    pad_value_time:
        Value used to pad time tensors.
    pad_value_obs:
        Value used to pad observation tensors.
    max_databatch_size:
        Maximum number of studies that can be stacked into a single batch.
    max_individuals:
        Maximum number of individuals per context or target block.
    max_observations:
        Maximum number of observation time points per individual.
    max_remaining:
        Maximum number of remaining time points per individual.
    max_context_individuals / max_target_individuals:
        Optional overrides specifying separate capacities for context and
        target individual counts.
    max_context_observations / max_target_observations:
        Optional overrides specifying per-block observation capacities.
    max_context_remaining / max_target_remaining:
        Optional overrides specifying per-block remaining simulation
        capacities.
    """

    pad_value_time: float = 0.0
    pad_value_obs: float = 0.0
    max_databatch_size: int = 8
    max_individuals: int = 1
    max_observations: int = 0
    max_remaining: int = 0
    # ``None`` in any override below means: fall back to the shared
    # ``max_individuals`` / ``max_observations`` / ``max_remaining`` value.
    max_context_individuals: Optional[int] = None
    max_target_individuals: Optional[int] = None
    max_context_observations: Optional[int] = None
    max_target_observations: Optional[int] = None
    max_context_remaining: Optional[int] = None
    max_target_remaining: Optional[int] = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class JSON2AICMEBuilder:
    """Convert empirical study JSON to :class:`AICMECompartmentsDataBatch`.

    The builder pads context and target individuals to fixed sizes and
    assembles the :class:`AICMECompartmentsDataBatch` expected by the models.
    """

    def __init__(self, cfg: EmpiricalBatchConfig) -> None:
        # Padding values and capacity limits; see EmpiricalBatchConfig.
        self.cfg = cfg

    # --- capacity helpers: per-block override if set, else shared default ---

    def _ctx_cap(self) -> int:
        """Maximum number of context individuals per study."""
        return (
            self.cfg.max_context_individuals
            if self.cfg.max_context_individuals is not None
            else self.cfg.max_individuals
        )

    def _tgt_cap(self) -> int:
        """Maximum number of target individuals per study."""
        return (
            self.cfg.max_target_individuals
            if self.cfg.max_target_individuals is not None
            else self.cfg.max_individuals
        )

    def _ctx_obs_cap(self) -> int:
        """Maximum number of observation time points per context individual."""
        return (
            self.cfg.max_context_observations
            if self.cfg.max_context_observations is not None
            else self.cfg.max_observations
        )

    def _tgt_obs_cap(self) -> int:
        """Maximum number of observation time points per target individual."""
        return (
            self.cfg.max_target_observations
            if self.cfg.max_target_observations is not None
            else self.cfg.max_observations
        )

    def _ctx_rem_cap(self) -> int:
        """Maximum number of remaining time points per context individual."""
        return (
            self.cfg.max_context_remaining
            if self.cfg.max_context_remaining is not None
            else self.cfg.max_remaining
        )

    def _tgt_rem_cap(self) -> int:
        """Maximum number of remaining time points per target individual."""
        return (
            self.cfg.max_target_remaining
            if self.cfg.max_target_remaining is not None
            else self.cfg.max_remaining
        )

    def _block_from_inds(
        self,
        inds: List[IndividualJSON],
        *,
        max_individuals: int,
        obs_cap: int,
        rem_cap: int,
    ) -> Dict[str, TT]:
        """Assemble tensors for a list of individuals.

        Padding is applied so that each block has the same number of
        individuals (``max_individuals``) and time steps
        (``max_observations``/``max_remaining``).  Individuals beyond
        ``max_individuals`` and observations beyond the caps are silently
        truncated.
        """

        I_max = max(0, max_individuals)
        ET = max(0, obs_cap)
        R = max(0, rem_cap)

        # Pre-filled padded tensors; mask starts all-False (everything padding).
        obs_tensor = torch.full((I_max, ET), self.cfg.pad_value_obs)  # [I, ET]
        time_tensor = torch.full((I_max, ET), self.cfg.pad_value_time)  # [I, ET]
        mask_tensor = torch.zeros((I_max, ET), dtype=torch.bool)  # [I, ET]

        # Zero-width tensors when there is no remaining capacity (R == 0).
        rem_tensor = (
            torch.full((I_max, R), self.cfg.pad_value_obs) if R else torch.zeros(I_max, 0)
        )  # [I, R]
        rem_time_tensor = (
            torch.full((I_max, R), self.cfg.pad_value_time) if R else torch.zeros(I_max, 0)
        )  # [I, R]
        rem_mask_tensor = (
            torch.zeros((I_max, R), dtype=torch.bool)
            if R
            else torch.zeros(I_max, 0, dtype=torch.bool)
        )  # [I, R]

        for i, ind in enumerate(inds[:I_max]):
            obs = torch.tensor(ind.get("observations", []), dtype=torch.float32)  # [ET?]
            time = torch.tensor(ind.get("observation_times", []), dtype=torch.float32)  # [ET?]
            L = min(obs.shape[0], ET)  # truncate to capacity
            obs_tensor[i, :L] = obs[:L]
            time_tensor[i, :L] = time[:L]
            mask_tensor[i, :L] = True

            rem = torch.tensor(ind.get("remaining", []), dtype=torch.float32)  # [R?]
            rem_t = torch.tensor(ind.get("remaining_times", []), dtype=torch.float32)  # [R?]
            Lr = min(rem.shape[0], R)
            if R:
                rem_tensor[i, :Lr] = rem[:Lr]
                rem_time_tensor[i, :Lr] = rem_t[:Lr]
                rem_mask_tensor[i, :Lr] = True

        return {
            "obs": obs_tensor,
            "time": time_tensor,
            "mask": mask_tensor,
            "rem": rem_tensor,
            "rem_time": rem_time_tensor,
            "rem_mask": rem_mask_tensor,
        }

    def build_study_batch(
        self, study: StudyJSON, meta_dosing: MetaDosingConfig
    ) -> AICMECompartmentsDataBatch:
        """Build a batch for a single study.

        Does not apply observation strategies; the observation structure is
        taken exactly as given by the JSON data.

        Parameters
        ----------
        study:
            Canonicalised representation of one study.
        meta_dosing:
            Global dosing configuration.

        Returns
        -------
        AICMECompartmentsDataBatch
            Batch with ``B=1``.
        """

        study = canonicalize_study(study)
        ctx_cap = self._ctx_cap()
        tgt_cap = self._tgt_cap()

        ctx_block = self._block_from_inds(
            study["context"],
            max_individuals=ctx_cap,
            obs_cap=self._ctx_obs_cap(),
            rem_cap=self._ctx_rem_cap(),
        )
        tgt_block = self._block_from_inds(
            study["target"],
            max_individuals=tgt_cap,
            obs_cap=self._tgt_obs_cap(),
            rem_cap=self._tgt_rem_cap(),
        )

        # Map route strings to integer ids following the global route order.
        route_vocab = {r: i for i, r in enumerate(meta_dosing.route_options)}

        def _dose_route(inds: List[IndividualJSON], I_max: int):
            # Only the first dose/route per individual is encoded; unknown
            # routes fall back to id 0.
            amounts = torch.zeros(1, I_max, dtype=torch.float32)  # [1, I]
            routes = torch.zeros(1, I_max, dtype=torch.long)  # [1, I]
            for i, ind in enumerate(inds[:I_max]):
                if ind.get("dosing"):
                    amounts[0, i] = ind["dosing"][0]
                    routes[0, i] = route_vocab.get(ind["dosing_type"][0], 0)
            return amounts, routes

        c_dose, c_route = _dose_route(study["context"], ctx_cap)
        t_dose, t_route = _dose_route(study["target"], tgt_cap)

        def _unsqueeze(block):
            # Add batch dim (B=1) and trailing channel dim expected downstream.
            obs = block["obs"].unsqueeze(0).unsqueeze(-1)  # [1, I, ET, 1]
            time = block["time"].unsqueeze(0).unsqueeze(-1)  # [1, I, ET, 1]
            mask = block["mask"].unsqueeze(0)  # [1, I, ET]
            rem = block["rem"].unsqueeze(0).unsqueeze(-1)  # [1, I, R, 1]
            rem_time = block["rem_time"].unsqueeze(0).unsqueeze(-1)  # [1, I, R, 1]
            rem_mask = block["rem_mask"].unsqueeze(0)  # [1, I, R]
            return obs, time, mask, rem, rem_time, rem_mask

        t_obs, t_time, t_mask, t_rem, t_rem_time, t_rem_mask = _unsqueeze(tgt_block)
        c_obs, c_time, c_mask, c_rem, c_rem_time, c_rem_mask = _unsqueeze(ctx_block)

        # Individual-level masks: True for real individuals, False for padding.
        mask_ctx_inds = torch.zeros(1, ctx_cap, dtype=torch.bool)  # [1, I]
        mask_ctx_inds[0, : min(len(study["context"]), ctx_cap)] = True
        mask_tgt_inds = torch.zeros(1, tgt_cap, dtype=torch.bool)  # [1, I]
        mask_tgt_inds[0, : min(len(study["target"]), tgt_cap)] = True

        study_name = [study["meta_data"]["study_name"]]
        substance_name = [study["meta_data"].get("substance_name", "")]

        # Subject names padded with "" up to the block capacity.
        context_subject_name = [
            [
                study["context"][i].get("name_id", "") if i < len(study["context"]) else ""
                for i in range(ctx_cap)
            ]
        ]
        target_subject_name = [
            [
                study["target"][i].get("name_id", "") if i < len(study["target"]) else ""
                for i in range(tgt_cap)
            ]
        ]

        batch = AICMECompartmentsDataBatch(
            target_obs=t_obs,
            target_obs_time=t_time,
            target_obs_mask=t_mask,
            target_rem_sim=t_rem,
            target_rem_sim_time=t_rem_time,
            target_rem_sim_mask=t_rem_mask,
            context_obs=c_obs,
            context_obs_time=c_time,
            context_obs_mask=c_mask,
            context_rem_sim=c_rem,
            context_rem_sim_time=c_rem_time,
            context_rem_sim_mask=c_rem_mask,
            target_dosing_amounts=t_dose,
            target_dosing_route_types=t_route,
            context_dosing_amounts=c_dose,
            context_dosing_route_types=c_route,
            mask_context_individuals=mask_ctx_inds,
            mask_target_individuals=mask_tgt_inds,
            study_name=study_name,
            context_subject_name=context_subject_name,
            target_subject_name=target_subject_name,
            substance_name=substance_name,
            # Placeholder time scales — empirical data carries none here.
            time_scales=torch.tensor([[0.0, 0.0]]),
            is_empirical=True,
        )
        return batch

    @staticmethod
    def _stack_B(
        batches: List[AICMECompartmentsDataBatch],
    ) -> AICMECompartmentsDataBatch:
        """Concatenate ``batches`` along the batch dimension ``B``.

        Each input batch must have ``B=1``; the returned batch will have
        ``B=len(batches)`` with index order preserved.

        NOTE(review): ``zip(*batches)`` relies on
        ``AICMECompartmentsDataBatch`` being iterable over its fields in
        declaration order (e.g. a NamedTuple-like type) — confirm against
        its definition.
        """

        if not batches:
            raise ValueError("batches must not be empty")

        stacked_fields = []
        # Walk the fields of all batches in lock-step and merge each field.
        for values in zip(*batches):
            first = values[0]
            if isinstance(first, torch.Tensor):
                stacked_fields.append(torch.cat(values, dim=0))  # [B, ...]
            elif isinstance(first, list):
                if first and isinstance(first[0], list):
                    # Nested list fields (e.g. per-study subject names).
                    merged_nested: List[List[str]] = []
                    for v in values:
                        merged_nested.extend(v)
                    stacked_fields.append(merged_nested)
                else:
                    # Flat list fields (e.g. study names).
                    merged: List[str] = []
                    for v in values:
                        merged.extend(v)
                    stacked_fields.append(merged)
            else:
                # Scalar fields (e.g. is_empirical): keep the first value.
                stacked_fields.append(first)
        return AICMECompartmentsDataBatch(*stacked_fields)

    def build_one_aicmebatch(
        self, studies: List[StudyJSON], meta_dosing: MetaDosingConfig
    ) -> AICMECompartmentsDataBatch:
        """Build a single batch from multiple studies.

        Parameters
        ----------
        studies:
            List of studies to combine. The resulting batch will have
            ``B=len(studies)``.
        meta_dosing:
            Global dosing configuration shared across studies.

        Returns
        -------
        AICMECompartmentsDataBatch
            Combined batch with batch dimension indexing the supplied
            studies in order.
        """

        per_study = [self.build_study_batch(s, meta_dosing) for s in studies]
        return self._stack_B(per_study)

    def build_one_aicmebatch_as_dataset(
        self,
        studies: List[StudyJSON],
        context_strategy,
        target_strategy,
        meta_dosing: MetaDosingConfig,
        *,
        return_studies: bool = False,  # debugging flag (default: False)
    ) -> List[Union[AICMECompartmentsDataBatch, List[StudyJSON]]]:
        """Create batches mirroring ``AICMECompartmentsDataset`` processing.

        For each study we generate leave-one-out permutations using
        :func:`held_out_ind_json`. The provided ``context_strategy`` and
        ``target_strategy`` are then used to apply the same empirical splitting
        between observed and remaining measurements as performed in
        :class:`AICMECompartmentsDataset`. Each permutation across all studies is
        stacked along the batch dimension ``B``.

        Parameters
        ----------
        studies:
            List of empirical studies. Each study is expected to contain only a
            context block; target individuals are produced via leave-one-out
            permutations.
        context_strategy / target_strategy:
            Observation strategies matching those used by
            :class:`AICMECompartmentsDataset` for shaping context and target
            data respectively.
        meta_dosing:
            Global dosing configuration.
        return_studies:
            If True, return the intermediate permuted study dicts
            instead of building full ``AICMECompartmentsDataBatch`` objects.
            Useful for debugging.

        Returns
        -------
        List[Union[AICMECompartmentsDataBatch, List[StudyJSON]]]
            If `return_studies` is True → list of permuted study dicts.
            If `return_studies` is False → list of ``AICMECompartmentsDataBatch``.
        """
        canon_studies = [canonicalize_study(s, drop_tgt_too_few=False) for s in studies]
        # One permutation per context individual; pad all studies to the
        # largest context so every permutation index is valid everywhere.
        max_perm = max(len(s["context"]) for s in canon_studies)
        per_study_perms = [held_out_ind_json(s, max_perm) for s in canon_studies]

        batches = []
        for perm_idx in range(max_perm):
            permuted_studies = [
                self._process_one_study_perm(
                    study_perms[perm_idx], context_strategy, target_strategy
                )
                for study_perms in per_study_perms
            ]
            if return_studies:
                batches.append(permuted_studies)  # debugging: raw dicts
            else:
                batches.append(self.build_one_aicmebatch(permuted_studies, meta_dosing))
        return batches

    def build_one_aicmebatch_as_dataset_no_heldout(
        self,
        studies: List[StudyJSON],
        context_strategy,
        target_strategy,
        meta_dosing: MetaDosingConfig,
        *,
        return_studies: bool = False,
    ) -> List[Union[AICMECompartmentsDataBatch, List[StudyJSON]]]:
        """Create a single empirical batch without leave-one-out targets.

        This method mirrors :meth:`build_one_aicmebatch_as_dataset` preprocessing
        but does not move any individual from context to target. All individuals
        remain in context and the returned list has a single element.

        Parameters
        ----------
        studies:
            List of empirical studies.
        context_strategy / target_strategy:
            Observation strategies matching those used by
            :class:`AICMECompartmentsDataset`.
        meta_dosing:
            Global dosing configuration.
        return_studies:
            If ``True``, return the processed ``StudyJSON`` records instead of a
            fully built :class:`AICMECompartmentsDataBatch`.

        Returns
        -------
        List[Union[AICMECompartmentsDataBatch, List[StudyJSON]]]
            A list with length one containing either processed studies or one
            ``AICMECompartmentsDataBatch``.
        """
        canon_studies = [canonicalize_study(s, drop_tgt_too_few=False) for s in studies]
        # Merge any pre-existing target individuals back into context.
        context_only_studies: List[StudyJSON] = []
        for study in canon_studies:
            all_inds = list(study.get("context", [])) + list(study.get("target", []))
            context_only_studies.append(
                {
                    "context": all_inds,
                    "target": [],
                    "meta_data": dict(study.get("meta_data", {})),
                }
            )
        processed_studies = [
            self._process_one_study_perm(study, context_strategy, target_strategy)
            for study in context_only_studies
        ]

        if return_studies:
            return [processed_studies]
        return [self.build_one_aicmebatch(processed_studies, meta_dosing)]

    def _process_one_study_perm(
        self,
        study: StudyJSON,
        context_strategy,
        target_strategy,
    ) -> StudyJSON:
        """Turn one permuted study into tensors and apply strategies."""
        processed = {"context": [], "target": [], "meta_data": study["meta_data"]}

        for block, inds, strat in (
            ("context", study["context"], context_strategy),
            ("target", study["target"], target_strategy),
        ):
            processed[block] = self._process_block(inds, strat)

        return processed

    def _process_block(
        self,
        inds: List[IndividualJSON],
        strat,
    ) -> List[IndividualJSON]:
        """Convert a list of individuals into padded tensors, then apply strategy."""
        if not inds:
            return []

        obs, times, mask = self._pack_individuals(inds)
        # Strategy splits the packed observations into observed vs remaining.
        obs_o, time_o, mask_o, rem_o, rem_t, rem_m = strat.generate_empirical(obs, times, mask)

        return self._rebuild_individuals(inds, obs_o, time_o, mask_o, rem_o, rem_t, rem_m)

    def _pack_individuals(
        self,
        inds: List[IndividualJSON],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pad individuals into (obs, times, mask)."""
        I = len(inds)
        # Pad to the longest observation sequence in this block.
        ET = max(len(ind["observations"]) for ind in inds)
        obs = torch.full((I, ET), self.cfg.pad_value_obs)
        times = torch.full((I, ET), self.cfg.pad_value_time)
        mask = torch.zeros((I, ET), dtype=torch.bool)

        for i, ind in enumerate(inds):
            o = torch.tensor(ind["observations"], dtype=torch.float32)
            t = torch.tensor(ind["observation_times"], dtype=torch.float32)
            L = o.shape[0]
            obs[i, :L], times[i, :L], mask[i, :L] = o, t, True
        return obs, times, mask

    def _rebuild_individuals(
        self,
        inds: List[IndividualJSON],
        obs_o: torch.Tensor,
        time_o: torch.Tensor,
        mask_o: torch.Tensor,
        rem_o: Optional[torch.Tensor],
        rem_t: Optional[torch.Tensor],
        rem_m: Optional[torch.Tensor],
    ) -> List[IndividualJSON]:
        """Convert tensors back to JSON-like dicts for each individual."""
        block_inds = []
        for i in range(obs_o.shape[0]):
            # Masked indexing drops the padding entries again.
            ind_dict: IndividualJSON = {
                "observations": obs_o[i][mask_o[i]].tolist(),
                "observation_times": time_o[i][mask_o[i]].tolist(),
            }
            name_id = inds[i].get("name_id") if i < len(inds) else None
            if name_id:
                ind_dict["name_id"] = name_id
            if rem_o is not None and rem_m is not None:
                ind_dict["remaining"] = rem_o[i][rem_m[i]].tolist()
                ind_dict["remaining_times"] = rem_t[i][rem_m[i]].tolist()
            block_inds.append(ind_dict)
        return block_inds
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def databatch_to_study_jsons(
    batch: AICMECompartmentsDataBatch,
    meta_dosing: MetaDosingConfig,
) -> list[StudyJSON]:
    """Convert an ``AICMECompartmentsDataBatch`` back to ``StudyJSON`` records.

    Parameters
    ----------
    batch:
        Batch carrying tensors with a leading study dimension ``B``.
    meta_dosing:
        Dosing configuration used to decode route type indices.

    Returns
    -------
    List[StudyJSON]
        One study per element along the batch dimension ``B``. Missing
        ``study_name`` or ``substance_name`` entries are replaced by
        fallback placeholders ``study_{b}`` and ``substance_{b}``.
    """
    route_options = meta_dosing.route_options
    studies: list[StudyJSON] = []
    B = batch.context_obs.shape[0]

    # NOTE: ``_block`` closes over the loop variable ``b`` defined in the
    # ``for b in range(B)`` loop below (late binding); it is only ever called
    # from inside that loop, where ``b`` is the current study index.
    def _block(
        obs: TT["B", "I", "T", 1],
        time: TT["B", "I", "T", 1],
        mask: TT["B", "I", "T"],
        rem: TT["B", "I", "R", 1],
        rem_time: TT["B", "I", "R", 1],
        rem_mask: TT["B", "I", "R"],
        doses: TT["B", "I"],
        routes: TT["B", "I"],
        ind_mask: TT["B", "I"],
        names: list[list[str]],
    ) -> list[IndividualJSON]:
        inds: list[IndividualJSON] = []
        for i in range(obs.shape[1]):
            # Skip padded (non-real) individuals.
            if not ind_mask[b, i]:
                continue
            name_list = names[b] if b < len(names) else []
            ind: IndividualJSON = {}
            if i < len(name_list) and name_list[i]:
                ind["name_id"] = name_list[i]
            obs_i = obs[b, i, :, 0]  # [T]
            time_i = time[b, i, :, 0]  # [T]
            mask_i = mask[b, i]  # [T]
            # Masked indexing drops padded time steps.
            ind["observations"] = obs_i[mask_i].tolist()
            ind["observation_times"] = time_i[mask_i].tolist()
            rem_i = rem[b, i, :, 0]  # [R]
            rem_time_i = rem_time[b, i, :, 0]  # [R]
            rem_mask_i = rem_mask[b, i]  # [R]
            rem_vals = rem_i[rem_mask_i].tolist()
            rem_times = rem_time_i[rem_mask_i].tolist()
            if rem_vals:
                ind["remaining"] = rem_vals
                ind["remaining_times"] = rem_times
            dose = float(doses[b, i].item())
            route_idx = int(routes[b, i].item())
            # Emit dosing info only when a non-zero dose or non-default route
            # was encoded; out-of-vocabulary route ids fall back to their
            # string form.
            if dose or route_idx:
                route = (
                    route_options[route_idx] if route_idx < len(route_options) else str(route_idx)
                )
                ind["dosing"] = [dose]
                ind["dosing_type"] = [route]
                ind["dosing_times"] = [meta_dosing.time]
                ind["dosing_name"] = [route]
            inds.append(ind)
        return inds

    for b in range(B):
        study_name = (
            batch.study_name[b]
            if b < len(batch.study_name) and batch.study_name[b]
            else f"study_{b}"
        )
        substance_name = (
            batch.substance_name[b]
            if b < len(batch.substance_name) and batch.substance_name[b]
            else f"substance_{b}"
        )
        meta = {"study_name": study_name, "substance_name": substance_name}
        ctx = _block(
            batch.context_obs,
            batch.context_obs_time,
            batch.context_obs_mask,
            batch.context_rem_sim,
            batch.context_rem_sim_time,
            batch.context_rem_sim_mask,
            batch.context_dosing_amounts,
            batch.context_dosing_route_types,
            batch.mask_context_individuals,
            batch.context_subject_name,
        )
        tgt = _block(
            batch.target_obs,
            batch.target_obs_time,
            batch.target_obs_mask,
            batch.target_rem_sim,
            batch.target_rem_sim_time,
            batch.target_rem_sim_mask,
            batch.target_dosing_amounts,
            batch.target_dosing_route_types,
            batch.mask_target_individuals,
            batch.target_subject_name,
        )
        studies.append({"context": ctx, "target": tgt, "meta_data": meta})
    return studies
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def prediction_to_study_jsons(
    prediction_sample: TT["S", "B", "It", "Tr", 1],
    prediction_time: TT["S", "B", "It", "Tr", 1],
    batch: AICMECompartmentsDataBatch,
    meta_dosing: MetaDosingConfig,
) -> list[StudyJSON]:
    """Attach prediction samples to study records.

    Parameters
    ----------
    prediction_sample:
        Predicted trajectories with a leading sample dimension ``S``.
    prediction_time:
        Time points corresponding to ``prediction_sample``.
    batch:
        Original :class:`AICMECompartmentsDataBatch` used to generate the
        predictions.
    meta_dosing:
        Dosing configuration for route decoding.

    Returns
    -------
    list[StudyJSON]
        Studies with ``prediction_samples`` and ``prediction_times`` fields in
        each predicted target individual.

    Notes
    -----
    Some predictive samplers (for example FlowPK individual prediction) may
    return predictions for only a subset of target individuals compared with
    the original batch. In that case this function keeps only the first ``It``
    target entries (where ``It`` is inferred from ``prediction_sample``) so
    JSON plots and exported records stay aligned with the predicted tensors.
    """

    studies = databatch_to_study_jsons(batch, meta_dosing)
    shape = prediction_sample.shape  # [S, B, It, Tr, 1]
    num_studies, num_pred_targets = shape[1], shape[2]

    for study_idx in range(num_studies):
        # Drop any target entries beyond what the sampler actually predicted.
        kept_targets = studies[study_idx]["target"][:num_pred_targets]
        studies[study_idx]["target"] = kept_targets
        for tgt_idx, target_ind in enumerate(kept_targets):
            per_sample = prediction_sample[:, study_idx, tgt_idx, :, 0]  # [S, Tr]
            # Times are shared across samples; take them from sample 0.
            shared_times = prediction_time[0, study_idx, tgt_idx, :, 0]  # [Tr]
            target_ind["prediction_samples"] = per_sample.tolist()
            target_ind["prediction_times"] = shared_times.tolist()
    return studies
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def simulation_obs_to_study_json(
    obs_out: torch.Tensor,
    time_out: torch.Tensor,
    mask_out: torch.Tensor,
    rem_sim: Optional[torch.Tensor],
    rem_time: Optional[torch.Tensor],
    rem_mask: Optional[torch.Tensor],
    dosing_config_array: list,
    dosing_amounts: torch.Tensor,
    study_config,
    idx: int,
) -> StudyJSON:
    """Convert processed simulation tensors into a :class:`StudyJSON` entry.

    Parameters
    ----------
    obs_out, time_out, mask_out:
        Tensors describing the observed concentrations and time points for the
        simulated individuals. ``mask_out`` identifies valid entries in the
        padded tensors.
    rem_sim, rem_time, rem_mask:
        Optional tensors describing the remaining (unobserved) simulation
        trajectory. When provided, the tensors must have the same leading
        dimensions as ``obs_out`` and ``time_out`` with ``rem_mask`` marking
        valid entries.
    dosing_config_array:
        Sequence with dosing configuration objects for each individual.
    dosing_amounts:
        Tensor containing the dosing amount per individual.
    study_config:
        Configuration object describing the simulated study. Only the
        ``drug_id`` attribute is accessed, if present.
    idx:
        Index used to label the generated study name.

    Returns
    -------
    StudyJSON
        JSON-compatible dictionary describing the context block of the
        simulation.
    """

    have_remaining = rem_sim is not None and rem_time is not None and rem_mask is not None
    context_inds: list[IndividualJSON] = []

    for i in range(obs_out.shape[0]):
        valid = mask_out[i].to(torch.bool)
        # Masked indexing strips the padding entries.
        entry: IndividualJSON = {
            "name_id": f"context_{i}",
            "observations": obs_out[i][valid].tolist(),
            "observation_times": time_out[i][valid].tolist(),
        }

        if have_remaining:
            valid_rem = rem_mask[i].to(torch.bool)
            if valid_rem.any():
                entry["remaining"] = rem_sim[i][valid_rem].tolist()
                entry["remaining_times"] = rem_time[i][valid_rem].tolist()

        cfg = dosing_config_array[i]
        amount = float(dosing_amounts[i].item())
        route_name = getattr(cfg, "route", "")
        # Only emit dosing fields when there is something meaningful to record.
        if amount or route_name:
            entry["dosing"] = [amount]
            entry["dosing_type"] = [route_name]
            entry["dosing_times"] = [float(getattr(cfg, "time", 0.0))]
            entry["dosing_name"] = [route_name]

        context_inds.append(entry)

    return {
        "context": context_inds,
        "target": [],
        "meta_data": {
            "study_name": f"simulated_study_{idx}",
            "substance_name": getattr(study_config, "drug_id", "simulated_substance"),
        },
    }
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def held_out_ind_json(study: "StudyJSON", max_held_out_individuals: int) -> "List[StudyJSON]":
    """Create study permutations with one individual moved to target.

    Parameters
    ----------
    study:
        Study JSON containing only context individuals (``target`` must be empty).
    max_held_out_individuals:
        Maximum number of permutations to generate.

    Returns
    -------
    List[StudyJSON]
        List with ``max_held_out_individuals`` studies where each of the first
        ``len(context)`` entries corresponds to one context individual being
        moved to the target block. Remaining entries repeat the original study
        with an empty target.

    Notes
    -----
    Every returned study dict (and its ``meta_data``/``context`` containers)
    is a fresh object: previously the padding entries all aliased a single
    shared dict, so mutating one padded study silently mutated its siblings.
    """
    context = list(study.get("context", []))
    meta = dict(study.get("meta_data", {}))
    out: List[StudyJSON] = []
    limit = min(max_held_out_individuals, len(context))
    for idx in range(limit):
        target = [context[idx]]
        ctx = context[:idx] + context[idx + 1 :]
        # Fresh meta_data per permutation so mutation cannot leak across entries.
        out.append({"context": ctx, "target": target, "meta_data": dict(meta)})
    while len(out) < max_held_out_individuals:
        # Build a distinct padding study each iteration; appending one shared
        # dict would alias every padding entry.
        out.append({"context": list(context), "target": [], "meta_data": dict(meta)})
    return out
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
def held_out_list_json(
    builder: JSON2AICMEBuilder,
    studies: List[StudyJSON],
    meta_dosing: MetaDosingConfig,
    max_held_out_individuals: int,
) -> List[AICMECompartmentsDataBatch]:
    """
    Generate batches for leave-one-out permutations across studies.

    Parameters
    ----------
    builder:
        Instance used to convert studies to :class:`AICMECompartmentsDataBatch`.
    studies:
        Studies where only the context block is populated.
    meta_dosing:
        Global dosing configuration.
    max_held_out_individuals:
        Maximum number of held-out permutations per study.

    Returns
    -------
    List[AICMECompartmentsDataBatch]
        ``max_held_out_individuals`` batches. The ``i``-th batch contains the
        ``i``-th permutation from each study stacked along the batch
        dimension.
    """
    # Per-study lists of permutations, each of length max_held_out_individuals.
    all_permutations = [held_out_ind_json(study, max_held_out_individuals) for study in studies]
    result: List[AICMECompartmentsDataBatch] = []
    for perm_idx in range(max_held_out_individuals):
        # Take the perm_idx-th permutation from every study and stack them.
        slice_across_studies = [perms[perm_idx] for perms in all_permutations]
        result.append(builder.build_one_aicmebatch(slice_across_studies, meta_dosing))
    return result
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def load_empirical_json_batches(
    json_path: Path,
    meta_dosing: Optional[MetaDosingConfig] = None,
    stats: Optional[EmpiricalJSONStats] = None,
    datamodule: Optional[AICMECompartmentsDataModule] = None,
) -> List[AICMECompartmentsDataBatch]:
    """
    Load an empirical study JSON file and build leave-one-out batches.

    All individuals are placed in the context; targets are generated by the
    leave-one-out permutation logic in :func:`held_out_list_json`.

    Parameters
    ----------
    json_path:
        Path to a JSON file containing a list of :class:`StudyJSON` records.
    meta_dosing:
        Global dosing configuration. If ``None`` a default
        :class:`MetaDosingConfig` is used.
    stats:
        Pre-computed statistics describing the dataset. When ``None`` the
        statistics are calculated from ``json_path`` via
        :func:`compute_json_stats`.
    datamodule:
        Optional synthetic data module providing shape information via
        :meth:`AICMECompartmentsDataModule.obtain_shapes`. When given, these
        shapes override those inferred from ``stats``.

    Returns
    -------
    List[AICMECompartmentsDataBatch]
        Leave-one-out batches constructed from the studies in ``json_path``.

    Raises
    ------
    ValueError
        If the JSON file does not contain a non-empty list of studies.

    Notes
    -----
    The function canonicalises all studies and either uses the provided
    ``stats`` or computes them from the JSON file to determine the number of
    leave-one-out permutations. When ``datamodule`` is supplied the padding
    shapes ``(max_individuals, max_observations, max_remaining)`` are taken
    from :meth:`AICMECompartmentsDataModule.obtain_shapes`.
    """

    # read file: it must contain a list of StudyJSON records
    with json_path.open() as f:
        raw_studies = json.load(f)

    if not isinstance(raw_studies, list):
        raise ValueError("Expected JSON file to contain a list of StudyJSON records")

    # ensure data quality
    canon_studies: List[StudyJSON] = [
        canonicalize_study(s, drop_tgt_too_few=False) for s in raw_studies
    ]

    # we set all the individuals as context; targets are produced later by the
    # leave-one-out permutations
    studies: List[StudyJSON] = []
    for study in canon_studies:
        all_individuals = list(study.get("context", [])) + list(study.get("target", []))
        studies.append(
            {"context": all_individuals, "target": [], "meta_data": study.get("meta_data", {})}
        )

    if not studies:
        raise ValueError("No studies found in JSON file")

    # BUG FIX: ``stats`` used to be computed only in the no-datamodule branch
    # (clobbering a caller-provided value there), while the datamodule branch
    # left it ``None`` even though ``stats.max_total_individuals`` is
    # dereferenced unconditionally below -> AttributeError. Compute once,
    # honouring a caller-supplied value, as the docstring promises.
    if stats is None:
        stats = compute_json_stats(canon_studies)

    # define shapes
    if datamodule is not None:
        max_inds, max_obs, max_rem = datamodule.obtain_shapes()  # (I, T, R)
        # assumes train_dataset exposes these capacity attributes; fall back to
        # the padded individual count otherwise
        ctx_cap = getattr(datamodule.train_dataset, "max_context_individuals", max_inds)
        tgt_cap = getattr(datamodule.train_dataset, "n_of_target_individuals", max_inds)
    else:
        max_inds, max_obs, max_rem = (
            stats.max_total_individuals,
            stats.max_observations,
            stats.max_remaining,
        )
        ctx_cap = max_inds
        tgt_cap = max_inds

    # the maximum batch is sized so that all empirical studies fit at once
    cfg = EmpiricalBatchConfig(
        max_databatch_size=len(studies),
        max_individuals=max_inds,
        max_observations=max_obs,
        max_remaining=max_rem,
        max_context_individuals=ctx_cap,
        max_target_individuals=tgt_cap,
    )
    builder = JSON2AICMEBuilder(cfg)
    meta = meta_dosing or MetaDosingConfig()

    return held_out_list_json(
        builder, studies, meta, max_held_out_individuals=stats.max_total_individuals
    )
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def load_empirical_json_batches_as_dm(
    json_path: Optional[Path] = None,
    meta_dosing: Optional[MetaDosingConfig] = None,
    stats: Optional[EmpiricalJSONStats] = None,
    datamodule: Optional[AICMECompartmentsDataModule] = None,
    raw_studies: Optional[List[StudyJSON]] = None,
    *,
    held_out: bool = True,
) -> List[AICMECompartmentsDataBatch]:
    """Load an empirical study JSON file and build leave-one-out batches.

    This variant mirrors the empirical preprocessing performed by
    :class:`AICMECompartmentsDataset` by relying on the observation strategies
    of a provided :class:`AICMECompartmentsDataModule` and using
    :meth:`JSON2AICMEBuilder.build_one_aicmebatch_as_dataset`.

    Parameters
    ----------
    json_path:
        Path to a JSON file containing a list of :class:`StudyJSON` records.
        Required when ``raw_studies`` is not given.
    meta_dosing:
        Global dosing configuration. If ``None`` a default
        :class:`MetaDosingConfig` is used.
    stats:
        Pre-computed statistics describing the dataset. When ``None`` the
        statistics are calculated from the loaded studies via
        :func:`compute_json_stats`.
    datamodule:
        Synthetic data module providing observation strategies and shape
        information via :meth:`AICMECompartmentsDataModule.obtain_shapes`.
        The module must be provided; its shapes override those inferred from
        ``stats``.
    raw_studies:
        Already-loaded list of :class:`StudyJSON` records. When given,
        ``json_path`` is ignored.
    held_out:
        If ``True`` (default), build leave-one-out permutations (one empirical
        individual in target). If ``False``, keep all empirical individuals in
        context and return a single batch.

    Returns
    -------
    List[AICMECompartmentsDataBatch]
        Leave-one-out batches constructed from the studies using the
        datamodule's strategies.

    Raises
    ------
    ValueError
        If neither ``json_path`` nor ``raw_studies`` is provided, if the JSON
        payload is not a list, if no studies remain after canonicalisation,
        or if the datamodule lacks observation strategies.
    """

    if datamodule is None:
        raise ValueError("datamodule must be provided to supply observation strategies")

    if raw_studies is None:
        # BUG FIX: ``json_path`` is Optional -- previously a missing path
        # crashed with AttributeError on ``None.open()``. Fail explicitly.
        if json_path is None:
            raise ValueError("either json_path or raw_studies must be provided")
        with json_path.open() as f:
            raw_studies = json.load(f)

    if not isinstance(raw_studies, list):
        raise ValueError("Expected JSON file to contain a list of StudyJSON records")

    canon_studies: List[StudyJSON] = [
        canonicalize_study(s, drop_tgt_too_few=False) for s in raw_studies
    ]

    if stats is None:
        stats = compute_json_stats(canon_studies)

    # place every empirical individual in the context block
    studies: List[StudyJSON] = []
    for study in canon_studies:
        all_inds = list(study.get("context", [])) + list(study.get("target", []))
        studies.append({"context": all_inds, "target": [], "meta_data": study.get("meta_data", {})})

    if not studies:
        raise ValueError("No studies found in JSON file")

    max_inds, max_obs, max_rem = datamodule.obtain_shapes()  # (I, T, R)
    ctx_cap = getattr(datamodule.train_dataset, "max_context_individuals", max_inds)
    tgt_cap = getattr(datamodule.train_dataset, "n_of_target_individuals", max_inds)
    context_strategy = getattr(datamodule, "context_strategy", None)
    # For empirical targets we prefer the dedicated datamodule override
    # (legacy PK behavior + fixed capacities), falling back to target_strategy.
    target_strategy = getattr(datamodule, "empirical_target_strategy", None)
    if target_strategy is None:
        target_strategy = getattr(datamodule, "target_strategy", None)
    if context_strategy is None or target_strategy is None:
        raise ValueError("datamodule is missing context or target strategies")

    ctx_obs_cap, ctx_rem_cap = context_strategy.get_shapes()
    tgt_obs_cap, tgt_rem_cap = target_strategy.get_shapes()

    cfg = EmpiricalBatchConfig(
        max_databatch_size=len(studies),
        max_individuals=max_inds,
        max_observations=max_obs,
        max_remaining=max_rem,
        max_context_individuals=ctx_cap,
        max_target_individuals=tgt_cap,
        max_context_observations=ctx_obs_cap,
        max_target_observations=tgt_obs_cap,
        max_context_remaining=ctx_rem_cap,
        max_target_remaining=tgt_rem_cap,
    )
    builder = JSON2AICMEBuilder(cfg)
    meta = meta_dosing or MetaDosingConfig()

    if held_out:
        return builder.build_one_aicmebatch_as_dataset(
            studies, context_strategy, target_strategy, meta
        )
    return builder.build_one_aicmebatch_as_dataset_no_heldout(
        studies, context_strategy, target_strategy, meta
    )
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def load_empirical_hf_batches_as_dm(
    repo_id: str,
    split: str = "train",
    meta_dosing: Optional[MetaDosingConfig] = None,
    stats: Optional[EmpiricalJSONStats] = None,
    datamodule: Optional[AICMECompartmentsDataModule] = None,
    *,
    held_out: bool = True,
) -> List[AICMECompartmentsDataBatch]:
    """Load a StudyJSON dataset from Hugging Face Hub.

    Mirrors :func:`load_empirical_json_batches_as_dm`, but sources the raw
    studies from a Hub dataset instead of a local JSON file.

    Parameters
    ----------
    repo_id:
        Hugging Face dataset id.
    split:
        Dataset split to load.
    meta_dosing:
        Dosing configuration; a default :class:`MetaDosingConfig` is used
        when ``None``.
    stats:
        Optional precomputed dataset statistics; recomputed from the
        canonicalised studies when ``None``.
    datamodule:
        Datamodule providing empirical shape and strategy information.
        Required; a ``ValueError`` is raised when ``None``.
    held_out:
        If ``True`` (default), build leave-one-out permutations. If ``False``,
        keep all empirical individuals in context and return a single batch.

    Raises
    ------
    ValueError
        If ``datamodule`` is missing or lacks observation strategies, or if
        the dataset yields no studies.
    """

    if datamodule is None:
        raise ValueError("datamodule must be provided to supply observation strategies")

    # Load from HF Hub; each row is expected to be one StudyJSON record.
    ds = load_dataset(repo_id, split=split)
    raw_studies = [dict(study) for study in ds]  # Hugging Face rows are dict-like

    # Same canonicalisation pipeline as the JSON-file loader.
    canon_studies: List[StudyJSON] = [
        canonicalize_study(s, drop_tgt_too_few=False) for s in raw_studies
    ]

    if stats is None:
        stats = compute_json_stats(canon_studies)

    # Merge context + target so every empirical individual sits in context;
    # targets are (re)created by the held-out builder below.
    studies: List[StudyJSON] = []
    for study in canon_studies:
        all_inds = list(study.get("context", [])) + list(study.get("target", []))
        studies.append({"context": all_inds, "target": [], "meta_data": study.get("meta_data", {})})

    if not studies:
        raise ValueError("No studies found in HF dataset")

    # Padding shapes come from the synthetic datamodule, not from ``stats``.
    max_inds, max_obs, max_rem = datamodule.obtain_shapes()
    # NOTE(review): assumes train_dataset exposes these capacity attributes;
    # falls back to the padded individual count otherwise -- confirm.
    ctx_cap = getattr(datamodule.train_dataset, "max_context_individuals", max_inds)
    tgt_cap = getattr(datamodule.train_dataset, "n_of_target_individuals", max_inds)
    context_strategy = getattr(datamodule, "context_strategy", None)
    # For empirical targets we prefer the dedicated datamodule override
    # (legacy PK behavior + fixed capacities), falling back to target_strategy.
    target_strategy = getattr(datamodule, "empirical_target_strategy", None)
    if target_strategy is None:
        target_strategy = getattr(datamodule, "target_strategy", None)
    if context_strategy is None or target_strategy is None:
        raise ValueError("datamodule is missing context or target strategies")

    # Per-block observation/remaining capacities come from the strategies.
    ctx_obs_cap, ctx_rem_cap = context_strategy.get_shapes()
    tgt_obs_cap, tgt_rem_cap = target_strategy.get_shapes()

    # Batch sized so every study fits in a single databatch.
    cfg = EmpiricalBatchConfig(
        max_databatch_size=len(studies),
        max_individuals=max_inds,
        max_observations=max_obs,
        max_remaining=max_rem,
        max_context_individuals=ctx_cap,
        max_target_individuals=tgt_cap,
        max_context_observations=ctx_obs_cap,
        max_target_observations=tgt_obs_cap,
        max_context_remaining=ctx_rem_cap,
        max_target_remaining=tgt_rem_cap,
    )
    builder = JSON2AICMEBuilder(cfg)
    meta = meta_dosing or MetaDosingConfig()

    if held_out:
        return builder.build_one_aicmebatch_as_dataset(
            studies, context_strategy, target_strategy, meta
        )
    return builder.build_one_aicmebatch_as_dataset_no_heldout(
        studies, context_strategy, target_strategy, meta
    )
|
sim_priors_pk/data/data_empirical/json_schema.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TypedDict schemas for empirical pharmacokinetic JSON inputs."""
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypedDict
|
| 4 |
+
|
| 5 |
+
try: # pragma: no cover - optional torch dependency
|
| 6 |
+
import torch
|
| 7 |
+
from torchtyping import TensorType as TT
|
| 8 |
+
except ModuleNotFoundError: # pragma: no cover - allow missing torch
|
| 9 |
+
torch = None # type: ignore
|
| 10 |
+
TT = object # type: ignore
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING: # pragma: no cover - typing only
|
| 13 |
+
from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class IndividualJSON(TypedDict, total=False):
    """Schema for a single individual's PK data.

    All keys are optional (``total=False``). Optional ``prediction_samples``
    and ``prediction_times`` fields allow storing model forecasts for the
    individual's future trajectory. Each element in ``prediction_samples``
    corresponds to a full simulated trajectory for the times listed in
    ``prediction_times``; ``prediction_mean`` and ``prediction_std`` hold
    per-time-point summaries across those samples.
    """

    # identifier of the subject within a study
    name_id: str
    # observed concentrations and their matching time stamps (equal length)
    observations: List[float]
    observation_times: List[float]
    # simulated/remaining points disjoint from the observation times
    remaining: List[float]
    remaining_times: List[float]
    # dosing event fields; all four must be present together, equal length
    dosing: List[float]
    dosing_type: List[str]
    dosing_times: List[float]
    dosing_name: List[str]
    prediction_samples: List[List[float]]
    prediction_times: List[float]
    # FIX: declared so that writers such as ``prediction_stats`` and
    # ``canonicalize_individual`` (which copy these keys) stay within the
    # schema instead of writing undeclared TypedDict keys.
    prediction_mean: List[float]
    prediction_std: List[float]
    covariates: Dict[str, object]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class StudyJSON(TypedDict):
    """Schema for a full study consisting of context and target individuals."""

    # Individuals whose trajectories condition the model.
    context: List[IndividualJSON]
    # Held-out individuals to be predicted.
    target: List[IndividualJSON]
    # Must carry non-empty "study_name" and "substance_name"
    # (enforced by ``canonicalize_study``).
    meta_data: Dict[str, str]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
MIN_OBS_DEFAULT = 0
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ValidationError(Exception):
    """Raised when data do not conform to :class:`StudyJSON`."""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def canonicalize_individual(
    ind: IndividualJSON,
    *,
    min_obs: int = MIN_OBS_DEFAULT,
    drop_if_too_few: bool = True,
) -> Optional[IndividualJSON]:
    """Return a canonical copy of ``ind``; the input dictionary is never mutated.

    Parameters
    ----------
    ind:
        Individual JSON record to canonicalize.
    min_obs:
        Minimum required number of observations. Defaults to
        :data:`MIN_OBS_DEFAULT`.
    drop_if_too_few:
        If ``True`` and fewer than ``min_obs`` observations remain after
        sorting/de-duplication, ``None`` is returned instead of a record.

    Returns
    -------
    Optional[IndividualJSON]
        Canonicalized record, or ``None`` when the individual was dropped.

    Raises
    ------
    ValidationError
        On missing/mismatched observation, remaining, or dosing fields.

    Notes
    -----
    Canonicalization performs, in order: validation of the observation
    arrays, ascending time sort with first-occurrence de-duplication, the
    optional ``min_obs`` drop, filtering of remaining points that collide
    with observation times, and an all-or-nothing check on the four dosing
    fields. Optional covariates and prediction fields are shallow-copied.
    """

    # --- observations & times ---
    if "observations" not in ind or "observation_times" not in ind:
        raise ValidationError("observations and observation_times are required")

    raw_obs = list(ind["observations"])
    raw_times = list(ind["observation_times"])
    if len(raw_obs) != len(raw_times):
        raise ValidationError("observations and observation_times must match in length")

    # Chronological order; on duplicate time stamps the first occurrence wins
    # (sorted() is stable, so original relative order is preserved).
    obs_sorted: List[float] = []
    times_sorted: List[float] = []
    dedup = set()
    for t, value in sorted(zip(raw_times, raw_obs), key=lambda pair: pair[0]):
        if t not in dedup:
            dedup.add(t)
            times_sorted.append(t)
            obs_sorted.append(value)

    if drop_if_too_few and len(obs_sorted) < min_obs:
        return None

    canon: IndividualJSON = {}
    if "name_id" in ind:
        canon["name_id"] = ind["name_id"]
    canon["observations"] = obs_sorted
    canon["observation_times"] = times_sorted

    # --- remaining (must be disjoint from observation times) ---
    if "remaining" in ind or "remaining_times" in ind:
        if not ("remaining" in ind and "remaining_times" in ind):
            raise ValidationError(
                "both remaining and remaining_times required when one is provided"
            )
        rem_vals = list(ind["remaining"])
        rem_times = list(ind["remaining_times"])
        if len(rem_vals) != len(rem_times):
            raise ValidationError("remaining and remaining_times must match in length")
        observed = set(times_sorted)
        kept = [(t, r) for t, r in zip(rem_times, rem_vals) if t not in observed]
        canon["remaining"] = [r for _, r in kept]
        canon["remaining_times"] = [t for t, _ in kept]

    # --- dosing: all four fields together, equal lengths ---
    dosing_keys = ("dosing", "dosing_type", "dosing_times", "dosing_name")
    provided = [k for k in dosing_keys if k in ind]
    if provided:
        if len(provided) < len(dosing_keys):
            raise ValidationError("all dosing fields must be present when dosing is provided")
        if len({len(ind[k]) for k in dosing_keys}) != 1:  # type: ignore[index]
            raise ValidationError("dosing fields must have equal lengths")
        for key in dosing_keys:
            canon[key] = list(ind[key])  # type: ignore[index]

    # --- covariates (shallow copy) ---
    if "covariates" in ind:
        canon["covariates"] = dict(ind["covariates"])

    # --- prediction fields (copied verbatim, per-list copies) ---
    if "prediction_samples" in ind:
        canon["prediction_samples"] = [list(s) for s in ind["prediction_samples"]]
    if "prediction_times" in ind:
        canon["prediction_times"] = list(ind["prediction_times"])
    if "prediction_mean" in ind:
        canon["prediction_mean"] = list(ind["prediction_mean"])
    if "prediction_std" in ind:
        canon["prediction_std"] = list(ind["prediction_std"])

    return canon
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def canonicalize_study(
    study: StudyJSON,
    *,
    min_obs_ctx: int = MIN_OBS_DEFAULT,
    min_obs_tgt: int = MIN_OBS_DEFAULT,
    drop_tgt_too_few: bool = True,
) -> StudyJSON:
    """Canonicalize every individual in ``study`` and validate its meta data.

    Context individuals are never dropped for having too few observations;
    target individuals are dropped according to ``drop_tgt_too_few``.

    Raises
    ------
    ValidationError
        When ``meta_data`` lacks a non-empty ``study_name`` or
        ``substance_name``.
    """

    meta = study.get("meta_data", {})
    if not meta.get("study_name") or not meta.get("substance_name"):
        raise ValidationError("meta_data must include non-empty study_name and substance_name")

    def _canon_block(
        block: List[IndividualJSON], min_obs: int, drop: bool
    ) -> List[IndividualJSON]:
        # Canonicalize each individual, discarding those that were dropped.
        cleaned = (
            canonicalize_individual(ind, min_obs=min_obs, drop_if_too_few=drop)
            for ind in block
        )
        return [c for c in cleaned if c is not None]

    return {
        "context": _canon_block(study.get("context", []), min_obs_ctx, False),
        "target": _canon_block(study.get("target", []), min_obs_tgt, drop_tgt_too_few),
        "meta_data": dict(meta),
    }
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def studies_from_sampled_targets(
    *,
    db: "AICMECompartmentsDataBatch",
    samples: "TT['S', 'B', 'T', 1]",
    times: "TT['B', 'T', 1]",
    mask: "TT['B', 'T']",
    route_options: Sequence[str],
    dosing_time: float,
    name_prefix: str = "new_individual",
) -> List[StudyJSON]:
    """Convert sampled trajectories into :class:`StudyJSON` records.

    Parameters
    ----------
    db:
        Batch containing the conditioning study information. Only the fields
        accessed in this function are required, allowing reuse with compatible
        NamedTuple implementations used throughout the project.
    samples, times, mask:
        Output tensors from ``sample_new_individual`` where ``samples`` carries
        the simulated trajectories, ``times`` their corresponding decode times
        and ``mask`` selects valid entries along the temporal dimension.
    route_options:
        Lookup table translating dosing route indices into human readable
        labels. Indices outside the provided range are returned as their string
        representation.
    dosing_time:
        Absolute time at which the dosing event occurred. Used for both context
        and newly sampled target individuals when dosing information is
        present.
    name_prefix:
        Prefix for generated target individual identifiers. Defaults to
        ``"new_individual"``.

    Returns
    -------
    list[StudyJSON]
        One ``StudyJSON`` per batch element in ``db``.
    """

    # torch may be absent (module-level optional import); fail loudly here.
    if torch is None:
        raise ValidationError("torch is required to build StudyJSON records from tensors")

    # S = number of sampled trajectories, B = batch size.
    S, B, _, _ = samples.shape
    studies: List[StudyJSON] = []

    for b in range(B):
        # Fall back to synthetic names when the batch carries no metadata.
        study_name = (
            db.study_name[b] if b < len(db.study_name) and db.study_name[b] else f"study_{b}"
        )
        substance_name = (
            db.substance_name[b]
            if b < len(db.substance_name) and db.substance_name[b]
            else f"substance_{b}"
        )

        # --- context individuals, rebuilt from the padded batch tensors ---
        context_list: List[IndividualJSON] = []
        # assumes context_obs is (B, I, T, 1) — TODO confirm against batch type
        I = db.context_obs.shape[1]
        for i in range(I):
            # skip padding slots
            if not db.mask_context_individuals[b, i]:
                continue

            ind: IndividualJSON = {}
            if b < len(db.context_subject_name) and i < len(db.context_subject_name[b]):
                name = db.context_subject_name[b][i]
                if name:
                    ind["name_id"] = name

            # keep only the observation entries flagged valid by the mask
            obs_i = db.context_obs[b, i, :, 0]
            time_i = db.context_obs_time[b, i, :, 0]
            mask_i = db.context_obs_mask[b, i]
            ind["observations"] = obs_i[mask_i].tolist()
            ind["observation_times"] = time_i[mask_i].tolist()

            # remaining (simulated) points, only when the batch carries any
            if db.context_rem_sim.shape[2] > 0:
                rem_i = db.context_rem_sim[b, i, :, 0]
                rem_t = db.context_rem_sim_time[b, i, :, 0]
                rem_m = db.context_rem_sim_mask[b, i]
                rem_vals = rem_i[rem_m].tolist()
                rem_times = rem_t[rem_m].tolist()
                if rem_vals:
                    ind["remaining"] = rem_vals
                    ind["remaining_times"] = rem_times

            # single dosing event per individual; zero dose with route 0 is
            # treated as "no dosing information"
            dose = float(db.context_dosing_amounts[b, i].item())
            route_idx = int(db.context_dosing_route_types[b, i].item())
            if dose or route_idx:
                route = (
                    route_options[route_idx] if route_idx < len(route_options) else str(route_idx)
                )
                ind["dosing"] = [dose]
                ind["dosing_type"] = [route]
                ind["dosing_times"] = [dosing_time]
                ind["dosing_name"] = [route]

            context_list.append(ind)

        # --- target individuals: one per sampled trajectory ---
        target_list: List[IndividualJSON] = []
        valid_mask = mask[b]
        valid_times = times[b, valid_mask, 0].tolist()
        for s in range(S):
            traj = samples[s, b, valid_mask, 0].tolist()
            ind: IndividualJSON = {
                "name_id": f"{name_prefix}_{s}",
                "observations": traj,
                "observation_times": valid_times,
            }
            # NOTE(review): only the first target slot's dosing is used for
            # every sampled individual — confirm this is intended.
            dose = float(db.target_dosing_amounts[b, 0].item())
            route_idx = int(db.target_dosing_route_types[b, 0].item())
            if dose or route_idx:
                route = (
                    route_options[route_idx] if route_idx < len(route_options) else str(route_idx)
                )
                ind["dosing"] = [dose]
                ind["dosing_type"] = [route]
                ind["dosing_times"] = [dosing_time]
                ind["dosing_name"] = [route]
            target_list.append(ind)

        studies.append(
            {
                "context": context_list,
                "target": target_list,
                "meta_data": {
                    "study_name": study_name,
                    "substance_name": substance_name,
                },
            }
        )

    return studies
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def prediction_stats(study: StudyJSON) -> StudyJSON:
    """Compute prediction mean and std for target individuals.

    For every target individual carrying ``prediction_samples``, the mean and
    population standard deviation are taken across the sample dimension and
    written to ``prediction_mean`` and ``prediction_std``.

    Parameters
    ----------
    study:
        ``StudyJSON`` record containing prediction samples.

    Returns
    -------
    StudyJSON
        The same study object (mutated in place for convenience), with
        ``prediction_mean`` and ``prediction_std`` added on targets.

    Raises
    ------
    ValidationError
        If torch is unavailable while samples need summarising.
    """

    for individual in study.get("target", []):
        sample_lists = individual.get("prediction_samples")
        if not sample_lists:
            # nothing to summarise for this individual
            continue
        if torch is None:
            raise ValidationError("torch is required to compute prediction summaries")
        stacked: TT["S", "Tr"] = torch.tensor(sample_lists)
        individual["prediction_mean"] = stacked.mean(dim=0).tolist()
        # unbiased=False -> population std, matching the sample dimension size
        individual["prediction_std"] = stacked.std(dim=0, unbiased=False).tolist()
    return study
|
sim_priors_pk/data/data_empirical/json_stats.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This is only used for checking the shapes of the empirical data that are passed to the Dataloader"""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Dict, List, Sequence, Set
|
| 8 |
+
|
| 9 |
+
from .json_schema import StudyJSON
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class EmpiricalJSONStats:
    """Aggregate statistics gathered from empirical study JSON files.

    Attributes
    ----------
    min_context_individuals, max_context_individuals:
        Smallest / largest number of context individuals seen in any study.
    min_target_individuals, max_target_individuals:
        Smallest / largest number of target individuals seen in any study.
    min_observation, max_observation:
        Extremal observed values over all individuals.
    substances:
        Sorted list of the distinct substance names encountered.
    max_total_individuals:
        Largest combined (context + target) individual count in one study.
    max_observations:
        Longest per-individual observation sequence.
    max_remaining:
        Longest per-individual "remaining" time-point sequence.
    substance_summaries:
        Per-substance statistics keyed by substance name; each entry holds the
        individual count, the min/max observation counts per individual, and
        the sorted unique time steps (observation and remaining) seen for that
        substance.
    studies_by_substance:
        Substance name -> list of the studies that reference it.
    """

    min_context_individuals: int
    max_context_individuals: int
    min_target_individuals: int
    max_target_individuals: int
    min_observation: float
    max_observation: float
    substances: List[str]
    max_total_individuals: int
    max_observations: int
    max_remaining: int
    substance_summaries: Dict[str, Dict[str, object]]
    studies_by_substance: Dict[str, List[StudyJSON]]

    def studies_for_substance(self, substance: str) -> List[StudyJSON]:
        """Return a fresh list of all studies referencing ``substance``.

        Parameters
        ----------
        substance:
            Name of the substance whose studies should be returned.

        Returns
        -------
        List[StudyJSON]
            A shallow copy of the stored study list; empty when the substance
            was never observed.
        """
        matches = self.studies_by_substance.get(substance)
        return [] if matches is None else list(matches)

    def get_substance_summary(self, substance: str) -> Dict[str, object]:
        """Return a copy of the per-substance statistics for ``substance``.

        Parameters
        ----------
        substance:
            Name of the substance whose statistics should be retrieved.

        Returns
        -------
        Dict[str, object]
            Copy of the summary dictionary (``individual_count``,
            ``min_observations``, ``max_observations``,
            ``observation_time_steps``); empty when the substance is unknown.
        """
        if substance not in self.substance_summaries:
            return {}
        return dict(self.substance_summaries[substance])
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def compute_json_stats(studies: Sequence[StudyJSON]) -> EmpiricalJSONStats:
    """Compute statistics across empirical pharmacokinetic studies.

    Parameters
    ----------
    studies:
        Sequence of :class:`StudyJSON` objects to aggregate.

    Returns
    -------
    EmpiricalJSONStats
        Aggregated statistics across all provided studies. Individual/context
        minima fall back to 0 and observation extremes to NaN when there is
        nothing to aggregate.
    """

    ctx_sizes: List[int] = []
    tgt_sizes: List[int] = []
    obs_low = float("inf")
    obs_high = float("-inf")
    longest_obs = 0
    longest_rem = 0

    names: Set[str] = set()
    per_substance_count: Dict[str, int] = defaultdict(int)
    per_substance_obs_lens: Dict[str, List[int]] = defaultdict(list)
    per_substance_times: Dict[str, Set[float]] = defaultdict(set)
    grouped_studies: Dict[str, List[StudyJSON]] = defaultdict(list)

    for study in studies:
        context = study.get("context", [])
        target = study.get("target", [])
        ctx_sizes.append(len(context))
        tgt_sizes.append(len(target))

        # Substance is taken from the study's metadata; individuals in studies
        # without a substance name are not counted in the per-substance stats.
        substance = study.get("meta_data", {}).get("substance_name")
        if substance:
            names.add(substance)
            grouped_studies[substance].append(study)

        for ind in context + target:
            obs = ind.get("observations", [])
            longest_obs = max(longest_obs, len(obs))
            longest_rem = max(longest_rem, len(ind.get("remaining", [])))
            if obs:
                obs_low = min(obs_low, min(obs))
                obs_high = max(obs_high, max(obs))
            if substance:
                per_substance_count[substance] += 1
                per_substance_obs_lens[substance].append(len(obs))
                per_substance_times[substance].update(ind.get("observation_times", []))
                per_substance_times[substance].update(ind.get("remaining_times", []))

    substance_summaries = {
        name: {
            "individual_count": per_substance_count.get(name, 0),
            # A substance with no individuals still gets a 0/0 summary.
            "min_observations": min(per_substance_obs_lens[name])
            if per_substance_obs_lens.get(name)
            else 0,
            "max_observations": max(per_substance_obs_lens[name])
            if per_substance_obs_lens.get(name)
            else 0,
            "observation_time_steps": sorted(per_substance_times.get(name, set())),
        }
        for name in sorted(names)
    }

    return EmpiricalJSONStats(
        min_context_individuals=min(ctx_sizes, default=0),
        max_context_individuals=max(ctx_sizes, default=0),
        min_target_individuals=min(tgt_sizes, default=0),
        max_target_individuals=max(tgt_sizes, default=0),
        # NaN signals "no observations anywhere".
        min_observation=float("nan") if obs_low == float("inf") else float(obs_low),
        max_observation=float("nan") if obs_high == float("-inf") else float(obs_high),
        substances=sorted(names),
        max_total_individuals=max(
            (c + t for c, t in zip(ctx_sizes, tgt_sizes)), default=0
        ),
        max_observations=longest_obs,
        max_remaining=longest_rem,
        substance_summaries=substance_summaries,
        studies_by_substance={name: list(group) for name, group in grouped_studies.items()},
    )
|
sim_priors_pk/data/data_empirical/simulx_to_json.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tools for converting simulx output .csv files (simulation from an NLME model) to study JSON format
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import csv
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from typing import Sequence
|
| 8 |
+
|
| 9 |
+
from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def simulx_to_json(
    csv_path,
    study_name="simulated_study",
    substance_name="Drug_A",
    dosing_type="oral"
) -> Sequence[StudyJSON]:
    """Convert a Simulx export CSV into one :class:`StudyJSON` per replicate.

    Parameters
    ----------
    csv_path:
        Path to the Simulx output CSV. Expected columns: ``rep``, ``ID``,
        ``TIME``, ``value``, ``AMOUNT`` (missing cells written as ``"."``).
    study_name:
        Base name for the generated studies; ``_rep<N>`` is appended.
    substance_name:
        Substance name stored in each study's metadata.
    dosing_type:
        Route label recorded for every dose (used for both ``dosing_type``
        and ``dosing_name``).

    Returns
    -------
    Sequence[StudyJSON]
        One study per replicate, each with all individuals as context.
    """
    # rep -> ID -> per-individual record
    reps = defaultdict(lambda: defaultdict(lambda: {
        "observations": [],
        "observation_times": [],
        "dosing": [],
        "dosing_type": [],
        "dosing_times": [],
        "dosing_name": []
    }))

    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            rep = int(row["rep"])
            record = reps[rep][row["ID"]]
            time = float(row["TIME"])

            # Observation rows ("." marks a missing value in Simulx output).
            if row["value"] != ".":
                record["observations"].append(float(row["value"]))
                record["observation_times"].append(time)

            # Dosing rows (assumed to occur at TIME == 0).
            if time == 0 and row["AMOUNT"] != ".":
                record["dosing"].append(float(row["AMOUNT"]))
                record["dosing_times"].append(0.0)
                record["dosing_type"].append(dosing_type)
                record["dosing_name"].append(dosing_type)

    # Build final output: one JSON object per rep.
    output = []
    for rep, ids in sorted(reps.items()):
        contexts = [
            {"name_id": f"context_{id_}", **data}
            for id_, data in ids.items()
        ]

        study = StudyJSON({
            "context": contexts,
            "meta_data": {
                "study_name": f"{study_name}_rep{rep}",
                "substance_name": substance_name
            }
        })
        output.append(study)

    return output
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
    # Ad-hoc smoke test: convert a sample Simulx export. The result is bound
    # but not used further — inspect it in a debugger or REPL.
    output = simulx_to_json(csv_path="data/raw_nlme_simulx/indometacin-test-data.csv")
|
| 71 |
+
|
sim_priors_pk/data/data_generation/__init__.py
ADDED
|
File without changes
|
sim_priors_pk/data/data_generation/compartment_models.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Callable, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from torchdiffeq import odeint
|
| 8 |
+
from torchtyping import TensorType
|
| 9 |
+
|
| 10 |
+
from sim_priors_pk.config_classes.data_config import (
|
| 11 |
+
DosingConfig,
|
| 12 |
+
DosingWithDurationConfig,
|
| 13 |
+
MetaDosingConfig,
|
| 14 |
+
MetaDosingWithDurationConfig,
|
| 15 |
+
MetaStudyConfig,
|
| 16 |
+
)
|
| 17 |
+
from sim_priors_pk.config_classes.node_pk_config import NodePKExperimentConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class StudyConfig:
    """
    Configuration of one simulated pharmacokinetic study.

    Holds the population-level parameter distributions (log-normal mean/std per
    rate constant), the time-variability settings (``*_tmag`` magnitude and
    ``*_tscl`` time scale), and the observation window, from which individual
    parameter sets are sampled.
    """

    drug_id: str  # Identifier for the drug
    num_individuals: int  # Number of individuals in the population
    num_peripherals: int  # Number of peripheral compartments
    log_k_a_mean: float  # Mean (log-space) absorption rate constant
    log_k_a_std: float  # Standard deviation (log-space) for absorption rate constant
    k_a_tmag: float  # Magnitude of time-dependent variation of absorption rate constant
    k_a_tscl: float  # Scale of time-dependent variation of absorption rate constant
    log_k_e_mean: float  # Mean (log-space) elimination rate constant
    log_k_e_std: float  # Standard deviation (log-space) for elimination rate constant
    k_e_tmag: float  # Magnitude of time-dependent variation of elimination rate constant
    k_e_tscl: float  # Scale of time-dependent variation of elimination rate constant
    log_V_mean: float  # Mean (log-space) volume of central compartment
    log_V_std: float  # Standard deviation (log-space) for volume of central compartment
    V_tmag: float  # Magnitude of time-dependent variation of volume of central compartment
    V_tscl: float  # Scale of time-dependent variation of volume of central compartment
    log_k_1p_mean: List[float]  # Mean rate constants (central to other peripherals), one per peripheral
    log_k_1p_std: List[float]  # Standard deviations for k_1p, one per peripheral
    k_1p_tmag: List[float]  # Magnitude of time-dependent variation of k_1p
    k_1p_tscl: List[float]  # Scale of time-dependent variation of k_1p
    log_k_p1_mean: List[float]  # Mean rate constants (other peripherals to central), one per peripheral
    log_k_p1_std: List[float]  # Standard deviations for k_p1, one per peripheral
    k_p1_tmag: List[float]  # Magnitude of time-dependent variation of k_p1
    k_p1_tscl: List[float]  # Scale of time-dependent variation of k_p1
    time_start: float  # Start time for the study
    time_stop: float  # End time for the study
    rel_ruv: float  # Relative residual unexplained variability for the study
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
class IndividualConfig:
    """
    Configuration of one individual's pharmacokinetic parameters.

    Rate constants and volume are time-dependent: each field is a callable
    mapping time ``t`` to the parameter value at that time, so sampled
    (e.g. OU-modulated) trajectories and constants share one interface.
    """

    num_peripherals: int = 2  # Number of peripheral compartments
    k_a: Callable[[float], float] = lambda t: 0.1  # Absorption rate constant (gut to central)
    k_e: Callable[[float], float] = lambda t: 0.05  # Elimination rate constant (central)
    V: Callable[[float], float] = lambda t: 0.05  # Volume of central compartment
    # NOTE(review): default volume 0.05 matches k_e's default — confirm units/intent.
    k_1p: List[Callable[[float], float]] = field(
        default_factory=lambda: [lambda t: 0.01, lambda t: 0.01]
    )  # Rate constants from central to other peripherals (one per peripheral)
    k_p1: List[Callable[[float], float]] = field(
        default_factory=lambda: [lambda t: 0.01, lambda t: 0.01]
    )  # Rate constants from other peripherals to central (one per peripheral)
    rel_ruv: float = 0.1  # Relative residual unexplained variability per individual
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def sample_study_config(config: MetaStudyConfig):
    """Draw one concrete :class:`StudyConfig` from the ranges in ``config``.

    Scalar parameters are drawn uniformly from their respective ranges; the
    per-peripheral lists receive one draw per sampled peripheral compartment.
    The order of random draws matches the historical implementation, so runs
    under a fixed seed are reproducible.
    """
    draw = random.uniform

    drug_id = random.choice(config.drug_id_options)
    num_individuals = random.randint(*config.num_individuals_range)
    num_peripherals = random.randint(*config.num_peripherals_range)

    # Absorption: (log-mean, log-std, time magnitude, time scale).
    log_k_a_mean = draw(*config.log_k_a_mean_range)
    log_k_a_std = draw(*config.log_k_a_std_range)
    k_a_tmag = draw(*config.k_a_tmag_range)
    k_a_tscl = draw(*config.k_a_tscl_range)

    # Elimination.
    log_k_e_mean = draw(*config.log_k_e_mean_range)
    log_k_e_std = draw(*config.log_k_e_std_range)
    k_e_tmag = draw(*config.k_e_tmag_range)
    k_e_tscl = draw(*config.k_e_tscl_range)

    # Central volume.
    log_V_mean = draw(*config.log_V_mean_range)
    log_V_std = draw(*config.log_V_std_range)
    V_tmag = draw(*config.V_tmag_range)
    V_tscl = draw(*config.V_tscl_range)

    # Central -> peripheral exchange, one draw per peripheral.
    log_k_1p_mean = [draw(*config.log_k_1p_mean_range) for _ in range(num_peripherals)]
    log_k_1p_std = [draw(*config.log_k_1p_std_range) for _ in range(num_peripherals)]
    k_1p_tmag = [draw(*config.k_1p_tmag_range) for _ in range(num_peripherals)]
    k_1p_tscl = [draw(*config.k_1p_tscl_range) for _ in range(num_peripherals)]

    # Peripheral -> central exchange, one draw per peripheral.
    log_k_p1_mean = [draw(*config.log_k_p1_mean_range) for _ in range(num_peripherals)]
    log_k_p1_std = [draw(*config.log_k_p1_std_range) for _ in range(num_peripherals)]
    k_p1_tmag = [draw(*config.k_p1_tmag_range) for _ in range(num_peripherals)]
    k_p1_tscl = [draw(*config.k_p1_tscl_range) for _ in range(num_peripherals)]

    rel_ruv = draw(*config.rel_ruv_range)

    return StudyConfig(
        drug_id=drug_id,
        num_individuals=num_individuals,
        num_peripherals=num_peripherals,
        log_k_a_mean=log_k_a_mean,
        log_k_a_std=log_k_a_std,
        k_a_tmag=k_a_tmag,
        k_a_tscl=k_a_tscl,
        log_k_e_mean=log_k_e_mean,
        log_k_e_std=log_k_e_std,
        k_e_tmag=k_e_tmag,
        k_e_tscl=k_e_tscl,
        log_V_mean=log_V_mean,
        log_V_std=log_V_std,
        V_tmag=V_tmag,
        V_tscl=V_tscl,
        log_k_1p_mean=log_k_1p_mean,
        log_k_1p_std=log_k_1p_std,
        k_1p_tmag=k_1p_tmag,
        k_1p_tscl=k_1p_tscl,
        log_k_p1_mean=log_k_p1_mean,
        log_k_p1_std=log_k_p1_std,
        k_p1_tmag=k_p1_tmag,
        k_p1_tscl=k_p1_tscl,
        time_start=config.time_start,
        time_stop=config.time_stop,
        rel_ruv=rel_ruv,
    )
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def sample_rate_function(mean_rate, variability, variability_type="sinusoidal"):
    """Build a time-dependent rate function.

    :param mean_rate: Mean rate constant.
    :param variability: Variability amplitude of the rate constant.
    :param variability_type: Either "sinusoidal" or "decaying".
    :return: A callable ``f(t)`` giving the rate at time ``t`` (``t`` is a
        torch tensor, as ``torch.sin``/``torch.exp`` are used).
    :raises ValueError: If ``variability_type`` is not recognized.
    """
    if variability_type == "sinusoidal":
        # Rate oscillates around the mean value.
        return lambda t: mean_rate + variability * torch.sin(t)
    if variability_type == "decaying":
        # Rate decays exponentially from the mean value.
        return lambda t: mean_rate * torch.exp(-variability * t)
    raise ValueError(f"Unknown variability_type: {variability_type}")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def simulate_ou_process(
    mu: float, sigma: float, theta: float, dt: float, T: float, seed: Optional[int] = None
) -> np.ndarray:
    """Simulate one path of a mean-reverting Ornstein-Uhlenbeck process.

    Euler-Maruyama discretization of ``dX = theta*(mu - X)*dt + sigma*dW``
    with ``int(T / dt)`` steps. The initial value is drawn from the stationary
    distribution ``N(mu, sigma^2 / (2*theta))``. Passing ``seed`` reseeds the
    global NumPy RNG, making the path reproducible.
    """
    if seed is not None:
        np.random.seed(seed)

    n_steps = int(T / dt)
    path = np.zeros(n_steps)

    # Start from the stationary distribution so the path has no burn-in.
    path[0] = np.random.normal(mu, np.sqrt(sigma**2 / (2 * theta)))

    sqrt_dt = np.sqrt(dt)
    for step in range(1, n_steps):
        dW = np.random.normal(0, sqrt_dt)
        path[step] = path[step - 1] + theta * (mu - path[step - 1]) * dt + sigma * dW

    return path
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def sample_individual_configs(study_config: StudyConfig, n: Optional[int] = None):
    """
    Samples parameters for a population of individuals.

    Static rate constants are drawn from the study's log-normal distributions;
    time variability is added by modulating each constant with the exponential
    of a stationary Ornstein-Uhlenbeck process evaluated on a fixed time grid.

    Parameters
    ----------
    study_config : StudyConfig
        Configuration object with parameter distributions.
    n : int, optional
        Number of individuals to sample. If None, defaults to
        study_config.num_individuals.

    Returns
    -------
    List[IndividualConfig]
        A list of sampled individual configurations.
    """
    num_individuals = n if n is not None else study_config.num_individuals
    duration = study_config.time_stop - study_config.time_start
    dt = 0.1  # time resolution of the OU modulation grid
    individual_configs = []

    for _ in range(num_individuals):
        # Sample static parameters from lognormal distributions.
        k_a = np.random.lognormal(study_config.log_k_a_mean, study_config.log_k_a_std)
        k_e = np.random.lognormal(study_config.log_k_e_mean, study_config.log_k_e_std)
        V = np.random.lognormal(study_config.log_V_mean, study_config.log_V_std)
        k_1p = [
            np.random.lognormal(mean, std)
            for mean, std in zip(study_config.log_k_1p_mean, study_config.log_k_1p_std)
        ]
        k_p1 = [
            np.random.lognormal(mean, std)
            for mean, std in zip(study_config.log_k_p1_mean, study_config.log_k_p1_std)
        ]

        # Ornstein–Uhlenbeck processes for time-dependent variability.
        ou_times = np.arange(study_config.time_start, study_config.time_stop, dt)

        def _ou_path(base, tmag, tscl):
            # exp(OU) gives a positive, mean-reverting multiplicative modulation.
            return base * np.exp(
                simulate_ou_process(0, tmag * np.sqrt(2 * tscl), tmag, dt, duration)
            )

        ou_k_a = _ou_path(k_a, study_config.k_a_tmag, study_config.k_a_tscl)
        ou_k_e = _ou_path(k_e, study_config.k_e_tmag, study_config.k_e_tscl)
        ou_V = _ou_path(V, study_config.V_tmag, study_config.V_tscl)

        # BUG FIX: np.arange(start, stop, dt) and int(T / dt) (used inside
        # simulate_ou_process) can disagree by one sample due to floating-point
        # rounding (e.g. int(0.3 / 0.1) == 2 while arange yields 3 points),
        # which would make np.interp raise on mismatched xp/fp lengths.
        # Trim everything to the common grid length.
        n_grid = min(len(ou_times), len(ou_k_a), len(ou_k_e), len(ou_V))
        ou_times = ou_times[:n_grid]
        ou_k_a, ou_k_e, ou_V = ou_k_a[:n_grid], ou_k_e[:n_grid], ou_V[:n_grid]

        # Time-dependent rate functions: linear interpolation on the OU grid.
        # BUG FIX: ou_times is bound as a default argument here; the previous
        # version closed over the loop variable, so after the loop every
        # individual interpolated against the last iteration's grid.
        def k_a_fn(t, ou_times=ou_times, ou_k_a=ou_k_a):
            return np.interp(t, ou_times, ou_k_a)

        def k_e_fn(t, ou_times=ou_times, ou_k_e=ou_k_e):
            return np.interp(t, ou_times, ou_k_e)

        def V_fn(t, ou_times=ou_times, ou_V=ou_V):
            return np.interp(t, ou_times, ou_V)

        # Peripheral exchange rates (sinusoidal modulation as placeholder);
        # defaults bind the per-peripheral values at definition time.
        k_1p_fn = [
            lambda t,
            k_1p_i=k_1p[i],
            tmag_i=study_config.k_1p_tmag[i],
            tscl_i=study_config.k_1p_tscl[i]: k_1p_i * (1 + tmag_i * np.sin(t / tscl_i))
            for i in range(len(k_1p))
        ]
        k_p1_fn = [
            lambda t,
            k_p1_i=k_p1[i],
            tmag_i=study_config.k_p1_tmag[i],
            tscl_i=study_config.k_p1_tscl[i]: k_p1_i * (1 + tmag_i * np.sin(t / tscl_i))
            for i in range(len(k_p1))
        ]

        # Create config for this individual.
        config = IndividualConfig(
            num_peripherals=study_config.num_peripherals,
            k_a=k_a_fn,
            k_e=k_e_fn,
            V=V_fn,
            k_1p=k_1p_fn,
            k_p1=k_p1_fn,
            rel_ruv=study_config.rel_ruv,
        )
        individual_configs.append(config)

    return individual_configs
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def create_dynamic_ode_matrix(config: IndividualConfig, t: float):
|
| 288 |
+
"""
|
| 289 |
+
Creates the ODE matrix for the compartment model at time t.
|
| 290 |
+
:param config: IndividualConfig object
|
| 291 |
+
:param t: Current time
|
| 292 |
+
:return: ODE matrix as a torch tensor
|
| 293 |
+
"""
|
| 294 |
+
num_compartments = 2 + config.num_peripherals # gut, central, and peripherals
|
| 295 |
+
ode_matrix = torch.zeros((num_compartments, num_compartments))
|
| 296 |
+
|
| 297 |
+
# Gut compartment
|
| 298 |
+
ode_matrix[0, 0] = -config.k_a(t) # d_gut/dt = -k_a(t) * gut
|
| 299 |
+
|
| 300 |
+
# Central compartment
|
| 301 |
+
ode_matrix[1, 0] = config.k_a(t) # d_central/dt += k_a(t) * gut
|
| 302 |
+
ode_matrix[1, 1] = -config.k_e(t) # d_central/dt += -k_e(t) * central
|
| 303 |
+
|
| 304 |
+
# Peripheral compartments
|
| 305 |
+
for i in range(config.num_peripherals):
|
| 306 |
+
ode_matrix[1, 1] -= config.k_1p[i](t) # d_central/dt += - sum_p(k_1p(t)) * central
|
| 307 |
+
ode_matrix[1, 2 + i] = config.k_p1[i](t) # d_central/dt += k_p1[i](t) * peripheral(i)
|
| 308 |
+
ode_matrix[2 + i, 1] = config.k_1p[i](t) # d_peripheral(i)/dt += k_1p[i](t) * central
|
| 309 |
+
ode_matrix[2 + i, 2 + i] = -config.k_p1[i](
|
| 310 |
+
t
|
| 311 |
+
) # d_peripheral(i)/dt += -k_p1[i](t) * peripheral(i)
|
| 312 |
+
|
| 313 |
+
return ode_matrix
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def create_dynamic_ode_matrix_batched(configs, t, num_peripherals):
    """
    Creates batched ODE rate matrices for multiple individuals at time ``t``.

    Compartment order per individual is [gut, central, peripheral_0, ...].

    Parameters:
    ----------
    configs : list
        List of IndividualConfig objects.
    t : float
        Current time point.
    num_peripherals : int
        Number of peripheral compartments (same for all individuals).

    Returns:
    -------
    A_all : torch.Tensor
        Tensor of shape (N, M, M), M = 2 + num_peripherals, containing the
        ODE matrices for all individuals.
    """
    # NOTE: the former function-local `import torch` was removed — torch is
    # already imported at module level.
    N = len(configs)
    M = 2 + num_peripherals
    A_all = torch.zeros((N, M, M), dtype=torch.float32)

    # Evaluate every individual's rate functions once at time t.
    k_a_all = torch.tensor([config.k_a(t) for config in configs], dtype=torch.float32)
    k_e_all = torch.tensor([config.k_e(t) for config in configs], dtype=torch.float32)
    k_1p_all = torch.tensor(
        [[config.k_1p[i](t) for i in range(num_peripherals)] for config in configs],
        dtype=torch.float32,
    )
    k_p1_all = torch.tensor(
        [[config.k_p1[i](t) for i in range(num_peripherals)] for config in configs],
        dtype=torch.float32,
    )

    # Populate the batched ODE matrices.
    A_all[:, 0, 0] = -k_a_all  # Gut compartment loses to central
    A_all[:, 1, 0] = k_a_all  # Absorption into central
    A_all[:, 1, 1] = -k_e_all - k_1p_all.sum(dim=1)  # Elimination + outflow to peripherals
    A_all[:, 1, 2 : 2 + num_peripherals] = k_p1_all  # Central gains from peripherals (k_p1)
    A_all[:, 2 : 2 + num_peripherals, 1] = k_1p_all  # Peripherals gain from central (k_1p)
    for i in range(num_peripherals):
        A_all[:, 2 + i, 2 + i] = -k_p1_all[:, i]  # Peripheral i loses back to central

    return A_all
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def sample_study(
    individual_config_array, dosing_config_array, t: torch.Tensor, solver_method: str = "rk4"
) -> Tuple[
    torch.Tensor,  # [N, T] concentration profiles
    torch.Tensor,  # [N, T] time points
    torch.Tensor,  # [N] dosing amounts
    torch.Tensor,  # [N] dosing route types (0 = oral, 1 = iv)
]:
    """
    Simulates the pharmacokinetic study for a group of individuals and returns
    concentration profiles, time points, and dosing metadata.

    Parameters:
    ----------
    individual_config_array : list
        List of IndividualConfig objects for each individual.
    dosing_config_array : list
        List of DosingConfig objects for each individual.
    t : torch.Tensor
        A 1D tensor of time points [T].
    solver_method : str
        Solver method forwarded to ``odeint`` (default "rk4").

    Returns:
    -------
    full_simulation : torch.Tensor
        Concentration profiles [N, T].
    full_simulation_times : torch.Tensor
        Time points [N, T].
    dosing_amounts : torch.Tensor
        Dosing amounts [N].
    dosing_route_types : torch.Tensor
        Route types [N], 0 = oral, 1 = iv.
    """
    # Sanity check
    if len(individual_config_array) != len(dosing_config_array):
        raise ValueError("Number of individuals and dosing configurations must match.")

    N = len(individual_config_array)
    num_peripherals_list = [cfg.num_peripherals for cfg in individual_config_array]
    all_same_peripherals = all(n == num_peripherals_list[0] for n in num_peripherals_list)

    # Extract dosing info
    dosing_amounts = torch.tensor(
        [cfg.dose for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]
    routes_str = [cfg.route for cfg in dosing_config_array]
    route_map = {"oral": 0, "iv": 1}
    dosing_route_types = torch.tensor([route_map[r] for r in routes_str], dtype=torch.int64)  # [N]

    if all_same_peripherals:
        # Fast path: identical compartment counts allow one batched ODE solve.
        P = num_peripherals_list[0]
        M = 2 + P
        y0 = torch.zeros((N, M), dtype=torch.float32)
        is_oral = dosing_route_types == 0
        is_iv = dosing_route_types == 1
        y0[is_oral, 0] = dosing_amounts[is_oral]  # oral dose starts in the gut
        y0[is_iv, 1] = dosing_amounts[is_iv]  # iv bolus starts in central

        def ode_func(t, y):
            A_all = create_dynamic_ode_matrix_batched(individual_config_array, t.item(), P)
            return torch.bmm(A_all, y.unsqueeze(-1)).squeeze(-1)

        y = odeint(ode_func, y0, t, method=solver_method)  # [T, N, M]
        V_all = torch.tensor(
            [[cfg.V(ti.item()) for ti in t] for cfg in individual_config_array], dtype=torch.float32
        )  # [N, T]
        full_simulation = y[:, :, 1].T / V_all  # [N, T]
        # BUGFIX: use each individual's own residual error. Previously the
        # first individual's rel_ruv was applied to everyone in this branch,
        # inconsistent with the per-individual branch below.
        rel_ruv_all = torch.tensor(
            [cfg.rel_ruv for cfg in individual_config_array], dtype=torch.float32
        ).unsqueeze(1)  # [N, 1], broadcasts over time
        full_simulation = full_simulation * (
            1 + torch.randn_like(full_simulation) * rel_ruv_all
        )
    else:
        # Slow path: solve each individual's ODE separately.
        full_simulation = []
        for config, dosing_config in zip(individual_config_array, dosing_config_array):
            P = config.num_peripherals
            M = 2 + P
            if dosing_config.route == "oral":
                y0 = torch.tensor([dosing_config.dose] + [0.0] * (M - 1), dtype=torch.float32)
            elif dosing_config.route == "iv":
                y0 = torch.tensor([0.0, dosing_config.dose] + [0.0] * (M - 2), dtype=torch.float32)
            else:
                raise ValueError(f"Unsupported route: {dosing_config.route}")

            def ode_func(t, y):
                A = create_dynamic_ode_matrix(config, t.item())
                return torch.matmul(A, y)

            y = odeint(ode_func, y0, t, method=solver_method)  # [T, M]
            V = torch.tensor([config.V(ti.item()) for ti in t], dtype=torch.float32)  # [T]
            concentration = y[:, 1] / V
            concentration *= 1 + torch.randn_like(concentration) * config.rel_ruv
            full_simulation.append(concentration)

        full_simulation = torch.stack(full_simulation)  # [N, T]

    full_times = t.unsqueeze(0).repeat(N, 1)  # [N, T]

    return full_simulation, full_times, dosing_amounts, dosing_route_types
def sample_study_with_duration(
    individual_config_array,
    dosing_config_array: List["DosingWithDurationConfig"],
    t: torch.Tensor,
    solver_method: str = "rk4",
) -> Tuple[
    torch.Tensor,  # [N, T] concentration profiles
    torch.Tensor,  # [N, T] time points
    torch.Tensor,  # [N] dosing amounts
    torch.Tensor,  # [N] dosing route types (0 = oral, 1 = iv)
]:
    """
    Simulates the pharmacokinetic study for a group of individuals and returns
    concentration profiles, time points, and dosing metadata.

    This is a parallel implementation to sample_study that supports infusion dosing with duration.
    Once validated, the two can be merged.

    Parameters:
    ----------
    individual_config_array : list
        List of IndividualConfig objects for each individual.
    dosing_config_array : list
        List of DosingWithDurationConfig objects for each individual.
    t : torch.Tensor
        A 1D tensor of time points [T].
    solver_method : str
        Solver method forwarded to ``odeint`` (default "rk4").

    Returns:
    -------
    full_simulation : torch.Tensor
        Concentration profiles [N, T].
    full_simulation_times : torch.Tensor
        Time points [N, T].
    dosing_amounts : torch.Tensor
        Dosing amounts [N].
    dosing_route_types : torch.Tensor
        Route types [N], 0 = oral, 1 = iv.
    """
    # Sanity check
    if len(individual_config_array) != len(dosing_config_array):
        raise ValueError("Number of individuals and dosing configurations must match.")

    N = len(individual_config_array)
    num_peripherals_list = [cfg.num_peripherals for cfg in individual_config_array]
    all_same_peripherals = all(n == num_peripherals_list[0] for n in num_peripherals_list)

    # Extract dosing info
    dosing_amounts = torch.tensor(
        [cfg.dose for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]
    routes_str = [cfg.route for cfg in dosing_config_array]
    route_map = {"oral": 0, "iv": 1}
    dosing_route_types = torch.tensor([route_map[r] for r in routes_str], dtype=torch.int64)  # [N]
    dosing_durations = torch.tensor(
        [cfg.duration for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]

    if all_same_peripherals and all(dosing_durations == 0):
        # Fast batched path: only used when there are NO infusions (all
        # durations are zero) and every individual has the same compartments.
        P = num_peripherals_list[0]
        M = 2 + P  # gut, central, peripherals
        y0 = torch.zeros((N, M), dtype=torch.float32)

        is_oral = dosing_route_types == 0
        is_iv_bolus = dosing_route_types == 1

        y0[is_oral, 0] = dosing_amounts[is_oral]  # oral dose starts in the gut
        y0[is_iv_bolus, 1] = dosing_amounts[is_iv_bolus]  # iv bolus starts in central

        def ode_func(t, y):
            A_all = create_dynamic_ode_matrix_batched(individual_config_array, t.item(), P)
            return torch.bmm(A_all, y.unsqueeze(-1)).squeeze(-1)

        y = odeint(ode_func, y0, t, method=solver_method)  # [T, N, M]
        V_all = torch.tensor(
            [[cfg.V(ti.item()) for ti in t] for cfg in individual_config_array], dtype=torch.float32
        )  # [N, T]
        full_simulation = y[:, :, 1].T / V_all  # [N, T]
        # BUGFIX: use each individual's own residual error. Previously the
        # first individual's rel_ruv was applied to everyone in this branch,
        # inconsistent with the per-individual branch below.
        rel_ruv_all = torch.tensor(
            [cfg.rel_ruv for cfg in individual_config_array], dtype=torch.float32
        ).unsqueeze(1)  # [N, 1], broadcasts over time
        full_simulation = full_simulation * (
            1 + torch.randn_like(full_simulation) * rel_ruv_all
        )
    else:
        # Slow path: per-individual ODE solve (needed for infusions and/or
        # heterogeneous compartment counts).
        full_simulation = []
        for config, dosing_config in zip(individual_config_array, dosing_config_array):
            P = config.num_peripherals
            M = 2 + P  # gut, central, peripherals
            if dosing_config.route == "oral":
                assert dosing_config.duration == 0, "Oral dosing cannot have a duration."
                y0 = torch.tensor([dosing_config.dose] + [0.0] * (M - 1), dtype=torch.float32)
            elif dosing_config.route == "iv":
                if dosing_config.duration > 0:
                    # Infusion: nothing in the system at t=0; the dose enters
                    # via the forcing term in ode_func below.
                    y0 = torch.zeros(M, dtype=torch.float32)
                else:  # Bolus dosing
                    y0 = torch.tensor(
                        [0.0, dosing_config.dose] + [0.0] * (M - 2), dtype=torch.float32
                    )
            else:
                raise ValueError(f"Unsupported route: {dosing_config.route}")

            def ode_func(t, y):
                A = create_dynamic_ode_matrix(config, t.item())
                b = torch.zeros_like(y)
                if (
                    dosing_config.route == "iv"
                    and dosing_config.duration > 0
                    and t.item() < dosing_config.duration
                ):
                    # During infusion, add rate to central compartment
                    b[1] = dosing_config.dose / dosing_config.duration
                return torch.matmul(A, y) + b

            y = odeint(ode_func, y0, t, method=solver_method)  # [T, M]
            V = torch.tensor([config.V(ti.item()) for ti in t], dtype=torch.float32)  # [T]
            concentration = y[:, 1] / V
            concentration *= 1 + torch.randn_like(concentration) * config.rel_ruv
            full_simulation.append(concentration)

        full_simulation = torch.stack(full_simulation)  # [N, T]

    full_times = t.unsqueeze(0).repeat(N, 1)  # [N, T]

    return full_simulation, full_times, dosing_amounts, dosing_route_types
def derive_timescale_parameters(config: "StudyConfig", meta_config: "MetaStudyConfig"):
    """
    Derive peak time and terminal half life for typical parameters,
    which can then be used to inform a study-specific sampling schedule.

    Parameters
    ----------
    config : StudyConfig
        Provides the typical log-rate parameters (log_k_a_mean, log_k_e_mean).
    meta_config : MetaStudyConfig
        Provides the study time horizon (time_stop) used for clamping.

    Returns
    -------
    torch.Tensor
        Tensor [tmax, t12]: time to peak concentration and terminal half-life,
        both clamped to lie within the study horizon.
    """
    k_a = np.exp(config.log_k_a_mean)
    k_e = np.exp(config.log_k_e_mean)
    # Peak time of a one-compartment oral model: ln(k_a / k_e) / (k_a - k_e).
    # BUGFIX: when k_a ≈ k_e the closed form is 0/0 (NaN, which then defeats
    # both clamps below); use the analytic limit 1 / k_a instead.
    if np.isclose(k_a, k_e):
        tmax = 1.0 / k_a
    else:
        tmax = (np.log(k_e) - np.log(k_a)) / (k_e - k_a)

    # mean residence time approximation for terminal half-life
    MRT = 1 / k_e
    # for i in range(config.num_peripherals):
    #     k_1i = np.exp(config.log_k_p1_mean[i])
    #     MRT += 1/k_1i
    t12 = np.log(2) * MRT
    # Clamp both timescales so a sampling schedule built from them stays
    # inside the observable study window.
    if t12 > meta_config.time_stop:
        t12 = float(meta_config.time_stop / 2.0)
    if tmax > t12:
        tmax = float(t12 * 0.5)
    return torch.Tensor([tmax, t12])
def sample_dosing_configs(config: MetaDosingConfig):
    """
    Sample per-individual dosing configurations from the meta dosing configuration.

    The administration route is either shared by the whole study or drawn per
    individual. Doses are lognormally distributed with log-mean and log-std
    drawn uniformly from the configured ranges; a logdose_std_range of (0, 0)
    makes every individual's dose identical.
    """
    n = config.num_individuals

    # Route assignment: one shared draw, or an independent draw per individual.
    if config.same_route:
        routes = np.repeat(np.random.choice(config.route_options, p=config.route_weights), n)
    else:
        routes = np.random.choice(config.route_options, p=config.route_weights, size=n)

    # Lognormal dose hyperparameters, then the per-individual doses.
    mu = np.random.uniform(*config.logdose_mean_range)
    sigma = np.random.uniform(*config.logdose_std_range)
    doses = np.random.lognormal(mu, sigma, size=n)

    return [
        DosingConfig(dose=doses[i], route=routes[i], time=config.time) for i in range(n)
    ]
def sample_dosing_with_duration_configs(config: MetaDosingWithDurationConfig):
    """
    Sample per-individual dosing configurations (with infusion durations)
    from the meta dosing configuration.

    The administration route is either shared by the whole study or drawn per
    individual. Doses are lognormally distributed with log-mean and log-std
    drawn uniformly from the configured ranges; a logdose_std_range of (0, 0)
    makes every individual's dose identical.
    """
    n = config.num_individuals

    # Route assignment: one shared draw, or an independent draw per individual.
    if config.same_route:
        routes = np.repeat(np.random.choice(config.route_options, p=config.route_weights), n)
    else:
        routes = np.random.choice(config.route_options, p=config.route_weights, size=n)

    # Candidate infusion durations, one per individual.
    durations = np.random.uniform(
        config.duration_range[0], config.duration_range[1], size=n
    )

    # Lognormal dose hyperparameters, then the per-individual doses.
    mu = np.random.uniform(*config.logdose_mean_range)
    sigma = np.random.uniform(*config.logdose_std_range)
    doses = np.random.lognormal(mu, sigma, size=n)

    return [
        DosingWithDurationConfig(
            dose=doses[i],
            route=routes[i],
            time=config.time,
            # Bernoulli gate: the sampled duration applies only with the
            # route-specific probability; otherwise the duration is 0 (bolus).
            duration=durations[i]
            * np.random.binomial(1, config.route_duration_weights[routes[i]], size=1)[0],
        )
        for i in range(n)
    ]
def get_random_simulation(
    model_config: NodePKExperimentConfig,
) -> Tuple[TensorType["I", "T"], TensorType["I", "T"]]:
    """
    Generates random simulation data based on the model configuration.

    Args:
        model_config (NodePKConfig): Configuration for the simulation.

    Returns:
        Tuple[TensorType["I", "T"], TensorType["I", "T"]]: Simulation points and time steps.
    """
    meta = model_config.meta_study
    num_individuals = meta.num_individuals_range[0]
    num_steps = meta.time_num_steps

    # Shared time grid, replicated once per individual -> [I, T].
    grid = torch.linspace(meta.time_start, meta.time_stop, num_steps, dtype=torch.float32)
    time_steps = grid.unsqueeze(0).repeat(num_individuals, 1)

    # Uniform random values scaled by the study horizon -> [I, T].
    simulation_points = torch.rand(num_individuals, num_steps) / meta.time_stop

    return simulation_points, time_steps
sim_priors_pk/data/data_generation/compartment_models_management.py
ADDED
|
@@ -0,0 +1,1338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pyright: reportAssignmentType=false
|
| 2 |
+
# compartment_models_management.py
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
from dataclasses import replace
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from torchtyping import TensorType
|
| 12 |
+
|
| 13 |
+
from sim_priors_pk.config_classes.data_config import (
|
| 14 |
+
DosingConfig,
|
| 15 |
+
DosingWithDurationConfig,
|
| 16 |
+
MetaDosingConfig,
|
| 17 |
+
MetaDosingWithDurationConfig,
|
| 18 |
+
MetaStudyConfig,
|
| 19 |
+
ObservationsConfig,
|
| 20 |
+
)
|
| 21 |
+
from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
|
| 22 |
+
from sim_priors_pk.data.data_generation.compartment_models import (
|
| 23 |
+
StudyConfig,
|
| 24 |
+
derive_timescale_parameters,
|
| 25 |
+
sample_dosing_configs,
|
| 26 |
+
sample_dosing_with_duration_configs,
|
| 27 |
+
sample_individual_configs,
|
| 28 |
+
sample_study,
|
| 29 |
+
sample_study_config,
|
| 30 |
+
sample_study_with_duration,
|
| 31 |
+
)
|
| 32 |
+
from sim_priors_pk.data.data_generation.observations_classes import ObservationStrategyFactory
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
# Resolve the JSON schema types lazily: the real classes are only needed by
# static type checkers; at runtime a plain Dict alias avoids an import cycle.
if TYPE_CHECKING:  # pragma: no cover - typing only
    from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON
else:  # pragma: no cover - runtime fallback avoids heavy import cycle
    # NOTE(review): StudyJSON is also imported unconditionally near the top of
    # this file; this branch rebinds it to a Dict alias at runtime — confirm
    # the shadowing is intentional.
    IndividualJSON = Dict[str, object]
    StudyJSON = Dict[str, object]
|
| 43 |
+
def is_valid_simulation(sim: torch.Tensor, max_value: float = 10.0) -> bool:
    """
    Check that a simulated concentration tensor is numerically usable.

    A simulation is valid when every value is finite, non-negative, and
    strictly below ``max_value``.

    Parameters
    ----------
    sim : torch.Tensor
        Simulated concentration values.
    max_value : float, optional
        Exclusive upper bound on acceptable values. Defaults to 10, the
        previously hard-coded threshold, so existing callers are unaffected.

    Returns
    -------
    bool
        True if the simulation passes all checks.
    """
    # bool(...) guarantees a plain Python bool (the chained `and` would
    # otherwise yield a 0-dim tensor), matching the annotated return type.
    return bool(torch.isfinite(sim).all() and (sim >= 0).all() and (sim < max_value).all())
def sample_dosing_configs_repeated_target(config: MetaDosingConfig, n_targets: int):
    """
    Build ``n_targets`` identical dosing configs: a single route and a single
    lognormal dose are drawn once and shared by every target individual.

    Parameters
    ----------
    config : MetaDosingConfig
        Meta dosing configuration (its num_individuals field may be ignored).
    n_targets : int
        Number of target individuals to generate configs for.

    Returns
    -------
    List[DosingConfig]
        The same dosing config repeated ``n_targets`` times.
    """
    # One route draw shared by all targets.
    shared_route = np.random.choice(config.route_options, p=config.route_weights)

    # One lognormal dose draw shared by all targets.
    mu = np.random.uniform(*config.logdose_mean_range)
    sigma = np.random.uniform(*config.logdose_std_range)
    shared_dose = float(np.random.lognormal(mu, sigma))

    return [
        DosingConfig(dose=shared_dose, route=shared_route, time=config.time)
        for _ in range(n_targets)
    ]
def sample_dosing_with_duration_configs_repeated_target(
    config: MetaDosingWithDurationConfig, n_targets: int
):
    """
    Generate dosing configs where all target individuals share the same
    dose, route, and infusion duration.

    Parameters
    ----------
    config : MetaDosingWithDurationConfig
        Meta dosing configuration with duration (num_individuals field may be ignored).
    n_targets : int
        Number of target individuals to generate.

    Returns
    -------
    List[DosingWithDurationConfig]
        Identical dosing configs repeated `n_targets` times.
    """
    # Choose one route for all targets
    route = np.random.choice(config.route_options, p=config.route_weights)

    # Handling the duration logic
    # NOTE(review): here the shared duration is the deterministic product
    # weight * uniform-draw, whereas sample_dosing_with_duration_configs uses
    # the weight as a Bernoulli probability gating the duration — confirm the
    # difference is intentional.
    duration_weight = config.route_duration_weights[route]
    duration_range = np.random.uniform(*config.duration_range)
    duration = duration_weight * duration_range

    # Sample one dose (lognormal)
    logdose_mean = np.random.uniform(*config.logdose_mean_range)
    logdose_std = np.random.uniform(*config.logdose_std_range)
    dose_value = float(np.random.lognormal(logdose_mean, logdose_std))

    # Build identical configs
    dosing_configs = [
        DosingWithDurationConfig(dose=dose_value, route=route, time=config.time, duration=duration)
        for _ in range(n_targets)
    ]
    return dosing_configs
|
| 120 |
+
# ──────────────────────────────────────────────────────────────
|
| 121 |
+
# NEW: split where *all* individuals are target
|
| 122 |
+
# ──────────────────────────────────────────────────────────────
|
| 123 |
+
def split_context_only(
    full_simulation: torch.Tensor,
    full_simulation_times: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, list[int]]:
    """Treat every individual as context; there are no target individuals."""
    n_individuals = full_simulation.shape[0]
    return full_simulation, full_simulation_times, list(range(n_individuals))
def split_simulations_repeated_target(
    full_simulation: torch.Tensor,
    full_simulation_times: torch.Tensor,
) -> Tuple[
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    torch.Tensor,
    torch.Tensor,
    list[int],
    list[int],
]:
    """
    Variant of split_simulations where **all individuals are in the target set**
    and no context individuals are returned.

    Parameters
    ----------
    full_simulation : torch.Tensor [N, T]
    full_simulation_times : torch.Tensor [N, T]

    Returns
    -------
    Tuple of (context_simulation=None, context_simulation_times=None,
    target_simulation [N, T], target_simulation_times [N, T],
    context_indices=[], target_indices=[0, ..., N-1]).
    """
    all_indices = list(range(full_simulation.shape[0]))
    return None, None, full_simulation, full_simulation_times, [], all_indices
def _generate_full_simulation(
    meta_study_config: MetaStudyConfig,
    meta_dosing_config: MetaDosingConfig,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    StudyConfig,
    list[DosingConfig],
    int,
]:
    """Internal helper returning the raw tensors alongside sampling metadata.

    Samples a study population and a dosing regimen, runs the simulator,
    and recursively re-samples when the result fails ``is_valid_simulation``.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration for the study population and solver.
    meta_dosing_config:
        Sampling configuration for the dosing regimen.
    retry_on_invalid:
        When ``True`` (default), re-sample recursively on an invalid
        simulation; when ``False``, raise ``RuntimeError`` instead.
    idx:
        Current recursion depth / attempt counter, used for diagnostics.

    Returns
    -------
    ``(full_sim, full_times, dosing_amounts, dosing_routes, time_points,
    time_scales, study_config, dosing_config_array, downstream_failures)``
    where ``downstream_failures`` counts how many invalid simulations were
    discarded before obtaining this valid one.

    Raises
    ------
    RuntimeError
        If the simulation is invalid and ``retry_on_invalid`` is ``False``.
    """
    # Draw study-level parameters, then per-individual parameters from them.
    study_config = sample_study_config(meta_study_config)
    indiv_config_array = sample_individual_configs(study_config)
    time_scales = derive_timescale_parameters(study_config, meta_study_config)

    # Shared evaluation grid for every individual.  [T]
    time_points = torch.linspace(
        meta_study_config.time_start,
        meta_study_config.time_stop,
        meta_study_config.time_num_steps,
        dtype=torch.float32,
    )

    # The dosing sampler needs the population size that was just sampled.
    local_meta_dosing = replace(meta_dosing_config, num_individuals=study_config.num_individuals)
    dosing_config_array = sample_dosing_configs(local_meta_dosing)

    full_sim, full_times, dosing_amounts, dosing_routes = sample_study(
        indiv_config_array,
        dosing_config_array,
        time_points,
        meta_study_config.solver_method,
    )

    if not is_valid_simulation(full_sim):
        attempt_number = idx + 1
        # Only warn once sampling has failed repeatedly; occasional invalid
        # draws are expected and retried silently.
        if attempt_number > 5:
            logger.warning(
                "Invalid simulation encountered during attempt %d (recursion depth %d); retry_on_invalid=%s.",
                attempt_number,
                idx,
                retry_on_invalid,
            )
        # NOTE(review): the retry recursion has no hard cap -- a
        # configuration that never yields a valid simulation would
        # eventually hit Python's recursion limit.  Confirm intended.
        if retry_on_invalid:
            (
                full_sim,
                full_times,
                dosing_amounts,
                dosing_routes,
                time_points,
                time_scales,
                study_config,
                dosing_config_array,
                downstream_failures,
            ) = _generate_full_simulation(
                meta_study_config,
                meta_dosing_config,
                retry_on_invalid=retry_on_invalid,
                idx=idx + 1,
            )
            # Bubble the failure count up, adding this level's failure.
            return (
                full_sim,
                full_times,
                dosing_amounts,
                dosing_routes,
                time_points,
                time_scales,
                study_config,
                dosing_config_array,
                downstream_failures + 1,
            )
        raise RuntimeError("Invalid simulation")

    return (
        full_sim,
        full_times,
        dosing_amounts,
        dosing_routes,
        time_points,
        time_scales,
        study_config,
        dosing_config_array,
        0,  # no failed attempts at this level
    )
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _generate_full_simulation_with_duration(
    meta_study_config: MetaStudyConfig,
    meta_dosing_config: MetaDosingWithDurationConfig,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    StudyConfig,
    list[DosingConfig],
    int,
]:
    """
    Internal helper returning the raw tensors alongside sampling metadata.
    This is a parallel implementation to `_generate_full_simulation` that supports
    dosing with duration. Once validated, the two can be merged.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration for the study population and solver.
    meta_dosing_config:
        Duration-aware dosing sampling configuration.
    retry_on_invalid:
        When ``True`` (default), re-sample recursively on an invalid
        simulation; when ``False``, raise ``RuntimeError`` instead.
    idx:
        Current recursion depth / attempt counter, used for diagnostics.

    Returns
    -------
    Same nine-tuple layout as ``_generate_full_simulation``; the final int
    counts invalid simulations discarded before this valid one.

    Raises
    ------
    RuntimeError
        If the simulation is invalid and ``retry_on_invalid`` is ``False``.
    """
    # Draw study-level, then per-individual, parameters.
    study_config = sample_study_config(meta_study_config)
    indiv_config_array = sample_individual_configs(study_config)
    time_scales = derive_timescale_parameters(study_config, meta_study_config)

    # Shared evaluation grid.  [T]
    time_points = torch.linspace(
        meta_study_config.time_start,
        meta_study_config.time_stop,
        meta_study_config.time_num_steps,
        dtype=torch.float32,
    )

    # Duration-aware dosing sampler needs the sampled population size.
    local_meta_dosing = replace(meta_dosing_config, num_individuals=study_config.num_individuals)
    dosing_config_array = sample_dosing_with_duration_configs(local_meta_dosing)

    full_sim, full_times, dosing_amounts, dosing_routes = sample_study_with_duration(
        indiv_config_array,
        dosing_config_array,
        time_points,
        meta_study_config.solver_method,
    )

    if not is_valid_simulation(full_sim):
        attempt_number = idx + 1
        # Only warn after repeated failures; occasional retries are normal.
        if attempt_number > 5:
            logger.warning(
                "Invalid simulation encountered during attempt %d (recursion depth %d); retry_on_invalid=%s.",
                attempt_number,
                idx,
                retry_on_invalid,
            )
        # NOTE(review): unbounded retry recursion (same as the
        # non-duration variant) -- confirm whether a cap is wanted.
        if retry_on_invalid:
            (
                full_sim,
                full_times,
                dosing_amounts,
                dosing_routes,
                time_points,
                time_scales,
                study_config,
                dosing_config_array,
                downstream_failures,
            ) = _generate_full_simulation_with_duration(
                meta_study_config,
                meta_dosing_config,
                retry_on_invalid=retry_on_invalid,
                idx=idx + 1,
            )
            # Bubble the failure count up, adding this level's failure.
            return (
                full_sim,
                full_times,
                dosing_amounts,
                dosing_routes,
                time_points,
                time_scales,
                study_config,
                dosing_config_array,
                downstream_failures + 1,
            )
        raise RuntimeError("Invalid simulation")

    return (
        full_sim,
        full_times,
        dosing_amounts,
        dosing_routes,
        time_points,
        time_scales,
        study_config,
        dosing_config_array,
        0,  # no failed attempts at this level
    )
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def _generate_simple_exp_simulation(
    meta_study_config,
) -> Tuple[
    TensorType["N", "T"],  # full_simulation
    TensorType["N", "T"],  # full_simulation_times
    TensorType["N"],  # dosing_amounts
    TensorType["N"],  # dosing_route_types
    TensorType["T"],  # time_points
    TensorType[2],  # time_scales [tmax, t12]
]:
    """Minimal synthetic PK-like simulator (single decaying exponential).

    One decay rate ``k ~ U(decay_rate_range)`` is drawn per RUN and shared
    by every individual.  Only ``band_scale_range``, ``baseline_range`` and
    ``decay_rate_range`` are consumed from the config.

    Per-RUN derivations:
        baseline_run   ~ U(baseline_range)
        band_scale_run ~ U(band_scale_range)
        decay_rate     ~ U(decay_rate_range)
        intercept_mean = 1.0 + baseline_run
        intercept_std  = 0.5 * band_scale_run
    """
    # Hyperparameters, all optional with safe defaults.
    num_individuals: int = getattr(meta_study_config, "num_individuals", 16)
    num_steps: int = getattr(meta_study_config, "time_num_steps", 40)
    time_start: float = getattr(meta_study_config, "time_start", 0.0)
    time_stop: float = getattr(meta_study_config, "time_stop", 24.0)

    band_scale_range = getattr(meta_study_config, "band_scale_range", (0.1, 0.3))
    baseline_range = getattr(meta_study_config, "baseline_range", (0.0, 0.1))
    decay_rate_range = getattr(meta_study_config, "decay_rate_range", (0.3, 0.6))

    # Per-RUN uniform draws from the global RNG (no seeding).
    def _draw_uniform(low, high):
        return (torch.rand(1) * (high - low) + low).item()

    band_scale_run = _draw_uniform(*band_scale_range)
    baseline_run = _draw_uniform(*baseline_range)
    decay_rate_k = _draw_uniform(*decay_rate_range)  # shared across individuals

    intercept_mean = 1.0 + baseline_run
    intercept_std = 0.5 * band_scale_run

    # Shared time grid and single-exponential shape with f(0) = 1.
    grid: TensorType["T"] = torch.linspace(time_start, time_stop, num_steps)
    shape_curve: TensorType["T"] = torch.exp(-decay_rate_k * grid)

    # Per-individual amplitudes, clamped so trajectories stay non-negative.
    intercepts = torch.normal(
        mean=float(intercept_mean),
        std=float(intercept_std),
        size=(num_individuals, 1, 1),
    ).clamp_min(0.0)

    # Amplitude-scaled curve plus the shared run baseline.  [N, T]
    samples: TensorType["N", "T"] = (
        intercepts.view(num_individuals, 1) * shape_curve.unsqueeze(0) + baseline_run
    ).clamp_min(0.0)

    # Dummy dosing metadata and heuristic time scales.
    dosing_amounts: TensorType["N"] = torch.zeros(num_individuals)
    dosing_routes: TensorType["N"] = torch.zeros(num_individuals)
    span = float(time_stop - time_start)
    time_scales: TensorType[2] = torch.tensor(
        [0.3 * span, 0.75 * span], dtype=torch.float32
    )

    # Every individual shares the same grid of observation times.  [N, T]
    full_sim_times = grid.unsqueeze(0).expand(num_individuals, -1)

    return (
        samples,
        full_sim_times,
        dosing_amounts,
        dosing_routes,
        grid,
        time_scales,
    )
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def _generate_pulse_simulation(
    meta_study_config,
) -> Tuple[
    TensorType["N", "T"],  # full_simulation
    TensorType["N", "T"],  # full_simulation_times
    TensorType["N"],  # dosing_amounts
    TensorType["N"],  # dosing_route_types
    TensorType["T"],  # time_points
    TensorType[2],  # time_scales [t_peak, t_half_tail]
]:
    """Pulse-shaped synthetic simulator (rise -> peak -> decay).

    A Gamma-like curve is built once per RUN:
        t_peak = 0.30 * (time_stop - time_start)
        r      ~ U(decay_rate_range)          # tail decay rate
        beta   = 1 / r                        # tail scale
        alpha  = 1 + r * t_peak               # puts the mode near t_peak
        f(t)   = t^(alpha-1) * exp(-t/beta), normalized so max(f) = 1

    Each individual draws an amplitude
    ``intercept_i ~ N(1 + baseline_run, 0.5 * band_scale_run)`` (clamped
    >= 0) that scales the curve; the shared ``baseline_run`` is added and
    the result clamped >= 0.

    Config used (all optional with safe defaults): ``num_individuals``,
    ``time_start``/``time_stop``/``time_num_steps``, ``band_scale_range``,
    ``baseline_range``, ``decay_rate_range``.
    """
    num_individuals: int = getattr(meta_study_config, "num_individuals", 16)
    num_steps: int = getattr(meta_study_config, "time_num_steps", 40)
    time_start: float = getattr(meta_study_config, "time_start", 0.0)
    time_stop: float = getattr(meta_study_config, "time_stop", 24.0)

    band_scale_range = getattr(meta_study_config, "band_scale_range", (0.1, 0.3))
    baseline_range = getattr(meta_study_config, "baseline_range", (0.0, 0.1))
    decay_rate_range = getattr(meta_study_config, "decay_rate_range", (0.3, 0.6))

    # Per-RUN uniform draws from the global RNG (no seeding).
    def _draw_uniform(low, high):
        return (torch.rand(1) * (high - low) + low).item()

    band_scale_run = _draw_uniform(*band_scale_range)
    baseline_run = _draw_uniform(*baseline_range)
    tail_rate = _draw_uniform(*decay_rate_range)  # shared this run

    span = float(time_stop - time_start)
    peak_time = 0.30 * span  # desired peak position
    tail_scale = 1.0 / max(tail_rate, 1e-6)
    # Guardrail: alpha must exceed 1 for a proper rise-then-decay shape.
    shape_alpha = max(1.0 + tail_rate * peak_time, 1.05)

    # Time grid and Gamma-shaped pulse (unnormalized first).
    grid: TensorType["T"] = torch.linspace(time_start, time_stop, num_steps)
    elapsed = grid - time_start  # shift so the pulse starts at 0
    pulse = (elapsed.clamp_min(0.0) ** (shape_alpha - 1.0)) * torch.exp(
        -elapsed / tail_scale
    )
    # Normalize to a unit peak so the intercept alone sets amplitude.
    pulse = pulse / torch.amax(pulse).clamp_min(1e-12)

    # Per-individual amplitudes, clamped non-negative.
    intercepts: TensorType["N", 1] = torch.normal(
        mean=float(1.0 + baseline_run),
        std=float(0.5 * band_scale_run),
        size=(num_individuals, 1),
    ).clamp_min(0.0)

    samples: TensorType["N", "T"] = (
        (intercepts * pulse.unsqueeze(0)) + baseline_run
    ).clamp_min(0.0)

    # Dummy dosing metadata.
    dosing_amounts: TensorType["N"] = torch.zeros(num_individuals)
    dosing_routes: TensorType["N"] = torch.zeros(num_individuals)

    # Report the peak and an approximate tail half-life past the peak.
    half_tail = peak_time + (torch.log(torch.tensor(2.0)) / max(tail_rate, 1e-6)).item()
    time_scales: TensorType[2] = torch.tensor(
        [peak_time, half_tail], dtype=torch.float32
    )

    return (
        samples,
        grid.unsqueeze(0).expand(num_individuals, -1),
        dosing_amounts,
        dosing_routes,
        grid,
        time_scales,
    )
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def _generate_simple_simulation(
    meta_study_config,
) -> Tuple[
    TensorType["N", "T"],
    TensorType["N", "T"],
    TensorType["N"],
    TensorType["N"],
    TensorType["T"],
    TensorType[2],
]:
    """Randomly dispatch between the two simple synthetic generators.

    With probability ``p1`` (config attribute, default 0.5, clamped into
    [0, 1]) delegates to ``_generate_simple_exp_simulation``; otherwise to
    ``_generate_pulse_simulation``.
    """
    raw_p1 = float(getattr(meta_study_config, "p1", 0.5))
    p1 = min(max(raw_p1, 0.0), 1.0)  # clamp into [0, 1]

    use_exponential = torch.rand(1).item() < p1
    if use_exponential:
        return _generate_simple_exp_simulation(meta_study_config)
    return _generate_pulse_simulation(meta_study_config)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def prepare_full_simulation(
    meta_study_config,
    meta_dosing_config,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> Tuple[
    TensorType["N", "T", 1],
    TensorType["N", "T"],
    TensorType["N"],
    TensorType["N"],
    TensorType["T"],
    TensorType[2],
]:
    """Generate a full INDIVIDUAL study simulation (before context/target split).

    Bundles the common steps shared across all dataset generators.  When
    ``meta_study_config.simple_mode`` is truthy, the cheap synthetic
    generator ``_generate_simple_simulation`` is used instead of the
    mechanistic one.
    """
    if getattr(meta_study_config, "simple_mode", False):
        return _generate_simple_simulation(meta_study_config)

    generated = _generate_full_simulation(
        meta_study_config,
        meta_dosing_config,
        retry_on_invalid=retry_on_invalid,
        idx=idx,
    )
    # Drop the trailing metadata (study config, dosing configs, failure
    # count) -- callers of this wrapper only need the tensors.
    return generated[:6]
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def prepare_full_simulation_with_duration(
    meta_study_config,
    meta_dosing_config,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> Tuple[
    TensorType["N", "T", 1],
    TensorType["N", "T"],
    TensorType["N"],
    TensorType["N"],
    TensorType["T"],
    TensorType[2],
]:
    """Generate a full INDIVIDUAL study simulation (before context/target split).

    Bundles the common steps shared across all dataset generators.  When
    ``meta_study_config.simple_mode`` is truthy, the cheap synthetic
    generator ``_generate_simple_simulation`` is used instead.

    This is a parallel implementation to ``prepare_full_simulation`` that
    supports dosing with duration.  Once validated, the two can be merged.
    """
    if getattr(meta_study_config, "simple_mode", False):
        return _generate_simple_simulation(meta_study_config)

    generated = _generate_full_simulation_with_duration(
        meta_study_config,
        meta_dosing_config,
        retry_on_invalid=retry_on_invalid,
        idx=idx,
    )
    # Drop the trailing metadata (study config, dosing configs, failure
    # count) -- callers of this wrapper only need the tensors.
    return generated[:6]
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def _ensure_strictly_increasing_observations(
|
| 703 |
+
obs_times: list[float], obs_vals: list[list[float]], *, individual_id: str
|
| 704 |
+
) -> None:
|
| 705 |
+
"""Validate that the provided observation times are strictly increasing.
|
| 706 |
+
|
| 707 |
+
Parameters
|
| 708 |
+
----------
|
| 709 |
+
obs_times:
|
| 710 |
+
Sequence of observation timestamps extracted from the simulator.
|
| 711 |
+
obs_vals:
|
| 712 |
+
Sequence of observation values sampled at ``obs_times``.
|
| 713 |
+
individual_id:
|
| 714 |
+
Identifier of the individual being validated. Included in the
|
| 715 |
+
diagnostic error message to simplify debugging when duplicates are
|
| 716 |
+
detected in batched runs.
|
| 717 |
+
"""
|
| 718 |
+
|
| 719 |
+
if len(obs_times) != len(obs_vals):
|
| 720 |
+
raise ValueError(
|
| 721 |
+
"Observation times must be sorted and match the number of observations. "
|
| 722 |
+
f"Received lengths times={len(obs_times)} and values={len(obs_vals)} for "
|
| 723 |
+
f"{individual_id}. Observations={obs_vals}, times={obs_times}."
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
for idx_time in range(len(obs_times) - 1):
|
| 727 |
+
if obs_times[idx_time] >= obs_times[idx_time + 1]:
|
| 728 |
+
raise ValueError(
|
| 729 |
+
"Observation times must be sorted and match the number of observations. "
|
| 730 |
+
f"Detected non-increasing times for {individual_id} at position {idx_time}. "
|
| 731 |
+
f"Observations={obs_vals}, times={obs_times}."
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def prepare_full_simulation_to_study_json(
    meta_study_config: MetaStudyConfig,
    observation_config: ObservationsConfig,
    meta_dosing_config: MetaDosingConfig,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> tuple[StudyJSON, int]:
    """Generate a full simulation and convert it into a :class:`StudyJSON` record.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration describing the population and numerical solver.
        If meta_study_config.simple_mode is True, uses simplified synthetic data.
    observation_config:
        Configuration for the observation strategy used to extract measurements
        from the raw simulation. All generated observations are stored under
        the ``context`` section of the returned study.
    meta_dosing_config:
        Configuration describing the dosing regimen for each simulated
        individual.
    retry_on_invalid:
        When ``True`` (default) the function retries simulation sampling if the
        generated trajectories are numerically invalid.
    idx:
        Internal recursion depth counter exposed for debugging and testing.

    Returns
    -------
    tuple[StudyJSON, int]
        Canonical JSON representation of the simulated study with all
        individuals stored in the ``context`` field and an empty ``target``
        list, alongside the number of failed attempts before obtaining the
        valid simulation.
    """
    if getattr(meta_study_config, "simple_mode", False):
        # Handle simple synthetic data generation
        (
            full_sim,
            full_times,
            dosing_amounts,
            dosing_routes,  # NOTE(review): bound but unused below -- routes come from dosing_cfg
            _time_points,
            time_scales,
        ) = _generate_simple_simulation(meta_study_config)
        # NOTE(review): placeholder object -- ``{""}`` is a set, so the
        # later ``getattr(study_config, "drug_id", ...)`` always falls
        # back to the default substance name.  Confirm this is intended.
        study_config = {""}
        # Synthetic runs carry dummy dosing: zero-dose, empty route, t=0.
        dosing_config_array = [
            DosingConfig(dose=float(d), route="", time=0.0) for d in dosing_amounts
        ]
        failed_attempts = 0
    else:
        # Original mechanistic simulation code
        (
            full_sim,
            full_times,
            dosing_amounts,
            _dosing_routes,
            _time_points,
            time_scales,
            study_config,
            dosing_config_array,
            failed_attempts,
        ) = _generate_full_simulation(
            meta_study_config,
            meta_dosing_config,
            retry_on_invalid=retry_on_invalid,
            idx=idx,
        )

    # Build the observation extractor and sample observed / remaining points.
    observation_strategy = ObservationStrategyFactory.from_config(
        observation_config, meta_study_config
    )
    obs_out, time_out, mask_out, rem_sim, rem_time, rem_mask, _ = observation_strategy.generate(
        full_simulation=full_sim,
        full_simulation_times=full_times,
        time_scales=time_scales,
    )

    context: list[IndividualJSON] = []
    num_individuals = full_sim.shape[0]

    for ind_idx in range(num_individuals):
        # Keep only the entries the strategy flagged as observed.
        mask = mask_out[ind_idx].to(torch.bool)
        observations = obs_out[ind_idx][mask].tolist()
        observation_times = time_out[ind_idx][mask].tolist()

        # Fail loudly on unsorted/duplicated timestamps before serializing.
        _ensure_strictly_increasing_observations(
            observation_times,
            observations,
            individual_id=f"context_{ind_idx}",
        )

        individual: IndividualJSON = {
            "name_id": f"context_{ind_idx}",
            "observations": observations,
            "observation_times": observation_times,
        }

        # Optionally attach the non-observed ("remaining") trajectory points.
        if rem_sim is not None and rem_time is not None and rem_mask is not None:
            rem_mask_row = rem_mask[ind_idx].to(torch.bool)
            if rem_mask_row.any():
                individual["remaining"] = rem_sim[ind_idx][rem_mask_row].tolist()
                individual["remaining_times"] = rem_time[ind_idx][rem_mask_row].tolist()

        dosing_cfg = dosing_config_array[ind_idx]
        dose = float(dosing_amounts[ind_idx].item())
        route = getattr(dosing_cfg, "route", "")
        dosing_time = float(getattr(dosing_cfg, "time", 0.0))

        # Only serialize dosing when there is a non-zero dose or a route;
        # the simple-mode dummies (dose=0, route="") are skipped here.
        if dose or route:
            individual["dosing"] = [dose]
            individual["dosing_type"] = [route]
            individual["dosing_times"] = [dosing_time]
            individual["dosing_name"] = [route]

        context.append(individual)

    study_json: StudyJSON = {
        "context": context,
        "target": [],
        "meta_data": {
            "study_name": f"simulated_study_{idx}",
            "substance_name": getattr(study_config, "drug_id", "simulated_substance"),
        },
    }

    return study_json, failed_attempts
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
def prepare_full_simulation_with_repeated_targets(
    meta_study_config: MetaStudyConfig,
    meta_dosing_config: MetaDosingConfig,
    n_targets: int,
    *,
    different_dosing: bool = False,
    retry_on_invalid: bool = True,
    idx: int = 0,
):
    """
    Generate a context simulation (normal dosing) plus a new set of target
    individuals.

    Fixes over the previous revision: the duplicated
    ``dosing_amounts_ctx`` / ``dosing_routes_ctx`` assignments were removed,
    and ``retry_on_invalid`` is now forwarded explicitly through both retry
    recursions (previously the context-retry relied on the default).

    Parameters
    ----------
    meta_study_config:
        Sampling configuration for the study population and solver.
    meta_dosing_config:
        Dosing-distribution configuration used for both context and targets.
    n_targets:
        Number of target individuals to sample.
    different_dosing:
        If ``False`` (default), all target individuals share one repeated
        dosing configuration.
        If ``True``, each target individual gets an independent dosing sample
        from the same distribution used for context individuals.
    retry_on_invalid:
        When ``True`` (default), re-sample the whole study if the context or
        target simulation is numerically invalid; otherwise raise.
    idx:
        Retry depth / attempt index, for diagnostics only.

    Returns
    -------
    context_sim, context_times,
    target_sim, target_times,
    dosing_amounts_ctx, dosing_routes_ctx,
    dosing_amounts_tgt, dosing_routes_tgt,
    time_points, time_scales

    Raises
    ------
    RuntimeError
        If a simulation is invalid and ``retry_on_invalid`` is ``False``.
    """
    study_config = sample_study_config(meta_study_config)
    indiv_config_array = sample_individual_configs(study_config)
    time_scales = derive_timescale_parameters(study_config, meta_study_config)

    # Shared evaluation grid for context and targets.  [T]
    time_points = torch.linspace(
        meta_study_config.time_start,
        meta_study_config.time_stop,
        meta_study_config.time_num_steps,
        dtype=torch.float32,
    )

    # --- Context part: normal per-individual dosing ---------------------
    local_meta_dosing_ctx = replace(
        meta_dosing_config, num_individuals=study_config.num_individuals
    )
    dosing_config_array_ctx = sample_dosing_configs(local_meta_dosing_ctx)

    full_sim, full_times, dosing_amounts_all, dosing_routes_all = sample_study(
        indiv_config_array,
        dosing_config_array_ctx,
        time_points,
        meta_study_config.solver_method,
    )
    if not is_valid_simulation(full_sim):
        if retry_on_invalid:
            # Re-sample the whole study; forward retry_on_invalid so the
            # retry policy survives the recursion.
            return prepare_full_simulation_with_repeated_targets(
                meta_study_config,
                meta_dosing_config,
                n_targets,
                different_dosing=different_dosing,
                retry_on_invalid=retry_on_invalid,
                idx=idx + 1,
            )
        raise RuntimeError("Invalid context simulation")

    context_sim, context_times, ctx_idx = split_context_only(full_sim, full_times)
    dosing_amounts_ctx = dosing_amounts_all[ctx_idx]
    dosing_routes_ctx = dosing_routes_all[ctx_idx]

    # --- Target part: fresh individuals, repeated or independent dosing -
    indiv_cfg_targets = sample_individual_configs(study_config, n=n_targets)
    local_meta_dosing_tgt = replace(meta_dosing_config, num_individuals=n_targets)
    if different_dosing:
        dosing_config_array_tgt = sample_dosing_configs(local_meta_dosing_tgt)
    else:
        # One dosing draw, repeated across all target individuals.
        dosing_config_array_tgt = sample_dosing_configs_repeated_target(
            local_meta_dosing_tgt, n_targets
        )

    full_sim_tgt, full_times_tgt, dosing_amounts_tgt, dosing_routes_tgt = sample_study(
        indiv_cfg_targets,
        dosing_config_array_tgt,
        time_points,
        meta_study_config.solver_method,
    )
    if not is_valid_simulation(full_sim_tgt):
        if retry_on_invalid:
            return prepare_full_simulation_with_repeated_targets(
                meta_study_config,
                meta_dosing_config,
                n_targets,
                different_dosing=different_dosing,
                retry_on_invalid=retry_on_invalid,
                idx=idx + 1,
            )
        raise RuntimeError("Invalid target simulation")

    _, _, target_sim, target_times, _, tgt_idx = split_simulations_repeated_target(
        full_sim_tgt, full_times_tgt
    )

    return (
        context_sim,
        context_times,
        target_sim,
        target_times,
        dosing_amounts_ctx,
        dosing_routes_ctx,
        dosing_amounts_tgt[tgt_idx],
        dosing_routes_tgt[tgt_idx],
        time_points,
        time_scales,
    )
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
def prepare_full_simulation_list_with_repeated_targets(
    meta_study_config: MetaStudyConfig,
    meta_dosing_config: MetaDosingConfig,
    n_targets: int,
    num_of_different_dosages: int,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
):
    """Generate one shared context and ``L`` target sets with repeated dosing.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration controlling PK population and solver behaviour.
    meta_dosing_config:
        Dosing-distribution configuration used for both context and targets.
    n_targets:
        Number of target individuals for each dosing condition.
    num_of_different_dosages:
        Number of target dosing conditions ``L``.
    retry_on_invalid:
        Whether to retry sampling when numerical invalid simulations are found.
    idx:
        Retry depth / attempt index used for diagnostics.

    Returns
    -------
    tuple
        ``(context_sim, context_times, dosing_amounts_ctx, dosing_routes_ctx,``
        ``target_simulations, target_times_list, target_dosing_amounts_list,``
        ``target_dosing_routes_list, time_points, time_scales)`` where each
        target list has length ``num_of_different_dosages``.
    """

    if num_of_different_dosages < 0:
        raise ValueError("num_of_different_dosages must be non-negative")

    # One study-level parameterisation shared by context and all target sets.
    study_config = sample_study_config(meta_study_config)
    indiv_config_array = sample_individual_configs(study_config)
    time_scales = derive_timescale_parameters(study_config, meta_study_config)

    # [T]
    time_points = torch.linspace(
        meta_study_config.time_start,
        meta_study_config.time_stop,
        meta_study_config.time_num_steps,
        dtype=torch.float32,
    )

    # Context is sampled exactly once.
    local_meta_dosing_ctx = replace(
        meta_dosing_config, num_individuals=study_config.num_individuals
    )
    dosing_config_array_ctx = sample_dosing_configs(local_meta_dosing_ctx)
    full_sim, full_times, dosing_amounts_all, dosing_routes_all = sample_study(
        indiv_config_array,
        dosing_config_array_ctx,
        time_points,
        meta_study_config.solver_method,
    )
    if not is_valid_simulation(full_sim):
        if retry_on_invalid:
            # NOTE(review): retries via recursion with no depth cap; relies on
            # invalid simulations being rare. ``idx`` tracks the retry depth.
            return prepare_full_simulation_list_with_repeated_targets(
                meta_study_config,
                meta_dosing_config,
                n_targets,
                num_of_different_dosages,
                idx=idx + 1,
            )
        raise RuntimeError("Invalid context simulation")

    # context_sim: [N_ctx, T], context_times: [N_ctx, T]
    context_sim, context_times, ctx_idx = split_context_only(full_sim, full_times)
    dosing_amounts_ctx = dosing_amounts_all[ctx_idx]
    dosing_routes_ctx = dosing_routes_all[ctx_idx]

    # Keep the same target PK individuals across all dosing conditions so that
    # only dosing changes across list elements.
    indiv_cfg_targets = sample_individual_configs(study_config, n=n_targets)
    local_meta_dosing_tgt = replace(meta_dosing_config, num_individuals=n_targets)

    target_simulations = []
    target_times_list = []
    target_dosing_amounts_list = []
    target_dosing_routes_list = []
    # Signatures (route, dose) of accepted target dosing conditions, used to
    # keep the L conditions pairwise distinct.
    seen_dosing_signatures: set[tuple[str, float]] = set()

    for _ in range(num_of_different_dosages):
        attempts = 0
        while True:
            attempts += 1
            dosing_config_array_tgt = sample_dosing_configs_repeated_target(
                local_meta_dosing_tgt, n_targets
            )

            # Ensure distinct dosing regimens across list elements.
            # The signature is derived from the first individual's config; with
            # repeated-target dosing all individuals share the same regimen.
            dosing_signature = ("", 0.0)
            if n_targets > 0 and len(dosing_config_array_tgt) > 0:
                first_cfg = dosing_config_array_tgt[0]
                dosing_signature = (
                    str(getattr(first_cfg, "route", "")),
                    float(getattr(first_cfg, "dose", 0.0)),
                )
            if dosing_signature in seen_dosing_signatures and num_of_different_dosages > 1:
                if attempts < 100:
                    continue
                # After 100 attempts we give up on uniqueness and accept a
                # duplicate signature rather than loop forever.
                logger.warning(
                    "Could not sample a unique repeated target dosing signature after %d attempts.",
                    attempts,
                )

            full_sim_tgt, full_times_tgt, dosing_amounts_tgt, dosing_routes_tgt = sample_study(
                indiv_cfg_targets,
                dosing_config_array_tgt,
                time_points,
                meta_study_config.solver_method,
            )
            if not is_valid_simulation(full_sim_tgt):
                # First retry locally (new dosing, same individuals); only after
                # 100 attempts restart the whole sampling from scratch.
                if retry_on_invalid and attempts < 100:
                    continue
                if retry_on_invalid:
                    return prepare_full_simulation_list_with_repeated_targets(
                        meta_study_config,
                        meta_dosing_config,
                        n_targets,
                        num_of_different_dosages,
                        idx=idx + 1,
                    )
                raise RuntimeError("Invalid target simulation")

            _, _, target_sim, target_times, _, tgt_idx = split_simulations_repeated_target(
                full_sim_tgt, full_times_tgt
            )

            target_simulations.append(target_sim)
            target_times_list.append(target_times)
            target_dosing_amounts_list.append(dosing_amounts_tgt[tgt_idx])
            target_dosing_routes_list.append(dosing_routes_tgt[tgt_idx])
            # Only track real signatures; the ("", 0.0) placeholder used when
            # n_targets == 0 must not poison the uniqueness set.
            if n_targets > 0:
                seen_dosing_signatures.add(dosing_signature)
            break

    return (
        context_sim,
        context_times,
        dosing_amounts_ctx,
        dosing_routes_ctx,
        target_simulations,
        target_times_list,
        target_dosing_amounts_list,
        target_dosing_routes_list,
        time_points,
        time_scales,
    )
|
| 1135 |
+
|
| 1136 |
+
|
| 1137 |
+
def prepare_ensemble_of_simulations(
    meta_study_config: MetaStudyConfig,
    observation_config: ObservationsConfig,
    meta_dosing_config: MetaDosingConfig,
    number_of_samples: int,
    file_name: Optional[str] = None,
    group_size: Optional[int] = None,
) -> tuple[list[StudyJSON] | list[list[StudyJSON]], float]:
    """Build an ensemble of independently simulated studies.

    Each study is produced by :func:`prepare_full_simulation_to_study_json`.
    When ``file_name`` is given, the flat study list is serialized to JSON
    before any grouping is applied, so the file always holds the full set.

    Parameters
    ----------
    meta_study_config:
        Population and numerical-solver sampling configuration.
    observation_config:
        Observation strategy applied to every generated simulation.
    meta_dosing_config:
        Per-individual dosing regimen configuration.
    number_of_samples:
        How many independent studies to simulate.
    file_name:
        Optional path where the generated ensemble is persisted as JSON.
    group_size:
        Optional chunk size; when set, studies are returned as a list of
        lists of exactly ``group_size`` elements each. A trailing partial
        group is silently discarded.

    Returns
    -------
    tuple[list[StudyJSON] | list[list[StudyJSON]], float]
        The (flat or grouped) studies and the fraction of failed simulation
        attempts observed while generating them.
    """

    studies: list[StudyJSON] = []
    total_failed_attempts = 0
    for sample_idx in range(number_of_samples):
        study, n_failed = prepare_full_simulation_to_study_json(
            meta_study_config=meta_study_config,
            observation_config=observation_config,
            meta_dosing_config=meta_dosing_config,
            idx=sample_idx,
        )
        total_failed_attempts += n_failed
        studies.append(study)

    # Persist the flat ensemble if a destination was requested.
    if file_name:
        Path(file_name).write_text(json.dumps(studies, indent=2))

    # Failure rate over every attempt (successes plus failed retries).
    total_attempts = total_failed_attempts + len(studies)
    failure_rate = total_failed_attempts / total_attempts if total_attempts > 0 else 0.0

    # Optionally chunk into fixed-size groups, dropping the ragged tail.
    if group_size and group_size > 0:
        usable = (len(studies) // group_size) * group_size
        grouped = [
            studies[start : start + group_size]
            for start in range(0, usable, group_size)
        ]
        return grouped, failure_rate

    return studies, failure_rate
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
def prepare_full_simulation_to_study_json_context_target(
    meta_study_config: MetaStudyConfig,
    observation_config: ObservationsConfig,
    meta_dosing_config_context: MetaDosingConfig,
    meta_dosing_config_target: MetaDosingConfig,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> tuple[StudyJSON, int]:
    """Generate a full simulation and convert it into a :class:`StudyJSON` record.
    Different dosing regimens are used for context and target individuals.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration describing the population and numerical solver.
        If meta_study_config.simple_mode is True, uses simplified synthetic data.
    observation_config:
        Configuration for the observation strategy used to extract measurements
        from the raw simulation.
    meta_dosing_config_context:
        Configuration describing the dosing regimen for each simulated
        individual in the context set.
    meta_dosing_config_target:
        Configuration describing the dosing regimen for each simulated
        individual in the target set.
    retry_on_invalid:
        When ``True`` (default) the function retries simulation sampling if the
        generated trajectories are numerically invalid.
    idx:
        Internal recursion depth counter exposed for debugging and testing.

    Returns
    -------
    tuple[StudyJSON, int]
        Canonical JSON representation of the simulated study with both a
        ``context`` and a ``target`` section, alongside the total number of
        failed attempts before obtaining the valid simulations.
    """

    def prepare_section(name, meta_dosing_config):
        # Simulate one full study for this section and unpack everything the
        # JSON conversion needs; routes / raw time grid are unused here.
        (
            full_sim,
            full_times,
            dosing_amounts,
            _dosing_routes,
            _time_points,
            time_scales,
            study_config,
            dosing_config_array,
            failed_attempts,
        ) = _generate_full_simulation(
            meta_study_config,
            meta_dosing_config,
            retry_on_invalid=retry_on_invalid,
            idx=idx,
        )

        observation_strategy = ObservationStrategyFactory.from_config(
            observation_config, meta_study_config
        )
        obs_out, time_out, mask_out, rem_sim, rem_time, rem_mask, _ = observation_strategy.generate(
            full_simulation=full_sim,
            full_simulation_times=full_times,
            time_scales=time_scales,
        )

        section: list[IndividualJSON] = []
        num_individuals = full_sim.shape[0]

        for ind_idx in range(num_individuals):
            # Boolean mask selects the valid (non-padded) observation slots.
            mask = mask_out[ind_idx].to(torch.bool)
            observations = obs_out[ind_idx][mask].tolist()
            observation_times = time_out[ind_idx][mask].tolist()

            # Validates monotone timestamps before serialization; raises on
            # violations rather than silently emitting malformed records.
            _ensure_strictly_increasing_observations(
                observation_times,
                observations,
                individual_id=f"{name}_{ind_idx}",
            )

            individual: IndividualJSON = {
                "name_id": f"{name}_{ind_idx}",
                "observations": observations,
                "observation_times": observation_times,
            }

            # Remaining-trajectory outputs are optional (depend on the
            # observation strategy's add_rem behaviour); only attach them
            # when at least one entry is valid.
            if rem_sim is not None and rem_time is not None and rem_mask is not None:
                rem_mask_row = rem_mask[ind_idx].to(torch.bool)
                if rem_mask_row.any():
                    individual["remaining"] = rem_sim[ind_idx][rem_mask_row].tolist()
                    individual["remaining_times"] = rem_time[ind_idx][rem_mask_row].tolist()

            dosing_cfg = dosing_config_array[ind_idx]
            dose = float(dosing_amounts[ind_idx].item())
            route = getattr(dosing_cfg, "route", "")
            dosing_time = float(getattr(dosing_cfg, "time", 0.0))

            # Dosing fields are omitted entirely when both dose and route are
            # falsy (zero dose and empty route).
            if dose or route:
                individual["dosing"] = [dose]
                individual["dosing_type"] = [route]
                individual["dosing_times"] = [dosing_time]
                individual["dosing_name"] = [route]

            section.append(individual)

        return section, study_config, failed_attempts

    # Set RNG to have the same study config for both context and target.
    # NOTE(review): torch.manual_seed mutates the *global* PyTorch RNG state,
    # so this function is not RNG-transparent to callers, and every call
    # reproduces the same draws — confirm this determinism is intended.
    torch.manual_seed(42)
    context, study_config, failed_attempts_context = prepare_section(
        "context", meta_dosing_config_context
    )
    torch.manual_seed(42)
    target, _, failed_attempts_target = prepare_section("target", meta_dosing_config_target)

    study_json: StudyJSON = {
        "context": context,
        "target": target,
        "meta_data": {
            "study_name": f"simulated_study_{idx}",
            "substance_name": getattr(study_config, "drug_id", "simulated_substance"),
        },
    }
    failed_attempts = failed_attempts_context + failed_attempts_target

    return study_json, failed_attempts
|
sim_priors_pk/data/data_generation/dosing_models.py
ADDED
|
File without changes
|
sim_priors_pk/data/data_generation/observations_classes.py
ADDED
|
@@ -0,0 +1,1776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Callable, Optional, Tuple, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torchtyping import TensorType
|
| 7 |
+
from sim_priors_pk.config_classes.data_config import ObservationsConfig, MetaStudyConfig
|
| 8 |
+
from sim_priors_pk.data.data_generation.observations_functions import fix_past_time_random_selection
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _sample_past_count_with_bias(
|
| 12 |
+
low: int,
|
| 13 |
+
high: int,
|
| 14 |
+
*,
|
| 15 |
+
generative_bias: bool,
|
| 16 |
+
generator: torch.Generator,
|
| 17 |
+
device: torch.device,
|
| 18 |
+
) -> int:
|
| 19 |
+
"""Sample the number of past observations under the configured bias mode."""
|
| 20 |
+
|
| 21 |
+
if high <= 0:
|
| 22 |
+
return 0
|
| 23 |
+
|
| 24 |
+
if generative_bias:
|
| 25 |
+
sample_zero = int(torch.randint(0, 2, (1,), generator=generator, device=device).item()) == 0
|
| 26 |
+
if sample_zero:
|
| 27 |
+
return 0
|
| 28 |
+
|
| 29 |
+
rest_low = max(1, low)
|
| 30 |
+
if rest_low > high:
|
| 31 |
+
return 0
|
| 32 |
+
if rest_low == high:
|
| 33 |
+
return rest_low
|
| 34 |
+
return int(
|
| 35 |
+
torch.randint(
|
| 36 |
+
rest_low,
|
| 37 |
+
high + 1,
|
| 38 |
+
(1,),
|
| 39 |
+
generator=generator,
|
| 40 |
+
device=device,
|
| 41 |
+
).item()
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
if low >= high:
|
| 45 |
+
return int(high)
|
| 46 |
+
|
| 47 |
+
return int(torch.randint(low, high + 1, (1,), generator=generator, device=device).item())
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ObservationStrategy(ABC):
    """Template base class for observation-extraction strategies.

    Subclasses implement the raw generation (:meth:`_generate_raw`) and shape
    reporting (:meth:`_get_shapes_raw`); this base class applies the
    ``add_rem`` flag from the observations config on top of both.
    """

    def __init__(self, observations_config: ObservationsConfig, meta_config: MetaStudyConfig):
        self.observations_config = observations_config
        self.meta_config = meta_config

    def _drop_non_positive_times_from_mask(self, times: Tensor, mask: Tensor) -> Tensor:
        """Invalidate mask entries at ``time <= 0`` when the config asks for it.

        Controlled by the optional ``drop_time_zero_observations`` attribute of
        :class:`ObservationsConfig`; when absent or falsy the mask is returned
        unchanged.
        """
        drop_zero = getattr(
            self.observations_config, "drop_time_zero_observations", False
        )
        if drop_zero:
            return mask & (times > 0)
        return mask

    def generate(
        self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
    ) -> Tuple[Tensor, ...]:
        """Run the subclass generator and then honour the ``add_rem`` flag."""
        raw = self._generate_raw(full_simulation, full_simulation_times, **kwargs)
        obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask = raw
        if self.observations_config.add_rem:
            return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, None
        # Remaining-trajectory outputs are suppressed when add_rem is off.
        return obs, obs_time, obs_mask, None, None, None, None

    @abstractmethod
    def _generate_raw(
        self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
    ) -> Tuple[
        Tensor,
        TensorType["B", "T_obs"],
        TensorType["B", "T_obs"],
        TensorType["B", "T_rem"],
        TensorType["B", "T_rem"],
        TensorType["B", "T_rem"],
    ]:
        """Produce observations and remaining sims unconditionally (ignores add_rem)."""

    def get_shapes(self) -> Tuple[int, int]:
        """Report ``(max_obs, max_rem)``, zeroing ``max_rem`` when add_rem is off."""
        max_obs, max_rem = self._get_shapes_raw()
        if self.observations_config.add_rem:
            return max_obs, max_rem
        return max_obs, 0

    @abstractmethod
    def _get_shapes_raw(self) -> Tuple[int, int]:
        """Return max observations and max remaining assuming ``add_rem=True``."""
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class PKPeakHalfLifeStrategy(ObservationStrategy):
    """Observation strategy tailored to pharmacokinetic (PK) curves.

    The strategy samples observations around the absorption peak and along the
    elimination phase of a PK simulation. It uses a canonical grid composed of
    four segments:

    1. Several points before the peak that are proportional to the configured
       peak time.
    2. The peak itself.
    3. Several points after the peak spaced by multiples of the provided
       half-life.
    4. Optional remainder points that are handed back to the caller when
       ``add_rem`` is enabled.

    For **synthetic simulations**, the strategy still uses this canonical grid
    and nearest-neighbour alignment.

    For **empirical data**, measurements are treated as already canonical:

    * No canonical time grid construction.
    * No time normalisation or template matching.
    * No interpolation or re-scaling to canonical coordinates.

    Empirical sequences are only padded / truncated to the internal capacity
    implied by :class:`ObservationsConfig` and :class:`MetaStudyConfig`, and
    then passed through the same past/future splitting logic.

    Past/future splitting
    ----------------------
    When ``split_past_future=True``, the canonical sequence for each row is
    split into:

    * a *past* observation block of fixed width (``max_obs``), and
    * an optional *remainder* block of width (``max_rem``).

    In the default mode (no fixed past selection), the number of past points
    is sampled according to ``generative_bias``:

    * ``False`` samples in ``[min_past, max_past]``.
    * ``True`` samples exactly ``0`` with probability 0.5 and, otherwise,
      samples uniformly in ``[max(1, min_past), max_past]``.

    Under ``generative_bias=False``, **short sequences** receive a special
    treatment: when the number of valid canonical points is less than or equal
    to the observation capacity, *all* valid points are placed in the
    observation block and none are shifted into the remainder.

    Fixed past selection
    --------------------
    Calling :meth:`fix_past_selection(k)` activates a strict mode in which
    the strategy tries to expose exactly ``k`` earliest valid timestamps as
    "past" for each series, subject to the following structural limits:

    1. The number of real data points available in the series.
    2. The observation capacity dictated by :meth:`_get_shapes_raw`.

    Concretely, for each row:

    * Let ``k`` be the fixed past count.
    * Let ``total_valid`` be the number of valid canonical points.
    * Let ``past_required = min(k, total_valid)``.

    The observation block receives

    * ``obs_count = min(past_required, max_obs)`` earliest valid points.

    If ``past_required > obs_count`` (for example because ``k`` exceeds the
    number of observation slots), the remaining required past events
    ``past_required - obs_count`` are the *first entries* in the remainder
    block (subject to the remainder capacity). This guarantees that, as long
    as data and shapes allow, the first ``k`` valid timestamps appear in
    ``obs`` + ``rem`` before any later timestamps.

    Calling :meth:`release_past_selection()` returns to the default stochastic
    behaviour governed by ``min_past``/``max_past``.
    """

    # Pre-peak sampling times expressed as fractions of the peak time.
    _PEAK_PHASE_MULTIPLIERS = (0.1, 0.2, 0.5, 0.8)
    # Post-peak sampling offsets expressed in units of the half-life,
    # added on top of the peak time.
    _POST_PEAK_HALF_LIFE_MULTIPLIERS = (
        0.25,
        0.50,
        1.00,
        2.00,
        4.00,
        6.00,
        8.00,
        9.00,
        14.0,
        19.0,
        30.0,
    )
    # Theoretical grid size: pre-peak points + the peak itself + post-peak points.
    _RAW_CANONICAL_POINTS = len(_PEAK_PHASE_MULTIPLIERS) + 1 + len(_POST_PEAK_HALF_LIFE_MULTIPLIERS)
|
| 198 |
+
|
| 199 |
+
def __init__(
    self, observations_config: ObservationsConfig, meta_config: MetaStudyConfig
) -> None:
    """Cache the observation-splitting knobs from the configuration objects."""
    super().__init__(observations_config, meta_config)
    cfg = observations_config
    self.max_num_obs = cfg.max_num_obs
    self.split_past_future = cfg.split_past_future
    self.min_past = cfg.min_past
    self.max_past = cfg.max_past
    self.generative_bias = cfg.generative_bias
    # None → default stochastic past selection; an int activates the strict
    # fixed-past semantics documented on the class (see fix_past_selection).
    self._fixed_past_obs_count: Optional[int] = None
|
| 211 |
+
|
| 212 |
+
def fix_past_selection(self, obs_count: int) -> None:
    """Activate strict ``k``-past behaviour.

    While active (and only when ``split_past_future=True``), every call to
    :meth:`generate` / :meth:`generate_empirical` exposes up to
    ``obs_count`` earliest valid timestamps as "past": first filling the
    observation block (bounded by the available data and the observation
    capacity), then placing any overflow of required past events at the
    front of the remainder block when one exists.

    Fewer than ``obs_count`` past events are allocated only when the series
    has fewer real data points, or the observation/remainder shapes leave
    insufficient slots. Otherwise the earliest valid timestamps always go
    observation block first, remainder second.
    """
    if not self.split_past_future:
        # Without a past/future split a fixed past count has no meaning.
        return

    in_bounds = self.min_past <= obs_count <= self.max_past
    if not in_bounds:
        raise ValueError(
            "Fixed past observation count must lie within the configured min/max bounds."
        )
    self._fixed_past_obs_count = int(obs_count)
|
| 244 |
+
|
| 245 |
+
def release_past_selection(self) -> None:
    """Deactivate strict fixed-past mode and resume stochastic selection."""
    # Clearing the marker restores the min_past/max_past sampling path.
    self._fixed_past_obs_count = None
|
| 248 |
+
|
| 249 |
+
@classmethod
def _build_canonical_grid(
    cls,
    *,
    t_peak: float,
    t_half: float,
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    """Build the canonical sampling times for a single simulation.

    Pre-peak points are fractions of ``t_peak``; post-peak points sit at
    ``t_peak`` plus multiples of ``t_half``. Both inputs are expected in
    the same units as the simulation time axis.
    """
    pre = [frac * t_peak for frac in cls._PEAK_PHASE_MULTIPLIERS]
    post = [t_peak + k * t_half for k in cls._POST_PEAK_HALF_LIFE_MULTIPLIERS]
    return torch.tensor([*pre, t_peak, *post], device=device, dtype=dtype)
|
| 270 |
+
|
| 271 |
+
def _canonical_grid_capacity(self) -> int:
    """Number of canonical grid points that can actually be gathered.

    Bounded by the simulator resolution (``time_num_steps`` when the meta
    config provides it, otherwise ``max_num_obs``) and by the theoretical
    grid size, so observation tensors never gather indices outside the
    original simulation.
    """
    resolution = getattr(self.meta_config, "time_num_steps", self.max_num_obs)
    capacity = min(int(self.max_num_obs), int(resolution), self._RAW_CANONICAL_POINTS)
    return max(0, capacity)
|
| 284 |
+
|
| 285 |
+
def _get_shapes_raw(self) -> Tuple[int, int]:
    """Compute the maximum observation and remainder slot counts.

    Returns
    -------
    max_obs, max_rem : int, int
        ``max_obs`` — maximum number of observation time-steps;
        ``max_rem`` — maximum number of remainder time-steps when
        ``add_rem`` is enabled.

    Raises
    ------
    ValueError
        When a past/future split is requested but the canonical capacity
        cannot cover the configured ``min_past``.
    """
    capacity = self._canonical_grid_capacity()
    if capacity == 0:
        return 0, 0

    if not self.split_past_future:
        # Without a split both blocks may use the full canonical width.
        return capacity, capacity

    if capacity < self.min_past:
        raise ValueError("Canonical grid capacity is smaller than the configured min_past")
    return min(self.max_past, capacity), max(0, capacity - self.min_past)
|
| 315 |
+
|
| 316 |
+
@staticmethod
def _deduplicate_sorted_indices(
    idx: Tensor, valid_mask: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
    """Push consecutive duplicate gather indices to the back of the tensor.

    ``idx`` must be monotonically increasing. The first occurrence of each
    index is kept at the front (original order preserved) together with
    its ``valid_mask`` bit; the duplicate slots are appended afterwards
    with their mask forced to ``False`` so downstream code ignores them.
    """
    if valid_mask is None:
        valid_mask = torch.ones_like(idx, dtype=torch.bool)

    if idx.numel() <= 1:
        return idx, valid_mask

    is_repeat = torch.zeros_like(idx, dtype=torch.bool)
    is_repeat[1:] = idx[1:] == idx[:-1]

    if not is_repeat.any():
        return idx, valid_mask

    firsts = ~is_repeat
    n_first = int(firsts.sum())

    reordered_idx = torch.empty_like(idx)
    reordered_idx[:n_first] = idx[firsts]
    reordered_idx[n_first:] = idx[is_repeat]

    reordered_mask = torch.zeros_like(valid_mask)
    reordered_mask[:n_first] = valid_mask[firsts]

    return reordered_idx, reordered_mask
|
| 351 |
+
|
| 352 |
+
def _assemble_from_canonical(
    self,
    canonical_vals: Tensor,
    canonical_times: Tensor,
    canonical_mask: Tensor,
    *,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Convert canonical tensors into output observations.

    The canonical representation stores **all** admissible samples for a
    batch element. This helper slices the canonical tensors into the
    "past" observations that will be returned to the caller and (when
    requested) the "future" remainder.

    Allocation invariants
    ---------------------
    For each batch row:

    * Let ``valid_idx`` be the indices where ``canonical_mask`` is True,
      sorted in ascending order.
    * The observation block always receives the **earliest**
      ``obs_count`` indices from ``valid_idx``.
    * The remainder block (when present) receives later indices only; it
      never contains timestamps that precede those in the observation block.
    * Under ``generative_bias=False``, short sequences
      (``total_valid <= max_obs``) keep all valid points in the
      observation block and do not shift points to the remainder.

    When :meth:`fix_past_selection(k)` is active, we define::

        past_required = min(k, total_valid)

    and allocate:

    * ``obs_count = min(past_required, max_obs)`` to the observation
      block; and
    * any surplus past events ``past_required - obs_count`` at the **front**
      of the remainder block (subject to the remainder capacity), followed
      by any truly future points.

    Releasing the fixed selection returns to the stochastic behaviour
    controlled by ``generative_bias``.
    """
    max_obs, max_rem = self._get_shapes_raw()
    device = canonical_vals.device
    dtype = canonical_vals.dtype
    batch, _ = canonical_vals.shape

    # Pre-allocate zero-padded outputs; rows with no valid points stay all-zero.
    obs_out = torch.zeros(batch, max_obs, dtype=dtype, device=device)
    obs_time = torch.zeros_like(obs_out)
    obs_mask = torch.zeros(batch, max_obs, dtype=torch.bool, device=device)

    rem_sim = rem_time = rem_mask = None
    if max_rem > 0:
        rem_sim = torch.zeros(batch, max_rem, dtype=dtype, device=device)
        rem_time = torch.zeros_like(rem_sim)
        rem_mask = torch.zeros(batch, max_rem, dtype=torch.bool, device=device)

    # Fall back to the global default generator so sampling stays well-defined
    # even when the caller does not provide one.
    gen = generator if generator is not None else torch.default_generator

    for row in range(batch):
        valid_idx = canonical_mask[row].nonzero(as_tuple=True)[0]
        total_valid = int(valid_idx.numel())
        if total_valid == 0:
            continue

        # Fixed-past mode only applies when a past/future split is configured.
        fixed_k = self._fixed_past_obs_count if self.split_past_future else None

        # ------------------------------------------------------------------
        # 1) Decide obs_count
        # ------------------------------------------------------------------
        if self.split_past_future and fixed_k is not None:
            # Strict fixed-past semantics. Structural limits:
            # - real data (total_valid)
            # - observation capacity (max_obs)
            past_required = min(fixed_k, total_valid)
            obs_capacity = min(max_obs, total_valid)
            obs_count = min(past_required, obs_capacity)
        else:
            # Default stochastic behaviour; the short-series fix is kept
            # for the non-biased mode only.
            if self.split_past_future:
                low = min(self.min_past, total_valid)
                high = min(self.max_past, total_valid)

                sampled = _sample_past_count_with_bias(
                    low=low,
                    high=high,
                    generative_bias=self.generative_bias,
                    generator=gen,
                    device=device,
                )

                if (not self.generative_bias) and total_valid <= max_obs:
                    # Short-series fix: never push valid points into the
                    # remainder just to satisfy a random split.
                    obs_count = total_valid
                else:
                    obs_count = min(sampled, max_obs)
            else:
                obs_count = min(total_valid, max_obs)

        # Safety clamp.
        obs_count = max(0, min(obs_count, min(max_obs, total_valid)))

        # ------------------------------------------------------------------
        # 2) Fill observation block (earliest obs_count indices)
        # ------------------------------------------------------------------
        if obs_count > 0:
            take = valid_idx[:obs_count]
            obs_out[row, :obs_count] = canonical_vals[row, take]
            obs_time[row, :obs_count] = canonical_times[row, take]
            obs_mask[row, :obs_count] = True

        # ------------------------------------------------------------------
        # 3) Fill remainder block (if enabled)
        # ------------------------------------------------------------------
        if rem_sim is not None:
            if self.split_past_future and fixed_k is not None:
                # Remaining required past events plus genuine future.
                past_required = min(fixed_k, total_valid)
                # indices that are still part of the fixed past window
                # but did not fit into the observation block
                extra_past_idx = valid_idx[obs_count:past_required]
                future_idx = valid_idx[past_required:]

                candidates: List[Tensor] = []
                if extra_past_idx.numel() > 0:
                    candidates.append(extra_past_idx)
                if future_idx.numel() > 0:
                    candidates.append(future_idx)
                if candidates:
                    # Surplus past events come first so they precede any
                    # genuinely future timestamps in the remainder block.
                    remainder_candidates = torch.cat(candidates, dim=0)
                else:
                    remainder_candidates = valid_idx.new_empty((0,), dtype=valid_idx.dtype)
            else:
                # Default behaviour: everything after the obs window.
                remainder_candidates = valid_idx[obs_count:]

            # Truncate to the remainder capacity; overflow points are dropped.
            rem_count = min(int(remainder_candidates.numel()), max_rem)
            if rem_count > 0:
                rem_idx = remainder_candidates[:rem_count]
                rem_sim[row, :rem_count] = canonical_vals[row, rem_idx]
                rem_time[row, :rem_count] = canonical_times[row, rem_idx]
                rem_mask[row, :rem_count] = True

    return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask
|
| 500 |
+
|
| 501 |
+
def _align_simulation_to_canonical(
    self,
    full_simulation: Tensor,
    full_simulation_times: Tensor,
    *,
    time_scales: Tensor,
    num_obs_sampler: Optional[Callable[[int], Tensor]] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Gather canonical samples from a simulated PK curve.

    Synthetic behaviour is unchanged compared to the original strategy:
    we build a canonical grid, snap it to the nearest simulation times and
    optionally subsample points via ``num_obs_sampler``.

    NOTE(review): the grid is built from the *first* row's timeline
    (``full_simulation_times[0]``), i.e. all rows are assumed to share the
    same time axis — confirm against the caller.
    """
    device = full_simulation.device
    dtype = full_simulation.dtype
    batch, _ = full_simulation.shape
    time_steps = int(full_simulation_times.size(1))

    # DataLoader workers may receive empty row slices (B=0). In that case
    # there is no reference timeline to align against; return an empty
    # canonical block and let _assemble_from_canonical create [B, *] outputs.
    if batch == 0 or time_steps == 0:
        zero = torch.zeros(batch, 0, dtype=dtype, device=device)
        mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
        return zero, zero, mask, time_scales.clone()

    canonical_cap = self._canonical_grid_capacity()
    if canonical_cap == 0:
        zero = torch.zeros(batch, 0, dtype=dtype, device=device)
        mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
        return zero, zero, mask, time_scales.clone()

    # time_scales[0] is the peak time, time_scales[1] the half-life
    # (same units as the simulation time axis).
    grid = self._build_canonical_grid(
        t_peak=time_scales[0].item(),
        t_half=time_scales[1].item(),
        device=device,
        dtype=dtype,
    )[:canonical_cap]

    # Grid points outside the simulated time range cannot be observed.
    ref_times = full_simulation_times[0]
    min_time = ref_times.min()
    max_time = ref_times.max()
    grid_valid_mask = (grid >= min_time) & (grid <= max_time)

    # Snap each grid time to its nearest simulated timestamp, then collapse
    # duplicate snaps (several grid points mapping to one timestamp).
    idx = torch.cdist(grid[:, None], ref_times[:, None]).argmin(dim=1)
    idx, order = idx.sort()
    grid_valid_mask = grid_valid_mask[order]
    idx, grid_valid_mask = self._deduplicate_sorted_indices(idx, grid_valid_mask)

    gather_idx = idx[None, :].expand(batch, -1)
    batch_idx = torch.arange(batch, device=device)[:, None]

    canonical_vals = full_simulation[batch_idx, gather_idx]
    canonical_times = full_simulation_times[batch_idx, gather_idx]

    # Zero out the slots that correspond to out-of-range or duplicate grid
    # points so padding stays consistent with the mask.
    invalid_slots = ~grid_valid_mask
    if invalid_slots.any():
        canonical_vals[:, invalid_slots] = 0
        canonical_times[:, invalid_slots] = 0

    # Per-row number of canonical points to expose: full capacity by
    # default, or a caller-supplied sample clamped to [0, canonical_cap].
    if num_obs_sampler is None:
        total_counts = torch.full((batch,), canonical_cap, dtype=torch.long, device=device)
    else:
        sampled = num_obs_sampler(batch).to(device=device).long()
        total_counts = sampled.clamp(min=0, max=canonical_cap)

    max_valid = int(grid_valid_mask.sum().item())
    if max_valid == 0:
        total_counts.zero_()
    else:
        total_counts.clamp_(max=max_valid)

    # Rank each valid slot (0-based) among the valid slots; invalid slots
    # get rank -1 so they never satisfy the count comparison below.
    valid_order = grid_valid_mask.long().cumsum(dim=0) - 1
    valid_order = torch.where(
        grid_valid_mask,
        valid_order,
        torch.full_like(valid_order, -1, dtype=valid_order.dtype),
    )
    canonical_mask = grid_valid_mask[None, :] & (valid_order[None, :] < total_counts[:, None])
    canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

    return canonical_vals, canonical_times, canonical_mask, time_scales.clone()
|
| 584 |
+
|
| 585 |
+
def _align_empirical_to_canonical(
    self,
    empirical_obs: Tensor,
    empirical_times: Tensor,
    empirical_mask: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """(Legacy) Project empirical observations onto the canonical grid.

    This method is retained for backward compatibility but is **not** used
    by :meth:`generate_empirical`, which now treats empirical data as
    already canonical. New code should avoid calling this helper.
    """
    device = empirical_obs.device
    dtype = empirical_obs.dtype
    batch, _ = empirical_obs.shape
    canonical_cap = self._canonical_grid_capacity()

    canonical_vals = torch.zeros(batch, canonical_cap, dtype=dtype, device=device)
    canonical_times = torch.zeros_like(canonical_vals)
    canonical_mask = torch.zeros(batch, canonical_cap, dtype=torch.bool, device=device)

    if canonical_cap == 0:
        return canonical_vals, canonical_times, canonical_mask

    for row in range(batch):
        valid_idx = empirical_mask[row].nonzero(as_tuple=True)[0]
        if valid_idx.numel() == 0:
            continue

        obs_row = empirical_obs[row, valid_idx]
        time_row = empirical_times[row, valid_idx]
        # Normalise time to [0, 1] (floor of 1.0 avoids division by ~0).
        max_time = torch.maximum(time_row.max(), torch.tensor(1.0, device=device))
        norm_time = time_row / max_time

        # Estimate the peak time from the largest measured concentration.
        peak_idx = obs_row.argmax().item()
        t_peak = norm_time[peak_idx].item()
        post_times = norm_time[peak_idx:]
        post_obs = obs_row[peak_idx:]
        # Estimate the half-life as the delay until the curve first drops
        # to half the peak level; fall back to the last post-peak time.
        half_level = obs_row[peak_idx] / 2
        below_half = (post_obs <= half_level).nonzero(as_tuple=True)[0]
        if below_half.numel() == 0:
            half_time = post_times[-1].item()
        else:
            half_time = post_times[below_half[0]].item()
        t_half = max(half_time - t_peak, 1e-3)

        # Build the grid in normalised coordinates (clamped to [., 1.0]),
        # using small epsilons to keep degenerate estimates well-defined.
        grid = self._build_canonical_grid(
            t_peak=t_peak if t_peak > 0 else 1e-3,
            t_half=t_half,
            device=device,
            dtype=dtype,
        )[:canonical_cap].clamp(max=1.0)

        # Back to the original time units, then snap each grid point to the
        # nearest measured timestamp (duplicates are possible here).
        actual_grid = grid * max_time
        distances = torch.cdist(actual_grid[:, None], time_row[:, None])
        nearest = distances.argmin(dim=1)

        usable = min(time_row.numel(), grid.numel())
        if usable == 0:
            continue

        canonical_vals[row, :usable] = obs_row[nearest[:usable]]
        canonical_times[row, :usable] = time_row[nearest[:usable]]
        canonical_mask[row, :usable] = True

    canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

    return canonical_vals, canonical_times, canonical_mask
|
| 653 |
+
|
| 654 |
+
def _prepare_empirical_as_canonical(
    self,
    empirical_obs: Tensor,
    empirical_times: Tensor,
    empirical_mask: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Pack empirical measurements into canonical-width tensors verbatim.

    Empirical data is treated as already canonical: no grid construction,
    no time normalisation or re-scaling, no interpolation. Valid points
    are copied in their original order, padded with zeros / ``False`` up
    to the canonical capacity, and non-positive timestamps are optionally
    masked out afterwards. The resulting width matches what
    :meth:`_assemble_from_canonical` expects.
    """
    device, dtype = empirical_obs.device, empirical_obs.dtype
    n_rows = empirical_obs.shape[0]
    width = self._canonical_grid_capacity()

    packed_vals = torch.zeros(n_rows, width, dtype=dtype, device=device)
    packed_times = torch.zeros_like(packed_vals)
    packed_mask = torch.zeros(n_rows, width, dtype=torch.bool, device=device)

    if width == 0:
        return packed_vals, packed_times, packed_mask

    for r in range(n_rows):
        keep = empirical_mask[r].nonzero(as_tuple=True)[0]
        n_keep = min(int(keep.numel()), width)
        if n_keep == 0:
            continue

        chosen = keep[:n_keep]
        packed_vals[r, :n_keep] = empirical_obs[r, chosen]
        packed_times[r, :n_keep] = empirical_times[r, chosen]
        packed_mask[r, :n_keep] = True

    packed_mask = self._drop_non_positive_times_from_mask(packed_times, packed_mask)

    return packed_vals, packed_times, packed_mask
|
| 699 |
+
|
| 700 |
+
def _generate_raw(
    self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
) -> Tuple[
    Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
]:
    """Deterministic canonical-grid sampling for synthetic simulations.

    Requires ``time_scales`` in ``kwargs``; optionally honours
    ``num_obs_sampler`` and ``generator``.
    """
    time_scales: Optional[Tensor] = kwargs.get("time_scales")
    if time_scales is None:
        raise ValueError("time_scales must be provided for PKPeakHalfLifeStrategy")

    # Snap the canonical grid onto the simulated timeline ...
    vals, times, mask, rescaled = self._align_simulation_to_canonical(
        full_simulation,
        full_simulation_times,
        time_scales=time_scales,
        num_obs_sampler=kwargs.get("num_obs_sampler"),
    )

    # ... then split the canonical block into past/remainder outputs.
    assembled = self._assemble_from_canonical(
        vals,
        times,
        mask,
        generator=kwargs.get("generator"),
    )

    return (*assembled, rescaled)
|
| 727 |
+
|
| 728 |
+
def _generate_random(
    self,
    full_simulation: Tensor,
    full_simulation_times: Tensor,
    *,
    time_scales: Tensor,
    generator: Optional[torch.Generator] = None,
) -> Tuple[
    Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
]:
    """Randomised variant of canonical observation generation.

    Pre-peak times are drawn uniformly in ``[0, t_peak]`` and each
    post-peak time uniformly in ``[t_peak, t_peak + mult * t_half]``,
    keeping the semantic meaning of the canonical points while injecting
    stochasticity that can improve robustness during training.

    Fix: the supplied ``generator`` is now threaded through *every* random
    draw (``torch.rand`` and ``Tensor.uniform_``) instead of only being
    forwarded to the assembly step, so a seeded generator makes this whole
    path reproducible. NOTE(review): a non-None generator must live on the
    same device as the simulation tensors.
    """
    device, dtype = full_simulation.device, full_simulation.dtype
    batch = full_simulation.size(0)
    time_steps = int(full_simulation_times.size(1))
    if batch == 0 or time_steps == 0:
        # Degenerate batch: hand empty canonical tensors to the assembler so
        # the outputs still carry the configured [B, max_obs/max_rem] widths.
        canonical_vals = torch.zeros(batch, 0, dtype=dtype, device=device)
        canonical_times = torch.zeros(batch, 0, dtype=dtype, device=device)
        canonical_mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
        obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
            canonical_vals, canonical_times, canonical_mask, generator=generator
        )
        return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, time_scales.clone()
    t_peak, t_half = time_scales[0].item(), time_scales[1].item()

    n_pre = len(self._PEAK_PHASE_MULTIPLIERS)

    # Uniform samples before the peak (reproducible under `generator`).
    pre_times = torch.rand(n_pre, device=device, dtype=dtype, generator=generator) * t_peak
    # Always include the peak itself.
    peak_time = torch.tensor([t_peak], device=device, dtype=dtype)
    # One uniform sample per post-peak segment [t_peak, t_peak + mult * t_half].
    post_times = []
    for mult in self._POST_PEAK_HALF_LIFE_MULTIPLIERS:
        t_end = t_peak + mult * t_half
        t_rand = torch.empty(1, device=device, dtype=dtype).uniform_(
            t_peak, t_end, generator=generator
        )
        post_times.append(t_rand)
    post_times = torch.cat(post_times, dim=0)

    # Truncate to canonical capacity.
    grid = torch.cat([pre_times, peak_time, post_times], dim=0)
    canonical_cap = self._canonical_grid_capacity()
    grid = grid[:canonical_cap]

    # Snap each sampled time to the nearest simulated timestamp, collapsing
    # duplicate snaps so gather indices stay unique.
    ref_times = full_simulation_times[0]
    idx = torch.cdist(grid[:, None], ref_times[:, None]).argmin(dim=1)
    idx, _ = idx.sort()
    valid_mask = torch.ones_like(idx, dtype=torch.bool)
    idx, valid_mask = self._deduplicate_sorted_indices(idx, valid_mask)
    gather_idx = idx[None, :].expand(batch, -1)
    batch_idx = torch.arange(batch, device=device)[:, None]

    canonical_vals = full_simulation[batch_idx, gather_idx]
    canonical_times = full_simulation_times[batch_idx, gather_idx]
    # Zero out the duplicate slots so padding agrees with the mask.
    invalid_slots = ~valid_mask
    if invalid_slots.any():
        canonical_vals[:, invalid_slots] = 0
        canonical_times[:, invalid_slots] = 0

    canonical_mask = valid_mask[None, :].expand(batch, -1).clone()
    canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

    obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
        canonical_vals, canonical_times, canonical_mask, generator=generator
    )
    return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, time_scales.clone()
|
| 801 |
+
|
| 802 |
+
def generate(
    self,
    full_simulation: Tensor,
    full_simulation_times: Tensor,
    **kwargs,
) -> Tuple[
    Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
]:
    """Generate PK observations for synthetic simulations.

    With probability ``randomize_prob`` (default 0.5) the method delegates
    to :meth:`_generate_random`; otherwise the deterministic
    :meth:`_generate_raw` path is taken. Setting ``deterministic_only=True``
    forces the deterministic branch. Both paths require ``time_scales`` and
    honour the ``add_rem`` flag.

    Parameters
    ----------
    full_simulation, full_simulation_times:
        Batched simulated concentration curve and its time axis.
    **kwargs:
        Must contain ``time_scales``; may contain ``generator``,
        ``deterministic_only`` and ``num_obs_sampler``.

    Raises
    ------
    ValueError
        If ``time_scales`` is missing from ``kwargs``.
    """
    time_scales: Optional[Tensor] = kwargs.get("time_scales")
    if time_scales is None:
        raise ValueError("time_scales must be provided for PKPeakHalfLifeStrategy")

    deterministic_only = kwargs.pop("deterministic_only", False)
    generator: Optional[torch.Generator] = kwargs.get("generator")

    use_random = False
    if not deterministic_only:
        # BUG FIX: draw the branch decision from the caller-supplied
        # generator so the whole pipeline is reproducible when a seeded
        # generator is passed. Previously ``torch.rand(())`` ignored
        # ``generator`` entirely. ``generator=None`` falls back to the
        # global RNG, preserving the old behaviour for callers that pass
        # no generator.
        use_random = torch.rand((), generator=generator) < getattr(self, "randomize_prob", 0.5)

    if use_random:
        obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled = self._generate_random(
            full_simulation,
            full_simulation_times,
            time_scales=time_scales,
            generator=generator,
        )
    else:
        obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled = self._generate_raw(
            full_simulation,
            full_simulation_times,
            **kwargs,
        )

    if not self.observations_config.add_rem:
        # Remainder tensors are suppressed when the config disables them.
        rem_sim = rem_time = rem_mask = None

    return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled
|
| 846 |
+
|
| 847 |
+
def generate_empirical(
    self,
    empirical_obs: Tensor,
    empirical_times: Tensor,
    empirical_mask: Tensor,
    *,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Generate observations from empirical measurements.

    Empirical data is assumed to already live on its correct time grid, so
    no canonical alignment or time normalisation happens here: sequences
    are only padded / truncated to the internal capacity and then split
    into past / future parts by :meth:`_assemble_from_canonical` according
    to :class:`ObservationsConfig`. Synthetic simulations keep using the
    canonical alignment path.
    """
    vals, times, mask = self._prepare_empirical_as_canonical(
        empirical_obs, empirical_times, empirical_mask
    )

    assembled = self._assemble_from_canonical(vals, times, mask, generator=generator)
    obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask = assembled

    # Drop the remainder tensors when the configuration disables them.
    if not self.observations_config.add_rem:
        return obs, obs_time, obs_mask, None, None, None
    return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
class PKPeakHalfLifeStrategyOld(ObservationStrategy):
    """Observation strategy tailored to pharmacokinetic (PK) curves.

    The strategy samples observations around the absorption peak and along the
    elimination phase of a PK simulation. It uses a canonical grid composed of
    four segments:

    1. Several points before the peak that are proportional to the configured
       peak time.
    2. The peak itself.
    3. Several points after the peak spaced by multiples of the provided
       half-life.
    4. Optional remainder points that are handed back to the caller when
       ``add_rem`` is enabled.

    The resulting observation tensor can be optionally split into "past" and
    "future" observations according to :class:`ObservationsConfig`.

    Parameters
    ----------
    observations_config:
        Simulation-level configuration that defines sampling constraints such
        as ``max_num_obs`` or the minimum/maximum number of "past" points when
        a split is requested.
    meta_config:
        Meta-study configuration. Only the ``time_num_steps`` attribute is
        used and allows clamping the canonical grid to the resolution of the
        simulator.
    """

    # Pre-peak sampling positions, expressed as fractions of the peak time.
    _PEAK_PHASE_MULTIPLIERS = (0.1, 0.2, 0.5, 0.8)
    # Post-peak sampling positions, expressed as multiples of the half-life
    # added on top of the peak time.
    _POST_PEAK_HALF_LIFE_MULTIPLIERS = (
        0.25,
        0.50,
        1.00,
        2.00,
        4.00,
        6.00,
        8.00,
        9.00,
        14.0,
        19.0,
        30.0,
    )
    # Theoretical grid size: pre-peak points + the peak itself + post-peak points.
    _RAW_CANONICAL_POINTS = len(_PEAK_PHASE_MULTIPLIERS) + 1 + len(_POST_PEAK_HALF_LIFE_MULTIPLIERS)
|
| 931 |
+
|
| 932 |
+
def __init__(
    self, observations_config: ObservationsConfig, meta_config: MetaStudyConfig
) -> None:
    """Initialise the strategy from the observation and meta-study configs."""
    super().__init__(observations_config, meta_config)

    # Mirror the sampling-related settings locally for convenient access.
    cfg = observations_config
    self.max_num_obs = cfg.max_num_obs
    self.split_past_future = cfg.split_past_future
    self.min_past = cfg.min_past
    self.max_past = cfg.max_past
    self.generative_bias = cfg.generative_bias

    # ``None`` means the number of past observations is sampled according
    # to the standard strategy. When populated, _assemble_from_canonical
    # always selects exactly this many past points (within the valid range).
    self._fixed_past_obs_count: Optional[int] = None
|
| 946 |
+
|
| 947 |
+
def fix_past_selection(self, obs_count: int) -> None:
    """Force the past observation count to ``obs_count`` when splitting.

    The override only takes effect when ``split_past_future`` is enabled,
    and ``obs_count`` must fall within ``[min_past, max_past]``.
    """
    # Without a past/future split there is nothing to override.
    if not self.split_past_future:
        return

    out_of_bounds = obs_count < self.min_past or obs_count > self.max_past
    if out_of_bounds:
        raise ValueError(
            "Fixed past observation count must lie within the configured min/max bounds."
        )

    self._fixed_past_obs_count = int(obs_count)
|
| 962 |
+
|
| 963 |
+
def release_past_selection(self) -> None:
    """Return to the default random past selection behaviour."""
    # Clearing the override re-enables random sampling of the past count
    # inside _assemble_from_canonical.
    self._fixed_past_obs_count = None
|
| 967 |
+
|
| 968 |
+
@classmethod
def _build_canonical_grid(
    cls,
    *,
    t_peak: float,
    t_half: float,
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    """Construct the canonical sampling grid for a single simulation.

    The grid covers the pre-peak, peak and post-peak regimes of the curve
    by scaling two fundamental quantities supplied at runtime: the time of
    the peak concentration ``t_peak`` and the half-life ``t_half``, both in
    the same units as the simulation time axis.

    Parameters
    ----------
    t_peak:
        Estimated time of the concentration peak.
    t_half:
        Estimated half-life used to position post-peak points.
    device, dtype:
        Placement of the returned tensor so it matches the simulation
        tensors gathered later on.

    Returns
    -------
    torch.Tensor
        One-dimensional tensor of monotonically increasing canonical times.
    """
    points = []
    # Pre-peak points as fractions of the peak time.
    for frac in cls._PEAK_PHASE_MULTIPLIERS:
        points.append(frac * t_peak)
    # The peak itself.
    points.append(t_peak)
    # Post-peak points as half-life multiples past the peak.
    for mult in cls._POST_PEAK_HALF_LIFE_MULTIPLIERS:
        points.append(t_peak + mult * t_half)
    return torch.tensor(points, device=device, dtype=dtype)
|
| 1005 |
+
|
| 1006 |
+
def _canonical_grid_capacity(self) -> int:
    """Return the number of canonical grid points available.

    The capacity is the minimum of the simulator resolution
    (``meta_config.time_num_steps``, falling back to ``max_num_obs`` when
    absent), the configured ``max_num_obs`` and the theoretical number of
    canonical points, clamped to be non-negative. This guarantees the
    observation tensors never attempt to gather indices outside the
    original simulation.

    Returns
    -------
    int
        Maximum number of grid points that can be sampled per simulation.
    """
    resolution = getattr(self.meta_config, "time_num_steps", self.max_num_obs)
    capacity = min(int(self.max_num_obs), int(resolution), self._RAW_CANONICAL_POINTS)
    return capacity if capacity > 0 else 0
|
| 1025 |
+
|
| 1026 |
+
def _get_shapes_raw(self) -> Tuple[int, int]:
    """Compute the maximum number of observation and remainder slots.

    Combines the canonical grid capacity with the ``split_past_future``
    configuration to decide how many points can be surfaced directly as
    observations and how many should be exposed as "remaining" (future)
    points.

    Returns
    -------
    tuple[int, int]
        ``(max_obs, max_rem)``: the maximum number of observations and the
        maximum number of remaining observations when ``add_rem`` is
        enabled.

    Raises
    ------
    ValueError
        If a past/future split is requested but the canonical capacity
        cannot satisfy the configured ``min_past`` requirement.
    """
    capacity = self._canonical_grid_capacity()
    if capacity == 0:
        return 0, 0

    if not self.split_past_future:
        # Without a split every canonical point may be an observation and
        # every point may also appear in the remainder.
        return capacity, capacity

    if capacity < self.min_past:
        raise ValueError("Canonical grid capacity is smaller than the configured min_past")
    return min(self.max_past, capacity), max(0, capacity - self.min_past)
|
| 1061 |
+
|
| 1062 |
+
@staticmethod
def _deduplicate_sorted_indices(
    idx: Tensor, valid_mask: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
    """Collapse repeated gather indices while preserving alignment.

    ``idx`` is assumed sorted. The first occurrence of every index keeps
    its position and mask value; duplicates are moved to the tail of the
    tensor with their mask slots cleared, so tensor shapes stay stable for
    downstream gathers.
    """
    if valid_mask is None:
        valid_mask = torch.ones_like(idx, dtype=torch.bool)

    if idx.numel() <= 1:
        return idx, valid_mask

    # An entry is a duplicate when it equals its left neighbour.
    is_dup = torch.zeros_like(idx, dtype=torch.bool)
    is_dup[1:] = idx[1:] == idx[:-1]

    if not is_dup.any():
        return idx, valid_mask

    keep = ~is_dup
    kept_count = int(keep.sum().item())

    # First occurrences first, then the displaced duplicates.
    out_idx = torch.cat([idx[keep], idx[is_dup]], dim=0)

    # Kept entries retain their validity; duplicate slots become invalid.
    out_mask = torch.zeros_like(valid_mask)
    out_mask[:kept_count] = valid_mask[keep]

    return out_idx, out_mask
|
| 1093 |
+
|
| 1094 |
+
def _assemble_from_canonical(
    self,
    canonical_vals: Tensor,
    canonical_times: Tensor,
    canonical_mask: Tensor,
    *,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Convert canonical tensors into output observations.

    The canonical representation stores **all** admissible samples for a
    batch element. This helper slices the canonical tensors into the
    "past" observations that will be returned to the caller and (when
    requested) the "future" remainder. The selection proceeds row by row:

    1. ``canonical_mask`` is inspected to identify the indices that contain
       valid information. These are the only points that may be surfaced.
    2. When ``split_past_future`` is ``False`` every valid point is treated
       as part of the observation history up to the configured capacity.
    3. Otherwise we randomly draw ``obs_count`` between ``min_past`` and
       ``max_past`` (capped by the number of valid canonical entries). The
       first ``obs_count`` indices become past observations while the
       remaining valid points are placed in the remainder tensors.

    Parameters
    ----------
    canonical_vals, canonical_times:
        Tensors produced by aligning the simulation or empirical data to
        the canonical grid.
    canonical_mask:
        Boolean tensor marking valid entries for each batch element.
    generator:
        Optional random generator used when sampling ``obs_count`` in
        split-past/future mode.

    Returns
    -------
    tuple of tensors
        Observation and remaining tensors matching the shapes dictated by
        :meth:`_get_shapes_raw`. All tensors share the same device and
        dtype as the inputs. ``None`` is returned for remainder tensors
        when the capacity is zero.
    """
    max_obs, max_rem = self._get_shapes_raw()
    device = canonical_vals.device
    dtype = canonical_vals.dtype
    batch, _ = canonical_vals.shape

    # Outputs are pre-zeroed: rows with no valid canonical entries simply
    # keep all-zero values and an all-False mask.
    obs_out = torch.zeros(batch, max_obs, dtype=dtype, device=device)
    obs_time = torch.zeros_like(obs_out)
    obs_mask = torch.zeros(batch, max_obs, dtype=torch.bool, device=device)

    rem_sim = rem_time = rem_mask = None
    if max_rem > 0:
        rem_sim = torch.zeros(batch, max_rem, dtype=dtype, device=device)
        rem_time = torch.zeros_like(rem_sim)
        rem_mask = torch.zeros(batch, max_rem, dtype=torch.bool, device=device)

    # Fall back to torch's global default generator so the count sampler
    # below always receives a concrete generator object.
    gen = generator if generator is not None else torch.default_generator

    for row in range(batch):
        valid_idx = canonical_mask[row].nonzero(as_tuple=True)[0]
        total_valid = valid_idx.numel()
        if total_valid == 0:
            continue

        if self.split_past_future:
            # Clamp sampling bounds to the points actually available.
            low = min(self.min_past, total_valid)
            high = min(self.max_past, total_valid)
            if self._fixed_past_obs_count is not None:
                # Caller pinned the past count via fix_past_selection().
                obs_count = min(self._fixed_past_obs_count, total_valid)
            else:
                obs_count = _sample_past_count_with_bias(
                    low=low,
                    high=high,
                    generative_bias=self.generative_bias,
                    generator=gen,
                    device=device,
                )
            obs_count = min(obs_count, max_obs)
        else:
            # No split requested: surface every valid point up to capacity.
            obs_count = min(total_valid, max_obs)

        if obs_count > 0:
            take = valid_idx[:obs_count]
            obs_out[row, :obs_count] = canonical_vals[row, take]
            obs_time[row, :obs_count] = canonical_times[row, take]
            obs_mask[row, :obs_count] = True

        if rem_sim is not None:
            # Valid points not used as past observations become the
            # "future" remainder, truncated to max_rem.
            rem_candidates = valid_idx[obs_count:]
            rem_count = min(rem_candidates.numel(), max_rem)
            if rem_count > 0:
                rem_idx = rem_candidates[:rem_count]
                rem_sim[row, :rem_count] = canonical_vals[row, rem_idx]
                rem_time[row, :rem_count] = canonical_times[row, rem_idx]
                rem_mask[row, :rem_count] = True

    return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask
|
| 1193 |
+
|
| 1194 |
+
def _align_simulation_to_canonical(
    self,
    full_simulation: Tensor,
    full_simulation_times: Tensor,
    *,
    time_scales: Tensor,
    num_obs_sampler: Optional[Callable[[int], Tensor]] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Gather the canonical samples from a simulated PK curve.

    The routine creates the canonical grid described in the configuration
    (using the provided ``time_scales``) and then performs a nearest-neighbour
    lookup on the simulated trajectory. Each grid location picks the
    closest time point from the reference simulation (the first batch row);
    the same indices are applied to every batch element so that values and
    times remain aligned across the batch. ``num_obs_sampler`` can further
    prune the resulting grid by specifying how many of those canonical
    points should remain valid for each row.

    Parameters
    ----------
    full_simulation, full_simulation_times:
        Batched tensors representing the simulated concentration curve and
        its time axis.
    time_scales:
        Two-element tensor with ``t_peak`` and ``t_half`` scaling factors.
    num_obs_sampler:
        Optional callable that samples how many canonical points should be
        retained for each batch element.

    Returns
    -------
    tuple of torch.Tensor
        The canonical values, their corresponding times, a boolean mask of
        valid entries and the (cloned) ``time_scales`` tensor. When the
        canonical capacity is zero, zero-sized tensors are returned for the
        first three entries.
    """
    device = full_simulation.device
    dtype = full_simulation.dtype
    batch, _ = full_simulation.shape
    time_steps = int(full_simulation_times.size(1))

    # Empty worker slices (B=0) and zero-step trajectories are valid edge
    # cases; return empty canonical tensors and keep shape assembly
    # delegated to _assemble_from_canonical.
    if batch == 0 or time_steps == 0:
        zero = torch.zeros(batch, 0, dtype=dtype, device=device)
        mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
        return zero, zero, mask, time_scales.clone()

    canonical_cap = self._canonical_grid_capacity()
    if canonical_cap == 0:
        zero = torch.zeros(batch, 0, dtype=dtype, device=device)
        mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
        return zero, zero, mask, time_scales.clone()

    grid = self._build_canonical_grid(
        t_peak=time_scales[0].item(),
        t_half=time_scales[1].item(),
        device=device,
        dtype=dtype,
    )[:canonical_cap]

    # The first batch row serves as the shared reference time axis; grid
    # points outside its range are flagged invalid rather than clamped.
    ref_times = full_simulation_times[0]
    min_time = ref_times.min()
    max_time = ref_times.max()
    grid_valid_mask = (grid >= min_time) & (grid <= max_time)

    # Nearest-neighbour lookup of every grid time on the reference axis.
    idx = torch.cdist(grid[:, None], ref_times[:, None]).argmin(dim=1)
    idx, order = idx.sort()
    # Keep the validity flags aligned with the sorted index order.
    grid_valid_mask = grid_valid_mask[order]
    idx, grid_valid_mask = self._deduplicate_sorted_indices(idx, grid_valid_mask)

    gather_idx = idx[None, :].expand(batch, -1)
    batch_idx = torch.arange(batch, device=device)[:, None]

    canonical_vals = full_simulation[batch_idx, gather_idx]
    canonical_times = full_simulation_times[batch_idx, gather_idx]

    # Zero out slots that fell outside the simulated time range or were
    # collapsed as duplicates.
    invalid_slots = ~grid_valid_mask
    if invalid_slots.any():
        canonical_vals[:, invalid_slots] = 0
        canonical_times[:, invalid_slots] = 0

    if num_obs_sampler is None:
        total_counts = torch.full((batch,), canonical_cap, dtype=torch.long, device=device)
    else:
        sampled = num_obs_sampler(batch).to(device=device).long()
        total_counts = sampled.clamp(min=0, max=canonical_cap)

    # Counts can never exceed the number of genuinely valid grid slots.
    max_valid = int(grid_valid_mask.sum().item())
    if max_valid == 0:
        total_counts.zero_()
    else:
        total_counts.clamp_(max=max_valid)

    # Rank each valid slot among the valid ones (0-based); invalid slots
    # get rank -1 so they can never satisfy the count comparison below.
    valid_order = grid_valid_mask.long().cumsum(dim=0) - 1
    valid_order = torch.where(
        grid_valid_mask,
        valid_order,
        torch.full_like(valid_order, -1, dtype=valid_order.dtype),
    )
    # A slot is kept for a row when it is valid and its rank is below that
    # row's sampled count.
    canonical_mask = grid_valid_mask[None, :] & (valid_order[None, :] < total_counts[:, None])
    canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

    return canonical_vals, canonical_times, canonical_mask, time_scales.clone()
|
| 1301 |
+
|
| 1302 |
+
def _align_empirical_to_canonical(
    self,
    empirical_obs: Tensor,
    empirical_times: Tensor,
    empirical_mask: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Project empirical observations onto the canonical grid.

    The projection normalises the empirical time axis to estimate the peak
    and half-life from the data itself. This allows harmonising real
    measurements with the canonical layout used during simulation-driven
    training.

    Parameters
    ----------
    empirical_obs, empirical_times, empirical_mask:
        Batched tensors storing empirical observations, the corresponding
        time stamps and a mask of valid entries.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Canonical values, times and boolean masks aligned to the canonical
        sampling scheme.
    """
    device = empirical_obs.device
    dtype = empirical_obs.dtype
    batch, _ = empirical_obs.shape
    canonical_cap = self._canonical_grid_capacity()

    canonical_vals = torch.zeros(batch, canonical_cap, dtype=dtype, device=device)
    canonical_times = torch.zeros_like(canonical_vals)
    canonical_mask = torch.zeros(batch, canonical_cap, dtype=torch.bool, device=device)

    if canonical_cap == 0:
        return canonical_vals, canonical_times, canonical_mask

    for row in range(batch):
        valid_idx = empirical_mask[row].nonzero(as_tuple=True)[0]
        if valid_idx.numel() == 0:
            continue

        obs_row = empirical_obs[row, valid_idx]
        time_row = empirical_times[row, valid_idx]
        # Normalise times to [0, 1]; the 1.0 floor guards against division
        # by tiny maxima. Assumes empirical times are non-negative — TODO
        # confirm against the data builder.
        max_time = torch.maximum(time_row.max(), torch.tensor(1.0, device=device))
        norm_time = time_row / max_time

        # Heuristic peak/half-life estimation from the data itself: the
        # peak is the maximum observation; the half-life is the time until
        # the signal first drops to half the peak (or the last time point
        # when it never does).
        peak_idx = obs_row.argmax().item()
        t_peak = norm_time[peak_idx].item()
        post_times = norm_time[peak_idx:]
        post_obs = obs_row[peak_idx:]
        half_level = obs_row[peak_idx] / 2
        below_half = (post_obs <= half_level).nonzero(as_tuple=True)[0]
        if below_half.numel() == 0:
            half_time = post_times[-1].item()
        else:
            half_time = post_times[below_half[0]].item()
        # Floor keeps the grid well-defined when the drop is instantaneous.
        t_half = max(half_time - t_peak, 1e-3)

        # Grid is built in normalised units and clamped to the observed
        # window; the 1e-3 fallback avoids a degenerate all-zero pre-peak
        # segment when the peak sits at t=0.
        grid = self._build_canonical_grid(
            t_peak=t_peak if t_peak > 0 else 1e-3,
            t_half=t_half,
            device=device,
            dtype=dtype,
        )[:canonical_cap].clamp(max=1.0)

        # Back to real time units, then nearest-neighbour match against the
        # row's own measurement times.
        actual_grid = grid * max_time
        distances = torch.cdist(actual_grid[:, None], time_row[:, None])
        nearest = distances.argmin(dim=1)

        usable = min(time_row.numel(), grid.numel())
        if usable == 0:
            continue

        canonical_vals[row, :usable] = obs_row[nearest[:usable]]
        canonical_times[row, :usable] = time_row[nearest[:usable]]
        canonical_mask[row, :usable] = True

    canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

    return canonical_vals, canonical_times, canonical_mask
|
| 1383 |
+
|
| 1384 |
+
def _generate_raw(
    self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
) -> Tuple[
    Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
]:
    """Deterministic canonical-grid observation generation.

    Aligns the simulated batch to the canonical grid derived from
    ``time_scales`` and assembles past/future observation tensors from it.

    Raises
    ------
    ValueError
        If ``time_scales`` is missing from ``kwargs``.
    """
    scales: Optional[Tensor] = kwargs.get("time_scales")
    if scales is None:
        raise ValueError("time_scales must be provided for PKPeakHalfLifeStrategy")

    aligned = self._align_simulation_to_canonical(
        full_simulation,
        full_simulation_times,
        time_scales=scales,
        num_obs_sampler=kwargs.get("num_obs_sampler"),
    )
    vals, times, mask, rescaled = aligned

    obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
        vals, times, mask, generator=kwargs.get("generator")
    )

    return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled
|
| 1410 |
+
|
| 1411 |
+
def _generate_random(
    self,
    full_simulation: Tensor,
    full_simulation_times: Tensor,
    *,
    time_scales: Tensor,
    generator: Optional[torch.Generator] = None,
) -> Tuple[
    Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
]:
    """Randomized variant of canonical observation generation.

    Instead of fixed multipliers, the pre- and post-peak segments are
    sampled from uniform distributions bounded by the canonical limits.
    This keeps the semantic meaning of the selected points while injecting
    stochasticity that improves robustness when training amortised
    inference models.

    Parameters and return values mirror :meth:`_generate_raw`.
    ``generator`` now also seeds the uniform draws, so results are
    reproducible under a seeded generator.
    """
    device, dtype = full_simulation.device, full_simulation.dtype
    batch = full_simulation.size(0)
    time_steps = int(full_simulation_times.size(1))
    if batch == 0 or time_steps == 0:
        # Degenerate batch: hand empty canonical tensors to the assembler
        # so output shapes stay consistent with the non-empty path.
        empty_vals = torch.zeros(batch, 0, dtype=dtype, device=device)
        empty_times = torch.zeros(batch, 0, dtype=dtype, device=device)
        empty_mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
        obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
            empty_vals, empty_times, empty_mask, generator=generator
        )
        return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, time_scales.clone()

    t_peak, t_half = time_scales[0].item(), time_scales[1].item()

    n_pre = len(self._PEAK_PHASE_MULTIPLIERS)

    # Uniform samples before the peak. BUG FIX: forward the caller-supplied
    # generator so the draws are reproducible when a seeded generator is
    # given (previously ``generator`` was accepted but never used).
    pre_times = torch.rand(n_pre, device=device, dtype=dtype, generator=generator) * t_peak
    # Always include the peak itself.
    peak_time = torch.tensor([t_peak], device=device, dtype=dtype)
    # One uniform sample per post-peak half-life multiplier, drawn from
    # [t_peak, t_peak + mult * t_half].
    post_samples = []
    for mult in self._POST_PEAK_HALF_LIFE_MULTIPLIERS:
        t_end = t_peak + mult * t_half
        t_rand = torch.empty(1, device=device, dtype=dtype).uniform_(
            t_peak, t_end, generator=generator
        )
        post_samples.append(t_rand)
    post_times = torch.cat(post_samples, dim=0)

    # Truncate to the canonical capacity.
    grid = torch.cat([pre_times, peak_time, post_times], dim=0)
    canonical_cap = self._canonical_grid_capacity()
    grid = grid[:canonical_cap]

    # Map each grid time onto the nearest simulated point, using the first
    # batch row as the shared reference axis.
    ref_times = full_simulation_times[0]
    idx = torch.cdist(grid[:, None], ref_times[:, None]).argmin(dim=1)
    idx, _ = idx.sort()
    valid_mask = torch.ones_like(idx, dtype=torch.bool)
    idx, valid_mask = self._deduplicate_sorted_indices(idx, valid_mask)
    gather_idx = idx[None, :].expand(batch, -1)
    batch_idx = torch.arange(batch, device=device)[:, None]

    canonical_vals = full_simulation[batch_idx, gather_idx]
    canonical_times = full_simulation_times[batch_idx, gather_idx]
    # Zero out slots collapsed as duplicate indices.
    invalid_slots = ~valid_mask
    if invalid_slots.any():
        canonical_vals[:, invalid_slots] = 0
        canonical_times[:, invalid_slots] = 0

    canonical_mask = valid_mask[None, :].expand(batch, -1).clone()
    canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

    obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
        canonical_vals, canonical_times, canonical_mask, generator=generator
    )
    return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, time_scales.clone()
|
| 1485 |
+
|
| 1486 |
+
def generate(
    self,
    full_simulation: Tensor,
    full_simulation_times: Tensor,
    **kwargs,
) -> Tuple[
    Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
]:
    """Generate PK observations using canonical or randomized schedules.

    With probability ``randomize_prob`` (default 0.5) the method delegates
    to :meth:`_generate_random`; otherwise the deterministic
    :meth:`_generate_raw` path is taken. Setting the keyword argument
    ``deterministic_only=True`` forces the deterministic branch regardless
    of the random draw. Both paths require the caller to provide
    ``time_scales`` specifying the peak and half-life. The method honours
    the ``add_rem`` flag by optionally returning remainder tensors.

    Raises:
        ValueError: if ``time_scales`` is missing from ``kwargs``.
    """
    time_scales: Optional[Tensor] = kwargs.get("time_scales")
    if time_scales is None:
        raise ValueError("time_scales must be provided for PKPeakHalfLifeStrategy")

    # Pop (not get) so the flag is not forwarded to _generate_raw below.
    deterministic_only = kwargs.pop("deterministic_only", False)

    use_random = False
    if not deterministic_only:
        # BUGFIX: draw the branch decision from the caller-provided generator
        # (when given) so the whole pipeline is reproducible under a fixed
        # seed; previously this draw always used the global RNG state.
        use_random = bool(
            torch.rand((), generator=kwargs.get("generator"))
            < getattr(self, "randomize_prob", 0.5)
        )

    if use_random:
        obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled = self._generate_random(
            full_simulation,
            full_simulation_times,
            time_scales=time_scales,
            generator=kwargs.get("generator"),
        )
    else:
        obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled = self._generate_raw(
            full_simulation,
            full_simulation_times,
            **kwargs,
        )

    # Remainder tensors are only exposed when the config asks for them.
    if not self.observations_config.add_rem:
        rem_sim = rem_time = rem_mask = None

    return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled
|
| 1532 |
+
|
| 1533 |
+
def generate_empirical(
    self,
    empirical_obs: Tensor,
    empirical_times: Tensor,
    empirical_mask: Tensor,
    *,
    generator: Optional[torch.Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """Build observation/remainder tensors from empirical measurements.

    The empirical values are first snapped onto the canonical grid via
    :meth:`_align_empirical_to_canonical`, then packed into observation and
    remainder blocks by :meth:`_assemble_from_canonical`. When ``add_rem``
    is disabled in the config, the remainder outputs are replaced by
    ``None``.
    """
    aligned = self._align_empirical_to_canonical(
        empirical_obs,
        empirical_times,
        empirical_mask,
    )
    packed = self._assemble_from_canonical(*aligned, generator=generator)

    if self.observations_config.add_rem:
        return packed

    obs, obs_time, obs_mask = packed[:3]
    return obs, obs_time, obs_mask, None, None, None
|
| 1558 |
+
|
| 1559 |
+
|
| 1560 |
+
class FixPastTimeRandomSelectionStrategy(ObservationStrategy):
    """Randomly sample observations and split with fixed-capacity past/future slots.

    For ``split_past_future=True`` this strategy enforces the contract:
    ``obs_capacity=max_past`` and ``rem_capacity=max_num_obs-max_past``
    (subject to ``fixed_M_max=min(max_num_obs, time_num_steps)``).
    """

    def __init__(self, config: ObservationsConfig, meta_config: MetaStudyConfig):
        super().__init__(config, meta_config)
        # Cap the selection size by the simulation grid length when the meta
        # config exposes one; otherwise fall back to max_num_obs.
        time_steps = getattr(meta_config, "time_num_steps", config.max_num_obs)
        self.fixed_M_max = min(config.max_num_obs, time_steps)
        self.split_past_future = config.split_past_future
        self.max_past = config.max_past
        self.min_past = config.min_past
        self.generative_bias = config.generative_bias
        # Fraction of the study horizon treated as "past" (defaults to 10%).
        self.boundary_ratio = getattr(config, "past_time_ratio", 0.1)

    def _generate_raw(self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs):
        # Thin delegation to the free function; num_obs_sampler and generator
        # are forwarded when the caller provides them.
        return fix_past_time_random_selection(
            full_simulation=full_simulation,
            full_simulation_times=full_simulation_times,
            boundary_ratio=self.boundary_ratio,
            fixed_M_max=self.fixed_M_max,
            num_obs_sampler=kwargs.get("num_obs_sampler", None),
            generator=kwargs.get("generator", None),
        )

    def _get_shapes_raw(self) -> Tuple[int, int]:
        """Return fixed-capacity shapes for random split outputs.

        With ``split_past_future=True``:
        - ``max_obs`` is bounded by ``max_past``
        - ``max_rem`` is bounded by ``max_num_obs - max_past``

        NOTE(review): the code below actually bounds ``max_rem`` by
        ``fixed_M_max - max_past`` (fixed_M_max may be smaller than
        max_num_obs when the time grid is shorter) — confirm which bound
        the docstring/contract intends.
        """
        if self.split_past_future:
            if self.min_past is None or self.max_past is None:
                raise ValueError(
                    "min_past and max_past must be specified when split_past_future=True"
                )
            if self.fixed_M_max < self.min_past:
                raise ValueError("fixed_M_max is smaller than the configured min_past")
            max_obs = min(self.max_past, self.fixed_M_max)
            max_rem = max(0, self.fixed_M_max - self.max_past)
        else:
            # Without splitting, both blocks share the full fixed capacity.
            max_obs = self.fixed_M_max
            max_rem = self.fixed_M_max

        return max_obs, max_rem

    def _split_by_boundary(
        self,
        obs: TensorType["B", "M"],
        obs_time: TensorType["B", "M"],
        obs_mask: TensorType["B", "M"],
        *,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Split sampled observations into strict past and future blocks.

        The split is boundary-based and strict:
        - Past block samples ``k`` points from ``time <= boundary`` candidates,
          where ``k`` follows ``min_past``/``max_past`` (and ``generative_bias``),
          capped by available candidates and ``K_max``.
        - When ``k > 0``, remainder receives up to ``R_cap`` points sampled
          from ``time > boundary`` only (strict future).
        - When ``k == 0``, boundary splitting is ignored for remainder and
          points are sampled from all valid candidates.

        Extra past/future candidates are ignored, and missing entries are
        padded by zeros with mask=False.
        """
        B, M = obs.shape
        # K_max: capacity of the past block [B, K_max]
        K_max = min(int(self.max_past), int(M))
        K_min = min(int(self.min_past), K_max)
        # R_cap: fixed capacity of the remainder block [B, R_cap]
        R_cap = max(0, int(M) - K_max)

        # Boundary in absolute time units, derived from the study horizon.
        boundary = self.meta_config.time_stop * self.boundary_ratio
        # NOTE(review): torch.default_generator lives on CPU; if obs is on a
        # CUDA device the randperm calls below may reject it — confirm the
        # intended device handling.
        gen = generator if generator is not None else torch.default_generator

        # Pre-allocated zero-padded outputs; mask=False marks unused slots.
        past_obs = torch.zeros(B, K_max, dtype=obs.dtype, device=obs.device)
        past_time = torch.zeros_like(past_obs)
        past_mask = torch.zeros(B, K_max, dtype=torch.bool, device=obs.device)

        rem_obs = torch.zeros(B, R_cap, dtype=obs.dtype, device=obs.device)
        rem_time = torch.zeros_like(rem_obs)
        rem_mask = torch.zeros(B, R_cap, dtype=torch.bool, device=obs.device)

        # Per-row processing: candidate pools and sample counts differ per row.
        for b in range(B):
            valid_idx = obs_mask[b].nonzero(as_tuple=True)[0]
            past_candidates = valid_idx[obs_time[b, valid_idx] <= boundary]
            future_candidates = valid_idx[obs_time[b, valid_idx] > boundary]

            # Keep candidate pools chronologically ordered.
            if past_candidates.numel() > 1:
                order = torch.argsort(obs_time[b, past_candidates])
                past_candidates = past_candidates[order]
            if future_candidates.numel() > 1:
                order = torch.argsort(obs_time[b, future_candidates])
                future_candidates = future_candidates[order]

            # Past is sampled uniformly without replacement from pre-boundary points.
            k_high = min(K_max, int(past_candidates.numel()))
            k_low = min(K_min, k_high)
            k = _sample_past_count_with_bias(
                low=int(k_low),
                high=int(k_high),
                generative_bias=self.generative_bias,
                generator=gen,
                device=obs.device,
            )
            if k > 0 and past_candidates.numel() > 0:
                chosen_offsets = torch.randperm(
                    past_candidates.numel(),
                    generator=gen,
                    device=obs.device,
                )[:k]
                chosen_past = past_candidates[chosen_offsets]
                # Re-sort the random subset so the packed block stays chronological.
                chosen_order = torch.argsort(obs_time[b, chosen_past])
                chosen_past = chosen_past[chosen_order]
            else:
                # Empty selection of the same dtype/device as the pool.
                chosen_past = past_candidates[:0]

            num_past = chosen_past.numel()
            if num_past > 0:
                past_obs[b, :num_past] = obs[b, chosen_past]
                past_time[b, :num_past] = obs_time[b, chosen_past]
                past_mask[b, :num_past] = True

            # If no past point is selected, allow remainder sampling across the
            # whole valid domain. Otherwise keep strict future-only remainder.
            rem_pool = valid_idx if num_past == 0 else future_candidates
            if rem_pool.numel() > 1:
                order = torch.argsort(obs_time[b, rem_pool])
                rem_pool = rem_pool[order]

            if R_cap <= 0 or rem_pool.numel() == 0:
                chosen_rem = rem_pool[:0]
            elif rem_pool.numel() <= R_cap:
                # Pool fits entirely — take everything, already sorted above.
                chosen_rem = rem_pool
            else:
                chosen_offsets = torch.randperm(
                    rem_pool.numel(),
                    generator=gen,
                    device=obs.device,
                )[:R_cap]
                chosen_rem = rem_pool[chosen_offsets]
                chosen_order = torch.argsort(obs_time[b, chosen_rem])
                chosen_rem = chosen_rem[chosen_order]

            r = chosen_rem.numel()
            if r > 0:
                rem_obs[b, :r] = obs[b, chosen_rem]
                rem_time[b, :r] = obs_time[b, chosen_rem]
                rem_mask[b, :r] = True

        return past_obs, past_time, past_mask, rem_obs, rem_time, rem_mask

    def generate(
        self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
    ) -> Tuple[Tensor, ...]:
        """Sample observations, optionally split past/future, and pad the output.

        Returns a 7-tuple ``(obs, obs_time, obs_mask, rem_obs, rem_time,
        rem_mask, None)``; the trailing ``None`` presumably mirrors the
        ``time_scales`` slot of other strategies — confirm against callers.
        """
        obs, obs_time, obs_mask, _, _, _ = self._generate_raw(
            full_simulation, full_simulation_times, **kwargs
        )
        # Drop non-positive timestamps from the valid set before splitting.
        obs_mask = self._drop_non_positive_times_from_mask(obs_time, obs_mask)

        if self.split_past_future:
            out = self._split_by_boundary(
                obs,
                obs_time,
                obs_mask,
                generator=kwargs.get("generator", None),
            )
        else:
            # No split requested: everything is "past", remainder is empty.
            past_obs, past_time, past_mask = obs, obs_time, obs_mask
            rem_obs = rem_time = rem_mask = None
            out = (past_obs, past_time, past_mask, rem_obs, rem_time, rem_mask)

        # Config may disable remainder outputs entirely.
        if not self.observations_config.add_rem:
            out = out[:3] + (None, None, None)

        return (*out, None)
|
| 1743 |
+
|
| 1744 |
+
|
| 1745 |
+
class ObservationStrategyFactory:
    """Builds the observation strategy requested by an ObservationsConfig."""

    @staticmethod
    def from_config(
        obs_config: ObservationsConfig, meta_config: MetaStudyConfig
    ) -> ObservationStrategy:
        """Instantiate the strategy selected by ``obs_config.type``.

        Legacy compatibility:
        - an omitted ``type`` defaults via the dataclass to ``pk_peak_half_life``
        - an explicit YAML ``type: null`` is loaded as ``None`` (or the strings
          "null"/"none"/"") and also falls back to ``pk_peak_half_life``

        Raises:
            ValueError: when the (normalized) type matches no known strategy.
        """
        raw_type = getattr(obs_config, "type", None)

        if raw_type is None:
            key = "pk_peak_half_life"
        else:
            # Non-string values are stringified before normalization, matching
            # the historical behaviour for odd config inputs.
            key = str(raw_type).strip().lower()
            if key in {"", "null", "none"}:
                key = "pk_peak_half_life"

        if key in {"observations_pk_peak_halflife", "pk_peak_half_life"}:
            return PKPeakHalfLifeStrategy(obs_config, meta_config)

        if key in {"fix_past_time_random_selection", "random"}:
            return FixPastTimeRandomSelectionStrategy(obs_config, meta_config)

        raise ValueError(f"Unknown observation type: {raw_type}")
|
sim_priors_pk/data/data_generation/observations_functions.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file contains the observation functions that create the separation
|
| 3 |
+
between observations and remainders, the reminder can be either future
|
| 4 |
+
or selected from random in betweens, or None
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
from typing import Callable, Optional, Tuple
|
| 9 |
+
from torchtyping import TensorType
|
| 10 |
+
|
| 11 |
+
def fix_past_time_random_selection(
    full_simulation: TensorType["N", "S"],
    full_simulation_times: TensorType["N", "S"],
    *,
    boundary_ratio: float = 0.1,
    fixed_M_max: int,
    num_obs_sampler: Optional[Callable[[int], torch.Tensor]] = None,
    generator: Optional[torch.Generator] = None,
    **kwargs,
) -> Tuple[
    TensorType["N", "M"],
    TensorType["N", "M"],
    TensorType["N", "M"],
    None,
    None,
    None,
]:
    """Pick observation time-points uniformly at random, without replacement.

    Every row draws an independent subset of the simulation grid (uniform,
    no replacement) and reorders the chosen points by their sampled
    timestamps so the packed outputs stay chronological. The remainder
    slots of the returned 6-tuple are always ``None`` here.
    """
    if full_simulation is None:
        return (None,) * 6

    device = full_simulation.device
    n_rows, n_steps = full_simulation.shape
    capacity = int(max(0, fixed_M_max))

    rng = torch.default_generator if generator is None else generator
    picked_vals = torch.zeros(n_rows, capacity, device=device, dtype=full_simulation.dtype)
    picked_times = torch.zeros(n_rows, capacity, device=device, dtype=full_simulation_times.dtype)
    picked_mask = torch.zeros(n_rows, capacity, dtype=torch.bool, device=device)

    # A row can never take more points than the grid offers or the capacity allows.
    per_row_cap = min(capacity, n_steps)
    if per_row_cap == 0:
        return picked_vals, picked_times, picked_mask, None, None, None

    if num_obs_sampler is None:
        counts = torch.full((n_rows,), per_row_cap, dtype=torch.long, device=device)
    else:
        counts = num_obs_sampler(n_rows).to(device=device, dtype=torch.long).clamp(1, per_row_cap)

    # Row-by-row sampling keeps each draw uniform without replacement.
    for row_idx, row_count_t in enumerate(counts):
        row_count = int(row_count_t)
        if row_count <= 0:
            continue
        picks = torch.randperm(n_steps, generator=rng, device=device)[:row_count]
        if row_count > 1:
            # Reorder the random picks by timestamp for stable packing.
            picks = picks[torch.argsort(full_simulation_times[row_idx, picks])]
        picked_vals[row_idx, :row_count] = full_simulation[row_idx, picks]
        picked_times[row_idx, :row_count] = full_simulation_times[row_idx, picks]
        picked_mask[row_idx, :row_count] = True

    return picked_vals, picked_times, picked_mask, None, None, None
|
sim_priors_pk/data/data_generation/study_population_stats.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This is used for calculating summary statistics over ensembles of StudyJSONs to check that
|
| 2 |
+
the distribution of simulated data matches empirical data."""
|
| 3 |
+
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StudyPopulationStats(ABC):
    """Abstract interface for computing and aggregating statistics over ensembles of StudyJSONs."""

    @abstractmethod
    def compute_per_individual(self, ind: IndividualJSON) -> Dict[str, float]:
        """Compute statistics for a single individual (e.g., min/max observation value, count)."""

    @abstractmethod
    def compute_per_study(self, study: StudyJSON) -> Dict[str, float]:
        """Compute statistics for a single study (e.g., min/max observation value, count)."""

    @abstractmethod
    def aggregate(self, per_study: List[Dict[str, float]]) -> Dict[str, object]:
        """Aggregate statistics across studies (e.g., global extrema, averages, or histograms)."""

    def compute_study_population_statistics(self, studies: List[StudyJSON]) -> Dict[str, object]:
        """Template method: score each study, then aggregate the results."""
        study_level: List[Dict[str, float]] = []
        for study in studies:
            study_level.append(self.compute_per_study(study))
        return self.aggregate(study_level)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BasicObservationStats(StudyPopulationStats):
    """Compute descriptive statistics for observation values across individuals.

    For each individual, computes:
    - nAUC: Area Under the Curve (AUC), normalized by dose, using trapezoidal rule.
    - nCmax: Maximum observed concentration, normalized by dose.
    - Tmax: Time at which Cmax occurs.
    - Nobs: Number of observations.
    - Duration: Duration of the observation period (max observation time).
    For each study, computes:
    - Mean and standard deviation of nAUC, nCmax, Tmax across individuals.
    - Mean and total number of observations (Nobs) across all individuals.
    - Total study duration (max Duration across individuals).
    Aggregates across studies to provide percentiles of each study-level statistic.
    """

    def __init__(self, alpha=0.1):
        # NOTE(review): alpha is stored but not read by any method visible in
        # this class — presumably consumed downstream; confirm before removal.
        self.alpha = alpha

    def compute_per_individual(self, ind: IndividualJSON) -> Dict[str, float]:
        """Compute dose-normalized PK summary statistics for one individual.

        Returns NaN-filled statistics (with Nobs=0) for individuals without
        observations.

        Raises:
            ValueError: for unsorted/mismatched times, multi-dose records,
                non-positive or NaN dose/dosing time, observations before
                the dose, or an unsupported dosing route.
        """
        obs_vals = ind.get("observations", [])
        obs_times = ind.get("observation_times", [])
        dose = ind.get("dosing", [])
        dosing_time = ind.get("dosing_times", [])
        route = ind.get("dosing_type", [])

        if not obs_vals:
            return {"nAUC": np.nan, "nCmax": np.nan, "Tmax": np.nan, "Nobs": 0, "Duration": np.nan}

        # Check that input times are strictly increasing and aligned with values.
        if len(obs_times) != len(obs_vals) or any(
            obs_times[i] >= obs_times[i + 1] for i in range(len(obs_times) - 1)
        ):
            raise ValueError(
                "Observation times must be sorted and match the number of observations."
            )

        # Check that there is only a single positive dose
        if len(dose) != 1 or len(dosing_time) != 1 or len(route) != 1:
            raise ValueError("Only single dosing is supported in this statistic.")
        # BUGFIX: test the scalar dose value, not the list — np.isnan(dose)
        # only worked via fragile size-1-array truthiness. The message now
        # also covers the NaN dosing-time case it always checked.
        if dose[0] <= 0 or np.isnan(dose[0]) or np.isnan(dosing_time[0]):
            raise ValueError("Dose must be positive and dose/dosing time must not be NaN.")

        # Check that dose precedes observations
        if any(t < dosing_time[0] for t in obs_times):
            raise ValueError("Dosing time must precede observation times.")

        # calculate AUC using the trapezoidal rule:
        # - for oral dosing, add a value of 0 at dosing time
        # - for iv bolus, add the first observation at dosing time
        obs_times_trapz = dosing_time + obs_times
        if route[0] == "oral":
            obs_vals_trapz = [0.0] + obs_vals
        elif route[0] == "iv":
            obs_vals_trapz = [obs_vals[0]] + obs_vals
        else:
            raise ValueError("Only 'oral' and 'iv' dosing types are supported.")

        # obs_vals is guaranteed non-empty here (early return above), so the
        # old `if len(obs_vals) > 0 else np.nan` fallback was unreachable.
        auc = np.trapezoid(obs_vals_trapz, obs_times_trapz) / dose[0]

        # Calculate Cmax and Tmax
        Cmax_idx = np.argmax(obs_vals)
        Cmax = obs_vals[Cmax_idx] / dose[0]
        Tmax = obs_times[Cmax_idx]

        return {
            "nAUC": float(auc),
            "nCmax": float(Cmax),
            "Tmax": float(Tmax),
            "Nobs": len(obs_vals),
            "Duration": np.max(obs_times),
        }

    def compute_per_study(self, study: StudyJSON) -> Dict[str, float]:
        """Summarize per-individual statistics over a study's context+target blocks."""
        ind_stats = [
            self.compute_per_individual(ind)
            for block in ("context", "target")
            for ind in study.get(block, [])
        ]
        if not ind_stats:
            # NOTE(review): these fallback keys differ from the metric keys
            # produced below — downstream consumers should not rely on a
            # uniform schema for empty studies.
            return {"max_obs": np.nan, "min_obs": np.nan, "mean_obs": np.nan, "num_obs": 0}

        # Calculate statistics (maybe a bit too much, can be simplified later)
        metrics = {
            "nAUC_mean": ("nAUC", np.mean),
            "nAUC_sd": ("nAUC", np.std),
            "nAUC_cv": ("nAUC", lambda x: np.std(x) / np.mean(x) * 100 if np.mean(x) != 0 else np.nan),
            "nCmax_mean": ("nCmax", np.mean),
            "nCmax_sd": ("nCmax", np.std),
            "nCmax_cv": ("nCmax", lambda x: np.std(x) / np.mean(x) * 100 if np.mean(x) != 0 else np.nan),
            "Tmax_mean": ("Tmax", np.mean),
            "Tmax_sd": ("Tmax", np.std),
            "Tmax_cv": ("Tmax", lambda x: np.std(x) / np.mean(x) * 100 if np.mean(x) != 0 else np.nan),
            "Nobs_mean": ("Nobs", np.mean),
            "Nobs_total": ("Nobs", np.sum),
            "Duration_max": ("Duration", np.max),
            "nID": ("Nobs", lambda x: len(x)),
        }

        results = {name: func([d[key] for d in ind_stats]) for name, (key, func) in metrics.items()}

        # Ensure all values are floats for JSON-friendliness or downstream compatibility
        return {k: float(v) for k, v in results.items()}

    def aggregate(
        self,
        per_study: List[Dict[str, float]],
    ) -> Dict[str, object]:
        """Aggregate statistics across studies into P5/P50/P95 percentiles."""
        # ROBUSTNESS: an empty ensemble previously raised IndexError on
        # per_study[0]; report a zero-study summary instead.
        if not per_study:
            return {"Nstudy": 0}

        # Calculate percentiles of study-level statistics
        percentiles = [5, 50, 95]
        summary: Dict[str, object] = {}
        for key in per_study[0].keys():
            values = [s[key] for s in per_study if not np.isnan(s[key])]
            if values:
                summary[f"{key}_percentiles"] = {
                    f"P{p}": float(np.percentile(values, p)) for p in percentiles
                }
            else:
                summary[f"{key}_percentiles"] = {f"P{p}": np.nan for p in percentiles}
        summary["Nstudy"] = len(per_study)

        return summary
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class ListedObservationStats(BasicObservationStats):
    """Variant of BasicObservationStats that returns lists of study-level statistics instead of percentiles.

    This is useful for more detailed analyses or visualizations of the distribution of study-level statistics.
    """

    def __init__(self, alpha=0.1):
        self.alpha = alpha

    def aggregate(
        self,
        per_study: List[Dict[str, float]],
    ) -> Dict[str, object]:
        """Collect each study-level statistic into a flat list, keyed ``<stat>_list``."""
        summary: Dict[str, object] = {}
        for stat_name in per_study[0]:
            summary[f"{stat_name}_list"] = [float(s[stat_name]) for s in per_study]
        summary["Nstudy"] = len(per_study)
        return summary
|
sim_priors_pk/data/data_preprocessing/__init__.py
ADDED
|
File without changes
|
sim_priors_pk/data/data_preprocessing/data_preprocessing_utils.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from torchtyping import TensorType
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torchtyping import TensorType
|
| 8 |
+
from typing import List,Tuple,Optional
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from sim_priors_pk.data.data_preprocessing.raw_to_tensors_bundles import substance_cvs_to_tensors_bundle,substances_csv_to_tensors
|
| 12 |
+
|
| 13 |
+
from typing import NamedTuple
|
| 14 |
+
import torch
|
| 15 |
+
from torchtyping import TensorType
|
| 16 |
+
|
| 17 |
+
class SubstanceTensorGroup(NamedTuple):
    """Per-substance tensor bundle with a leading singleton substance axis.

    Axes: ``1`` = a single substance, ``I`` = individuals (subjects),
    ``T`` = time points.
    """

    # Observed values per subject and time point; padded slots presumably
    # hold zeros — TODO confirm against the bundle builder.
    observations: TensorType[1, "I", "T"]
    # Timestamps aligned element-wise with ``observations``.
    times: TensorType[1, "I", "T"]
    # Per-time-point validity mask (truthy = real measurement); exact dtype
    # is not visible here — confirm bool vs. float with the producer.
    mask: TensorType[1, "I", "T"]
    # Per-subject validity mask (truthy = populated subject slot).
    subject_mask: TensorType[1, "I"]
|
| 23 |
+
def apply_timescale_filter(
    observations: TensorType["S", "I", "T"],
    times: TensorType["S", "I", "T"],
    masks: TensorType["S", "I", "T"],
    subject_mask: TensorType["S", "I"],
    *,
    strategy: str = "log_zscore",  # "log_zscore" | "median_fraction" | "none"
    max_abs_z: float = 2.0,  # for "log_zscore"
    tau: float = 0.4,  # for "median_fraction" (≈ ln 1.5)
) -> Tuple[
    TensorType["S", "I", "T"],  # filtered observations
    TensorType["S", "I", "T"],  # filtered times
    TensorType["S", "I", "T"],  # filtered masks
    TensorType["S", "I"],  # filtered subject_mask
]:
    """
    Zeroes-out and un-masks subjects whose time-span is an outlier
    w.r.t. other subjects in the *same* substance.

    Axes: S = substances, I = individuals, T = time points.

    - strategy="log_zscore": keep subjects with |z| <= max_abs_z in log-span
    - strategy="median_fraction": keep subjects within +/-tau of median(log-span)
    - strategy="none": return inputs unchanged

    NOTE(review): any *other* strategy string also falls through to the
    final ``else`` and silently returns the inputs unchanged — confirm
    whether an unknown strategy should raise instead.
    """
    if strategy == "none":
        return observations, times, masks, subject_mask

    # combine padding + subject mask to know valid time points
    valid = masks.bool() & subject_mask.unsqueeze(-1)

    # --- compute log-spans ----------------------------------------------------
    # For subjects with zero valid points, max over -inf / min over +inf gives
    # span <= 0, which the clamp turns into 1e-12 (log ≈ -27.6).
    t_max = times.masked_fill(~valid, float("-inf")).max(dim=2).values  # [S, I]
    t_min = times.masked_fill(~valid, float("inf")).min(dim=2).values  # [S, I]
    span = (t_max - t_min).clamp(min=1e-12)
    log_span = span.log()  # [S, I]

    # --- decide which subjects to keep ---------------------------------------
    # NOTE(review): the mean/std/median below run over ALL I slots, so fully
    # invalid subjects (log-span ≈ -27.6) can skew the statistics — confirm
    # whether they should be excluded before computing z-scores.
    if strategy == "log_zscore":
        z = (log_span - log_span.mean(dim=1, keepdim=True)) / \
            (log_span.std(dim=1, keepdim=True).clamp(min=1e-6))
        keep = torch.abs(z) <= max_abs_z  # [S, I]

    elif strategy == "median_fraction":
        med = log_span.median(dim=1, keepdim=True).values  # [S,1]
        keep = (log_span >= med - tau) & (log_span <= med + tau)  # [S,I]

    else:
        # No filtering applied — return inputs unchanged
        return observations, times, masks, subject_mask

    # --- apply filter: zero & un-mask ----------------------------------------
    # clone so we don't mutate original tensors accidentally
    obs_f = observations.clone()
    times_f = times.clone()
    masks_f = masks.clone()
    subj_f = subject_mask.clone()

    # indices where we drop subjects; only subjects that were active can be
    # dropped, so already-inactive slots are left untouched.
    drop = ~keep & subj_f.bool()
    subj_f[drop] = False
    masks_f[drop] = False
    obs_f[drop] = 0.0
    times_f[drop] = 0.0

    return obs_f, times_f, masks_f, subj_f
|
| 87 |
+
|
| 88 |
+
def plot_subjects_for_substance(
    drug_data_frame,
    substance_label: str,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
    time_strategy: str = "log_zscore",   # "log_zscore" | "median_fraction" | "none"
    max_abs_z: float = 2.0,
    x_scale: str = "linear",             # "linear" (default) or "log"
    y_scale: str = "linear",             # "linear" (default) or "log"
    alpha: float = 1.0,                  # 0 <= alpha <= 1
    legend_outside: bool = True,         # park legend to the right
    figsize: Tuple[float, float] = (10, 5),
    save_dir: Optional[str] = None,      # if set, saves the figure here
) -> None:
    """
    Draw every subject trajectory (points + line) for *one* substance.

    Parameters
    ----------
    drug_data_frame : pandas.DataFrame
        Long-format data consumed by ``substance_cvs_to_tensors_bundle``.
    substance_label : str
        Substance to plot; must appear in the bundle's ``substance_names``.
    z_score_normalization, normalize_by_max : bool, optional
        Forwarded to ``substance_cvs_to_tensors_bundle``.
    time_strategy : str, optional
        Per-subject time-span outlier filter; see ``apply_timescale_filter``.
    max_abs_z : float, optional
        Z-score threshold used when ``time_strategy == "log_zscore"``.
    x_scale, y_scale : {"linear", "log"}, optional
        Axis scaling.  With "log", data on that axis must be strictly > 0
        or Matplotlib will complain.
    alpha : float in [0, 1], optional
        Transparency applied to both the line and the markers.
    legend_outside : bool, optional
        True -> legend in a separate column to the right;
        False -> legend inside the plot.
    figsize : tuple of float, optional
        Figure width x height in inches.
    save_dir : str, optional
        If given, the figure is saved there (directory created if missing).

    Raises
    ------
    ValueError
        If ``substance_label`` is not present in the data.
    """
    # ── 1. Pull tensors ────────────────────────────────────────────
    # BUGFIX: the normalization keywords were previously ignored and
    # normalize_by_max was hard-coded to True; forward the caller's choices.
    data_bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    all_obs = data_bundle.observations              # [S, I, T]
    all_times = data_bundle.times                   # [S, I, T]
    all_masks = data_bundle.masks                   # [S, I, T]
    all_subj_mask = data_bundle.individuals_mask    # [S, I]
    substance_labels = data_bundle.substance_names  # [S]
    mapping = data_bundle.mapping

    # ── 2. Find substance row ──────────────────────────────────────
    # BUGFIX: substance_names is a plain list; comparing a list to a str
    # yields a scalar False, so np.where() always came back empty.
    # Wrap it in an ndarray for an element-wise comparison.
    try:
        s_idx: int = int(
            np.where(np.asarray(substance_labels) == substance_label)[0][0]
        )
    except IndexError:
        raise ValueError(f"Substance '{substance_label}' not found.") from None

    # ("I", "T")
    obs: TensorType["I", "T"] = all_obs[s_idx]
    times: TensorType["I", "T"] = all_times[s_idx]
    step_mask: TensorType["I", "T"] = all_masks[s_idx].bool()
    subj_mask: TensorType["I"] = all_subj_mask[s_idx].bool()

    # ── 3. Filter time series ──────────────────────────────────────
    # apply_timescale_filter expects a leading substance axis: [S, I, T].
    obs_b, times_b, step_mask_b, subj_mask_b = apply_timescale_filter(
        observations=obs.unsqueeze(0),
        times=times.unsqueeze(0),
        masks=step_mask.unsqueeze(0),
        subject_mask=subj_mask.unsqueeze(0),
        strategy=time_strategy,
        max_abs_z=max_abs_z,
        tau=0.4,
    )

    # Remove batch dim again
    obs = obs_b[0]
    times = times_b[0]
    step_mask = step_mask_b[0]
    subj_mask = subj_mask_b[0]

    # ── 4. Plot one line per *real* subject ────────────────────────
    fig, ax = plt.subplots(figsize=figsize)
    for i in range(obs.shape[0]):        # iterate subjects (I)
        if not subj_mask[i]:
            continue                     # skip padded / filtered-out rows

        valid: TensorType["T"] = step_mask[i]       # True -> real sample
        t: TensorType["T"] = times[i][valid].cpu()
        y: TensorType["T"] = obs[i][valid].cpu()

        ax.plot(t, y, marker="o", alpha=alpha, label=f"subject {i}")

    # ── 5. Styling ────────────────────────────────────────────────
    ax.set_title(f"All subjects – {substance_label}")
    ax.set_xlabel("Time (normalised per substance)")
    ax.set_ylabel("Observation")

    # Axis scales
    ax.set_xscale(x_scale)
    ax.set_yscale(y_scale)

    # Legend placement
    if legend_outside:
        # ncol=1 -> vertical list; bbox_to_anchor shifts legend fully outside
        ax.legend(
            loc="center left",
            bbox_to_anchor=(1.02, 0.5),
            borderaxespad=0.0,
            frameon=False,
        )
        plt.tight_layout(rect=[0, 0, 0.82, 1])  # leave room on the right
    else:
        ax.legend(frameon=False)
        plt.tight_layout()

    # Save figure if a directory is given
    if save_dir is not None:
        from pathlib import Path

        study_name = mapping[substance_label]["study_name"]
        index = mapping[substance_label]["index"]
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        filename = f"{study_name}_{substance_label}_{index}.png"
        fig.savefig(Path(save_dir) / filename, bbox_inches="tight", dpi=300)

    plt.show()
|
| 217 |
+
|
| 218 |
+
def substances_with_min_timesteps(
    drug_data_frame,
    min_timesteps: int = 140,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
) -> List[str]:
    """
    Return the substance labels whose **best** subject has at least
    ``min_timesteps`` valid observations.

    Parameters
    ----------
    drug_data_frame : pandas.DataFrame
        Same dataframe passed to ``substance_cvs_to_tensors_bundle``.
    min_timesteps : int, default = 140
        Threshold on the number of valid (unpadded) time points.
    z_score_normalization, normalize_by_max : bool, default = False
        Passed straight through to ``substance_cvs_to_tensors_bundle``.

    Returns
    -------
    List[str]
        Substance labels that satisfy the criterion.
    """
    # BUGFIX: substance_cvs_to_tensors_bundle returns an
    # EmpiricalSubstanceTensorBundle (10 named fields); the previous
    # 6-element tuple unpacking raised a ValueError at runtime.
    bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    # --- Shapes -------------------------------------------------------
    # S = substances, I = max subjects per substance, T = max time steps.
    # bundle.masks            : (S, I, T) – True at valid positions
    # bundle.individuals_mask : (S, I)    – True for *existing* subjects
    # -----------------------------------------------------------------

    # Combine step mask with subject mask (broadcast over T).
    valid_masks = bundle.masks.bool() & bundle.individuals_mask.bool().unsqueeze(-1)

    # counts[s, i] = number of valid time points of subject i in substance s
    counts = valid_masks.sum(dim=2)                # (S, I)

    # Best subject per substance
    max_counts = counts.max(dim=1).values          # (S,)

    # Substances that meet / beat the threshold
    qualifying = max_counts >= min_timesteps       # (S,) bool

    # BUGFIX: substance_names is a plain list, so no .tolist() on it.
    return [
        label
        for label, keep in zip(bundle.substance_names, qualifying.tolist())
        if keep
    ]
|
| 280 |
+
|
| 281 |
+
def get_substance_tensors_by_label(
    drug_data_frame,
    substance_label: str,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
) -> SubstanceTensorGroup:
    """
    Fetch the padded tensors of one substance, keeping an S=1 batch axis.

    Shapes:
        observations : [1, I, T]
        times        : [1, I, T]
        mask         : [1, I, T]
        subject_mask : [1, I]

    Raises:
        ValueError: if ``substance_label`` does not exist in the data.
    """
    bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    # Locate the requested substance among the stacked [S, ...] tensors.
    labels = list(bundle.substance_names)
    if substance_label not in labels:
        raise ValueError(f"Substance label '{substance_label}' not found.")
    s_idx = labels.index(substance_label)

    # Re-introduce the batch dimension: [1, I, T] / [1, I].
    return SubstanceTensorGroup(
        observations=bundle.observations[s_idx].unsqueeze(0),
        times=bundle.times[s_idx].unsqueeze(0),
        mask=bundle.masks[s_idx].unsqueeze(0).bool(),
        subject_mask=bundle.individuals_mask[s_idx].unsqueeze(0).bool(),
    )
|
| 321 |
+
|
sim_priors_pk/data/data_preprocessing/raw_to_tensors_bundles.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Here we define the functions required to process the data
|
| 3 |
+
|
| 4 |
+
https://pk-db.com/
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
from torchtyping import TensorType
|
| 12 |
+
from typing import NamedTuple, List, Dict
|
| 13 |
+
from torchtyping import TensorType
|
| 14 |
+
from typing import Dict, Tuple, List
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from torchtyping import TensorType
|
| 18 |
+
|
| 19 |
+
from typing import Dict, Tuple, List, Optional
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from torchtyping import TensorType
|
| 23 |
+
|
| 24 |
+
# Default dose per substance in mg per g (matches the "dose mg/g per subject"
# field of EmpiricalSubstanceTensorBundle).  Keys are matched by *substring*
# against the lowercased substance name inside substances_csv_to_tensors;
# unmatched substances fall back to a default of 0.5 there.
# NOTE(review): presumably taken from the Lenuzza et al. cocktail study, as
# the name suggests — confirm source and units against the publication.
lenuzza_doses_mg_per_g = {
    "memantine": 0.005,
    "omeprazole": 0.010,
    "repaglinide": 0.00025,
    "rosuvastatin": 0.005,
    "tolbutamide": 0.010,
    "dextromethorphan": 0.018,
    "digoxin": 0.00025,
    "paracetamol": 0.060,
    "caffeine": 0.073,
    "midazolam": 0.004,
    "paraxanthine":0.073,
    "dextrorphan":0.018,
}
|
| 38 |
+
|
| 39 |
+
class EmpiricalSubstanceTensorBundle(NamedTuple):
    """Padded per-substance tensors plus metadata for empirical PK data.

    Axis conventions: S = substances, I = max subjects per substance,
    T = max time steps per subject.  Produced by
    ``substance_cvs_to_tensors_bundle``.
    """
    observations: TensorType["S", "I", "T"]  # padded concentration values
    times: TensorType["S", "I", "T"]  # padded normalized times [0,1]
    masks: TensorType["S", "I", "T"]  # 1 = observed, 0 = missing or padded
    individuals_mask: TensorType["S", "I"]  # 1 = real subject, 0 = padded row
    study_names: List[str]  # [S] → one study name per substance
    individuals_names: List[List[str]]  # [S][I] → subject name per padded subject
    substance_names: List[str]  # [S] substance_label entries
    mapping: Dict[str, Dict[str, object]]  # substance_label → {"index", "study_name"}
    dosing_amounts: TensorType["S", "I"]  # dose mg/g per subject
    dosing_route_types: TensorType["S", "I"]  # route type index per subject
|
| 50 |
+
|
| 51 |
+
def map_substance_to_index_and_study(
    drug_data_frame
) -> dict[str, dict[str, object]]:
    """
    Map each substance_label to its position (np.unique order) and the
    study_name of the first row carrying that label.

    Returns
    -------
    dict: {
        "substance_label": {
            "index": int,
            "study_name": str
        },
        ...
    }
    """
    unique_labels = np.unique(drug_data_frame["substance_label"].values)

    result: dict[str, dict[str, object]] = {}
    for position, label in enumerate(unique_labels):
        # First occurrence of this label supplies the study name.
        rows_for_label = drug_data_frame[drug_data_frame["substance_label"] == label]
        result[label] = {
            "index": position,
            "study_name": rows_for_label["study_name"].iloc[0],
        }

    return result
|
| 81 |
+
|
| 82 |
+
def substances_csv_to_tensors(drug_data_frame, substance_label='omeprazole'):
    """
    Extract per-subject time series for one substance, right-pad them with
    zeros to a common length, and return tensors plus dosing metadata.

    Params:
        drug_data_frame (pd.DataFrame): Input DataFrame with 'substance_label',
            'subject_name', 'time' and 'value' columns (optionally
            'substance_name').
        substance_label (str): The substance label to filter by.
            Defaults to 'omeprazole'.

    Returns:
        observations (torch.Tensor): Padded values, shape [num_subjects, max_time].
        observations_times (torch.Tensor): Padded times, shape [num_subjects, max_time].
        observations_mask (torch.Tensor): 1.0 for real samples, 0.0 for padding,
            shape [num_subjects, max_time].
        dosing_amounts (torch.Tensor): Dose amount per subject [num_subjects].
        dosing_route_types (torch.Tensor): Route type index per subject [num_subjects].
    """
    # Keep only rows belonging to the requested substance.
    subset = drug_data_frame[drug_data_frame['substance_label'] == substance_label]

    per_subject_times = []
    per_subject_values = []
    dose_per_subject = []
    route_per_subject = []

    # One entry per subject, in groupby order.
    for _, subject_rows in subset.groupby('subject_name'):
        # Chronological order within the subject.
        ordered = subject_rows.sort_values('time')
        per_subject_times.append(ordered['time'].values.astype(np.float32))
        per_subject_values.append(ordered['value'].values.astype(np.float32))

        # Substance name used for the dose lookup (falls back to the label).
        if 'substance_name' in subject_rows.columns:
            name_lower = str(subject_rows['substance_name'].iloc[0]).lower()
        else:
            name_lower = str(substance_label).lower()

        # First substring match in the lookup table wins; 0.5 is the fallback.
        dose = next(
            (amount for key, amount in lenuzza_doses_mg_per_g.items()
             if key in name_lower),
            0.5,
        )
        dose_per_subject.append(dose)
        route_per_subject.append(0)  # oral

    # Longest series defines the padded width.
    max_len = max((len(t) for t in per_subject_times), default=0)

    # Right-pad times/values with zeros; mask marks real samples with 1.0.
    padded_time_rows, padded_value_rows, mask_rows = [], [], []
    for t, v in zip(per_subject_times, per_subject_values):
        tail = max_len - len(t)
        padded_time_rows.append(
            np.pad(t, (0, tail), mode='constant', constant_values=0))
        padded_value_rows.append(
            np.pad(v, (0, tail), mode='constant', constant_values=0))
        row_mask = np.ones(max_len, dtype=np.float32)
        row_mask[len(t):] = 0
        mask_rows.append(row_mask)

    # Assemble PyTorch tensors.
    observations = torch.tensor(padded_value_rows, dtype=torch.float32)   # [P, T]
    observations_times = torch.tensor(padded_time_rows, dtype=torch.float32)  # [P, T]
    observations_mask = torch.tensor(mask_rows, dtype=torch.float32)      # [P, T]
    dosing_amounts = torch.tensor(dose_per_subject, dtype=torch.float32)  # [P]
    dosing_route_types = torch.tensor(route_per_subject, dtype=torch.long)  # [P]

    return observations, observations_times, observations_mask, dosing_amounts, dosing_route_types
|
| 163 |
+
|
| 164 |
+
def substance_dict_to_tensors(
    selected_series: Optional[Dict[str, Dict[str, List[float]]]],
    hidden_series: Optional[Dict[str, Dict[str, List[float]]]],
) -> Tuple[
    Optional[TensorType["N_sel", "T"]], Optional[TensorType["N_sel", "T"]], Optional[TensorType["N_sel", "T"]],
    Optional[TensorType["N_hid", "T"]], Optional[TensorType["N_hid", "T"]], Optional[TensorType["N_hid", "T"]],
]:
    """
    Convert two dictionaries of time series (typically coming from the
    frontend payload) into padded tensors that share one common maximum
    sequence length.

    Args:
        selected_series: Mapping subject_name -> {'timepoints': [...], 'values': [...]}.
        hidden_series: Mapping subject_name -> {'timepoints': [...], 'values': [...]}.

    Returns:
        sel_obs, sel_times, sel_mask: [N_sel, T] tensors, or None when empty.
        hid_obs, hid_times, hid_mask: [N_hid, T] tensors, or None when empty.
    """
    def _sorted_arrays(series):
        # Per subject: float32 arrays reordered chronologically.
        times_acc, values_acc = [], []
        for payload in series.values():
            t = np.array(payload['timepoints'], dtype=np.float32)
            v = np.array(payload['values'], dtype=np.float32)
            order = np.argsort(t)
            times_acc.append(t[order])
            values_acc.append(v[order])
        return times_acc, values_acc

    def _to_padded(times_acc, values_acc, width):
        # Right-pad every series with zeros up to `width`; masks are 1.0 on
        # real samples and 0.0 on padding.
        t_rows, v_rows, m_rows = [], [], []
        for t, v in zip(times_acc, values_acc):
            fill = width - len(t)
            t_rows.append(np.pad(t, (0, fill), mode='constant', constant_values=0))
            v_rows.append(np.pad(v, (0, fill), mode='constant', constant_values=0))
            row_mask = np.ones(width, dtype=np.float32)
            row_mask[len(t):] = 0
            m_rows.append(row_mask)
        return (
            torch.tensor(v_rows, dtype=torch.float32),  # observations [N, T]
            torch.tensor(t_rows, dtype=torch.float32),  # times        [N, T]
            torch.tensor(m_rows, dtype=torch.float32),  # mask         [N, T]
        )

    sel_t, sel_v = _sorted_arrays(selected_series) if selected_series else ([], [])
    hid_t, hid_v = _sorted_arrays(hidden_series) if hidden_series else ([], [])

    # Both groups are padded to a single shared length T.
    width = max(
        max((len(t) for t in sel_t), default=0),
        max((len(t) for t in hid_t), default=0),
    )

    sel_obs = sel_times = sel_mask = None
    if sel_t:
        sel_obs, sel_times, sel_mask = _to_padded(sel_t, sel_v, width)

    hid_obs = hid_times = hid_mask = None
    if hid_t:
        hid_obs, hid_times, hid_mask = _to_padded(hid_t, hid_v, width)

    return sel_obs, sel_times, sel_mask, hid_obs, hid_times, hid_mask
|
| 241 |
+
|
| 242 |
+
def substance_cvs_to_tensors_bundle(
    drug_data_frame: pd.DataFrame,
    **kwargs
) -> EmpiricalSubstanceTensorBundle:
    """
    Build one padded :class:`EmpiricalSubstanceTensorBundle` from a
    long-format dataframe, grouping rows by ``substance_label``.

    For each substance, ``substances_csv_to_tensors`` yields [P, T] tensors;
    NaN observations are zeroed and removed from the mask, then every
    substance is right-padded to the global maximum subject count (I) and
    time-step count (T) and stacked along a leading substance axis S.

    NOTE(review): ``**kwargs`` (e.g. ``z_score_normalization`` /
    ``normalize_by_max`` as passed by callers) is accepted but currently
    ignored — no normalization happens in this function; confirm whether
    that is intended.

    Returns:
        EmpiricalSubstanceTensorBundle with:
            observations, times, masks         : [S, I, T]
            individuals_mask                   : [S, I] (1 = real subject)
            substance_names, study_names       : length-S lists
            individuals_names                  : [S][I], padded with ""
            mapping                            : substance_label -> {index, study_name}
            dosing_amounts, dosing_route_types : [S, I]
    """
    # Redundant re-imports kept as-is; F (torch.nn.functional) is required
    # for the padding pass below.
    import numpy as np
    import torch
    import torch.nn.functional as F

    substance_labels = np.unique(drug_data_frame["substance_label"].values)
    mapping = map_substance_to_index_and_study(drug_data_frame)

    # Per-substance accumulators (lists of [P, T] / [P] tensors).
    substance_observations = []
    substance_times = []
    substance_masks = []
    subject_masks = []
    substance_doses = []
    substance_routes = []

    study_names_per_substance = []
    subject_names_per_substance = []

    # Global maxima used for the second (padding) pass.
    max_time_steps = 0
    max_subjects = 0

    for substance_label in substance_labels:
        df_sub = drug_data_frame[drug_data_frame["substance_label"] == substance_label]
        obs, times, masks, doses, routes = substances_csv_to_tensors(
            drug_data_frame, substance_label=substance_label
        )
        # obs, times, masks: [P, T]

        # Drop NaN observations from the mask and zero them in the data.
        valid_obs_mask = ~torch.isnan(obs)
        masks = masks.bool() & valid_obs_mask
        obs = obs.nan_to_num(nan=0.0)

        max_time_steps = max(max_time_steps, obs.shape[1])
        max_subjects = max(max_subjects, obs.shape[0])

        # --- Metadata collection ---
        # groupby().first() gives one row per subject; its index supplies the
        # subject names, and the first row supplies the study name.
        grouped = df_sub.groupby("subject_name").first()
        subject_names = list(grouped.index)
        study_name = grouped["study_name"].iloc[0] if len(grouped) > 0 else ""

        study_names_per_substance.append(study_name)
        subject_names_per_substance.append(subject_names)

        substance_observations.append(obs)
        substance_times.append(times)
        substance_masks.append(masks)
        subject_masks.append(torch.ones(obs.shape[0], dtype=torch.float32))  # [P]
        substance_doses.append(doses)
        substance_routes.append(routes)

    # Padding pass: bring every substance to [max_subjects, max_time_steps].
    all_observations, all_times, all_masks, all_subjects_mask = [], [], [], []
    all_doses, all_routes = [], []

    for obs, time, mask, subj_mask, subj_names, doses, routes in zip(
        substance_observations,
        substance_times,
        substance_masks,
        subject_masks,
        subject_names_per_substance,
        substance_doses,
        substance_routes,
    ):
        pad_subjects = max_subjects - obs.shape[0]
        pad_timesteps = max_time_steps - obs.shape[1]

        # F.pad pads the last dim first: (left_T, right_T, top_I, bottom_I).
        obs_padded = F.pad(obs, (0, pad_timesteps, 0, pad_subjects))  # [I, T]
        time_padded = F.pad(time, (0, pad_timesteps, 0, pad_subjects))  # [I, T]
        mask_padded = F.pad(mask, (0, pad_timesteps, 0, pad_subjects))  # [I, T]
        subj_mask_padded = F.pad(subj_mask, (0, pad_subjects))  # [I]
        dose_padded = F.pad(doses, (0, pad_subjects))  # [I]
        route_padded = F.pad(routes, (0, pad_subjects))  # [I]
        subj_names += [""] * pad_subjects  # [I] → pad with "" (mutates in place)

        all_observations.append(obs_padded)
        all_times.append(time_padded)
        all_masks.append(mask_padded)
        all_subjects_mask.append(subj_mask_padded)
        all_doses.append(dose_padded)
        all_routes.append(route_padded)

    return EmpiricalSubstanceTensorBundle(
        observations=torch.stack(all_observations),  # [S, I, T]
        times=torch.stack(all_times),  # [S, I, T]
        masks=torch.stack(all_masks),  # [S, I, T]
        individuals_mask=torch.stack(all_subjects_mask),  # [S, I]
        substance_names=list(substance_labels),  # [S]
        mapping=mapping,
        study_names=study_names_per_substance,  # [S]
        individuals_names=subject_names_per_substance,  # [S][I]
        dosing_amounts=torch.stack(all_doses),  # [S, I]
        dosing_route_types=torch.stack(all_routes)  # [S, I]
    )
|
sim_priors_pk/data/data_preprocessing/tensors_to_databatch.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility for initializing :class:`AICMECompartmentsDataBatch` objects.
|
| 2 |
+
|
| 3 |
+
This small helper is primarily used in older preprocessing scripts. It takes
|
| 4 |
+
precomputed observation tensors and wraps them into a minimal
|
| 5 |
+
``AICMECompartmentsDataBatch`` where only the context fields are populated.
|
| 6 |
+
All other entries are set to empty tensors or placeholders so that the
|
| 7 |
+
resulting object conforms to the new metadata interface.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def initialize_aicme_batch(
    observations: torch.Tensor,
    observations_times: torch.Tensor,
    observations_mask: torch.Tensor,
) -> AICMECompartmentsDataBatch:
    """Wrap raw tensors into an :class:`AICMECompartmentsDataBatch`.

    Parameters
    ----------
    observations:
        Tensor of shape ``[I, T]`` containing concentration values.
    observations_times:
        Tensor of shape ``[I, T]`` with the corresponding time points.
    observations_mask:
        Boolean tensor of shape ``[I, T]`` indicating valid entries.

    Returns
    -------
    AICMECompartmentsDataBatch
        Batch with ``B=1`` in which only the context fields carry data;
        everything else is a placeholder (``None``, zeros or empty strings).
    """
    n_individuals = observations.shape[0]

    # Context tensors gain a leading batch axis (B=1); observations and
    # times additionally gain a trailing feature axis of size 1.
    ctx_obs = observations[None, ..., None]            # [1, I, T, 1]
    ctx_obs_time = observations_times[None, ..., None]  # [1, I, T, 1]
    ctx_obs_mask = observations_mask[None, ...]         # [1, I, T]

    return AICMECompartmentsDataBatch(
        target_obs=None,
        target_obs_time=None,
        target_obs_mask=None,
        target_rem_sim=None,
        target_rem_sim_time=None,
        target_rem_sim_mask=None,
        target_dosing_amounts=torch.zeros(1, 0),
        target_dosing_route_types=torch.zeros(1, 0, dtype=torch.long),
        context_obs=ctx_obs,
        context_obs_time=ctx_obs_time,
        context_obs_mask=ctx_obs_mask,
        context_rem_sim=None,
        context_rem_sim_time=None,
        context_rem_sim_mask=None,
        context_dosing_amounts=torch.zeros(1, n_individuals),
        context_dosing_route_types=torch.zeros(1, n_individuals, dtype=torch.long),
        study_name=[""],
        context_subject_name=[[""] * n_individuals],
        target_subject_name=[[]],
        substance_name=[""],
        time_scales=None,
        is_empirical=False,
    )
|
| 72 |
+
|
sim_priors_pk/data/datasets/aicme_batch.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Batch structures shared between synthetic and empirical pipelines."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from collections import namedtuple
|
| 6 |
+
from typing import List, NamedTuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torchtyping import TensorType
|
| 10 |
+
|
| 11 |
+
# Lightweight record describing the dimensions of a collated batch:
# batch size, number of context/target individuals, and the number of
# (remaining) observations on each side.
ShapeConfig = namedtuple(
    "ShapeConfig",
    "batch_size c_individuals num_obs_c remaining_obs_c "
    "t_individuals num_obs_t remaining_obs_t",
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AICMECompartmentsDataBatch(NamedTuple):
    """Container aggregating context and target trajectories.

    The tuple carries tensors describing observed measurements, simulated
    remainders, dosing metadata and masking utilities used across both the
    synthetic simulation pipeline and the empirical JSON tooling.

    NOTE: field ORDER is load-bearing — ``detach_all`` and ``to_device``
    rebuild the tuple positionally (``AICMECompartmentsDataBatch(*...)``),
    so new fields must be appended, never inserted.
    """

    # max_num_individuals-max_n_new_individuals = n_c_individuals
    # Target-side observed measurements (values, time points, validity mask).
    target_obs: TensorType["B", "t_ind", "num_obs_t", 1]
    target_obs_time: TensorType["B", "t_ind", "num_obs_t", 1]
    target_obs_mask: TensorType["B", "t_ind", "num_obs_t"]

    # Target-side simulated "remainder" segments (the not-yet-observed part
    # of each trajectory), with matching time points and mask.
    target_rem_sim: TensorType["B", "t_ind", "rem_obs_t", 1]
    target_rem_sim_time: TensorType["B", "t_ind", "rem_obs_t", 1]
    target_rem_sim_mask: TensorType["B", "t_ind", "rem_obs_t"]

    # Context-side observed measurements.
    context_obs: TensorType["B", "c_ind", "num_obs_c", 1]
    context_obs_time: TensorType["B", "c_ind", "num_obs_c", 1]
    context_obs_mask: TensorType["B", "c_ind", "num_obs_c"]

    # Context-side simulated remainder segments.
    context_rem_sim: TensorType["B", "c_ind", "rem_obs_c", 1]
    context_rem_sim_time: TensorType["B", "c_ind", "rem_obs_c", 1]
    context_rem_sim_mask: TensorType["B", "c_ind", "rem_obs_c"]

    # Dosing information
    target_dosing_amounts: TensorType["B", "t_ind"]
    target_dosing_route_types: TensorType["B", "t_ind"]
    context_dosing_amounts: TensorType["B", "c_ind"]
    context_dosing_route_types: TensorType["B", "c_ind"]

    # Masks over padded individuals
    mask_context_individuals: TensorType["B", "c_ind"]
    mask_target_individuals: TensorType["B", "t_ind"]

    # 🆕 NEW: tracking metadata
    study_name: List[str]
    """Study identifier for each element in the batch (length ``B``)."""
    context_subject_name: List[List[str]]
    """Names of context individuals: shape ``[B][c_ind]``."""
    target_subject_name: List[List[str]]
    """Names of target individuals: shape ``[B][t_ind]``."""
    substance_name: List[str]
    """Drug or compound names corresponding to each study (length ``B``)."""

    # Meta information
    time_scales: TensorType["B", 2]  # shape : [B,2]
    is_empirical: bool = False  # NEW: True ⇢ empirical CSV, False ⇢ simulation

    @property
    def mask_individuals(self) -> TensorType["B", "c_ind"]:
        """Alias for backward compatibility; returns ``mask_context_individuals``."""

        return self.mask_context_individuals

    def detach_all(self) -> "AICMECompartmentsDataBatch":
        """Detaches all tensor fields from the computation graph."""

        # Non-tensor fields (strings, bools, None) are passed through as-is;
        # rebuilding positionally relies on the declared field order above.
        return AICMECompartmentsDataBatch(
            *(t.detach() if isinstance(t, torch.Tensor) else t for t in self)
        )

    def log_transform(self) -> "AICMECompartmentsDataBatch":
        """Applies log transformation to observation and remainder tensors.

        Deprecated for training: log scaling is now expected to be handled by
        ``PKScaler`` (for example via ``value_method="log"`` or
        ``value_method="log_and_max"``).
        Kept for backward compatibility with older utilities.
        """

        transformed_tensors = []
        for name, tensor in zip(self._fields, self):
            if name in [
                "target_obs",
                "target_rem_sim",
                "context_obs",
                "context_rem_sim",
            ]:
                # Small epsilon keeps log() finite for zero-valued entries.
                # NOTE(review): assumes these fields are non-None tensors when
                # this method is called — a None field would raise here.
                transformed_tensors.append(torch.log(tensor + 1e-6))
            else:
                transformed_tensors.append(tensor)
        return AICMECompartmentsDataBatch(*transformed_tensors)

    def to_device(self, device: torch.device) -> "AICMECompartmentsDataBatch":
        """Moves all tensor fields to the specified device (leaves strings untouched)."""

        return AICMECompartmentsDataBatch(
            *(t.to(device) if isinstance(t, torch.Tensor) else t for t in self)
        )

    def to(self, device: torch.device | str) -> "AICMECompartmentsDataBatch":
        """PyTorch-style alias delegating to :meth:`to_device`.

        Several generic utilities expect batch-like objects to implement
        ``.to(device)``. Exposing this alias keeps the explicit
        ``to_device(...)`` API while allowing those utilities to move the full
        databatch onto the target device safely.
        """

        return self.to_device(torch.device(device))

    def to_reconstruct_type(self) -> "AICMECompartmentsDataBatch":
        """
        Return a new databatch where the target trajectories are reconstructed
        by concatenating observed and remainder segments, then right-padding
        so that the target has the same time dimension as the context.
        The context is left untouched.
        """

        B, Ic, Tc, _ = self.context_obs.shape  # context time dimension is reference
        _, It, _, _ = self.target_obs.shape

        T_max = Tc  # max length for padding

        # allocate reconstructed tensors (zero-padded to the context length)
        Xt_full = torch.zeros(
            B, It, T_max, 1, dtype=self.target_obs.dtype, device=self.target_obs.device
        )
        Tt_full = torch.zeros(
            B, It, T_max, 1, dtype=self.target_obs_time.dtype, device=self.target_obs_time.device
        )
        Mt_full = torch.zeros(B, It, T_max, dtype=torch.bool, device=self.target_obs_mask.device)

        # fill with observed + remainder segments
        # NOTE(review): assumes valid entries are left-packed (masks count
        # a contiguous prefix), since lengths come from mask sums but slices
        # take the leading elements — confirm against the data builders.
        for b in range(B):
            for i in range(It):
                o_len = int(self.target_obs_mask[b, i].sum().item())
                r_len = int(self.target_rem_sim_mask[b, i].sum().item())
                total = o_len + r_len
                if total == 0:
                    continue
                Xt_full[b, i, :o_len] = self.target_obs[b, i, :o_len]
                Xt_full[b, i, o_len:total] = self.target_rem_sim[b, i, :r_len]
                Tt_full[b, i, :o_len] = self.target_obs_time[b, i, :o_len]
                Tt_full[b, i, o_len:total] = self.target_rem_sim_time[b, i, :r_len]
                Mt_full[b, i, :total] = True

        # Only the target fields change; context fields are shared untouched.
        return self._replace(
            target_obs=Xt_full,
            target_obs_time=Tt_full,
            target_obs_mask=Mt_full,
        )
|
sim_priors_pk/data/datasets/aicme_datasets.py
ADDED
|
@@ -0,0 +1,1874 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import tempfile
|
| 4 |
+
import warnings
|
| 5 |
+
from dataclasses import replace
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 8 |
+
|
| 9 |
+
import lightning.pytorch as pl
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch.utils.data import DataLoader, Dataset
|
| 13 |
+
from torch.utils.data.dataloader import default_collate
|
| 14 |
+
|
| 15 |
+
from sim_priors_pk import data_dir
|
| 16 |
+
from sim_priors_pk.config_classes.node_pk_config import NodePKExperimentConfig
|
| 17 |
+
from sim_priors_pk.data.data_generation.compartment_models_management import (
|
| 18 |
+
prepare_full_simulation,
|
| 19 |
+
prepare_full_simulation_list_with_repeated_targets as prepare_full_simulation_list_with_repeated_targets_backend,
|
| 20 |
+
prepare_full_simulation_with_repeated_targets,
|
| 21 |
+
)
|
| 22 |
+
from sim_priors_pk.data.data_generation.observations_classes import (
|
| 23 |
+
ObservationStrategyFactory,
|
| 24 |
+
)
|
| 25 |
+
from sim_priors_pk.data.datasets.aicme_batch import (
|
| 26 |
+
AICMECompartmentsDataBatch,
|
| 27 |
+
)
|
| 28 |
+
from sim_priors_pk.utils.tensors_operations import ensure_mask_or_empty, ensure_tensor_or_empty
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def ensure_min_valid(mask, min_length):
    """Guarantee that every row (last dim) of ``mask`` has at least ``min_length`` valid entries.

    Rows already satisfying the constraint are returned unchanged.  Deficient
    rows are replaced by a mask that keeps every originally valid position
    (a true entry always outranks the noise added to invalid ones) and
    activates additional randomly chosen positions until ``min_length``
    entries are valid.

    Parameters
    ----------
    mask:
        Tensor of shape ``[..., T]`` whose nonzero entries mark valid
        positions.  Float and boolean masks are both supported (the previous
        implementation crashed on boolean masks because ``torch.rand_like``
        does not accept a bool dtype).
    min_length:
        Minimum number of valid entries per row; must not exceed ``T``.

    Returns
    -------
    torch.Tensor
        A mask with the same shape and dtype as the input.

    Raises
    ------
    ValueError
        If ``min_length`` exceeds the size of the last dimension.
    """
    if min_length > mask.shape[-1]:
        raise ValueError(
            f"min_length ({min_length}) exceeds the time dimension ({mask.shape[-1]})"
        )

    valid_counts = mask.sum(dim=-1, keepdim=True)  # Count valid entries along time dimension
    needs_fixing = valid_counts < min_length  # Identify sequences needing more valid entries

    if needs_fixing.any():
        # Score = mask + small noise: existing valid entries (>= 1) always
        # outrank invalid ones (< 0.01), while the noise randomizes which
        # invalid positions get promoted.  Scoring is done in float so the
        # routine also works on boolean masks.
        scores = mask.float() + torch.rand(mask.shape, device=mask.device) * 0.01
        _, topk_indices = torch.topk(scores, k=min_length, dim=-1, sorted=True)

        # Build the replacement mask in float, then cast back to the input
        # dtype so scatter_ never has to write floats into a bool tensor.
        fixed = torch.zeros(mask.shape, dtype=torch.float, device=mask.device)
        fixed.scatter_(-1, topk_indices, 1.0)
        fixed = fixed.to(mask.dtype)

        # Replace only the rows flagged as deficient (broadcast over last dim).
        mask = torch.where(needs_fixing, fixed, mask)

    return mask
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def is_valid_simulation(sim: torch.Tensor) -> bool:
    """Return ``True`` if the simulation is numerically valid and all values are < 10.

    "Valid" means every value is finite, non-negative and strictly below 10
    (values >= 10 presumably indicate a diverged trajectory on this
    pipeline's normalized scale — TODO confirm against the simulators).

    The result is cast to a plain Python ``bool``; the previous
    implementation returned a 0-dim tensor (the ``and`` chain yields its last
    operand), which broke ``is True`` checks and the declared return type.
    """
    return bool(torch.isfinite(sim).all() and (sim >= 0).all() and (sim < 10).all())
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _stack_one_perm(
    batches: Sequence["AICMECompartmentsDataBatch"],
) -> "AICMECompartmentsDataBatch":
    """Collate several databatches field-by-field into one batch.

    String metadata fields are concatenated as plain Python lists; every
    other field is stacked with PyTorch's ``default_collate``.
    """
    collated = []
    for field in AICMECompartmentsDataBatch._fields:
        values = [getattr(batch, field) for batch in batches]

        if field in ("study_name", "substance_name"):
            # Flat lists of strings: concatenate across the batches.
            flat = []
            for value in values:
                if isinstance(value, (list, tuple)):
                    flat.extend(map(str, value))
                elif isinstance(value, str):
                    flat.append(value)
                else:
                    raise TypeError(f"Unexpected type for {field}: {type(value)}")
            collated.append(flat)
        elif field in ("context_subject_name", "target_subject_name"):
            # Nested [B][ind] lists: extend with fresh copies of the inner lists.
            nested = []
            for value in values:
                if isinstance(value, (list, tuple)):
                    nested.extend(list(inner) for inner in value)
                else:
                    raise TypeError(f"Unexpected type for {field}: {type(value)}")
            collated.append(nested)
        else:
            # Tensor-like fields: default torch collation stacks them.
            collated.append(default_collate(values))

    return AICMECompartmentsDataBatch(*collated)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _collate_aicme_batches(batch_list):
|
| 94 |
+
"""
|
| 95 |
+
Handles:
|
| 96 |
+
- [B] of AICMECompartmentsDataBatch → returns one collated batch
|
| 97 |
+
- [B][P] of AICMECompartmentsDataBatch → returns list of P collated batches
|
| 98 |
+
"""
|
| 99 |
+
if not batch_list:
|
| 100 |
+
return batch_list
|
| 101 |
+
|
| 102 |
+
first = batch_list[0]
|
| 103 |
+
|
| 104 |
+
# Case 1: flat list of AICME batches
|
| 105 |
+
if hasattr(first, "_fields"): # NamedTuple-like
|
| 106 |
+
return _stack_one_perm(batch_list)
|
| 107 |
+
|
| 108 |
+
# Case 2: nested [B][P]
|
| 109 |
+
if isinstance(first, (list, tuple)) and hasattr(first[0], "_fields"):
|
| 110 |
+
# transpose [B][P] -> [P][B]
|
| 111 |
+
transposed = list(zip(*batch_list))
|
| 112 |
+
return [_stack_one_perm(list(group)) for group in transposed]
|
| 113 |
+
|
| 114 |
+
# If we reach here and elements are Tensors, do NOT recurse further.
|
| 115 |
+
if torch.is_tensor(first):
|
| 116 |
+
raise TypeError(
|
| 117 |
+
"Got a list of tensors instead of AICMECompartmentsDataBatch. "
|
| 118 |
+
"Check that your Dataset returns AICMECompartmentsDataBatch, not raw tensors."
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
raise TypeError(
|
| 122 |
+
f"Unexpected element type in batch_list: {type(first)}. "
|
| 123 |
+
"Expected AICMECompartmentsDataBatch or list thereof."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def split_individuals_tensor_batch(
    full_tensor_a: torch.Tensor,
    full_tensor_b: torch.Tensor,
    full_tensor_c: Optional[torch.Tensor],
    n_of_target_individuals: int,
    seed: Optional[int] = None,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
]:
    """Randomly split individuals (dim 0) into context and target groups.

    Parameters
    ----------
    full_tensor_a, full_tensor_b:
        Tensors indexed by individual along dim 0; both are split with the
        same index sets so rows stay aligned.
    full_tensor_c:
        Optional third tensor split the same way; ``None`` passes through.
    n_of_target_individuals:
        Number of individuals drawn without replacement for the target
        group.  When 0, everything stays context and the three target slots
        are ``None`` (despite the non-Optional annotation — kept for
        backward compatibility).
    seed:
        Optional seed for a reproducible split.  A private ``random.Random``
        is used so the module-level RNG state is never mutated (the previous
        implementation called ``random.seed(seed)``, a global side effect;
        the drawn sample is identical for the same seed).

    Returns
    -------
    tuple
        ``(context_a, context_b, context_c, target_a, target_b, target_c)``.
    """
    num_individuals = full_tensor_a.shape[0]

    if n_of_target_individuals == 0:
        return full_tensor_a, full_tensor_b, full_tensor_c, None, None, None

    # Seeded draws use a private RNG; unseeded draws keep the old behaviour
    # of consuming the module-level RNG stream.
    rng = random.Random(seed) if seed is not None else random

    all_indices = list(range(num_individuals))
    target_indices = rng.sample(all_indices, n_of_target_individuals)
    target_set = set(target_indices)  # O(1) membership instead of O(n) list scans
    context_indices = [i for i in all_indices if i not in target_set]

    context_a = full_tensor_a[context_indices]
    context_b = full_tensor_b[context_indices]
    context_c = full_tensor_c[context_indices] if full_tensor_c is not None else None

    target_a = full_tensor_a[target_indices]
    target_b = full_tensor_b[target_indices]
    target_c = full_tensor_c[target_indices] if full_tensor_c is not None else None

    return context_a, context_b, context_c, target_a, target_b, target_c
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def list_of_databath_to_device(
    batch_list: List[AICMECompartmentsDataBatch],
    device: torch.device | str,
) -> List[AICMECompartmentsDataBatch]:
    """Return copies of every batch in ``batch_list`` moved to ``device``.

    Each element delegates to ``AICMECompartmentsDataBatch.to_device``,
    which moves tensor fields and leaves string metadata untouched.
    """
    moved = []
    for databatch in batch_list:
        moved.append(databatch.to_device(device))
    return moved
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def build_reconstruction_db(
    db: AICMECompartmentsDataBatch,
) -> AICMECompartmentsDataBatch:
    """Rebuild target trajectories as observed + remainder, padded to context length.

    The observed and remainder segments of every target individual are
    concatenated, the result is right-padded with zeros until it matches the
    context time dimension, and a fresh validity mask is produced.  Context
    fields are returned unchanged.
    """
    batch_size, _, ctx_len, _ = db.context_obs.shape  # context time dim is the reference
    _, n_targets, _, _ = db.target_obs.shape

    # Fresh target tensors, zero-padded to the context length.
    new_obs = torch.zeros(
        batch_size, n_targets, ctx_len, 1,
        dtype=db.target_obs.dtype, device=db.target_obs.device,
    )
    new_times = torch.zeros(
        batch_size, n_targets, ctx_len, 1,
        dtype=db.target_obs_time.dtype, device=db.target_obs_time.device,
    )
    new_mask = torch.zeros(
        batch_size, n_targets, ctx_len,
        dtype=torch.bool, device=db.target_obs_mask.device,
    )

    # Copy observed prefix, then the remainder segment, per individual.
    for b_idx in range(batch_size):
        for t_idx in range(n_targets):
            n_obs = int(db.target_obs_mask[b_idx, t_idx].sum().item())
            n_rem = int(db.target_rem_sim_mask[b_idx, t_idx].sum().item())
            n_total = n_obs + n_rem
            if n_total == 0:
                continue
            new_obs[b_idx, t_idx, :n_obs] = db.target_obs[b_idx, t_idx, :n_obs]
            new_obs[b_idx, t_idx, n_obs:n_total] = db.target_rem_sim[b_idx, t_idx, :n_rem]
            new_times[b_idx, t_idx, :n_obs] = db.target_obs_time[b_idx, t_idx, :n_obs]
            new_times[b_idx, t_idx, n_obs:n_total] = db.target_rem_sim_time[b_idx, t_idx, :n_rem]
            new_mask[b_idx, t_idx, :n_total] = True

    # Swap in only the reconstructed target fields.
    return db._replace(
        target_obs=new_obs,
        target_obs_time=new_times,
        target_obs_mask=new_mask,
    )
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class AICMECompartmentsDataset(Dataset):
|
| 225 |
+
"""Dataset generating synthetic PK batches for AICME models.
|
| 226 |
+
|
| 227 |
+
Target observation strategies should already divide past and future
|
| 228 |
+
observations (``split_past_future=True``).
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
def __init__(
|
| 232 |
+
self,
|
| 233 |
+
model_config: NodePKExperimentConfig,
|
| 234 |
+
ctx_fn,
|
| 235 |
+
tgt_fn,
|
| 236 |
+
number_of_process=1000,
|
| 237 |
+
*,
|
| 238 |
+
store_in_tempfile: bool = False,
|
| 239 |
+
keep_tempfile: bool = False,
|
| 240 |
+
recreate_tempfile: bool = False,
|
| 241 |
+
tempfile_path: str | None = None,
|
| 242 |
+
show_progress: bool = True,
|
| 243 |
+
split: str = "",
|
| 244 |
+
use_shared_target_dosing: bool = False,
|
| 245 |
+
shared_target_n_targets: int = 100,
|
| 246 |
+
):
|
| 247 |
+
self.mix_data_config = model_config.mix_data
|
| 248 |
+
self.meta_study_config = model_config.meta_study
|
| 249 |
+
self.meta_dosing_config = model_config.dosing
|
| 250 |
+
self.number_of_process = number_of_process
|
| 251 |
+
# ``n_of_permutations`` specifies how many shuffled versions of the
|
| 252 |
+
# context/target split are generated for a single simulation.
|
| 253 |
+
# ``n_of_databatches`` is a deprecated alias kept for backward
|
| 254 |
+
# compatibility and mirrors ``n_of_permutations``.
|
| 255 |
+
self.n_of_permutations = model_config.mix_data.n_of_permutations
|
| 256 |
+
self.n_of_databatches = self.n_of_permutations # deprecated alias
|
| 257 |
+
self.n_of_target_individuals = int(model_config.mix_data.n_of_target_individuals)
|
| 258 |
+
if self.n_of_target_individuals < 0:
|
| 259 |
+
raise ValueError("n_of_target_individuals must be >= 0")
|
| 260 |
+
|
| 261 |
+
# `num_individuals_range` controls context individuals only.
|
| 262 |
+
self.min_context_individuals = int(self.meta_study_config.num_individuals_range[0])
|
| 263 |
+
self.max_context_individuals = int(self.meta_study_config.num_individuals_range[-1])
|
| 264 |
+
if self.min_context_individuals < 0:
|
| 265 |
+
raise ValueError("meta_study.num_individuals_range minimum must be >= 0")
|
| 266 |
+
if self.max_context_individuals < self.min_context_individuals:
|
| 267 |
+
raise ValueError("meta_study.num_individuals_range must satisfy max >= min")
|
| 268 |
+
|
| 269 |
+
# Fixed total capacity used by downstream consumers.
|
| 270 |
+
self.max_individuals = self.max_context_individuals + self.n_of_target_individuals
|
| 271 |
+
|
| 272 |
+
self.context_fn = ctx_fn
|
| 273 |
+
self.target_fn = tgt_fn
|
| 274 |
+
self.store_in_tempfile = store_in_tempfile
|
| 275 |
+
self.keep_tempfile = keep_tempfile
|
| 276 |
+
self.recreate_tempfile = recreate_tempfile
|
| 277 |
+
self.show_progress = True
|
| 278 |
+
self._tmpfile_path: List[str] | None = None
|
| 279 |
+
self._loaded_data = None
|
| 280 |
+
self.run_id = getattr(model_config, "run_index", 0)
|
| 281 |
+
self.model_name = model_config.name_str
|
| 282 |
+
|
| 283 |
+
if self.store_in_tempfile:
|
| 284 |
+
self._prepare_tempfile_data(tempfile_path=tempfile_path, split=split)
|
| 285 |
+
|
| 286 |
+
self.use_shared_target_dosing = use_shared_target_dosing
|
| 287 |
+
self.shared_target_n_targets = shared_target_n_targets
|
| 288 |
+
|
| 289 |
+
def __del__(self):
    """Best-effort removal of the temporary data file when the dataset dies.

    ``__del__`` may run during interpreter shutdown or on a partially
    initialised instance (e.g. ``__init__`` raised before setting these
    attributes), so every attribute lookup is guarded and filesystem
    errors are swallowed: a destructor must never raise.
    """
    try:
        if (
            getattr(self, "store_in_tempfile", False)
            and not getattr(self, "keep_tempfile", True)
            and getattr(self, "_tmpfile_path", None)
            and os.path.exists(self._tmpfile_path)
        ):
            os.remove(self._tmpfile_path)
    except (AttributeError, OSError):
        # Ignore deletion races (file already gone) and teardown artifacts.
        pass
|
| 297 |
+
|
| 298 |
+
def __len__(self):
    """Report the dataset length.

    ``number_of_process`` acts as a (possibly large) virtual size, since
    items are produced on demand rather than stored up front.
    """
    return self.number_of_process
|
| 300 |
+
|
| 301 |
+
def _prepare_tempfile_data(self, *, tempfile_path: str | None, split: str) -> None:
    """Handle creation and (re)generation of the temporary data file.

    Parameters
    ----------
    tempfile_path :
        Base path for the cache file. ``None`` means "reserve a fresh
        OS-level temporary file"; a tuple/list of path parts is joined
        under the module-level ``data_dir``; a plain string is used as-is.
    split : str
        Dataset split tag ("train"/"val"/"test"); embedded into the cache
        file name so splits never collide.
    """
    if tempfile_path is None:
        # Reserve a unique file name; the handle is closed right away
        # because torch.save below reopens the path itself.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
        self._tmpfile_path = tmp.name
        tmp.close()
    else:
        # Allow both Tuple paths from YAML and plain strings
        if isinstance(tempfile_path, (tuple, list)):
            base_path = os.path.join(data_dir, *tempfile_path)
        else:
            base_path = tempfile_path

        dirname = os.path.dirname(base_path)
        basename = os.path.basename(base_path)

        # Encode model name, split, and run id into the file name so
        # concurrent runs/models do not overwrite each other's caches.
        suffix = f"_{self.model_name}_{split}"
        if self.run_id is not None:
            suffix += f"_run{self.run_id}"
        new_basename = basename + suffix + ".tr"

        self._tmpfile_path = os.path.join(dirname, new_basename)

    # Regenerate when explicitly requested or when no cache exists yet.
    if self.recreate_tempfile or not os.path.exists(self._tmpfile_path):
        print("RECREATING DATASET!")
        iterator = range(self.number_of_process)
        if self.show_progress:
            from tqdm.auto import tqdm

            iterator = tqdm(iterator, desc="Generating AICME data")
        # Materialise every item once, then persist the whole list.
        data = [self._generate_item(i) for i in iterator]
        torch.save(data, self._tmpfile_path)
|
| 333 |
+
|
| 334 |
+
def split_simulations(
    self, full_simulation, full_simulation_times
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    list[int],
    list[int],
]:
    """Split a full simulation into context and target individuals.

    Randomly selects ``n_of_target_individuals`` rows as targets and keeps
    the rest as context. If ``n_of_target_individuals == 0`` the whole
    simulation is context and the target fields are ``None``.

    Parameters
    ----------
    full_simulation : torch.Tensor
        Simulated trajectories; the first dimension indexes individuals.
    full_simulation_times : torch.Tensor
        Matching time stamps with the same first dimension.

    Returns
    -------
    tuple
        ``(context_sim, context_times, target_sim | None,
        target_times | None, context_indices, target_indices)``.

    Raises
    ------
    ValueError
        If the simulation has fewer individuals than requested targets.
    """
    n_of_target_individuals = self.n_of_target_individuals
    num_individuals = full_simulation.shape[0]

    if n_of_target_individuals == 0:
        context_simulation = full_simulation
        context_simulation_times = full_simulation_times
        return (
            context_simulation,
            context_simulation_times,
            None,
            None,
            list(range(num_individuals)),
            [],
        )

    if num_individuals < n_of_target_individuals:
        raise ValueError(
            "Simulation contains fewer individuals than requested targets: "
            f"num_individuals={num_individuals}, "
            f"n_of_target_individuals={n_of_target_individuals}."
        )

    # Randomly select indices for target individuals. A set makes the
    # complementary membership test O(1) instead of O(n_targets) per item.
    target_indices = random.sample(range(num_individuals), n_of_target_individuals)
    target_index_set = set(target_indices)
    context_indices = [i for i in range(num_individuals) if i not in target_index_set]

    # Split the simulations and times
    target_simulation = full_simulation[target_indices]
    target_simulation_times = full_simulation_times[target_indices]
    context_simulation = full_simulation[context_indices]
    context_simulation_times = full_simulation_times[context_indices]

    return (
        context_simulation,
        context_simulation_times,
        target_simulation,
        target_simulation_times,
        context_indices,
        target_indices,
    )
|
| 388 |
+
|
| 389 |
+
def _build_generation_meta_study_config(self):
    """Return a meta-study config whose totals include the fixed targets.

    The user-facing ``meta_study.num_individuals_range`` counts context
    individuals only, so the config handed to raw simulation generation
    must request ``context + n_of_target_individuals`` individuals total.
    """
    lo = self.min_context_individuals + self.n_of_target_individuals
    hi = self.max_context_individuals + self.n_of_target_individuals

    if not getattr(self.meta_study_config, "simple_mode", False):
        return replace(
            self.meta_study_config,
            num_individuals_range=(lo, hi),
        )

    # Simple mode: draw one concrete total and pin the range to it.
    n_total = random.randint(lo, hi)
    return replace(
        self.meta_study_config,
        num_individuals=n_total,
        num_individuals_range=(n_total, n_total),
    )
|
| 411 |
+
|
| 412 |
+
def __getitem__(self, idx):
    """Return the item at ``idx``.

    Three serving modes:
    - cached tempfile: lazily load the saved list, then index it
      (rank-aware sharding when ``torch.distributed`` is active);
    - shared-target dosing: generate on the fly with repeated targets;
    - default: generate a fresh item on the fly.
    """
    if self.store_in_tempfile:
        if self._loaded_data is None:
            # weights_only=False is required because the cache stores
            # arbitrary Python objects, not plain tensors. The cache file
            # is produced by this process, so it is trusted input.
            self._loaded_data = torch.load(self._tmpfile_path, weights_only=False)
        # If in distributed mode, adjust the index based on process rank/world size.
        # Guard with is_available() as well: CPU-only builds may lack the
        # distributed backend entirely (matches _is_global_zero_process).
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            total_len = len(self._loaded_data)
            # Stride the global index across ranks, wrapping around when it
            # would run past the cached data (modulo is a no-op in range).
            adjusted_idx = (idx * world_size + rank) % total_len
            return self._loaded_data[adjusted_idx]
        return self._loaded_data[idx]

    if self.use_shared_target_dosing:
        return self._generate_item_sample_target_dosing(
            idx, n_targets=self.shared_target_n_targets
        )
    return self._generate_item(idx)
|
| 434 |
+
|
| 435 |
+
def _generate_item(self, idx) -> List[AICMECompartmentsDataBatch]:
    """Generate a list of ``AICMECompartmentsDataBatch`` objects.

    Each element corresponds to one permutation of the context/target split.
    Target observations are generated using ``target_fn``, which is expected
    to divide past and future observations.
    """
    # One raw simulation is drawn per item; the permutation loop below only
    # re-splits it into different context/target subsets.
    (
        full_simulation,
        full_simulation_times,
        dosing_amounts,
        dosing_routes,
        time_points,
        time_scales,
    ) = prepare_full_simulation(
        self._build_generation_meta_study_config(),
        self.meta_dosing_config,
    )

    list_of_databatches: List[AICMECompartmentsDataBatch] = []
    for _ in range(self.n_of_permutations):
        # Split into context and target
        (
            context_simulation,
            context_simulation_times,
            target_simulation,
            target_simulation_times,
            context_indices,
            target_indices,
        ) = self.split_simulations(full_simulation, full_simulation_times)

        # _safe_generate retries empty draws and returns an all-None tuple
        # when its simulation input is None (no targets configured).
        context_observations = self._safe_generate(
            self.context_fn,
            context_simulation,
            context_simulation_times,
            time_scales=time_scales,
        )

        target_observations = self._safe_generate(
            self.target_fn,
            target_simulation,
            target_simulation_times,
            time_scales=time_scales,
        )

        (
            context_obs,  # [c_ind, num_obs_c, 1]
            context_obs_time,  # [c_ind, num_obs_c, 1]
            context_obs_mask,  # [c_ind, num_obs_c]
            context_rem_sim,  # [c_ind, rem_obs_c, 1]
            context_rem_sim_time,  # [c_ind, rem_obs_c, 1]
            context_rem_sim_mask,  # [c_ind, rem_obs_c]
            context_time_scales,
        ) = context_observations

        (
            target_obs,  # [t_ind, num_obs_t, 1]
            target_obs_time,  # [t_ind, num_obs_t, 1]
            target_obs_mask,  # [t_ind, num_obs_t]
            target_rem_sim,  # [t_ind, rem_obs_t, 1]
            target_rem_sim_time,  # [t_ind, rem_obs_t, 1]
            target_rem_sim_mask,  # [t_ind, rem_obs_t]
            target_time_scales,
        ) = target_observations

        # Use provided time scales or fall back to simulation defaults
        # (priority: context strategy > target strategy > raw simulation).
        ts = (
            context_time_scales
            if context_time_scales is not None
            else target_time_scales
            if target_time_scales is not None
            else time_scales
        )

        batch = self._build_padded_batch(
            context_obs,
            context_obs_time,
            context_obs_mask,
            context_rem_sim,
            context_rem_sim_time,
            context_rem_sim_mask,
            dosing_amounts[context_indices],
            dosing_routes[context_indices],
            target_obs,
            target_obs_time,
            target_obs_mask,
            target_rem_sim,
            target_rem_sim_time,
            target_rem_sim_mask,
            # Empty target selections (n_of_target_individuals == 0) have no
            # dosing rows; pass None so padding uses empty placeholders.
            dosing_amounts[target_indices] if len(target_indices) > 0 else None,
            dosing_routes[target_indices] if len(target_indices) > 0 else None,
            ts,
        )
        list_of_databatches.append(batch)

    return list_of_databatches
|
| 531 |
+
|
| 532 |
+
def _generate_item_sample_target_dosing(
    self,
    idx: int,
    n_targets: int = 100,
    different_dosing: bool = False,
):
    """Generate one batch via ``prepare_full_simulation_with_repeated_targets``.

    Unlike :meth:`_generate_item`, context/target are split at simulation
    level (no permutation loop), and target padding capacity is ``n_targets``
    instead of the configured ``n_of_target_individuals``.

    Parameters
    ----------
    idx : int
        Item index, forwarded to the simulation routine.
    n_targets : int
        Number of target individuals requested from the simulation.
    different_dosing : bool
        Forwarded to the simulation routine; controls target dosing
        variation there.
    """
    (
        context_sim,
        context_times,
        target_sim,
        target_times,
        dosing_amounts_ctx,
        dosing_routes_ctx,
        dosing_amounts_tgt,
        dosing_routes_tgt,
        time_points,
        time_scales,
    ) = prepare_full_simulation_with_repeated_targets(
        self.meta_study_config,
        self.meta_dosing_config,
        n_targets,
        different_dosing=different_dosing,
        idx=idx,
    )

    # Observations
    context_obs_pack = self._safe_generate(
        self.context_fn, context_sim, context_times, time_scales=time_scales
    )
    target_obs_pack = self._safe_generate(
        self.target_fn, target_sim, target_times, time_scales=time_scales
    )

    (
        context_obs,
        context_obs_time,
        context_obs_mask,
        context_rem_sim,
        context_rem_sim_time,
        context_rem_sim_mask,
        context_time_scales,
    ) = context_obs_pack

    (
        target_obs,
        target_obs_time,
        target_obs_mask,
        target_rem_sim,
        target_rem_sim_time,
        target_rem_sim_mask,
        target_time_scales,
    ) = target_obs_pack

    # NOTE(review): the `or` relies on the truthiness of target_time_scales;
    # if it can be a multi-element tensor this raises. _generate_item uses an
    # explicit `is not None` chain instead — confirm the intended semantics.
    ts = (
        context_time_scales
        if context_time_scales is not None
        else (target_time_scales or time_scales)
    )

    # Build batch
    batch = self._build_padded_batch(
        # context
        context_obs,
        context_obs_time,
        context_obs_mask,
        context_rem_sim,
        context_rem_sim_time,
        context_rem_sim_mask,
        dosing_amounts_ctx,
        dosing_routes_ctx,
        # target
        target_obs,
        target_obs_time,
        target_obs_mask,
        target_rem_sim,
        target_rem_sim_time,
        target_rem_sim_mask,
        dosing_amounts_tgt,
        dosing_routes_tgt,
        # time scales
        ts=ts,
        # Pad the target block to n_targets, not n_of_target_individuals.
        target_capacity=n_targets,
    )

    return [batch]
|
| 617 |
+
|
| 618 |
+
# ------------------------------------------------------------------ #
|
| 619 |
+
# utilities
|
| 620 |
+
# ------------------------------------------------------------------ #
|
| 621 |
+
|
| 622 |
+
def _build_padded_batch(
    self,
    ctx_obs: Tensor,  # [c_ind, num_obs_c]
    ctx_time: Tensor,  # [c_ind, num_obs_c]
    ctx_mask: Tensor,  # [c_ind, num_obs_c]
    ctx_rem: Optional[Tensor],  # [c_ind, rem_obs_c] | None
    ctx_rem_time: Optional[Tensor],  # [c_ind, rem_obs_c] | None
    ctx_rem_mask: Optional[Tensor],  # [c_ind, rem_obs_c] | None
    ctx_dose: Tensor,  # [c_ind]
    ctx_route: Tensor,  # [c_ind]
    tgt_obs: Optional[Tensor],  # [t_ind, num_obs_t] | None
    tgt_time: Optional[Tensor],  # [t_ind, num_obs_t] | None
    tgt_mask: Optional[Tensor],  # [t_ind, num_obs_t] | None
    tgt_rem: Optional[Tensor],  # [t_ind, rem_obs_t] | None
    tgt_rem_time: Optional[Tensor],  # [t_ind, rem_obs_t] | None
    tgt_rem_mask: Optional[Tensor],  # [t_ind, rem_obs_t] | None
    tgt_dose: Optional[Tensor],  # [t_ind] | None
    tgt_route: Optional[Tensor],  # [t_ind] | None
    ts: Tensor,  # [B(=1), 2]
    *,
    target_capacity: Optional[
        int
    ] = None,  # ← NEW (optional). If None, use self.n_of_target_individuals
) -> AICMECompartmentsDataBatch:
    """Pad context and target tensors then pack them into a batch.

    Context rows are padded to ``self.max_context_individuals``; target rows
    to ``target_capacity`` (falling back to ``self.n_of_target_individuals``).
    ``None`` target inputs are replaced with minimal empty placeholders via
    ``ensure_tensor_or_empty`` / ``ensure_mask_or_empty`` before padding, and
    per-group boolean masks record which padded rows hold real individuals.
    """

    max_c = self.max_context_individuals  # (unchanged)
    max_t = (
        target_capacity if target_capacity is not None else self.n_of_target_individuals
    )  # ← ONLY CHANGE

    # ── target padding (unchanged) ─────────────────────────────────────────
    t_obs_p = self._pad_first_dim(
        ensure_tensor_or_empty(
            tgt_obs.unsqueeze(-1) if tgt_obs is not None else None, (1, 1, 1)
        ),  # to [t_ind, Tt, 1]
        max_t,
    )
    t_time_p = self._pad_first_dim(
        ensure_tensor_or_empty(
            tgt_time.unsqueeze(-1) if tgt_time is not None else None, (1, 1, 1)
        ),  # to [t_ind, Tt, 1]
        max_t,
    )
    t_mask_p = self._pad_first_dim(
        ensure_mask_or_empty(
            tgt_mask if tgt_mask is not None else None, (1, 1)
        ),  # to [t_ind, Tt]
        max_t,
    )
    # Remaining-simulation placeholders mirror the observed-target row count.
    t_rem_p = self._pad_first_dim(
        ensure_tensor_or_empty(
            tgt_rem.unsqueeze(-1) if tgt_rem is not None else None, (t_obs_p.size(0), 1, 1)
        ),  # [t_ind, Rt,1]
        max_t,
    )
    t_rem_time_p = self._pad_first_dim(
        ensure_tensor_or_empty(
            tgt_rem_time.unsqueeze(-1) if tgt_rem_time is not None else None,
            (t_obs_p.size(0), 1, 1),
        ),
        max_t,
    )
    t_rem_mask_p = self._pad_first_dim(
        ensure_mask_or_empty(
            tgt_rem_mask if tgt_rem_mask is not None else None, (t_obs_p.size(0), 1)
        ),
        max_t,
    )
    t_dose_p = self._pad_first_dim(
        ensure_tensor_or_empty(tgt_dose if tgt_dose is not None else None, (1,)),  # [t_ind]
        max_t,
    )
    # Route types are categorical indices, hence the cast to long.
    t_route_p = self._pad_first_dim(
        ensure_tensor_or_empty(tgt_route if tgt_route is not None else None, (1,)),  # [t_ind]
        max_t,
    ).long()

    # ── context padding (unchanged) ────────────────────────────────────────
    c_obs_p = self._pad_first_dim(ctx_obs, max_c).unsqueeze(-1)  # [c_ind, Tc, 1]
    c_time_p = self._pad_first_dim(ctx_time, max_c).unsqueeze(-1)  # [c_ind, Tc, 1]
    c_mask_p = self._pad_first_dim(ctx_mask, max_c)  # [c_ind, Tc]
    c_rem_p = self._pad_first_dim(
        ensure_tensor_or_empty(
            ctx_rem.unsqueeze(-1) if ctx_rem is not None else None, (ctx_obs.size(0), 1, 1)
        ),
        max_c,
    )
    c_rem_time_p = self._pad_first_dim(
        ensure_tensor_or_empty(
            ctx_rem_time.unsqueeze(-1) if ctx_rem_time is not None else None,
            (ctx_obs.size(0), 1, 1),
        ),
        max_c,
    )
    c_rem_mask_p = self._pad_first_dim(
        ensure_mask_or_empty(
            ctx_rem_mask if ctx_rem_mask is not None else None, (ctx_obs.size(0), 1)
        ),
        max_c,
    )
    c_dose_p = self._pad_first_dim(ctx_dose, max_c)  # [c_ind]
    c_route_p = self._pad_first_dim(ctx_route, max_c).long()  # [c_ind]

    # Boolean masks flag which padded rows correspond to real individuals.
    total_c = ctx_obs.size(0)
    mask_c_inds = torch.zeros(self.max_context_individuals, dtype=torch.bool)
    mask_c_inds[:total_c] = True

    total_t = tgt_obs.size(0) if tgt_obs is not None else 0
    mask_t_inds = torch.zeros(
        max_t, dtype=torch.bool
    )  # ← use max_t here (unchanged logic, just variable)
    mask_t_inds[:total_t] = True

    return AICMECompartmentsDataBatch(
        target_obs=t_obs_p,
        target_obs_time=t_time_p,
        target_obs_mask=t_mask_p,
        target_rem_sim=t_rem_p,
        target_rem_sim_time=t_rem_time_p,
        target_rem_sim_mask=t_rem_mask_p,
        target_dosing_amounts=t_dose_p,
        target_dosing_route_types=t_route_p,
        context_obs=c_obs_p,
        context_obs_time=c_time_p,
        context_obs_mask=c_mask_p,
        context_rem_sim=c_rem_p,
        context_rem_sim_time=c_rem_time_p,
        context_rem_sim_mask=c_rem_mask_p,
        context_dosing_amounts=c_dose_p,
        context_dosing_route_types=c_route_p,
        mask_context_individuals=mask_c_inds,
        mask_target_individuals=mask_t_inds,
        # Synthetic batches carry empty name metadata.
        study_name=[""],
        context_subject_name=[[""] * max_c],
        target_subject_name=[[""] * max_t],  # ← still uses max_t
        substance_name=[""],
        time_scales=ts,
        is_empirical=False,
    )
|
| 762 |
+
|
| 763 |
+
@staticmethod
|
| 764 |
+
def _safe_generate(strategy, sim, times, **kw):
|
| 765 |
+
"""
|
| 766 |
+
Call ObservationStrategy.generate() only when `sim` is not None.
|
| 767 |
+
Returns a 7-tuple of Nones otherwise.
|
| 768 |
+
"""
|
| 769 |
+
if sim is None:
|
| 770 |
+
return (None, None, None, None, None, None, None)
|
| 771 |
+
|
| 772 |
+
for _ in range(10): # retries, like old manager
|
| 773 |
+
out = strategy.generate(sim, times, **kw)
|
| 774 |
+
if out[0] is not None: # got a non-empty slice
|
| 775 |
+
return out
|
| 776 |
+
raise RuntimeError(
|
| 777 |
+
"Unable to generate non-empty observations "
|
| 778 |
+
"after 10 attempts – check strategy parameters."
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
@staticmethod
|
| 782 |
+
def _pad_first_dim(t: torch.Tensor, size: int) -> torch.Tensor:
|
| 783 |
+
"""Pad tensor along the first dimension up to ``size``.
|
| 784 |
+
|
| 785 |
+
Parameters
|
| 786 |
+
----------
|
| 787 |
+
t : TensorType["I", *Ts]
|
| 788 |
+
Input tensor where ``I`` may be smaller than ``size``.
|
| 789 |
+
size : int
|
| 790 |
+
Desired first-dimension size after padding.
|
| 791 |
+
|
| 792 |
+
Returns
|
| 793 |
+
-------
|
| 794 |
+
TensorType["size", *Ts]
|
| 795 |
+
Tensor padded with zeros (or ``False`` for bool tensors) so that the
|
| 796 |
+
first dimension equals ``size``. If ``t`` already has ``size`` or
|
| 797 |
+
more elements along the first dimension, it is truncated.
|
| 798 |
+
"""
|
| 799 |
+
|
| 800 |
+
current = t.size(0)
|
| 801 |
+
if current >= size:
|
| 802 |
+
return t[:size]
|
| 803 |
+
|
| 804 |
+
pad_shape = (size - current, *t.shape[1:])
|
| 805 |
+
pad_value = False if t.dtype == torch.bool else 0.0
|
| 806 |
+
padding = torch.full(pad_shape, pad_value, dtype=t.dtype, device=t.device)
|
| 807 |
+
return torch.cat([t, padding], dim=0)
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
class AICMECompartmentsDataModule(pl.LightningDataModule):
|
| 811 |
+
"""LightningDataModule for synthetic PK simulation data."""
|
| 812 |
+
|
| 813 |
+
# Empirical target batches always use the legacy PK observation strategy
|
| 814 |
+
# with a fixed capacity profile, independent from synthetic target config.
|
| 815 |
+
_EMPIRICAL_TARGET_MAX_NUM_OBS = 15
|
| 816 |
+
_EMPIRICAL_TARGET_MIN_PAST = 0
|
| 817 |
+
_EMPIRICAL_TARGET_MAX_PAST = 5
|
| 818 |
+
|
| 819 |
+
def __init__(
    self,
    model_config: NodePKExperimentConfig,
):
    """Cache config sub-sections and initialise empty empirical caches.

    Parameters
    ----------
    model_config : NodePKExperimentConfig
        Full experiment configuration; the relevant sub-configs are
        unpacked into attributes for convenient access.
    """
    super().__init__()
    self.model_config = model_config
    self.context_config = model_config.context_observations
    self.target_config = model_config.target_observations
    self.meta_config = model_config.meta_study
    self.data_config = model_config.mix_data
    # NOTE: study_config aliases meta_config (both read meta_study).
    self.study_config = model_config.meta_study
    self.num_workers = model_config.train.num_workers
    self.persistent_workers = model_config.train.persistent_workers
    # Older configs may lack shuffle_val; default to shuffling validation.
    self.shuffle_val = getattr(model_config.train, "shuffle_val", True)
    self.train_size = self.data_config.train_size
    self.val_size = self.data_config.val_size
    self.test_size = self.data_config.test_size
    self.batch_size = model_config.train.batch_size
    # Set by prepare_data(); guards against double preparation.
    self._prepared = False
    # Cached shape parameters for empirical batch builders
    self.max_individuals: int | None = None
    self.max_observations: int | None = None
    self.max_remaining: int | None = None
    # Populated in prepare_data(); empirical targets use a dedicated config.
    self.empirical_target_config = None
    self.empirical_target_strategy = None
    # Empirical evaluation caches, keyed by Hugging Face repo id.
    self.empirical_test_batches: Dict[str, List["AICMECompartmentsDataBatch"]] = {}
    self.empirical_test_batches_no_heldout: Dict[str, List["AICMECompartmentsDataBatch"]] = {}
|
| 846 |
+
|
| 847 |
+
def prepare_data(self):
    """Build observation strategies, the three synthetic datasets, and
    the empirical evaluation caches.

    Also records the shape parameters empirical batch builders query via
    :meth:`obtain_shapes`. Empirical downloads are restricted to the
    global-zero process under DDP.
    """
    # Use this method to download or prepare data if needed.
    # This is called only once and on a single GPU.
    # Here the Observation Manager Also Handles Empirical Data
    tempfile_path = getattr(self.data_config, "tempfile_path", None)
    if tempfile_path:
        temp_dir = Path(data_dir).joinpath(*tempfile_path)
    else:
        temp_dir = Path(data_dir) / "preprocessed"
    # Only the directory must exist here; datasets build their own cache
    # file names inside it.
    temp_dir.mkdir(parents=True, exist_ok=True)

    self.context_strategy = ObservationStrategyFactory.from_config(
        self.context_config,
        self.meta_config,
    )
    self.target_strategy = ObservationStrategyFactory.from_config(
        self.target_config,
        self.meta_config,
    )
    # Empirical target path: enforce legacy PK strategy and fixed capacities.
    # This is intentionally decoupled from synthetic target strategy settings.
    self.empirical_target_config = replace(
        self.target_config,
        type=None,
        split_past_future=True,
        max_num_obs=self._EMPIRICAL_TARGET_MAX_NUM_OBS,
        min_past=self._EMPIRICAL_TARGET_MIN_PAST,
        max_past=self._EMPIRICAL_TARGET_MAX_PAST,
    )
    self.empirical_target_strategy = ObservationStrategyFactory.from_config(
        self.empirical_target_config,
        self.meta_config,
    )
    # The three synthetic splits differ only in size and split tag.
    self.train_dataset = AICMECompartmentsDataset(
        self.model_config,
        ctx_fn=self.context_strategy,
        tgt_fn=self.target_strategy,
        number_of_process=self.train_size,
        store_in_tempfile=self.data_config.store_in_tempfile,
        keep_tempfile=self.data_config.keep_tempfile,
        recreate_tempfile=self.data_config.recreate_tempfile,
        tempfile_path=self.data_config.tempfile_path,
        show_progress=self.data_config.tqdm_progress,
        split="train",
    )
    self.val_dataset = AICMECompartmentsDataset(
        self.model_config,
        ctx_fn=self.context_strategy,
        tgt_fn=self.target_strategy,
        number_of_process=self.val_size,
        store_in_tempfile=self.data_config.store_in_tempfile,
        keep_tempfile=self.data_config.keep_tempfile,
        recreate_tempfile=self.data_config.recreate_tempfile,
        tempfile_path=self.data_config.tempfile_path,
        show_progress=self.data_config.tqdm_progress,
        split="val",
    )
    self.test_dataset = AICMECompartmentsDataset(
        self.model_config,
        ctx_fn=self.context_strategy,
        tgt_fn=self.target_strategy,
        number_of_process=self.test_size,
        store_in_tempfile=self.data_config.store_in_tempfile,
        keep_tempfile=self.data_config.keep_tempfile,
        recreate_tempfile=self.data_config.recreate_tempfile,
        tempfile_path=self.data_config.tempfile_path,
        show_progress=self.data_config.tqdm_progress,
        split="test",
    )
    # Record shapes for empirical builders
    ctx_obs, ctx_rem = self.context_strategy.get_shapes()
    tgt_obs, tgt_rem = self.target_strategy.get_shapes()
    self.max_observations = max(ctx_obs, tgt_obs)
    self.max_remaining = max(ctx_rem, tgt_rem)
    # NOTE(review): the dataset itself defines max_individuals as
    # context + targets, while this uses max(context, targets) — confirm
    # the empirical builders really expect the per-group maximum.
    self.max_individuals = max(
        self.train_dataset.max_context_individuals,
        self.train_dataset.n_of_target_individuals,
    )
    self._prepared = True
    self._empirical_loaded = False

    # Preload empirical datasets during prepare_data so they are available
    # before training callbacks query them.
    # In DDP, keep network/download activity on rank 0 only.
    if self._is_global_zero_process():
        self._load_empirical_test_batches()
        self._empirical_loaded = True
|
| 934 |
+
|
| 935 |
+
def setup(self, stage=None):
|
| 936 |
+
# Use this method to split data into train, validation, and test sets.
|
| 937 |
+
# This is called on every GPU.
|
| 938 |
+
if not self._prepared:
|
| 939 |
+
self.prepare_data()
|
| 940 |
+
|
| 941 |
+
def train_dataloader(self):
    """Build the shuffled training DataLoader."""
    workers, keep_workers = self._resolve_dataloader_workers()
    loader_kwargs = dict(
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=workers,
        persistent_workers=keep_workers,
        collate_fn=_collate_aicme_batches,
    )
    return DataLoader(self.train_dataset, **loader_kwargs)
|
| 952 |
+
|
| 953 |
+
def val_dataloader(self):
    """Build the validation DataLoader (shuffling is configurable)."""
    workers, keep_workers = self._resolve_dataloader_workers()
    loader_kwargs = dict(
        batch_size=self.batch_size,
        shuffle=self.shuffle_val,
        num_workers=workers,
        persistent_workers=keep_workers,
        collate_fn=_collate_aicme_batches,
    )
    return DataLoader(self.val_dataset, **loader_kwargs)
|
| 964 |
+
|
| 965 |
+
def test_dataloader(self):
    """Build the (unshuffled) test DataLoader."""
    workers, keep_workers = self._resolve_dataloader_workers()
    loader_kwargs = dict(
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=workers,
        persistent_workers=keep_workers,
        collate_fn=_collate_aicme_batches,
    )
    return DataLoader(self.test_dataset, **loader_kwargs)
|
| 977 |
+
|
| 978 |
+
def obtain_shapes(self) -> Tuple[int, int, int]:
    """Expose dataset shape parameters for empirical batching.

    Returns
    -------
    Tuple[int, int, int]
        ``(max_individuals, max_observations, max_remaining)`` as used by
        :class:`AICMECompartmentsDataset`.

    Raises
    ------
    RuntimeError
        If the shape attributes remain unset even after preparation. The
        previous ``assert`` statements vanished under ``python -O``; an
        explicit check always fires.
    """

    if not self._prepared:
        self.prepare_data()

    # Explicit validation instead of asserts so the guard survives -O.
    if (
        self.max_individuals is None
        or self.max_observations is None
        or self.max_remaining is None
    ):
        raise RuntimeError(
            "Dataset shapes are unavailable; prepare_data() did not set them."
        )
    return (
        self.max_individuals,
        self.max_observations,
        self.max_remaining,
    )
|
| 999 |
+
|
| 1000 |
+
def _resolve_dataloader_workers(self) -> Tuple[int, bool]:
    """Return DataLoader worker settings safe for single-process runs."""
    workers = int(self.num_workers)
    if workers < 0:
        workers = 0
    # Persistent workers only make sense when workers actually exist.
    keep_alive = self.persistent_workers and workers > 0
    return workers, keep_alive
|
| 1005 |
+
|
| 1006 |
+
@staticmethod
|
| 1007 |
+
def _is_global_zero_process() -> bool:
|
| 1008 |
+
"""Return True for rank 0 (or single-process execution)."""
|
| 1009 |
+
|
| 1010 |
+
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
| 1011 |
+
return torch.distributed.get_rank() == 0
|
| 1012 |
+
return True
|
| 1013 |
+
|
| 1014 |
+
def _load_empirical_test_batches(self) -> None:
    """Download and cache empirical Hugging Face datasets for evaluation.

    For each configured repo id, two caches are filled: leave-one-out
    batches (``held_out=True``) and all-in-context batches
    (``held_out=False``). Any failure or empty result is reported as a
    warning and the repo is skipped rather than aborting preparation.
    """

    # Imported lazily to keep the (heavy) empirical stack out of the
    # synthetic-only code path.
    from sim_priors_pk.data.data_empirical import load_empirical_hf_batches_as_dm

    datasets = getattr(self.data_config, "test_empirical_datasets", [])
    # Reset caches so repeated calls never accumulate stale entries.
    self.empirical_test_batches = {}
    self.empirical_test_batches_no_heldout = {}
    if not datasets:
        return

    for repo_id in datasets:
        # Held-out (leave-one-out) variant first.
        try:
            batches = load_empirical_hf_batches_as_dm(
                repo_id,
                meta_dosing=self.model_config.dosing,
                datamodule=self,
                held_out=True,
            )
        except Exception as exc:  # noqa: BLE001 - surface download issues
            warnings.warn(
                f"Failed to load empirical dataset '{repo_id}': {exc}",
                stacklevel=2,
            )
            continue

        if not batches:
            warnings.warn(
                f"No empirical batches returned for dataset '{repo_id}'",
                stacklevel=2,
            )
            continue

        self.empirical_test_batches[repo_id] = batches
        # No-heldout variant: all individuals stay in context.
        try:
            no_heldout_batches = load_empirical_hf_batches_as_dm(
                repo_id,
                meta_dosing=self.model_config.dosing,
                datamodule=self,
                held_out=False,
            )
        except Exception as exc:  # noqa: BLE001 - surface download issues
            warnings.warn(
                f"Failed to load no-heldout empirical dataset '{repo_id}': {exc}",
                stacklevel=2,
            )
            continue

        if not no_heldout_batches:
            warnings.warn(
                f"No no-heldout empirical batches returned for dataset '{repo_id}'",
                stacklevel=2,
            )
            continue

        self.empirical_test_batches_no_heldout[repo_id] = no_heldout_batches
|
| 1070 |
+
|
| 1071 |
+
def get_empirical_test_batches(
|
| 1072 |
+
self,
|
| 1073 |
+
*,
|
| 1074 |
+
no_heldout: bool = False,
|
| 1075 |
+
device: Optional[torch.device | str] = None,
|
| 1076 |
+
) -> Dict[str, List["AICMECompartmentsDataBatch"]]:
|
| 1077 |
+
"""Return cached empirical batches keyed by Hugging Face dataset id.
|
| 1078 |
+
|
| 1079 |
+
Parameters
|
| 1080 |
+
----------
|
| 1081 |
+
no_heldout:
|
| 1082 |
+
If ``True``, return batches where all empirical individuals remain
|
| 1083 |
+
in context (no held-out target). If ``False`` (default), return the
|
| 1084 |
+
leave-one-out batches.
|
| 1085 |
+
device:
|
| 1086 |
+
Optional device where returned batches should live. When provided,
|
| 1087 |
+
returned batches are moved to ``device`` without mutating the
|
| 1088 |
+
internal cache.
|
| 1089 |
+
"""
|
| 1090 |
+
|
| 1091 |
+
# Safety fallback for direct/manual datamodule usage.
|
| 1092 |
+
if not getattr(self, "_empirical_loaded", False):
|
| 1093 |
+
if not self._prepared:
|
| 1094 |
+
self.prepare_data()
|
| 1095 |
+
elif self._is_global_zero_process():
|
| 1096 |
+
self._load_empirical_test_batches()
|
| 1097 |
+
self._empirical_loaded = True
|
| 1098 |
+
|
| 1099 |
+
batch_map = (
|
| 1100 |
+
self.empirical_test_batches_no_heldout if no_heldout else self.empirical_test_batches
|
| 1101 |
+
)
|
| 1102 |
+
|
| 1103 |
+
if device is None:
|
| 1104 |
+
return batch_map
|
| 1105 |
+
|
| 1106 |
+
return {
|
| 1107 |
+
repo_id: list_of_databath_to_device(batch_list, device)
|
| 1108 |
+
for repo_id, batch_list in batch_map.items()
|
| 1109 |
+
}
|
| 1110 |
+
|
| 1111 |
+
def get_empirical_batches(
|
| 1112 |
+
self,
|
| 1113 |
+
*,
|
| 1114 |
+
split: str,
|
| 1115 |
+
empirical_name: Optional[str],
|
| 1116 |
+
device: Optional[torch.device | str] = None,
|
| 1117 |
+
) -> List["AICMECompartmentsDataBatch"]:
|
| 1118 |
+
"""Return one empirical batch list using scheduler-oriented split aliases.
|
| 1119 |
+
|
| 1120 |
+
Supported split aliases:
|
| 1121 |
+
- ``empirical_heldout``: leave-one-out empirical targets
|
| 1122 |
+
- ``empirical_no_heldout``: all empirical individuals remain in context
|
| 1123 |
+
"""
|
| 1124 |
+
|
| 1125 |
+
normalized_split = str(split).strip().lower()
|
| 1126 |
+
if normalized_split == "empirical_heldout":
|
| 1127 |
+
batch_map = self.get_empirical_test_batches(no_heldout=False, device=device)
|
| 1128 |
+
elif normalized_split == "empirical_no_heldout":
|
| 1129 |
+
batch_map = self.get_empirical_test_batches(no_heldout=True, device=device)
|
| 1130 |
+
else:
|
| 1131 |
+
raise ValueError(
|
| 1132 |
+
f"Unsupported empirical split alias '{split}'. "
|
| 1133 |
+
"Expected 'empirical_heldout' or 'empirical_no_heldout'."
|
| 1134 |
+
)
|
| 1135 |
+
|
| 1136 |
+
if empirical_name is None:
|
| 1137 |
+
raise ValueError("`empirical_name` must be provided for empirical scheduler tasks.")
|
| 1138 |
+
try:
|
| 1139 |
+
return batch_map[str(empirical_name)]
|
| 1140 |
+
except KeyError as exc:
|
| 1141 |
+
raise ValueError(
|
| 1142 |
+
f"No empirical batches found for split='{split}' and empirical_name='{empirical_name}'."
|
| 1143 |
+
) from exc
|
| 1144 |
+
|
| 1145 |
+
@staticmethod
|
| 1146 |
+
def _normalize_substance_name(name: object) -> str:
|
| 1147 |
+
"""Normalize substance names for robust matching."""
|
| 1148 |
+
|
| 1149 |
+
return "".join(ch.lower() for ch in str(name) if ch.isalnum())
|
| 1150 |
+
|
| 1151 |
+
def select_empirical_batch_list(
    self,
    dataset_key: Optional[str] = None,
    *,
    no_heldout: bool = False,
) -> Tuple[Optional[str], List["AICMECompartmentsDataBatch"]]:
    """Pick one empirical dataset's batch list for plotting/evaluation.

    ``dataset_key`` is honoured when it names a cached dataset; otherwise the
    first dataset with a non-empty batch list is selected.

    Returns
    -------
    Tuple[Optional[str], List[AICMECompartmentsDataBatch]]
        The selected dataset key and its batch list.

    Raises
    ------
    RuntimeError
        If no dataset in the requested cache has any batches.
    """

    cache = self.get_empirical_test_batches(no_heldout=no_heldout)

    if dataset_key is not None and dataset_key in cache:
        chosen_key: Optional[str] = dataset_key
        chosen_batches = cache[dataset_key]
    else:
        # Fall back to the first dataset that actually holds batches.
        chosen_key, chosen_batches = next(
            ((repo_id, batches) for repo_id, batches in cache.items() if batches),
            (None, None),
        )

    if not chosen_batches:
        label = "no-heldout" if no_heldout else "heldout"
        raise RuntimeError(f"No empirical {label} batches available for predictive plotting.")

    return chosen_key, chosen_batches
|
| 1191 |
+
|
| 1192 |
+
def describe_empirical_test_batches(
    self,
    empirical_batches: Optional[Dict[str, List["AICMECompartmentsDataBatch"]]] = None,
    *,
    no_heldout: bool = False,
    batch_index: int = 0,
    print_available: bool = True,
) -> Tuple[List[str], List[str]]:
    """Describe empirical test batches and return available studies/drugs.

    Designed to be called after :meth:`get_empirical_test_batches` in
    notebook/script workflows.

    Parameters
    ----------
    empirical_batches:
        Optional pre-fetched empirical batches (typically from
        :meth:`get_empirical_test_batches`). If ``None``, batches are fetched
        internally.
    no_heldout:
        Whether to describe no-heldout batches.
    batch_index:
        Batch index to inspect within each dataset. Default is ``0``.
    print_available:
        If ``True``, print available datasets/studies/drugs.

    Returns
    -------
    Tuple[List[str], List[str]]
        Unique available study names and drug names, collected from the
        selected ``batch_index`` across datasets in first-seen order.

    Raises
    ------
    ValueError
        If ``batch_index`` is negative.
    """

    batch_map = empirical_batches
    if batch_map is None:
        batch_map = self.get_empirical_test_batches(no_heldout=no_heldout)

    if batch_index < 0:
        raise ValueError("batch_index must be non-negative")

    available_studies: List[str] = []
    available_drugs: List[str] = []
    seen_studies: set[str] = set()
    seen_drugs: set[str] = set()

    if print_available:
        label = "no_heldout=True" if no_heldout else "heldout"
        print(f"Available empirical datasets ({label}):", list(batch_map.keys()))

    for repo_id, batch_list in batch_map.items():
        if print_available:
            print(f"Dataset '{repo_id}' contains {len(batch_list)} empirical batch(es).")
        if batch_index >= len(batch_list):
            # Dataset has fewer permutations than requested; skip, don't fail.
            if print_available:
                print(
                    f"  Skipping dataset '{repo_id}': batch_index={batch_index} "
                    "is out of range."
                )
            continue

        batch = batch_list[batch_index]
        studies, drugs = self.describe_empirical_batch(batch, print_available=False)
        # Deduplicate while preserving first-seen order across datasets.
        for study in studies:
            if study not in seen_studies:
                seen_studies.add(study)
                available_studies.append(study)
        for drug in drugs:
            if drug not in seen_drugs:
                seen_drugs.add(drug)
                available_drugs.append(drug)

        if print_available:
            print(f"  Batch {batch_index} studies:", studies)
            print(f"  Batch {batch_index} drugs:", drugs)

    if print_available:
        print("Available studies:", available_studies)
        print("Available drugs:", available_drugs)

    return available_studies, available_drugs
|
| 1272 |
+
|
| 1273 |
+
@staticmethod
|
| 1274 |
+
def describe_empirical_batch(
|
| 1275 |
+
batch: "AICMECompartmentsDataBatch",
|
| 1276 |
+
*,
|
| 1277 |
+
print_available: bool = True,
|
| 1278 |
+
) -> Tuple[List[str], List[str]]:
|
| 1279 |
+
"""Return display-ready study and substance names for a batch.
|
| 1280 |
+
|
| 1281 |
+
Parameters
|
| 1282 |
+
----------
|
| 1283 |
+
batch:
|
| 1284 |
+
Empirical batch to inspect.
|
| 1285 |
+
print_available:
|
| 1286 |
+
If ``True``, print available studies and drugs to stdout.
|
| 1287 |
+
"""
|
| 1288 |
+
|
| 1289 |
+
studies = [str(name) if name else f"study_{i}" for i, name in enumerate(batch.study_name)]
|
| 1290 |
+
drugs = [
|
| 1291 |
+
str(name) if name else f"substance_{i}" for i, name in enumerate(batch.substance_name)
|
| 1292 |
+
]
|
| 1293 |
+
|
| 1294 |
+
if print_available:
|
| 1295 |
+
print("Available studies in selected batch:", studies)
|
| 1296 |
+
print("Available drugs in selected batch:", drugs)
|
| 1297 |
+
|
| 1298 |
+
return studies, drugs
|
| 1299 |
+
|
| 1300 |
+
@staticmethod
|
| 1301 |
+
def slice_single_substance_batch(
|
| 1302 |
+
batch: "AICMECompartmentsDataBatch",
|
| 1303 |
+
b_idx: int,
|
| 1304 |
+
) -> "AICMECompartmentsDataBatch":
|
| 1305 |
+
"""Extract one substance entry from a multi-substance batch.
|
| 1306 |
+
|
| 1307 |
+
Parameters
|
| 1308 |
+
----------
|
| 1309 |
+
batch:
|
| 1310 |
+
Batch with leading batch dimension ``B``.
|
| 1311 |
+
b_idx:
|
| 1312 |
+
Substance index along ``B``.
|
| 1313 |
+
|
| 1314 |
+
Returns
|
| 1315 |
+
-------
|
| 1316 |
+
AICMECompartmentsDataBatch
|
| 1317 |
+
Single-substance batch with tensors sliced to ``B=1``.
|
| 1318 |
+
"""
|
| 1319 |
+
|
| 1320 |
+
if b_idx < 0 or b_idx >= len(batch.substance_name):
|
| 1321 |
+
raise IndexError(
|
| 1322 |
+
f"Substance index {b_idx} is out of range for batch size "
|
| 1323 |
+
f"{len(batch.substance_name)}."
|
| 1324 |
+
)
|
| 1325 |
+
|
| 1326 |
+
values = []
|
| 1327 |
+
for field_name in batch._fields:
|
| 1328 |
+
value = getattr(batch, field_name)
|
| 1329 |
+
if isinstance(value, torch.Tensor):
|
| 1330 |
+
# Keep tensor rank stable by preserving a singleton leading B axis.
|
| 1331 |
+
values.append(value[b_idx : b_idx + 1])
|
| 1332 |
+
elif field_name in {
|
| 1333 |
+
"study_name",
|
| 1334 |
+
"substance_name",
|
| 1335 |
+
"context_subject_name",
|
| 1336 |
+
"target_subject_name",
|
| 1337 |
+
}:
|
| 1338 |
+
values.append([value[b_idx]])
|
| 1339 |
+
else:
|
| 1340 |
+
values.append(value)
|
| 1341 |
+
return batch.__class__(*values)
|
| 1342 |
+
|
| 1343 |
+
@classmethod
def slice_single_substance_batch_by_name(
    cls,
    batch: "AICMECompartmentsDataBatch",
    substance_name: str,
) -> "AICMECompartmentsDataBatch":
    """Extract the single-substance slice whose drug name matches ``substance_name``.

    Names are compared through :meth:`_normalize_substance_name`; the first
    match wins.

    Raises
    ------
    ValueError
        If no drug in ``batch`` matches.
    """

    _, drugs = cls.describe_empirical_batch(batch, print_available=False)
    wanted = cls._normalize_substance_name(substance_name)
    for idx, candidate in enumerate(drugs):
        if cls._normalize_substance_name(candidate) == wanted:
            return cls.slice_single_substance_batch(batch, idx)
    raise ValueError(
        f"Selected drug '{substance_name}' not found in heldout batch. "
        f"Choose from: {drugs}"
    )
|
| 1364 |
+
|
| 1365 |
+
def select_empirical_drug_batch(
|
| 1366 |
+
self,
|
| 1367 |
+
empirical_batches: Dict[str, List["AICMECompartmentsDataBatch"]],
|
| 1368 |
+
selected_drug: str,
|
| 1369 |
+
*,
|
| 1370 |
+
permutation_indexes: Optional[int | Sequence[int]] = None,
|
| 1371 |
+
print_selection: bool = True,
|
| 1372 |
+
) -> Tuple[
|
| 1373 |
+
"AICMECompartmentsDataBatch | List[AICMECompartmentsDataBatch]",
|
| 1374 |
+
str,
|
| 1375 |
+
str,
|
| 1376 |
+
]:
|
| 1377 |
+
"""Select one drug from empirical batches, optionally across permutations.
|
| 1378 |
+
|
| 1379 |
+
Parameters
|
| 1380 |
+
----------
|
| 1381 |
+
empirical_batches:
|
| 1382 |
+
Mapping returned by :meth:`get_empirical_test_batches`.
|
| 1383 |
+
selected_drug:
|
| 1384 |
+
Drug name to match across all empirical batches.
|
| 1385 |
+
permutation_indexes:
|
| 1386 |
+
Optional permutation index or list of permutation indices within the
|
| 1387 |
+
selected empirical dataset's batch list. When ``None`` (default),
|
| 1388 |
+
the method preserves legacy behaviour and returns the first matching
|
| 1389 |
+
single-substance batch. When a list/tuple is provided, returns a
|
| 1390 |
+
list of single-substance batches in the requested permutation order.
|
| 1391 |
+
print_selection:
|
| 1392 |
+
If ``True``, print where the match was found.
|
| 1393 |
+
"""
|
| 1394 |
+
|
| 1395 |
+
norm_target = self._normalize_substance_name(selected_drug)
|
| 1396 |
+
requested_permutations: Optional[List[int]]
|
| 1397 |
+
return_many = isinstance(permutation_indexes, (list, tuple))
|
| 1398 |
+
if permutation_indexes is None:
|
| 1399 |
+
requested_permutations = None
|
| 1400 |
+
elif return_many:
|
| 1401 |
+
if len(permutation_indexes) == 0:
|
| 1402 |
+
raise ValueError("'permutation_indexes' must not be empty.")
|
| 1403 |
+
requested_permutations = [int(idx) for idx in permutation_indexes]
|
| 1404 |
+
if len(set(requested_permutations)) != len(requested_permutations):
|
| 1405 |
+
raise ValueError("'permutation_indexes' must contain unique indices.")
|
| 1406 |
+
else:
|
| 1407 |
+
requested_permutations = [int(permutation_indexes)]
|
| 1408 |
+
|
| 1409 |
+
all_available_drugs: List[str] = []
|
| 1410 |
+
seen_drugs: set[str] = set()
|
| 1411 |
+
|
| 1412 |
+
for repo_id, batch_list in empirical_batches.items():
|
| 1413 |
+
for batch_index, batch in enumerate(batch_list):
|
| 1414 |
+
_, available_drugs = self.describe_empirical_batch(batch, print_available=False)
|
| 1415 |
+
for drug in available_drugs:
|
| 1416 |
+
if drug not in seen_drugs:
|
| 1417 |
+
seen_drugs.add(drug)
|
| 1418 |
+
all_available_drugs.append(drug)
|
| 1419 |
+
|
| 1420 |
+
matches = [
|
| 1421 |
+
i
|
| 1422 |
+
for i, name in enumerate(available_drugs)
|
| 1423 |
+
if self._normalize_substance_name(name) == norm_target
|
| 1424 |
+
]
|
| 1425 |
+
if matches:
|
| 1426 |
+
if requested_permutations is None:
|
| 1427 |
+
selected_batches: List[AICMECompartmentsDataBatch] = [
|
| 1428 |
+
self.slice_single_substance_batch(batch, matches[0])
|
| 1429 |
+
]
|
| 1430 |
+
chosen_permutations = [batch_index]
|
| 1431 |
+
else:
|
| 1432 |
+
selected_batches = []
|
| 1433 |
+
chosen_permutations = requested_permutations
|
| 1434 |
+
for permutation_index in requested_permutations:
|
| 1435 |
+
if permutation_index < 0 or permutation_index >= len(batch_list):
|
| 1436 |
+
raise IndexError(
|
| 1437 |
+
f"Permutation index {permutation_index} is out of range for "
|
| 1438 |
+
f"dataset '{repo_id}' with {len(batch_list)} permutations."
|
| 1439 |
+
)
|
| 1440 |
+
|
| 1441 |
+
perm_batch = batch_list[permutation_index]
|
| 1442 |
+
_, perm_drugs = self.describe_empirical_batch(
|
| 1443 |
+
perm_batch, print_available=False
|
| 1444 |
+
)
|
| 1445 |
+
perm_matches = [
|
| 1446 |
+
i
|
| 1447 |
+
for i, name in enumerate(perm_drugs)
|
| 1448 |
+
if self._normalize_substance_name(name) == norm_target
|
| 1449 |
+
]
|
| 1450 |
+
if not perm_matches:
|
| 1451 |
+
raise ValueError(
|
| 1452 |
+
f"Selected drug '{selected_drug}' was not found in dataset "
|
| 1453 |
+
f"'{repo_id}' at permutation index {permutation_index}."
|
| 1454 |
+
)
|
| 1455 |
+
selected_batches.append(
|
| 1456 |
+
self.slice_single_substance_batch(perm_batch, perm_matches[0])
|
| 1457 |
+
)
|
| 1458 |
+
|
| 1459 |
+
studies, drugs = self.describe_empirical_batch(
|
| 1460 |
+
selected_batches[0], print_available=False
|
| 1461 |
+
)
|
| 1462 |
+
selected_study = studies[0]
|
| 1463 |
+
selected_name = drugs[0]
|
| 1464 |
+
if print_selection:
|
| 1465 |
+
print("Selected empirical dataset key:", repo_id)
|
| 1466 |
+
if len(chosen_permutations) == 1:
|
| 1467 |
+
print("Selected empirical batch index:", chosen_permutations[0])
|
| 1468 |
+
else:
|
| 1469 |
+
print("Selected empirical batch indexes:", chosen_permutations)
|
| 1470 |
+
print("Selected study:", selected_study)
|
| 1471 |
+
print("Selected drug:", selected_name)
|
| 1472 |
+
if return_many:
|
| 1473 |
+
return selected_batches, selected_study, selected_name
|
| 1474 |
+
return selected_batches[0], selected_study, selected_name
|
| 1475 |
+
|
| 1476 |
+
raise ValueError(
|
| 1477 |
+
f"Selected drug '{selected_drug}' was not found in empirical batches. "
|
| 1478 |
+
f"Choose from: {all_available_drugs}"
|
| 1479 |
+
)
|
| 1480 |
+
|
| 1481 |
+
def _select_strategy(self, who: str):
|
| 1482 |
+
"""Return the observation strategy requested via ``who``.
|
| 1483 |
+
|
| 1484 |
+
Parameters
|
| 1485 |
+
----------
|
| 1486 |
+
who:
|
| 1487 |
+
Either ``"target"`` or ``"context"``.
|
| 1488 |
+
|
| 1489 |
+
Returns
|
| 1490 |
+
-------
|
| 1491 |
+
ObservationStrategy
|
| 1492 |
+
The strategy matching the requested role.
|
| 1493 |
+
"""
|
| 1494 |
+
|
| 1495 |
+
if who == "target":
|
| 1496 |
+
return self.target_strategy
|
| 1497 |
+
if who == "context":
|
| 1498 |
+
return self.context_strategy
|
| 1499 |
+
raise ValueError("'who' must be either 'target' or 'context'.")
|
| 1500 |
+
|
| 1501 |
+
def _select_strategies(self, who: str) -> List[object]:
|
| 1502 |
+
"""Return strategy list for the requested role.
|
| 1503 |
+
|
| 1504 |
+
For ``who='target'`` this includes both synthetic and empirical target
|
| 1505 |
+
strategies so past-selection overrides remain consistent when empirical
|
| 1506 |
+
batches are generated from the datamodule.
|
| 1507 |
+
"""
|
| 1508 |
+
|
| 1509 |
+
if who == "context":
|
| 1510 |
+
return [self.context_strategy]
|
| 1511 |
+
if who == "target":
|
| 1512 |
+
strategies: List[object] = [self.target_strategy]
|
| 1513 |
+
empirical_target_strategy = getattr(self, "empirical_target_strategy", None)
|
| 1514 |
+
if empirical_target_strategy is not None:
|
| 1515 |
+
strategies.append(empirical_target_strategy)
|
| 1516 |
+
# Keep order stable while avoiding duplicate objects.
|
| 1517 |
+
deduped: List[object] = []
|
| 1518 |
+
seen_ids: set[int] = set()
|
| 1519 |
+
for strategy in strategies:
|
| 1520 |
+
strategy_id = id(strategy)
|
| 1521 |
+
if strategy_id in seen_ids:
|
| 1522 |
+
continue
|
| 1523 |
+
seen_ids.add(strategy_id)
|
| 1524 |
+
deduped.append(strategy)
|
| 1525 |
+
return deduped
|
| 1526 |
+
raise ValueError("'who' must be either 'target' or 'context'.")
|
| 1527 |
+
|
| 1528 |
+
def fix_past_selection(self, fix_past_value: int, *, who: str = "target") -> None:
    """Force a fixed number of past observations for the selected strategies.

    Strategies without a ``fix_past_selection`` attribute are skipped, so the
    override only applies where the strategy supports it.
    """

    if not self._prepared:
        self.prepare_data()

    for strategy in self._select_strategies(who):
        if hasattr(strategy, "fix_past_selection"):
            strategy.fix_past_selection(fix_past_value)

    # Invalidate the lazy empirical cache so it is rebuilt under the new
    # strategy settings.
    self._empirical_loaded = False
|
| 1543 |
+
|
| 1544 |
+
def release_past_selection(self, *, who: str = "target") -> None:
    """Restore default past sampling for the selected strategies.

    Counterpart of :meth:`fix_past_selection`; strategies without a
    ``release_past_selection`` attribute are skipped.
    """

    if not self._prepared:
        self.prepare_data()

    for strategy in self._select_strategies(who):
        if hasattr(strategy, "release_past_selection"):
            strategy.release_past_selection()

    # Invalidate the lazy empirical cache so it is rebuilt with the restored
    # strategy settings.
    self._empirical_loaded = False
|
| 1555 |
+
|
| 1556 |
+
def set_shared_target_dosing(self, enable: bool = True, n_targets: int = 100) -> None:
    """Enable/disable shared target dosing across all datasets.

    The flag and target count are stored on the datamodule and forwarded to
    every dataset split that already exists (``train``/``val``/``test``).

    Parameters
    ----------
    enable : bool
        Whether to enable shared-target dosing.
    n_targets : int
        Number of target individuals to sample when enabled.
    """
    self.use_shared_target_dosing = enable
    self.shared_target_n_targets = n_targets

    for attr in ("train_dataset", "val_dataset", "test_dataset"):
        dataset = getattr(self, attr, None)
        if dataset is None:
            continue
        dataset.use_shared_target_dosing = enable
        dataset.shared_target_n_targets = n_targets
|
| 1577 |
+
|
| 1578 |
+
def unset_shared_target_dosing(self) -> None:
    """Disable shared target dosing again.

    Equivalent to calling ``set_shared_target_dosing(False)``.
    """
    self.set_shared_target_dosing(False)
|
| 1581 |
+
|
| 1582 |
+
@staticmethod
def _add_batch_dim_to_synthetic_batch(
    batch: AICMECompartmentsDataBatch,
) -> AICMECompartmentsDataBatch:
    """Return a copy of ``batch`` whose tensor fields carry a leading ``B=1`` axis."""

    expanded: list = []
    for field, item in zip(batch._fields, batch):
        if not isinstance(item, torch.Tensor):
            expanded.append(item)
        elif field == "time_scales":
            # ``time_scales`` is often already batched ([B, 2]) while other
            # fields come out of dataset-level generation as [I, ...]; only
            # lift the 1-D case.
            expanded.append(item.unsqueeze(0) if item.dim() == 1 else item)
        else:
            expanded.append(item.unsqueeze(0))

    return AICMECompartmentsDataBatch(*expanded)
|
| 1603 |
+
|
| 1604 |
+
def _generate_synthetic_list_with_repeated_target(
    self,
    *,
    shared_context_pack: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor],
    target_sim: Tensor,
    target_times: Tensor,
    target_dosing_amounts: Tensor,
    target_dosing_routes: Tensor,
    base_time_scales: Tensor,
    num_targets: int,
) -> AICMECompartmentsDataBatch:
    """Package one list element from shared context and dosing-specific targets.

    The context tensors in ``shared_context_pack`` are generated once by the
    caller and reused verbatim for every list element; only the target side
    is regenerated here via the train dataset's target observation function.

    Parameters
    ----------
    shared_context_pack:
        Context tensors generated once and reused across all list elements.
    target_sim:
        Target simulation for one dosing condition, ``[n_targets, T]``.
    target_times:
        Target simulation times, ``[n_targets, T]``.
    target_dosing_amounts:
        Target dosing amounts, ``[n_targets]``.
    target_dosing_routes:
        Target dosing route types, ``[n_targets]``.
    base_time_scales:
        Simulation-level time scales from the sampler.
    num_targets:
        Target-individual capacity for this synthetic sample.
    """
    (
        context_obs,
        context_obs_time,
        context_obs_mask,
        context_rem_sim,
        context_rem_sim_time,
        context_rem_sim_mask,
        dosing_amounts_ctx,
        dosing_routes_ctx,
    ) = shared_context_pack

    # Regenerate only the target observations for this dosing condition.
    (
        target_obs,
        target_obs_time,
        target_obs_mask,
        target_rem_sim,
        target_rem_sim_time,
        target_rem_sim_mask,
        _unused_target_time_scales,
    ) = self.train_dataset._safe_generate(
        self.train_dataset.target_fn,
        target_sim,
        target_times,
        time_scales=base_time_scales,
    )

    # Time-scale metadata stays aligned with the shared context payload.
    return self.train_dataset._build_padded_batch(
        # context (shared across all list elements)
        context_obs,
        context_obs_time,
        context_obs_mask,
        context_rem_sim,
        context_rem_sim_time,
        context_rem_sim_mask,
        dosing_amounts_ctx,
        dosing_routes_ctx,
        # target (specific to one repeated-dosing condition)
        target_obs,
        target_obs_time,
        target_obs_mask,
        target_rem_sim,
        target_rem_sim_time,
        target_rem_sim_mask,
        target_dosing_amounts,
        target_dosing_routes,
        # time scales
        ts=base_time_scales,
        target_capacity=num_targets,
    )
|
| 1687 |
+
|
| 1688 |
+
def prepare_full_simulation_list_with_repeated_targets(
    self,
    num_targets: int,
    batch_index: int = 0,
    num_of_different_dosages: int = 1,
    device: Optional[torch.device | str] = None,
) -> List["AICMECompartmentsDataBatch"]:
    """Build one shared context and ``L`` repeated-target dosing batches.

    The context is created exactly once, then the method loops over
    ``num_of_different_dosages`` target dosing conditions; packaging of each
    element is delegated to
    :meth:`_generate_synthetic_list_with_repeated_target`.

    Parameters
    ----------
    num_targets:
        Number of target individuals per dosing condition.
    batch_index:
        Synthetic sample index used by the simulation backend.
    num_of_different_dosages:
        Number of target dosing conditions ``L``.
    device:
        Optional device where returned batches should live.

    Raises
    ------
    ValueError
        If any of the integer arguments is negative.
    """

    for label, value in (
        ("num_targets", num_targets),
        ("batch_index", batch_index),
        ("num_of_different_dosages", num_of_different_dosages),
    ):
        if value < 0:
            raise ValueError(f"{label} must be non-negative")

    if not self._prepared:
        self.prepare_data()

    (
        context_sim,
        context_times,
        dosing_amounts_ctx,
        dosing_routes_ctx,
        target_simulations,
        target_times_list,
        target_dosing_amounts_list,
        target_dosing_routes_list,
        _time_points,
        time_scales,
    ) = prepare_full_simulation_list_with_repeated_targets_backend(
        self.meta_config,
        self.model_config.dosing,
        n_targets=num_targets,
        num_of_different_dosages=num_of_different_dosages,
        idx=batch_index,
    )

    # Context observations are generated once and reused verbatim below.
    (
        context_obs,
        context_obs_time,
        context_obs_mask,
        context_rem_sim,
        context_rem_sim_time,
        context_rem_sim_mask,
        context_time_scales,
    ) = self.train_dataset._safe_generate(
        self.train_dataset.context_fn,
        context_sim,
        context_times,
        time_scales=time_scales,
    )

    shared_context_pack = (
        context_obs,
        context_obs_time,
        context_obs_mask,
        context_rem_sim,
        context_rem_sim_time,
        context_rem_sim_mask,
        dosing_amounts_ctx,
        dosing_routes_ctx,
    )

    # Prefer the time scales returned by the context pass when present.
    base_time_scales = time_scales if context_time_scales is None else context_time_scales

    synthetic_batches: List["AICMECompartmentsDataBatch"] = [
        self._generate_synthetic_list_with_repeated_target(
            shared_context_pack=shared_context_pack,
            target_sim=one_sim,
            target_times=one_times,
            target_dosing_amounts=one_amounts,
            target_dosing_routes=one_routes,
            base_time_scales=base_time_scales,
            num_targets=num_targets,
        )
        for one_sim, one_times, one_amounts, one_routes in zip(
            target_simulations,
            target_times_list,
            target_dosing_amounts_list,
            target_dosing_routes_list,
        )
    ]

    if device is None:
        return synthetic_batches
    return list_of_databath_to_device(synthetic_batches, device)
|
| 1796 |
+
|
| 1797 |
+
def generate_synthetic_with_repeated_target(
|
| 1798 |
+
self,
|
| 1799 |
+
num_targets: int,
|
| 1800 |
+
batch_index: int = 0,
|
| 1801 |
+
different_dosing: bool = False,
|
| 1802 |
+
device: Optional[torch.device | str] = None,
|
| 1803 |
+
) -> List["AICMECompartmentsDataBatch"]:
|
| 1804 |
+
"""Generate one synthetic batch list with configurable target dosing.
|
| 1805 |
+
|
| 1806 |
+
The generated sample follows the current datamodule ``data_config`` and
|
| 1807 |
+
observation strategies while overriding only the number of target
|
| 1808 |
+
individuals. The returned tensors include an explicit leading batch
|
| 1809 |
+
dimension ``B=1`` to match dataloader outputs.
|
| 1810 |
+
|
| 1811 |
+
Parameters
|
| 1812 |
+
----------
|
| 1813 |
+
num_targets:
|
| 1814 |
+
Number of target individuals to include in the generated synthetic
|
| 1815 |
+
sample.
|
| 1816 |
+
batch_index:
|
| 1817 |
+
Dataset index used by the internal synthetic generator.
|
| 1818 |
+
different_dosing:
|
| 1819 |
+
If ``False`` (default), target individuals share one repeated dosing
|
| 1820 |
+
configuration.
|
| 1821 |
+
If ``True``, each target individual receives an independent dosing
|
| 1822 |
+
sample drawn from the same dosing distribution as context.
|
| 1823 |
+
device:
|
| 1824 |
+
Optional device where returned batches should live. When provided,
|
| 1825 |
+
returned batches are moved to ``device``.
|
| 1826 |
+
|
| 1827 |
+
Returns
|
| 1828 |
+
-------
|
| 1829 |
+
List[AICMECompartmentsDataBatch]
|
| 1830 |
+
A list containing a single synthetic databatch.
|
| 1831 |
+
"""
|
| 1832 |
+
|
| 1833 |
+
if num_targets < 0:
|
| 1834 |
+
raise ValueError("num_targets must be non-negative")
|
| 1835 |
+
if batch_index < 0:
|
| 1836 |
+
raise ValueError("batch_index must be non-negative")
|
| 1837 |
+
|
| 1838 |
+
if not self._prepared:
|
| 1839 |
+
self.prepare_data()
|
| 1840 |
+
|
| 1841 |
+
batches = self.train_dataset._generate_item_sample_target_dosing(
|
| 1842 |
+
batch_index,
|
| 1843 |
+
n_targets=num_targets,
|
| 1844 |
+
different_dosing=different_dosing,
|
| 1845 |
+
)
|
| 1846 |
+
batch_list = [self._add_batch_dim_to_synthetic_batch(batch) for batch in batches]
|
| 1847 |
+
|
| 1848 |
+
if device is None:
|
| 1849 |
+
return batch_list
|
| 1850 |
+
|
| 1851 |
+
return list_of_databath_to_device(batch_list, device)
|
| 1852 |
+
|
| 1853 |
+
    def generate_synthetic_list_of_repeated_target(
        self,
        num_targets: int,
        batch_index: int = 0,
        num_of_different_dosages: int = 1,
        device: Optional[torch.device | str] = None,
    ) -> List["AICMECompartmentsDataBatch"]:
        """Generate a list of synthetic batches sharing one context.

        The returned list has length ``num_of_different_dosages``. Context
        fields are identical across all elements, while targets are regenerated
        per element using repeated dosing within each element.

        Parameters
        ----------
        num_targets:
            Number of target individuals per generated batch.
        batch_index:
            Dataset index forwarded to the internal synthetic generator.
        num_of_different_dosages:
            Number of batches (dosing variants) to generate; context is shared
            across all of them.
        device:
            Optional device for the generated batches; forwarded to
            ``prepare_full_simulation_list_with_repeated_targets``.

        Returns
        -------
        List[AICMECompartmentsDataBatch]
            One synthetic databatch per dosing variant, each passed through
            ``_add_batch_dim_to_synthetic_batch`` to match dataloader outputs.
        """

        # Delegate batch generation; the helper returns batches without the
        # explicit leading batch dimension, which is added below.
        synthetic_batches = self.prepare_full_simulation_list_with_repeated_targets(
            num_targets=num_targets,
            batch_index=batch_index,
            num_of_different_dosages=num_of_different_dosages,
            device=device,
        )
        batch_list = [self._add_batch_dim_to_synthetic_batch(batch) for batch in synthetic_batches]
        return batch_list
|
sim_priors_pk/data/extra/compartment_models_vectorized.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
def sample_individual_configs_vectorized(study_config):
    """
    Vectorizes the sampling of parameters for a population of individuals.

    Central parameters (``k_a``, ``k_e``, ``V``) are drawn from log-normal
    distributions, one vectorized draw per parameter. Peripheral parameters
    are drawn with one call per peripheral compartment, preserving the same
    NumPy random stream order as before.

    Parameters
    ----------
    study_config : StudyConfig
        Contains the study settings and distribution parameters.

    Returns
    -------
    config_dict : dict
        Dictionary containing the vectorized parameters and time-magnitudes.
        Keys:
            'k_a', 'k_e', 'V': Tensors of shape (N,)
            'k_1p', 'k_p1': Tensors of shape (N, P)
            'k_a_tmag', 'k_e_tmag', 'V_tmag': Scalars
            'k_1p_tmag', 'k_p1_tmag': Tensors of shape (P,)
            'num_peripherals': int
    """
    N = study_config.num_individuals
    P = study_config.num_peripherals

    # Sample the central parameters as tensors of shape (N,)
    k_a = torch.from_numpy(np.random.lognormal(study_config.log_k_a_mean, study_config.log_k_a_std, size=N)).float()
    k_e = torch.from_numpy(np.random.lognormal(study_config.log_k_e_mean, study_config.log_k_e_std, size=N)).float()
    V = torch.from_numpy(np.random.lognormal(study_config.log_V_mean, study_config.log_V_std, size=N)).float()

    # Sample the peripheral parameters, one (N,) draw per peripheral so the
    # random stream matches the original per-compartment sampling order.
    k_1p_cols = []
    k_p1_cols = []
    for i in range(P):
        k_1p_cols.append(torch.from_numpy(np.random.lognormal(study_config.log_k_1p_mean[i],
                                                              study_config.log_k_1p_std[i], size=N)).float())
        k_p1_cols.append(torch.from_numpy(np.random.lognormal(study_config.log_k_p1_mean[i],
                                                              study_config.log_k_p1_std[i], size=N)).float())
    if P > 0:
        # Stack along the peripheral dimension: shape becomes (N, P)
        k_1p = torch.stack(k_1p_cols, dim=1)
        k_p1 = torch.stack(k_p1_cols, dim=1)
    else:
        # torch.stack on an empty list raises; a one-compartment model
        # (P == 0) gets explicit empty (N, 0) tensors instead.
        k_1p = torch.empty((N, 0), dtype=torch.float32)
        k_p1 = torch.empty((N, 0), dtype=torch.float32)

    # Pack time-magnitudes (assumed scalars for central parameters and lists for peripherals)
    k_a_tmag = study_config.k_a_tmag  # scalar
    k_e_tmag = study_config.k_e_tmag  # scalar
    V_tmag = study_config.V_tmag  # scalar
    # For peripherals, we assume the study_config gives lists/arrays of length P.
    k_1p_tmag = torch.tensor(study_config.k_1p_tmag).float()  # shape (P,)
    k_p1_tmag = torch.tensor(study_config.k_p1_tmag).float()  # shape (P,)

    config_dict = {
        'k_a': k_a,
        'k_e': k_e,
        'V': V,
        'k_1p': k_1p,
        'k_p1': k_p1,
        'k_a_tmag': k_a_tmag,
        'k_e_tmag': k_e_tmag,
        'V_tmag': V_tmag,
        'k_1p_tmag': k_1p_tmag,
        'k_p1_tmag': k_p1_tmag,
        'num_peripherals': P,
    }
    return config_dict
|
| 68 |
+
|
| 69 |
+
import torch
|
| 70 |
+
|
| 71 |
+
def compute_rates(config, t):
    """
    Computes the dynamic rates for all individuals at a given time t.

    Each rate decays exponentially from its sampled baseline with its
    time-magnitude: ``rate(t) = rate0 * exp(-tmag * t)``.

    Parameters
    ----------
    config : dict
        Dictionary returned by sample_individual_configs_vectorized.
    t : float or torch.Tensor
        Current time point.

    Returns
    -------
    k_a, k_e, V : torch.Tensor
        Tensors of shape (N,).
    k_1p, k_p1 : torch.Tensor
        Tensors of shape (N, P).
    """
    # Ensure t is a tensor. Derive dtype/device from 'k_a', which is always a
    # tensor; 'k_a_tmag' may be a plain Python scalar (per the sampler's
    # contract) and would have no .dtype/.device attributes.
    if not isinstance(t, torch.Tensor):
        t = torch.tensor(t, dtype=config['k_a'].dtype, device=config['k_a'].device)

    k_a = config['k_a'] * torch.exp(-config['k_a_tmag'] * t)
    k_e = config['k_e'] * torch.exp(-config['k_e_tmag'] * t)
    V = config['V'] * torch.exp(-config['V_tmag'] * t)

    # Broadcasting for peripheral compartments: (N, P) * (P,) -> (N, P)
    k_1p = config['k_1p'] * torch.exp(-config['k_1p_tmag'] * t)
    k_p1 = config['k_p1'] * torch.exp(-config['k_p1_tmag'] * t)

    return k_a, k_e, V, k_1p, k_p1
|
| 102 |
+
|
| 103 |
+
def ode_func(t_val, y, config):
    """
    ODE right-hand side using vectorized rate computations.

    Parameters
    ----------
    t_val : torch.Tensor
        Current time point.
    y : torch.Tensor
        Current state, shape (N, M) where M = 2 + num_peripherals
        (gut, central, then peripheral compartments).
    config : dict
        Vectorized individual configuration dictionary.

    Returns
    -------
    dy_dt : torch.Tensor
        Time derivative of y, shape (N, M).
    """
    # Get the dynamic rates for all individuals at time t_val.
    k_a, k_e, _, k_1p, k_p1 = compute_rates(config, t_val)
    N = y.size(0)
    P = config['num_peripherals']
    M = 2 + P

    # Build the per-individual rate matrix A(t). Allocate on y's dtype/device
    # so the function also works for CUDA inputs and double precision, instead
    # of a hard-coded CPU float32 buffer.
    A_all = torch.zeros((N, M, M), dtype=y.dtype, device=y.device)
    A_all[:, 0, 0] = -k_a                      # loss from gut (absorption)
    A_all[:, 1, 0] = k_a                       # transfer gut -> central
    A_all[:, 1, 1] = -k_e - k_1p.sum(dim=1)    # elimination + distribution to peripherals
    A_all[:, 1, 2:2 + P] = k_p1                # transfer peripherals -> central (rate k_p1)
    A_all[:, 2:2 + P, 1] = k_1p                # transfer central -> peripherals (rate k_1p)
    # Loss from each peripheral compartment back to central: batched diagonal
    # assignment replaces the former per-peripheral Python loop.
    if P > 0:
        diag = torch.arange(2, 2 + P, device=y.device)
        A_all[:, diag, diag] = -k_p1

    # Compute dy/dt = A_all @ y for each individual (batched matmul).
    dy_dt = torch.bmm(A_all, y.unsqueeze(-1)).squeeze(-1)
    return dy_dt
|
| 141 |
+
|
| 142 |
+
def sample_study_vectorized(study_config, dosing_config, t, solver_method="rk4"):
    """
    Simulates the pharmacokinetic study using vectorized individual configurations.

    Parameters
    ----------
    study_config : StudyConfig
        Contains global study settings and distribution parameters.
    dosing_config : DosingConfig
        Contains dosing information; ``dosing_config.dose`` is placed in the
        gut compartment at t = 0.
    t : torch.Tensor
        Time points at which the simulation is evaluated.
    solver_method : str, optional
        Solver name forwarded to ``torchdiffeq.odeint`` (default ``"rk4"``).

    Returns
    -------
    full_simulation : torch.Tensor
        Central-compartment concentration profiles, shape (N, len(t)).
    full_times : torch.Tensor
        Time points replicated for each individual, shape (N, len(t)).
    """
    # Imported lazily so torchdiffeq is only required when simulating.
    from torchdiffeq import odeint

    # Get the vectorized configuration dictionary
    config = sample_individual_configs_vectorized(study_config)
    N = study_config.num_individuals
    P = study_config.num_peripherals
    M = 2 + P

    # Initial conditions: dose in the gut (first compartment), zeros elsewhere.
    y0 = torch.zeros((N, M), dtype=torch.float32)
    y0[:, 0] = dosing_config.dose

    def wrapped_ode(t_val, y):
        # Close over `config` so the callback matches odeint's (t, y) signature.
        return ode_func(t_val, y, config)

    # Solve the ODE system for all individuals in batch.
    # odeint returns the solution with shape (len(t), N, M).
    y = odeint(wrapped_ode, y0, t, method=solver_method)
    # Extract central compartment (index 1) and transpose to (N, len(t)).
    full_simulation = y[:, :, 1].T
    full_times = t.unsqueeze(0).repeat(N, 1)
    return full_simulation, full_times
|
sim_priors_pk/data/extra/kernels.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import gpytorch
|
| 3 |
+
|
| 4 |
+
def create_kernel(config):
    """Create a GPyTorch kernel from a config object.

    Thin wrapper around :func:`create_kernel_mix` that reads the kernel
    specification from ``config.kernel_params`` and the ARD dimensionality
    from ``config.input_dim``. The previous implementation duplicated
    ``create_kernel_mix`` line for line; delegating keeps one source of truth.

    Parameters
    ----------
    config:
        Object with ``kernel_params`` (dict) and ``input_dim`` (int) attributes.

    Raises
    ------
    ValueError
        If the kernel type is missing or unsupported.
    KeyError
        If ``raw_lengthscale`` is absent from ``kernel_params['params']``.
    """
    return create_kernel_mix(config.kernel_params, input_dim=config.input_dim)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_kernel_mix(kernel_params,input_dim=1):
    """Build a GPyTorch kernel from a kernel-spec dict.

    Parameters
    ----------
    kernel_params : dict
        Must contain ``'type'`` (only ``'RBF'`` is supported) and a
        ``'params'`` dict with a scalar ``'raw_lengthscale'`` that is
        replicated across all ``input_dim`` ARD dimensions.
    input_dim : int
        Number of ARD input dimensions.

    Raises
    ------
    ValueError
        If ``'type'`` is missing or not supported.
    KeyError
        If ``'raw_lengthscale'`` is absent from ``kernel_params['params']``.
    """
    if 'type' not in kernel_params:
        raise ValueError("Kernel type must be specified in kernel_params")
    if kernel_params['type'] == 'RBF':
        # NOTE(review): requires_grad=False presumably freezes the lengthscale
        # hyperparameter — confirm against the gpytorch RBFKernel API.
        kernel = gpytorch.kernels.RBFKernel(ard_num_dims=input_dim, requires_grad=False)
        kernel_params_ = kernel_params.get('params', {})
        # The .get default only guards a missing 'params' dict; a missing
        # 'raw_lengthscale' key still raises KeyError on the next line.
        kernel_length_scale = kernel_params_["raw_lengthscale"]
        kernel_length_scale = torch.tensor([kernel_length_scale] * input_dim)
        kernel.initialize(raw_lengthscale=kernel_length_scale)
        return kernel
    raise ValueError(f"Unsupported kernel type: {kernel_params['type']}")
|
sim_priors_pk/hub_runtime/README.md
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hub Runtime Bundle
|
| 2 |
+
|
| 3 |
+
This directory contains the parallel Hugging Face export path for
|
| 4 |
+
consumer-facing model bundles.
|
| 5 |
+
|
| 6 |
+
The existing training export remains unchanged:
|
| 7 |
+
|
| 8 |
+
- native export: `BasicLightningExperiment._push_model_to_hub(...)`
|
| 9 |
+
- runtime export: `push_loaded_model_runtime_bundle(...)`
|
| 10 |
+
|
| 11 |
+
The runtime export is intended for users who should be able to load a model
|
| 12 |
+
from the Hugging Face Hub through `transformers` without installing the local
|
| 13 |
+
`sim_priors_pk` package.
|
| 14 |
+
|
| 15 |
+
## Important Constraint
|
| 16 |
+
|
| 17 |
+
The consumer entrypoint is `transformers`, but `transformers` alone is **not**
|
| 18 |
+
enough today.
|
| 19 |
+
|
| 20 |
+
These runtime bundles still execute PyTorch-based custom code and reconstruct
|
| 21 |
+
the internal PK architecture, so the user needs the runtime Python
|
| 22 |
+
dependencies, but not a local checkout of this repository.
|
| 23 |
+
|
| 24 |
+
Reliable consumer install:
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
pip install torch transformers huggingface_hub lightning datasets pandas torchtyping gpytorch pot torchdiffeq torchsde ruamel.yaml pyyaml
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
What the consumer does **not** need:
|
| 31 |
+
|
| 32 |
+
- `pip install sim_priors_pk`
|
| 33 |
+
- a local clone of this repository
|
| 34 |
+
- access to the training checkpoint directory
|
| 35 |
+
|
| 36 |
+
## Consumer Workflow
|
| 37 |
+
|
| 38 |
+
Use the runtime repo, not the native training-artifact repo.
|
| 39 |
+
|
| 40 |
+
```python
|
| 41 |
+
from transformers import AutoModel
|
| 42 |
+
|
| 43 |
+
model = AutoModel.from_pretrained(
|
| 44 |
+
"your-org/your-model-runtime",
|
| 45 |
+
trust_remote_code=True,
|
| 46 |
+
)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Then call the stable runtime task API:
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
outputs = model.run_task(
|
| 53 |
+
task="generate", # or "predict"
|
| 54 |
+
studies=studies, # one StudyJSON or a list[StudyJSON]
|
| 55 |
+
num_samples=8,
|
| 56 |
+
)
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
The return payload is:
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
{
|
| 63 |
+
"task": "generate",
|
| 64 |
+
"io_schema_version": "studyjson-v1",
|
| 65 |
+
"model_info": {...},
|
| 66 |
+
"results": [
|
| 67 |
+
{
|
| 68 |
+
"input_index": 0,
|
| 69 |
+
"samples": [study_json_0, study_json_1, ...],
|
| 70 |
+
}
|
| 71 |
+
],
|
| 72 |
+
}
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## Generate Example
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
from transformers import AutoModel
|
| 79 |
+
|
| 80 |
+
model = AutoModel.from_pretrained(
|
| 81 |
+
"your-org/your-model-runtime",
|
| 82 |
+
trust_remote_code=True,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
studies = [
|
| 86 |
+
{
|
| 87 |
+
"context": [
|
| 88 |
+
{
|
| 89 |
+
"name_id": "ctx_0",
|
| 90 |
+
"observations": [0.2, 0.5, 0.3],
|
| 91 |
+
"observation_times": [0.5, 1.0, 2.0],
|
| 92 |
+
"dosing": [1.0],
|
| 93 |
+
"dosing_type": ["oral"],
|
| 94 |
+
"dosing_times": [0.0],
|
| 95 |
+
"dosing_name": ["oral"],
|
| 96 |
+
}
|
| 97 |
+
],
|
| 98 |
+
"target": [],
|
| 99 |
+
"meta_data": {
|
| 100 |
+
"study_name": "demo",
|
| 101 |
+
"substance_name": "drug_x",
|
| 102 |
+
},
|
| 103 |
+
}
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
outputs = model.run_task(
|
| 107 |
+
task="generate",
|
| 108 |
+
studies=studies,
|
| 109 |
+
num_samples=4,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
generated_studies = outputs["results"][0]["samples"]
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## Predict Example
|
| 116 |
+
|
| 117 |
+
```python
|
| 118 |
+
from transformers import AutoModel
|
| 119 |
+
|
| 120 |
+
model = AutoModel.from_pretrained(
|
| 121 |
+
"your-org/your-model-runtime",
|
| 122 |
+
trust_remote_code=True,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
predict_studies = [
|
| 126 |
+
{
|
| 127 |
+
"context": [
|
| 128 |
+
{
|
| 129 |
+
"name_id": "ctx_0",
|
| 130 |
+
"observations": [0.2, 0.5, 0.3],
|
| 131 |
+
"observation_times": [0.5, 1.0, 2.0],
|
| 132 |
+
"dosing": [1.0],
|
| 133 |
+
"dosing_type": ["oral"],
|
| 134 |
+
"dosing_times": [0.0],
|
| 135 |
+
"dosing_name": ["oral"],
|
| 136 |
+
}
|
| 137 |
+
],
|
| 138 |
+
"target": [
|
| 139 |
+
{
|
| 140 |
+
"name_id": "tgt_0",
|
| 141 |
+
"observations": [0.25, 0.31],
|
| 142 |
+
"observation_times": [0.5, 1.0],
|
| 143 |
+
"remaining": [0.0, 0.0, 0.0],
|
| 144 |
+
"remaining_times": [2.0, 4.0, 8.0],
|
| 145 |
+
"dosing": [1.0],
|
| 146 |
+
"dosing_type": ["oral"],
|
| 147 |
+
"dosing_times": [0.0],
|
| 148 |
+
"dosing_name": ["oral"],
|
| 149 |
+
}
|
| 150 |
+
],
|
| 151 |
+
"meta_data": {
|
| 152 |
+
"study_name": "demo",
|
| 153 |
+
"substance_name": "drug_x",
|
| 154 |
+
},
|
| 155 |
+
}
|
| 156 |
+
]
|
| 157 |
+
|
| 158 |
+
outputs = model.run_task(
|
| 159 |
+
task="predict",
|
| 160 |
+
studies=predict_studies,
|
| 161 |
+
num_samples=4,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
prediction_samples = outputs["results"][0]["samples"]
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
## Producer Workflow
|
| 168 |
+
|
| 169 |
+
To publish a runtime repo from a locally loaded experiment:
|
| 170 |
+
|
| 171 |
+
```python
|
| 172 |
+
from sim_priors_pk.hub_runtime import push_loaded_model_runtime_bundle
|
| 173 |
+
|
| 174 |
+
runtime_repo_id = push_loaded_model_runtime_bundle(
|
| 175 |
+
experiment=experiment,
|
| 176 |
+
model_card_path=["hf_model_cards", "AICME-PK_Readme.md"],
|
| 177 |
+
)
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
By default this creates a separate repo:
|
| 181 |
+
|
| 182 |
+
```text
|
| 183 |
+
<namespace>/<hf_model_name>-runtime
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
That keeps the native training artifact export and the consumer runtime export
|
| 187 |
+
separate.
|
sim_priors_pk/hub_runtime/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public helpers for the parallel Hugging Face runtime bundle path."""
|
| 2 |
+
|
| 3 |
+
from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
|
| 4 |
+
from sim_priors_pk.hub_runtime.modeling_sim_priors_pk import PKHubModel
|
| 5 |
+
from sim_priors_pk.hub_runtime.runtime_bundle import (
|
| 6 |
+
RuntimeBundleArtifacts,
|
| 7 |
+
build_runtime_bundle_dir,
|
| 8 |
+
default_runtime_repo_id,
|
| 9 |
+
push_loaded_model_runtime_bundle,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"PKHubConfig",
|
| 14 |
+
"PKHubModel",
|
| 15 |
+
"RuntimeBundleArtifacts",
|
| 16 |
+
"build_runtime_bundle_dir",
|
| 17 |
+
"default_runtime_repo_id",
|
| 18 |
+
"push_loaded_model_runtime_bundle",
|
| 19 |
+
]
|
sim_priors_pk/hub_runtime/configuration_sim_priors_pk.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face configuration for self-contained PK runtime bundles."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from transformers import PretrainedConfig
|
| 8 |
+
|
| 9 |
+
from sim_priors_pk.hub_runtime.runtime_contract import STUDY_JSON_IO_VERSION
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PKHubConfig(PretrainedConfig):
|
| 13 |
+
"""Public Hub config describing a consumer-facing PK runtime bundle."""
|
| 14 |
+
|
| 15 |
+
model_type = "sim_priors_pk"
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
architecture_name: Optional[str] = None,
|
| 20 |
+
experiment_type: str = "nodepk",
|
| 21 |
+
experiment_config: Optional[Dict[str, Any]] = None,
|
| 22 |
+
builder_config: Optional[Dict[str, Any]] = None,
|
| 23 |
+
supported_tasks: Optional[List[str]] = None,
|
| 24 |
+
default_task: Optional[str] = None,
|
| 25 |
+
io_schema_version: str = STUDY_JSON_IO_VERSION,
|
| 26 |
+
original_repo_id: Optional[str] = None,
|
| 27 |
+
runtime_repo_id: Optional[str] = None,
|
| 28 |
+
**kwargs,
|
| 29 |
+
) -> None:
|
| 30 |
+
super().__init__(**kwargs)
|
| 31 |
+
self.architecture_name = architecture_name
|
| 32 |
+
self.experiment_type = experiment_type
|
| 33 |
+
self.experiment_config = dict(experiment_config or {})
|
| 34 |
+
self.builder_config = dict(builder_config or {})
|
| 35 |
+
self.supported_tasks = list(supported_tasks or [])
|
| 36 |
+
self.default_task = default_task or (self.supported_tasks[0] if self.supported_tasks else None)
|
| 37 |
+
self.io_schema_version = io_schema_version
|
| 38 |
+
self.original_repo_id = original_repo_id
|
| 39 |
+
self.runtime_repo_id = runtime_repo_id
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
__all__ = ["PKHubConfig"]
|
sim_priors_pk/hub_runtime/modeling_sim_priors_pk.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face AutoModel wrapper for consumer-facing PK runtime bundles."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, Optional, Sequence, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import PreTrainedModel
|
| 9 |
+
|
| 10 |
+
from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
|
| 11 |
+
from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
|
| 12 |
+
from sim_priors_pk.hub_runtime.runtime_contract import (
|
| 13 |
+
RuntimeBuilderConfig,
|
| 14 |
+
build_batch_from_studies,
|
| 15 |
+
infer_supported_tasks,
|
| 16 |
+
instantiate_backbone_from_hub_config,
|
| 17 |
+
normalize_studies_input,
|
| 18 |
+
split_runtime_samples,
|
| 19 |
+
validate_studies_for_task,
|
| 20 |
+
)
|
| 21 |
+
from sim_priors_pk.models.amortized_inference.generative_pk import (
|
| 22 |
+
NewGenerativeMixin,
|
| 23 |
+
NewPredictiveMixin,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class PKHubModel(PreTrainedModel):
    """Thin wrapper exposing a stable StudyJSON runtime API on top of PK models."""

    config_class = PKHubConfig
    base_model_prefix = "backbone"

    def __init__(self, config: PKHubConfig, backbone: Optional[torch.nn.Module] = None) -> None:
        """Wrap an existing PK backbone, or rebuild one from the hub config.

        Parameters
        ----------
        config:
            Runtime-bundle configuration.
        backbone:
            Optional pre-instantiated PK model; when omitted, the backbone is
            reconstructed from ``config`` via the runtime-contract helpers.
        """
        super().__init__(config)
        self.backbone = backbone if backbone is not None else instantiate_backbone_from_hub_config(config)
        # Runtime bundles are inference-only; keep the backbone in eval mode.
        self.backbone.eval()

    def forward(self, *args, **kwargs):
        """Delegate raw forward calls to the wrapped PK backbone."""

        return self.backbone(*args, **kwargs)

    @property
    def supported_tasks(self) -> Sequence[str]:
        """Tasks supported by this runtime model.

        Prefers the explicit list on the config; falls back to introspecting
        the backbone's capabilities when the config declares none.
        """

        return tuple(getattr(self.config, "supported_tasks", []) or infer_supported_tasks(self.backbone))

    @torch.inference_mode()
    def run_task(
        self,
        *,
        task: str,
        studies: Union[StudyJSON, Sequence[StudyJSON]],
        num_samples: int = 1,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the public StudyJSON inference contract for the requested task.

        Parameters
        ----------
        task:
            One of the supported task names ("generate" or "predict").
        studies:
            A single StudyJSON or a sequence of StudyJSONs.
        num_samples:
            Number of samples drawn per input study; must be >= 1.
        **kwargs:
            Task-specific options (``num_steps`` is forwarded for "generate").

        Returns
        -------
        Dict[str, Any]
            Payload with "task", "io_schema_version", "model_info", and a
            per-input "results" list of sampled StudyJSONs.

        Raises
        ------
        ValueError
            For an unsupported task, ``num_samples`` < 1, or a backbone that
            lacks the mixin required by the task.
        """

        supported_tasks = list(self.supported_tasks)
        if task not in supported_tasks:
            raise ValueError(
                f"Unsupported task {task!r}. Supported tasks: {supported_tasks or 'none'}."
            )
        if int(num_samples) < 1:
            raise ValueError("num_samples must be >= 1.")

        # Canonicalize and validate inputs before any tensor work.
        canonical_studies = normalize_studies_input(studies)
        builder_config = RuntimeBuilderConfig.from_dict(self.config.builder_config)
        validate_studies_for_task(canonical_studies, task=task, builder_config=builder_config)

        # Rebuild the meta-dosing object from the exported experiment config
        # when a "dosing" payload is present; otherwise reuse the backbone's.
        experiment_config_payload = getattr(self.config, "experiment_config", {})
        meta_dosing_payload = experiment_config_payload.get("dosing", {})
        batch = build_batch_from_studies(
            canonical_studies,
            builder_config=builder_config,
            meta_dosing=self.backbone.meta_dosing.__class__(**meta_dosing_payload)
            if meta_dosing_payload
            else self.backbone.meta_dosing,
        )
        batch = batch.to(self.device)

        if task == "generate":
            if not isinstance(self.backbone, NewGenerativeMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support generate.")
            output_studies = self.backbone.sample_new_individuals_to_studyjson(
                batch,
                sample_size=int(num_samples),
                num_steps=kwargs.get("num_steps"),
            )
        elif task == "predict":
            if not isinstance(self.backbone, NewPredictiveMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support predict.")
            output_studies = self.backbone.sample_individual_prediction_from_batch_list_to_studyjson(
                [batch],
                sample_size=int(num_samples),
            )[0]
        else:
            # Defensive guard: reachable only if the config declares a task
            # this wrapper does not implement.
            raise ValueError(f"Unsupported task {task!r}.")

        results = [
            {
                "input_index": index,
                "samples": split_runtime_samples(task, study),
            }
            for index, study in enumerate(output_studies)
        ]

        return {
            "task": task,
            "io_schema_version": self.config.io_schema_version,
            "model_info": {
                "architecture_name": self.config.architecture_name,
                "experiment_type": self.config.experiment_type,
                "supported_tasks": supported_tasks,
                "runtime_repo_id": self.config.runtime_repo_id,
                "original_repo_id": self.config.original_repo_id,
            },
            "results": results,
        }
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
__all__ = ["PKHubModel"]
|
sim_priors_pk/hub_runtime/runtime_bundle.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Manual export path for consumer-facing Hugging Face runtime bundles."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
import shutil
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from tempfile import TemporaryDirectory
|
| 10 |
+
from typing import Optional, Sequence
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from huggingface_hub import HfApi, create_repo
|
| 14 |
+
|
| 15 |
+
from sim_priors_pk import config_dir, project_dir
|
| 16 |
+
from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
|
| 17 |
+
from sim_priors_pk.hub_runtime.modeling_sim_priors_pk import PKHubModel
|
| 18 |
+
from sim_priors_pk.hub_runtime.runtime_contract import (
|
| 19 |
+
build_runtime_config_payload,
|
| 20 |
+
resolve_model_card_text,
|
| 21 |
+
runtime_readme_text,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
ROOT_CONFIGURATION_FILENAME = "configuration_sim_priors_pk.py"
|
| 25 |
+
ROOT_MODELING_FILENAME = "modeling_sim_priors_pk.py"
|
| 26 |
+
_HF_TOKEN_PATTERN = re.compile(r"hf_[A-Za-z0-9]{20,}")
|
| 27 |
+
_COMET_KEY_ASSIGNMENT_PATTERN = re.compile(r"(COMET_API_KEY\s*=\s*)(['\"]).*?\2")
|
| 28 |
+
_HF_KEY_ASSIGNMENT_PATTERN = re.compile(r"(HF_KEYS\s*=\s*)(['\"]).*?\2")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class RuntimeBundleArtifacts:
    """Return metadata for a staged runtime bundle."""

    # Local staging directory containing the assembled bundle contents.
    bundle_dir: Path
    # Hub repo id the bundle is (or will be) pushed to.
    runtime_repo_id: str
    # Hub id of the native training-artifact repo, when resolvable.
    original_repo_id: Optional[str]
    # Path to the bundle's README/model card file (presumably inside
    # bundle_dir — confirm against the staging code).
    readme_path: Path
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def default_runtime_repo_id(experiment, *, suffix: str = "-runtime") -> str:
    """Resolve the default runtime bundle repo id for a loaded experiment.

    The id has the form ``<hub-username>/<hf_model_name><suffix>``, where the
    username is looked up from the experiment's Hugging Face token.

    Raises
    ------
    RuntimeError
        If the experiment config or the Hugging Face token is missing.
    """

    exp_config = getattr(experiment, "exp_config", None)
    if exp_config is None:
        raise RuntimeError("Experiment config is not loaded.")
    token = getattr(experiment, "hf_token", None)
    if token is None:
        raise RuntimeError(
            "No Hugging Face token available. Set hugging_face_token in the config or KEYS.txt."
        )

    namespace = HfApi().whoami(token=token)["name"]
    return f"{namespace}/{exp_config.hf_model_name}{suffix}"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _default_original_repo_id(experiment) -> Optional[str]:
|
| 56 |
+
"""Infer the legacy/native Hub repo id if enough metadata is available."""
|
| 57 |
+
|
| 58 |
+
if getattr(experiment, "exp_config", None) is None:
|
| 59 |
+
return None
|
| 60 |
+
if getattr(experiment, "hf_token", None) is None:
|
| 61 |
+
return None
|
| 62 |
+
user = HfApi().whoami(token=experiment.hf_token)["name"]
|
| 63 |
+
return f"{user}/{experiment.exp_config.hf_model_name}"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _validate_loaded_experiment(experiment) -> None:
|
| 67 |
+
"""Ensure the loaded experiment has the minimum state needed for manual export."""
|
| 68 |
+
|
| 69 |
+
if getattr(experiment, "model", None) is None:
|
| 70 |
+
raise RuntimeError("Experiment model is not loaded.")
|
| 71 |
+
if getattr(experiment, "exp_config", None) is None:
|
| 72 |
+
raise RuntimeError("Experiment config is not loaded.")
|
| 73 |
+
if getattr(experiment, "experiment_dir", None) is None:
|
| 74 |
+
raise RuntimeError("Experiment directory is required before pushing.")
|
| 75 |
+
if getattr(experiment, "hf_token", None) is None:
|
| 76 |
+
raise RuntimeError(
|
| 77 |
+
"No Hugging Face token available. Set hugging_face_token in the config or KEYS.txt."
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _copy_runtime_support_files(bundle_dir: Path) -> None:
    """Copy the local package and root remote-code entrypoints into the bundle."""

    # NOTE(review): ``project_dir`` is a module-level path defined elsewhere in
    # this file's imports — presumed to be the repository root.
    pkg_source = project_dir / "sim_priors_pk"
    shutil.copytree(
        pkg_source,
        bundle_dir / "sim_priors_pk",
        dirs_exist_ok=True,
        ignore=shutil.ignore_patterns("__pycache__"),
    )

    # Promote the remote-code entrypoints to the bundle root, where the Hub
    # config's auto_map expects to find them.
    hub_runtime_dir = pkg_source / "hub_runtime"
    for entrypoint in (ROOT_CONFIGURATION_FILENAME, ROOT_MODELING_FILENAME):
        shutil.copy2(hub_runtime_dir / entrypoint, bundle_dir / entrypoint)

    # Optional top-level extras are copied only when present.
    for extra_name in ("requirements.txt", "LICENSE"):
        candidate = project_dir / extra_name
        if candidate.is_file():
            shutil.copy2(candidate, bundle_dir / extra_name)

    # Scrub secrets from the copied tree, then hard-fail if any token survived.
    _scrub_runtime_bundle_secrets(bundle_dir)
    _validate_no_hf_secrets(bundle_dir)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _scrub_runtime_bundle_secrets(bundle_dir: Path) -> None:
    """Remove token-like secrets from copied source files before Hub upload."""

    # Pattern/replacement pairs applied to every text file in the bundle.
    replacements = (
        (_HF_TOKEN_PATTERN, "hf_REDACTED"),
        (_COMET_KEY_ASSIGNMENT_PATTERN, r"\1\2REDACTED\2"),
        (_HF_KEY_ASSIGNMENT_PATTERN, r"\1\2REDACTED\2"),
    )
    # The local utils module hard-codes workstation paths and API keys, so its
    # entire body is replaced with inert placeholders.
    sentinel_module = (
        "PASCAL_BASE_DIR = ''\n"
        "NERSC_BASE_DIR = ''\n"
        "NERSC_EXPERIMENT_DIR = ''\n"
        "COMET_API_KEY = 'REDACTED'\n"
        "HF_KEYS = 'REDACTED'\n"
        "WORKSPACE = ''\n"
        "PROJECT = ''\n"
    )

    candidates = []
    for glob_pattern in ("*.py", "*.md", "*.txt", "*.json"):
        candidates.extend(bundle_dir.rglob(glob_pattern))

    for path in candidates:
        try:
            original = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            # Mis-labelled binary content; leave untouched.
            continue

        scrubbed = original
        for pattern, replacement in replacements:
            scrubbed = pattern.sub(replacement, scrubbed)

        if path.as_posix().endswith("sim_priors_pk/utils/__init__.py"):
            scrubbed = sentinel_module

        # Only rewrite files that actually changed (preserves mtimes elsewhere).
        if scrubbed != original:
            path.write_text(scrubbed, encoding="utf-8")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _validate_no_hf_secrets(bundle_dir: Path) -> None:
    """Fail fast if token-like Hugging Face secrets remain after scrubbing.

    Raises:
        RuntimeError: listing every offending file (relative to the bundle).
    """

    text_suffixes = {".py", ".md", ".txt", ".json"}
    offending_files: list[str] = []
    for candidate in bundle_dir.rglob("*"):
        if not candidate.is_file() or candidate.suffix not in text_suffixes:
            continue
        try:
            content = candidate.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            continue
        if _HF_TOKEN_PATTERN.search(content):
            offending_files.append(str(candidate.relative_to(bundle_dir)))

    if offending_files:
        raise RuntimeError(
            "Refusing to upload runtime bundle because token-like Hugging Face secrets "
            f"remain after scrubbing: {offending_files}"
        )
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def build_runtime_bundle_dir(
    *,
    experiment,
    bundle_dir: Path,
    model_card_path: Optional[Sequence[str]] = None,
    hf_repo_id: Optional[str] = None,
    original_repo_id: Optional[str] = None,
) -> RuntimeBundleArtifacts:
    """Stage a self-contained runtime bundle in ``bundle_dir`` without uploading it.

    Args:
        experiment: Loaded experiment exposing ``model``, ``exp_config``,
            ``experiment_dir`` and ``hf_token``.
        bundle_dir: Directory to stage into (created if missing).
        model_card_path: Path segments to the base model card; defaults to the
            config's ``hf_model_card_path`` or ``("hf_model_cards", "README.md")``.
        hf_repo_id: Target runtime repo id; derived from the Hub user and
            ``hf_model_name`` when omitted.
        original_repo_id: Legacy/native repo id; inferred when omitted.

    Returns:
        RuntimeBundleArtifacts describing the staged bundle.
    """

    _validate_loaded_experiment(experiment)
    bundle_dir.mkdir(parents=True, exist_ok=True)

    runtime_repo_id = hf_repo_id or default_runtime_repo_id(experiment)
    native_repo_id = original_repo_id or _default_original_repo_id(experiment)

    # Explicit argument wins over the config's model-card location.
    normalized_model_card_path = tuple(
        model_card_path
        if model_card_path is not None
        else getattr(experiment.exp_config, "hf_model_card_path", ("hf_model_cards", "README.md"))
    )
    # NOTE(review): ``config_dir`` is a module-level path defined elsewhere in
    # this file — presumed to point at the local config tree.
    local_model_card_path = Path(config_dir).joinpath(*normalized_model_card_path)
    base_model_card = resolve_model_card_text(local_model_card_path)

    runtime_payload = build_runtime_config_payload(
        backbone=experiment.model,
        exp_config=experiment.exp_config,
        original_repo_id=native_repo_id,
        runtime_repo_id=runtime_repo_id,
    )
    # auto_map points transformers' Auto* classes at the root entrypoint
    # modules copied next to this config ("<module>.<ClassName>").
    runtime_config = PKHubConfig(
        **runtime_payload,
        auto_map={
            "AutoConfig": f"{ROOT_CONFIGURATION_FILENAME[:-3]}.PKHubConfig",
            "AutoModel": f"{ROOT_MODELING_FILENAME[:-3]}.PKHubModel",
        },
        architectures=["PKHubModel"],
    )

    # Wrap the loaded backbone and serialize a CPU-only state dict so the
    # bundle loads on machines without the original device.
    runtime_model = PKHubModel(runtime_config, backbone=experiment.model)
    state_dict = {name: tensor.detach().cpu() for name, tensor in runtime_model.state_dict().items()}
    torch.save(state_dict, bundle_dir / "pytorch_model.bin")
    runtime_config.save_pretrained(str(bundle_dir))

    # Copies the package + entrypoints, then scrubs and validates secrets.
    _copy_runtime_support_files(bundle_dir)

    readme_text = runtime_readme_text(
        base_model_card=base_model_card,
        runtime_repo_id=runtime_repo_id,
        original_repo_id=native_repo_id,
        supported_tasks=runtime_config.supported_tasks,
        default_task=runtime_config.default_task,
    )
    readme_path = bundle_dir / "README.md"
    readme_path.write_text(readme_text, encoding="utf-8")

    return RuntimeBundleArtifacts(
        bundle_dir=bundle_dir,
        runtime_repo_id=runtime_repo_id,
        original_repo_id=native_repo_id,
        readme_path=readme_path,
    )
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def push_loaded_model_runtime_bundle(
    experiment,
    model_card_path: Optional[Sequence[str]] = None,
    hf_repo_id: Optional[str] = None,
    alias_name: str = "runtime_bundle_hf",
    commit_message: str = "manual runtime bundle push",
    *,
    original_repo_id: Optional[str] = None,
    exist_ok: bool = True,
) -> str:
    """Build and upload the consumer-facing runtime bundle for a loaded experiment.

    Args:
        experiment: Loaded experiment (model, config, experiment_dir, hf_token).
        model_card_path: Optional path segments to the base model card.
        hf_repo_id: Target runtime repo id; derived from the Hub user when omitted.
        alias_name: Subdirectory of ``experiment_dir`` used for staging.
        commit_message: Commit message for the Hub upload.
        original_repo_id: Legacy/native repo id; inferred when omitted.
        exist_ok: Passed to ``create_repo``; when False, an existing repo errors.

    Returns:
        The runtime repo id the bundle was uploaded to.
    """

    _validate_loaded_experiment(experiment)
    runtime_repo_id = hf_repo_id or default_runtime_repo_id(experiment)
    # Create (or reuse) the target repo before spending time staging.
    create_repo(runtime_repo_id, exist_ok=exist_ok, token=experiment.hf_token)

    bundle_root = Path(experiment.experiment_dir) / alias_name
    bundle_root.mkdir(parents=True, exist_ok=True)

    # Stage inside a temp dir under the experiment dir so partial bundles are
    # cleaned up automatically even if the upload fails.
    with TemporaryDirectory(dir=str(bundle_root), prefix="hf_runtime_bundle_") as temp_dir:
        staged_dir = Path(temp_dir)
        build_runtime_bundle_dir(
            experiment=experiment,
            bundle_dir=staged_dir,
            model_card_path=model_card_path,
            hf_repo_id=runtime_repo_id,
            original_repo_id=original_repo_id,
        )

        api = HfApi(token=experiment.hf_token)
        api.upload_folder(
            folder_path=str(staged_dir),
            repo_id=runtime_repo_id,
            commit_message=commit_message,
            token=experiment.hf_token,
        )

    return runtime_repo_id
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# Public names exported by this module.
__all__ = [
    "RuntimeBundleArtifacts",
    "build_runtime_bundle_dir",
    "default_runtime_repo_id",
    "push_loaded_model_runtime_bundle",
]
|
sim_priors_pk/hub_runtime/runtime_contract.py
ADDED
|
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared runtime-contract helpers for consumer-facing Hub bundles.
|
| 2 |
+
|
| 3 |
+
This module is imported both by the local exporter and by the copied package
|
| 4 |
+
inside the generated Hugging Face runtime bundle. Keep dependencies limited to
|
| 5 |
+
modules that are already required for model inference.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from dataclasses import asdict, dataclass
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, get_args, get_origin
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import PretrainedConfig
|
| 17 |
+
|
| 18 |
+
from sim_priors_pk.config_classes.data_config import (
|
| 19 |
+
MetaDosingConfig,
|
| 20 |
+
MetaStudyConfig,
|
| 21 |
+
MixDataConfig,
|
| 22 |
+
ObservationsConfig,
|
| 23 |
+
SimpleMetaStudyConfig,
|
| 24 |
+
)
|
| 25 |
+
from sim_priors_pk.config_classes.diffusion_pk_config import DiffusionPKExperimentConfig
|
| 26 |
+
from sim_priors_pk.config_classes.flow_pk_config import FlowPKExperimentConfig, VectorFieldPKConfig
|
| 27 |
+
from sim_priors_pk.config_classes.node_pk_config import (
|
| 28 |
+
EncoderDecoderNetworkConfig,
|
| 29 |
+
NodePKExperimentConfig,
|
| 30 |
+
)
|
| 31 |
+
from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
|
| 32 |
+
from sim_priors_pk.config_classes.training_config import TrainingConfig
|
| 33 |
+
from sim_priors_pk.data.data_empirical.builder import EmpiricalBatchConfig, JSON2AICMEBuilder
|
| 34 |
+
from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON, canonicalize_study
|
| 35 |
+
from sim_priors_pk.data.data_generation.observations_classes import ObservationStrategyFactory
|
| 36 |
+
from sim_priors_pk.models import get_model_class
|
| 37 |
+
from sim_priors_pk.models.amortized_inference.generative_pk import (
|
| 38 |
+
NewGenerativeMixin,
|
| 39 |
+
NewPredictiveMixin,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Backbone class names the v1 runtime bundle knows how to wrap and export;
# anything else is rejected by validate_runtime_architecture().
SUPPORTED_RUNTIME_ARCHITECTURES = {
    "AICMEPK",
    "ContextVAEPK",
    "FlowPK",
    "PredictionPK",
}
# Version tag for the StudyJSON input/output schema stored in the Hub config.
STUDY_JSON_IO_VERSION = "studyjson-v1"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
class RuntimeBuilderConfig:
    """Fixed builder capacities serialized into the Hub runtime config."""

    # Field order is part of the construction contract (positional callers).
    max_context_individuals: int
    max_target_individuals: int
    max_context_observations: int
    max_target_observations: int
    max_context_remaining: int
    max_target_remaining: int

    def to_dict(self) -> Dict[str, int]:
        """Return a JSON-serializable representation."""

        return asdict(self)

    @classmethod
    def from_dict(cls, payload: Mapping[str, Any]) -> "RuntimeBuilderConfig":
        """Instantiate the builder capacities from serialized config payload.

        Raises:
            KeyError: if any capacity key is missing from ``payload``.
        """

        capacity_keys = (
            "max_context_individuals",
            "max_target_individuals",
            "max_context_observations",
            "max_target_observations",
            "max_context_remaining",
            "max_target_remaining",
        )
        # Values may arrive as JSON numbers or numeric strings; force ints.
        return cls(**{key: int(payload[key]) for key in capacity_keys})

    def to_empirical_batch_config(self, *, max_databatch_size: int) -> EmpiricalBatchConfig:
        """Translate runtime capacities to the builder used by StudyJSON IO."""

        # Shared caps take the max of context/target so either side fits.
        individuals_cap = max(self.max_context_individuals, self.max_target_individuals)
        observations_cap = max(self.max_context_observations, self.max_target_observations)
        remaining_cap = max(self.max_context_remaining, self.max_target_remaining)
        return EmpiricalBatchConfig(
            max_databatch_size=int(max_databatch_size),
            max_individuals=individuals_cap,
            max_observations=observations_cap,
            max_remaining=remaining_cap,
            max_context_individuals=self.max_context_individuals,
            max_target_individuals=self.max_target_individuals,
            max_context_observations=self.max_context_observations,
            max_target_observations=self.max_target_observations,
            max_context_remaining=self.max_context_remaining,
            max_target_remaining=self.max_target_remaining,
        )
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _coerce_annotation(annotation: Any, value: Any) -> Any:
|
| 98 |
+
"""Best-effort coercion of JSON-loaded values into dataclass field types."""
|
| 99 |
+
|
| 100 |
+
if value is None:
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
origin = get_origin(annotation)
|
| 104 |
+
args = get_args(annotation)
|
| 105 |
+
|
| 106 |
+
if origin is Union:
|
| 107 |
+
non_none = [arg for arg in args if arg is not type(None)]
|
| 108 |
+
for candidate in non_none:
|
| 109 |
+
if candidate in (dict, Dict, Any, Mapping):
|
| 110 |
+
continue
|
| 111 |
+
try:
|
| 112 |
+
return _coerce_annotation(candidate, value)
|
| 113 |
+
except Exception:
|
| 114 |
+
continue
|
| 115 |
+
return value
|
| 116 |
+
|
| 117 |
+
if origin in (list, List, Sequence):
|
| 118 |
+
(inner_type,) = args if args else (Any,)
|
| 119 |
+
return [_coerce_annotation(inner_type, item) for item in value]
|
| 120 |
+
|
| 121 |
+
if origin in (tuple,):
|
| 122 |
+
if not args:
|
| 123 |
+
return tuple(value)
|
| 124 |
+
if len(args) == 2 and args[1] is Ellipsis:
|
| 125 |
+
return tuple(_coerce_annotation(args[0], item) for item in value)
|
| 126 |
+
return tuple(_coerce_annotation(inner, item) for inner, item in zip(args, value))
|
| 127 |
+
|
| 128 |
+
if origin in (dict, Dict, Mapping):
|
| 129 |
+
return dict(value)
|
| 130 |
+
|
| 131 |
+
if annotation is Any:
|
| 132 |
+
return value
|
| 133 |
+
|
| 134 |
+
if annotation is MetaStudyConfig and isinstance(value, Mapping) and value.get("simple_mode"):
|
| 135 |
+
return SimpleMetaStudyConfig(**dict(value))
|
| 136 |
+
|
| 137 |
+
if hasattr(annotation, "__dataclass_fields__") and isinstance(value, Mapping):
|
| 138 |
+
kwargs = {}
|
| 139 |
+
for field_name, field_def in annotation.__dataclass_fields__.items():
|
| 140 |
+
if field_name in value:
|
| 141 |
+
kwargs[field_name] = _coerce_annotation(field_def.type, value[field_name])
|
| 142 |
+
return annotation(**kwargs)
|
| 143 |
+
|
| 144 |
+
return value
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _rebuild_node_config(payload: Mapping[str, Any]) -> NodePKExperimentConfig:
    """Reconstruct a ``NodePKExperimentConfig`` from serialized dict content.

    Missing keys fall back to the same defaults used by a fresh NodePK config;
    nested sections are rebuilt via ``_coerce_annotation``.
    """

    return NodePKExperimentConfig(
        experiment_type=str(payload.get("experiment_type", "nodepk")).lower(),
        name_str=str(payload.get("name_str", "NodePK")),
        comet_ai_key=payload.get("comet_ai_key"),
        experiment_name=str(payload.get("experiment_name", "node_pk_compartments")),
        hugging_face_token=payload.get("hugging_face_token"),
        upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)),
        hf_model_name=str(payload.get("hf_model_name", "NodePK_runtime")),
        hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))),
        tags=list(payload.get("tags", [])),
        # NOTE: "indentifier" mirrors the (misspelled) field name on the config
        # class — keep the spelling in sync with it, do not "fix" here.
        experiment_indentifier=payload.get("experiment_indentifier"),
        my_results_path=payload.get("my_results_path"),
        experiment_dir=payload.get("experiment_dir"),
        verbose=bool(payload.get("verbose", False)),
        run_index=int(payload.get("run_index", 0)),
        debug_test=bool(payload.get("debug_test", False)),
        network=_coerce_annotation(EncoderDecoderNetworkConfig, payload.get("network", {})),
        mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})),
        context_observations=_coerce_annotation(
            ObservationsConfig, payload.get("context_observations", {})
        ),
        target_observations=_coerce_annotation(
            ObservationsConfig, payload.get("target_observations", {})
        ),
        meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})),
        dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})),
        train=_coerce_annotation(TrainingConfig, payload.get("train", {})),
    )
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _rebuild_flow_config(payload: Mapping[str, Any]) -> FlowPKExperimentConfig:
    """Reconstruct a ``FlowPKExperimentConfig`` from serialized dict content.

    Same shape as ``_rebuild_node_config`` plus the flow-specific fields
    (``flow_num_steps``, ``vector_field``, ``source_process``).
    """

    return FlowPKExperimentConfig(
        experiment_type=str(payload.get("experiment_type", "flowpk")).lower(),
        name_str=str(payload.get("name_str", "FlowPK")),
        comet_ai_key=payload.get("comet_ai_key"),
        experiment_name=str(payload.get("experiment_name", "flow_pk_compartments")),
        hugging_face_token=payload.get("hugging_face_token"),
        upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)),
        hf_model_name=str(payload.get("hf_model_name", "FlowPK_runtime")),
        hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))),
        tags=list(payload.get("tags", [])),
        # NOTE: "indentifier" mirrors the (misspelled) config field name.
        experiment_indentifier=payload.get("experiment_indentifier"),
        my_results_path=payload.get("my_results_path"),
        experiment_dir=payload.get("experiment_dir"),
        verbose=bool(payload.get("verbose", False)),
        run_index=int(payload.get("run_index", 0)),
        debug_test=bool(payload.get("debug_test", False)),
        flow_num_steps=int(payload.get("flow_num_steps", 50)),
        vector_field=_coerce_annotation(VectorFieldPKConfig, payload.get("vector_field", {})),
        source_process=_coerce_annotation(SourceProcessConfig, payload.get("source_process", {})),
        mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})),
        context_observations=_coerce_annotation(
            ObservationsConfig, payload.get("context_observations", {})
        ),
        target_observations=_coerce_annotation(
            ObservationsConfig, payload.get("target_observations", {})
        ),
        meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})),
        dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})),
        train=_coerce_annotation(TrainingConfig, payload.get("train", {})),
    )
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _rebuild_diffusion_config(payload: Mapping[str, Any]) -> DiffusionPKExperimentConfig:
    """Reconstruct a ``DiffusionPKExperimentConfig`` from serialized dict content.

    Same shape as ``_rebuild_node_config`` plus the diffusion-specific fields
    (``diffusion_type``, ``predict_gaussian_noise``, ``source_process``).
    """

    return DiffusionPKExperimentConfig(
        experiment_type=str(payload.get("experiment_type", "diffusionpk")).lower(),
        name_str=str(payload.get("name_str", "ContinuousDiffusionPK")),
        diffusion_type=str(payload.get("diffusion_type", "continuous")),
        comet_ai_key=payload.get("comet_ai_key"),
        experiment_name=str(payload.get("experiment_name", "diffusion_pk_compartments")),
        hugging_face_token=payload.get("hugging_face_token"),
        upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)),
        hf_model_name=str(payload.get("hf_model_name", "DiffusionPK_runtime")),
        hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))),
        tags=list(payload.get("tags", [])),
        # NOTE: "indentifier" mirrors the (misspelled) config field name.
        experiment_indentifier=payload.get("experiment_indentifier"),
        my_results_path=payload.get("my_results_path"),
        experiment_dir=payload.get("experiment_dir"),
        verbose=bool(payload.get("verbose", False)),
        run_index=int(payload.get("run_index", 0)),
        debug_test=bool(payload.get("debug_test", False)),
        predict_gaussian_noise=bool(payload.get("predict_gaussian_noise", True)),
        network=_coerce_annotation(EncoderDecoderNetworkConfig, payload.get("network", {})),
        source_process=_coerce_annotation(SourceProcessConfig, payload.get("source_process", {})),
        mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})),
        context_observations=_coerce_annotation(
            ObservationsConfig, payload.get("context_observations", {})
        ),
        target_observations=_coerce_annotation(
            ObservationsConfig, payload.get("target_observations", {})
        ),
        meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})),
        dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})),
        train=_coerce_annotation(TrainingConfig, payload.get("train", {})),
    )
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def rebuild_experiment_config(
    payload: Mapping[str, Any],
) -> Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig]:
    """Rebuild the serialized experiment config stored in the Hub config.

    Raises:
        ValueError: for an unrecognized ``experiment_type``.
    """

    experiment_type = str(payload.get("experiment_type", "nodepk")).lower()
    if experiment_type not in ("nodepk", "flowpk", "diffusionpk"):
        raise ValueError(f"Unsupported experiment_type for runtime bundle: {experiment_type!r}.")
    if experiment_type == "flowpk":
        return _rebuild_flow_config(payload)
    if experiment_type == "diffusionpk":
        return _rebuild_diffusion_config(payload)
    return _rebuild_node_config(payload)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def compute_runtime_builder_config(
    exp_config: Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig],
) -> RuntimeBuilderConfig:
    """Compute fixed empirical StudyJSON capacities from the experiment config.

    Raises:
        ValueError: if the configured target-individual count is negative.
    """

    # Observation strategies determine per-individual observation/remaining caps.
    ctx_caps = ObservationStrategyFactory.from_config(
        exp_config.context_observations,
        exp_config.meta_study,
    ).get_shapes()
    tgt_caps = ObservationStrategyFactory.from_config(
        exp_config.target_observations,
        exp_config.meta_study,
    ).get_shapes()

    # Upper end of the configured individuals range bounds the context side.
    n_context = int(exp_config.meta_study.num_individuals_range[-1])
    n_target = int(getattr(exp_config.mix_data, "n_of_target_individuals", 1))
    if n_target < 0:
        raise ValueError("n_of_target_individuals must be >= 0 for Hub runtime export.")

    return RuntimeBuilderConfig(
        max_context_individuals=n_context,
        max_target_individuals=n_target,
        max_context_observations=int(ctx_caps[0]),
        max_target_observations=int(tgt_caps[0]),
        max_context_remaining=int(ctx_caps[1]),
        max_target_remaining=int(tgt_caps[1]),
    )
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def infer_supported_tasks(backbone: torch.nn.Module) -> List[str]:
    """Infer the public task surface supported by the wrapped model.

    Order matters: the first entry becomes the runtime's default task.
    """

    capability_map = (
        ("generate", NewGenerativeMixin),
        ("predict", NewPredictiveMixin),
    )
    return [task for task, mixin in capability_map if isinstance(backbone, mixin)]
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def validate_runtime_architecture(backbone: torch.nn.Module) -> str:
    """Ensure the loaded architecture is supported by the runtime bundle v1.

    Returns:
        The backbone's class name.

    Raises:
        ValueError: for architectures outside ``SUPPORTED_RUNTIME_ARCHITECTURES``.
    """

    name = backbone.__class__.__name__
    if name in SUPPORTED_RUNTIME_ARCHITECTURES:
        return name
    raise ValueError(
        "Runtime Hub export only supports "
        f"{sorted(SUPPORTED_RUNTIME_ARCHITECTURES)}, got {name!r}."
    )
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def build_runtime_config_payload(
    *,
    backbone: torch.nn.Module,
    exp_config: Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig],
    original_repo_id: Optional[str],
    runtime_repo_id: Optional[str],
) -> Dict[str, Any]:
    """Build the serializable fields stored in the Hub config.

    Raises:
        ValueError: for unsupported architectures or models without tasks.
    """

    architecture_name = validate_runtime_architecture(backbone)
    supported_tasks = infer_supported_tasks(backbone)
    if not supported_tasks:
        raise ValueError(f"Model {architecture_name!r} does not expose runtime tasks.")

    payload: Dict[str, Any] = {
        "architecture_name": architecture_name,
        "experiment_type": str(getattr(exp_config, "experiment_type", "nodepk")).lower(),
        # Full experiment config as plain dicts; rebuilt at load time by
        # rebuild_experiment_config().
        "experiment_config": asdict(exp_config),
        "builder_config": compute_runtime_builder_config(exp_config).to_dict(),
        "supported_tasks": supported_tasks,
        # First inferred task doubles as the default.
        "default_task": supported_tasks[0],
        "io_schema_version": STUDY_JSON_IO_VERSION,
        "original_repo_id": original_repo_id,
        "runtime_repo_id": runtime_repo_id,
    }
    return payload
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def instantiate_backbone_from_hub_config(config: PretrainedConfig) -> torch.nn.Module:
    """Rebuild the internal PK model represented by the public Hub wrapper.

    Raises:
        ValueError: if the config lacks a serialized ``experiment_config``.
    """

    payload = getattr(config, "experiment_config", None)
    if not isinstance(payload, Mapping):
        raise ValueError("Hub config is missing the serialized experiment_config payload.")
    rebuilt_config = rebuild_experiment_config(payload)
    backbone = get_model_class(rebuilt_config)(rebuilt_config)
    # Runtime bundles serve inference only.
    backbone.eval()
    return backbone
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def normalize_studies_input(
    studies: Union[StudyJSON, Sequence[StudyJSON]],
) -> List[StudyJSON]:
    """Normalize runtime input to a mutable list of canonicalized studies."""

    # A single study arrives as a Mapping; a batch arrives as a sequence.
    # Shallow-copy each study so canonicalization never mutates caller data.
    raw_studies = (
        [dict(studies)] if isinstance(studies, Mapping) else [dict(study) for study in studies]
    )
    return [canonicalize_study(study, drop_tgt_too_few=False) for study in raw_studies]
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def validate_studies_for_task(
    studies: Sequence[StudyJSON],
    *,
    task: str,
    builder_config: RuntimeBuilderConfig,
) -> None:
    """Check task semantics and runtime capacity limits for every study.

    Raises ValueError on the first study that violates the task contract
    (``generate`` needs context and no target; ``predict`` needs target)
    or exceeds the configured individual/observation capacities.
    """
    for idx, entry in enumerate(studies):
        ctx_block = list(entry.get("context", []))
        tgt_block = list(entry.get("target", []))

        # Task-specific presence rules come first, so an unsupported task
        # is rejected before any capacity checks run.
        if task == "generate":
            if not ctx_block:
                raise ValueError("`generate` requires at least one context individual per study.")
            if tgt_block:
                raise ValueError("`generate` expects target to be empty in the input StudyJSON.")
        elif task == "predict":
            if not tgt_block:
                raise ValueError("`predict` requires at least one target individual per study.")
        else:
            raise ValueError(f"Unsupported task {task!r}.")

        ctx_cap = builder_config.max_context_individuals
        if len(ctx_block) > ctx_cap:
            raise ValueError(
                f"Study {idx} exceeds context individual capacity "
                f"({len(ctx_block)} > {ctx_cap})."
            )
        tgt_cap = builder_config.max_target_individuals
        if len(tgt_block) > tgt_cap:
            raise ValueError(
                f"Study {idx} exceeds target individual capacity "
                f"({len(tgt_block)} > {tgt_cap})."
            )

        # Per-individual series lengths are checked so the empirical builder
        # never silently truncates observations.
        _validate_individual_block(
            study_idx=idx,
            block_name="context",
            individuals=ctx_block,
            max_observations=builder_config.max_context_observations,
            max_remaining=builder_config.max_context_remaining,
        )
        _validate_individual_block(
            study_idx=idx,
            block_name="target",
            individuals=tgt_block,
            max_observations=builder_config.max_target_observations,
            max_remaining=builder_config.max_target_remaining,
        )
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _validate_individual_block(
    *,
    study_idx: int,
    block_name: str,
    individuals: Sequence[IndividualJSON],
    max_observations: int,
    max_remaining: int,
) -> None:
    """Raise if any individual's series would be truncated by the builder.

    Compares each individual's ``observations`` and ``remaining`` lengths
    against the runtime capacities, failing fast instead of letting the
    empirical builder silently clip the data.
    """
    for pos, person in enumerate(individuals):
        n_obs = len(person.get("observations", []))
        n_rem = len(person.get("remaining", []))
        if n_obs > max_observations:
            raise ValueError(
                f"Study {study_idx} {block_name}[{pos}] exceeds observation capacity "
                f"({n_obs} > {max_observations})."
            )
        if n_rem > max_remaining:
            raise ValueError(
                f"Study {study_idx} {block_name}[{pos}] exceeds remaining capacity "
                f"({n_rem} > {max_remaining})."
            )
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def build_batch_from_studies(
    studies: Sequence[StudyJSON],
    *,
    builder_config: RuntimeBuilderConfig,
    meta_dosing: MetaDosingConfig,
):
    """Turn canonical studies into the internal PK databatch representation.

    The empirical builder is sized to hold every study in a single batch
    (at least one slot even for an empty sequence).
    """
    batch_capacity = max(1, len(studies))
    empirical_config = builder_config.to_empirical_batch_config(
        max_databatch_size=batch_capacity
    )
    return JSON2AICMEBuilder(empirical_config).build_one_aicmebatch(
        list(studies), meta_dosing
    )
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def split_runtime_samples(task: str, study: StudyJSON) -> List[StudyJSON]:
    """Fan a model-specific StudyJSON output out into per-sample StudyJSONs."""
    splitters = {
        "generate": _split_generate_samples,
        "predict": _split_predict_samples,
    }
    try:
        splitter = splitters[task]
    except KeyError:
        raise ValueError(f"Unsupported task {task!r}.") from None
    return splitter(study)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _split_generate_samples(study: StudyJSON) -> List[StudyJSON]:
    """Expand generated target individuals into one single-target study each.

    A study without targets is returned unchanged (as a deep copy) so the
    caller always receives at least one study.
    """
    generated = list(study.get("target", []))
    if not generated:
        return [deepcopy(study)]

    return [
        {
            "context": deepcopy(study.get("context", [])),
            "target": [deepcopy(individual)],
            "meta_data": deepcopy(study.get("meta_data", {})),
        }
        for individual in generated
    ]
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def _split_predict_samples(study: StudyJSON) -> List[StudyJSON]:
    """Split target prediction samples into one StudyJSON per sample index.

    Each output study keeps the full context but narrows every target's
    ``prediction_samples`` list down to the single sample at that index.
    Raises ValueError when non-empty sample lists disagree in length.
    """

    targets = list(study.get("target", []))
    if not targets:
        # No targets at all: nothing to split, hand back a defensive copy.
        return [deepcopy(study)]

    # The number of per-sample studies is the largest sample count seen
    # across all targets.
    sample_count = 0
    for target in targets:
        sample_count = max(sample_count, len(target.get("prediction_samples", [])))
    if sample_count == 0:
        # Targets exist but none carries samples: return the study as-is.
        return [deepcopy(study)]

    split: List[StudyJSON] = []
    for sample_idx in range(sample_count):
        target_block: List[IndividualJSON] = []
        for target in targets:
            target_copy: IndividualJSON = deepcopy(target)
            samples = list(target.get("prediction_samples", []))
            if samples:
                # NOTE(review): targets with an *empty* sample list are passed
                # through untouched; only non-empty lists must agree in length.
                if sample_idx >= len(samples):
                    raise ValueError(
                        "All target individuals must expose the same number of prediction samples."
                    )
                target_copy["prediction_samples"] = [deepcopy(samples[sample_idx])]
            target_block.append(target_copy)

        split.append(
            {
                "context": deepcopy(study.get("context", [])),
                "target": target_block,
                "meta_data": deepcopy(study.get("meta_data", {})),
            }
        )
    return split
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def runtime_readme_text(
    *,
    base_model_card: str,
    runtime_repo_id: str,
    original_repo_id: Optional[str],
    supported_tasks: Sequence[str],
    default_task: str,
) -> str:
    """Compose the README uploaded with the consumer-facing runtime bundle.

    The base model card is kept verbatim and a "Runtime Bundle" usage
    section (installation, generate/predict examples, notes) is appended.
    """

    # Provenance line: fall back to an explicit "not recorded" marker when
    # the native repo id was not captured at export time.
    original_line = (
        f"- Native training/artifact repo: `{original_repo_id}`"
        if original_repo_id
        else "- Native training/artifact repo: not recorded"
    )
    tasks_literal = ", ".join(f"`{task}`" for task in supported_tasks)

    # Doubled braces ({{ }}) below are literal braces in the rendered
    # markdown; single-brace fields are interpolated here.
    usage = f"""

## Runtime Bundle

This repository is the consumer-facing runtime bundle for this PK model.

- Runtime repo: `{runtime_repo_id}`
{original_line}
- Supported tasks: {tasks_literal}
- Default task: `{default_task}`
- Load path: `AutoModel.from_pretrained(..., trust_remote_code=True)`

### Installation

You do **not** need to install `sim_priors_pk` to use this runtime bundle.

`transformers` is the public loading entrypoint, but `transformers` alone is
not sufficient because this is a PyTorch model with custom runtime code. A
reliable consumer environment is:

```bash
pip install torch transformers huggingface_hub lightning datasets pandas torchtyping gpytorch pot torchdiffeq torchsde ruamel.yaml pyyaml
```

### Python Usage

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("{runtime_repo_id}", trust_remote_code=True)

studies = [
    {{
        "context": [
            {{
                "name_id": "ctx_0",
                "observations": [0.2, 0.5, 0.3],
                "observation_times": [0.5, 1.0, 2.0],
                "dosing": [1.0],
                "dosing_type": ["oral"],
                "dosing_times": [0.0],
                "dosing_name": ["oral"],
            }}
        ],
        "target": [],
        "meta_data": {{"study_name": "demo", "substance_name": "drug_x"}},
    }}
]

outputs = model.run_task(
    task="{default_task}",
    studies=studies,
    num_samples=4,
)
print(outputs["results"][0]["samples"])
```

### Predictive Sampling

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("{runtime_repo_id}", trust_remote_code=True)

predict_studies = [
    {{
        "context": [
            {{
                "name_id": "ctx_0",
                "observations": [0.2, 0.5, 0.3],
                "observation_times": [0.5, 1.0, 2.0],
                "dosing": [1.0],
                "dosing_type": ["oral"],
                "dosing_times": [0.0],
                "dosing_name": ["oral"],
            }}
        ],
        "target": [
            {{
                "name_id": "tgt_0",
                "observations": [0.25, 0.31],
                "observation_times": [0.5, 1.0],
                "remaining": [0.0, 0.0, 0.0],
                "remaining_times": [2.0, 4.0, 8.0],
                "dosing": [1.0],
                "dosing_type": ["oral"],
                "dosing_times": [0.0],
                "dosing_name": ["oral"],
            }}
        ],
        "meta_data": {{"study_name": "demo", "substance_name": "drug_x"}},
    }}
]

outputs = model.run_task(
    task="predict",
    studies=predict_studies,
    num_samples=4,
)
print(outputs["results"][0]["samples"][0]["target"][0]["prediction_samples"])
```

### Notes

- `trust_remote_code=True` is required because this model uses custom Hugging Face Hub runtime code.
- The consumer API is `transformers` + `run_task(...)`; the consumer does not need a local clone of this repository.
- This runtime bundle is intentionally separate from the native training export so you can evaluate both distribution paths in parallel.
"""

    # Trim trailing/leading whitespace at the seam so the appended section
    # joins the card with exactly one blank boundary.
    return base_model_card.rstrip() + "\n" + usage.strip() + "\n"
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def resolve_model_card_text(model_card_path: Path) -> str:
    """Load the model card that seeds the runtime README.

    Raises FileNotFoundError when the path is not a regular file.
    """
    if model_card_path.is_file():
        return model_card_path.read_text(encoding="utf-8")
    raise FileNotFoundError(f"Model card not found at: {model_card_path}")
|
sim_priors_pk/metrics/__init__.py
ADDED
|
File without changes
|
sim_priors_pk/metrics/pk_metrics.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import List,Tuple
|
| 4 |
+
from matplotlib import pyplot as plt
|
| 5 |
+
from torchtyping import TensorType, patch_typeguard
|
| 6 |
+
from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch
|
| 7 |
+
from scipy import stats
|
| 8 |
+
import torch
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
Tensor = torch.Tensor # for brevity – keep your own alias if you prefer
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
def ensure_folder_exists(folder_name: str):
    """Create *folder_name* (including missing parents) if it does not exist.

    Fixes a TOCTOU race in the original: another process could create the
    directory between the ``os.path.exists`` check and ``os.makedirs``,
    turning the call into a FileExistsError crash. ``exist_ok=True`` makes
    the creation path race-free; the pre-check is kept only for the status
    message.
    """
    if os.path.exists(folder_name):
        print(f"📁 Folder already exists: {folder_name}")
        return
    # exist_ok guards against a concurrent creator winning the race.
    os.makedirs(folder_name, exist_ok=True)
    print(f"✅ Created folder: {folder_name}")
|
| 20 |
+
|
| 21 |
+
def combine_samples(
    samples_list: list[TensorType["S", "B", "I", "T", 1]]
) -> TensorType["S", "P", "T"]:
    """Stack P held-out sample tensors of shape [S, B=1, I=1, T, 1] into [S, P, T].

    Each entry is reduced to its [S, T] slice (batch, individual and channel
    axes are singletons), then the P slices are stacked on a new middle axis.
    """
    flat_slices = [sample[:, 0, 0, :, 0] for sample in samples_list]  # each [S, T]
    return torch.stack(flat_slices, dim=1)  # [S, P, T]
|
| 42 |
+
|
| 43 |
+
def extract_context_by_mask(
    db: AICMECompartmentsDataBatch
) -> Tuple[
    List[TensorType["n_i"]],  # context observations per compartment
    List[TensorType["n_i"]]   # context times per compartment
]:
    """Collect the valid (mask==1) context observations and times per individual.

    Expects a batch of size one:
      - db.context_obs:      [1, c_ind, num_obs_c, 1]
      - db.context_obs_time: [1, c_ind, num_obs_c, 1]
      - db.context_obs_mask: [1, c_ind, num_obs_c]

    Returns two lists of length c_ind holding 1-D tensors of each
    individual's valid observations and their matching times.
    """
    batch, n_individuals, _, trailing = db.context_obs.shape
    assert batch == 1 and trailing == 1, f"Expected B=1 and last dim=1, got B={batch}, last={trailing}"

    # Drop batch and channel singleton axes: [1, c_ind, num_obs_c, 1] -> [c_ind, num_obs_c]
    obs_grid = db.context_obs.squeeze(0).squeeze(-1)
    time_grid = db.context_obs_time.squeeze(0).squeeze(-1)
    mask_grid = db.context_obs_mask.squeeze(0)

    obs_per_ind: List[torch.Tensor] = []
    times_per_ind: List[torch.Tensor] = []
    for row in range(n_individuals):
        keep = mask_grid[row].bool()
        obs_per_ind.append(obs_grid[row][keep])
        times_per_ind.append(time_grid[row][keep])

    return obs_per_ind, times_per_ind
|
| 80 |
+
|
| 81 |
+
def compute_pd(
    y_obs : TensorType["I", "T"],
    y_sim : TensorType["S", "I", "T"],
    mask : TensorType["I", "T"],
) -> TensorType["I", "T"]:
    """Prediction discrepancy (pd) — Eq. (4) of Comets et al. 2008.

    NOTE: there is no batch axis; this operates on a single substance.

    Parameters
    ----------
    y_obs : [I, T] observed values (padding content is irrelevant, `mask`
            marks the trustworthy entries)
    y_sim : [S, I, T] S Monte-Carlo replicates from the model
    mask  : [I, T] binary mask — True at valid observation points

    Returns
    -------
    pd : [I, T] empirical CDF value at (i, j); NaN where mask == 0
    """
    n_sim, n_ind, n_time = y_sim.shape
    assert y_obs.shape == (n_ind, n_time), "y_obs must be [I,T]"
    assert mask.shape == (n_ind, n_time), "mask must be [I,T]"

    # Fraction of simulations strictly below the observation = empirical CDF.
    # unsqueeze(0) broadcasts y_obs against the S replicate axis.
    below = (y_sim < y_obs.unsqueeze(0)).float()  # [S,I,T]
    pd = below.mean(dim=0)                        # [I,T]

    # Report padded (mask == 0) entries as NaN so callers can spot them.
    nan_fill = torch.full_like(pd, float("nan"))
    return torch.where(mask.bool(), pd, nan_fill)
|
| 119 |
+
|
| 120 |
+
def sample_covariance_manual_torch(
    X: TensorType["S", "Tv"]  # simulations for one subject, S×Tv
):
    """Unbiased sample covariance [Tv,Tv] and mean vector [Tv] of the S rows."""
    n_samples = X.shape[0]
    mu = X.mean(dim=0)          # [Tv]
    centered = X - mu           # [S,Tv]
    covariance = centered.t() @ centered / (n_samples - 1)  # [Tv,Tv]
    return covariance, mu
|
| 132 |
+
|
| 133 |
+
def whiten_manual_torch_old(
    X: TensorType["S", "Tv"],  # data to whiten
    eps: float = 1e-8          # ridge for numerical safety
):
    """Eigen-decomposition based whitening (legacy version).

    Returns the whitened data, the whitening matrix W = Σ^{-1/2}, and the
    sample mean vector.
    """
    covariance, mu = sample_covariance_manual_torch(X)  # Σ, μ
    ridge = eps * torch.eye(covariance.size(0), device=X.device)
    lam, Q = torch.linalg.eigh(covariance + ridge)
    W = Q @ torch.diag(torch.rsqrt(lam)) @ Q.t()  # Σ^{-1/2}
    return (X - mu) @ W, W, mu
|
| 147 |
+
|
| 148 |
+
def compute_npde_full_old(
    y_obs: TensorType["I", "T"],
    y_sim: TensorType["S", "I", "T"],
    mask : TensorType["I", "T"],
    eps : float = 1e-8
) -> TensorType["I", "T"]:
    """Full NPDE with within-subject decorrelation Σ^{-1/2} (legacy entry point).

    NOTE: there is no batch axis; this operates on a single substance.

    Shapes
    ------
    y_obs : [I,T]   observations (padding allowed)
    y_sim : [S,I,T] S Monte-Carlo replicates
    mask  : [I,T]   True/1 = valid time-points

    Returns
    -------
    out : [I,T] NPDE values; NaN where masked out or whitening failed.
    """
    S, I, T = y_sim.shape
    N01 = torch.distributions.Normal(0.0, 1.0)
    out = torch.full_like(y_obs, float("nan"))  # result placeholder

    for i in range(I):
        # ---- select the irregular grid for subject i -------------------
        valid_idx = mask[i].bool()
        if not valid_idx.any():
            continue  # nothing to do

        y_i_obs = y_obs[i, valid_idx]    # [Tv]
        y_i_sim = y_sim[:, i, valid_idx] # [S,Tv]

        # ---- whitening ---------------------------------------------------
        # BUGFIX: whiten_manual_torch returns FOUR values
        # (X_white, W, mean, ok); the previous three-name unpack raised
        # "ValueError: too many values to unpack" on every call.
        y_i_sim_white, W, mean_vec, ok = whiten_manual_torch(y_i_sim, eps)
        if not ok:
            # Whitening degraded to the diagonal fallback → mark as NaN.
            out[i, valid_idx] = float("nan")
            continue

        # same transform for the single observation vector
        y_i_obs_white = (y_i_obs - mean_vec) @ W  # [Tv]

        # ---- empirical CDF on whitened scale (Eq. 4) -------------------
        delta = (y_i_sim_white < y_i_obs_white).float()  # [S,Tv]
        pde = delta.mean(dim=0)                          # [Tv]

        # ---- edge-case rule (Eq. 6): keep pde strictly inside (0, 1) ---
        one_over_S = 1.0 / S
        pde = torch.where(pde == 0, torch.full_like(pde, one_over_S), pde)
        pde = torch.where(pde == 1, torch.full_like(pde, 1 - one_over_S), pde)

        # ---- NPDE (Eq. 7): inverse standard-normal CDF ------------------
        npde = N01.icdf(pde)  # [Tv]

        # ---- write back to full-size tensor ----------------------------
        out[i, valid_idx] = npde

    return out
|
| 205 |
+
|
| 206 |
+
# ---------------------------------------------------------------------
|
| 207 |
+
# 1. Robust whitening
|
| 208 |
+
# ---------------------------------------------------------------------
|
| 209 |
+
def whiten_manual_torch(
|
| 210 |
+
X: Tensor, # [S, Tᵥ]
|
| 211 |
+
eps: float = 1e-8,
|
| 212 |
+
max_attempts: int = 5,
|
| 213 |
+
base_jitter: float = 1e-6
|
| 214 |
+
) -> Tuple[Tensor, torch.Tensor | None, Tensor, bool]:
|
| 215 |
+
"""
|
| 216 |
+
Returns
|
| 217 |
+
-------
|
| 218 |
+
X_white : [S,Tᵥ] whitened simulations
|
| 219 |
+
W : [Tᵥ,Tᵥ] | None Σ^{-½} (None ⇒ degraded to diag)
|
| 220 |
+
mean : [Tᵥ] sample mean
|
| 221 |
+
ok : bool True if full Σ^{-½} was used
|
| 222 |
+
"""
|
| 223 |
+
S, T = X.shape
|
| 224 |
+
X64 = X.double()
|
| 225 |
+
mean = X64.mean(dim=0)
|
| 226 |
+
Xm = X64 - mean
|
| 227 |
+
cov = (Xm.T @ Xm) / (S - 1)
|
| 228 |
+
I = torch.eye(T, dtype=X64.dtype, device=X.device)
|
| 229 |
+
|
| 230 |
+
W = None
|
| 231 |
+
for k in range(max_attempts):
|
| 232 |
+
jitter = base_jitter * (10.0 ** k)
|
| 233 |
+
try:
|
| 234 |
+
eigvals, eigvecs = torch.linalg.eigh(cov + (eps + jitter) * I)
|
| 235 |
+
if torch.any(eigvals <= 0):
|
| 236 |
+
raise RuntimeError("non-positive eigenvalues")
|
| 237 |
+
inv_sqrt = torch.rsqrt(eigvals)
|
| 238 |
+
W = eigvecs @ torch.diag(inv_sqrt) @ eigvecs.T
|
| 239 |
+
break
|
| 240 |
+
except RuntimeError:
|
| 241 |
+
pass # try bigger jitter
|
| 242 |
+
|
| 243 |
+
if W is None: # final fallback
|
| 244 |
+
var = cov.diag().clamp_min(eps)
|
| 245 |
+
W = torch.diag(torch.rsqrt(var)) # diagonal only
|
| 246 |
+
ok = False
|
| 247 |
+
else:
|
| 248 |
+
ok = True
|
| 249 |
+
|
| 250 |
+
X_white = (Xm @ W).float()
|
| 251 |
+
return X_white, W.float() if ok else None, mean.float(), ok
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ---------------------------------------------------------------------
|
| 255 |
+
# 2. NPDE with an *output* validity mask
|
| 256 |
+
# ---------------------------------------------------------------------
|
| 257 |
+
|
| 258 |
+
def compute_npde_full(
    y_obs: TensorType["I", "T"],
    y_sim: TensorType["S", "I", "T"],
    mask : TensorType["I", "T"],
    eps : float = 1e-8,
) -> Tuple[TensorType["I", "T"], TensorType["I", "T"]]:
    """Full NPDE with within-subject decorrelation (Σ^{-1/2}) per subject.

    NOTE: there is no batch axis; this operates on a single substance.

    Args
    ------
    y_obs : [I,T]   observations (padding allowed)
    y_sim : [S,I,T] S Monte-Carlo replicates
    mask  : [I,T]   True/1 = valid time-points

    Returns
    -------
    npde       : [I,T] – same shape as `y_obs`; NaN outside valid entries
    valid_mask : [I,T] – True where npde is statistically valid (the input
                 mask, further cleared where whitening degraded)
    """
    S, I, T = y_sim.shape
    N01 = torch.distributions.Normal(0.0, 1.0)

    npde_out = torch.full_like(y_obs, float("nan"))
    valid_out = mask.clone().bool()  # start with the user mask

    for i in range(I):
        # ---- select the irregular grid for subject i -------------------
        valid_idx = mask[i].bool()
        if not valid_idx.any():
            valid_out[i] = False
            continue

        y_i_obs = y_obs[i, valid_idx]  # [Tᵥ]
        y_i_sim = y_sim[:, i, valid_idx]  # [S,Tᵥ]

        # ---- whitening: decorrelate the within-subject time points -----
        y_i_sim_white, W, mean_vec, ok = whiten_manual_torch(y_i_sim, eps)

        if not ok:  # whitening failed → invalidate
            valid_out[i, valid_idx] = False
            continue

        # same transform for the single observation vector
        y_i_obs_white = (y_i_obs - mean_vec) @ W

        # ---- empirical CDF on whitened scale (Eq. 4) -------------------
        delta = (y_i_sim_white < y_i_obs_white).float()
        pde = delta.mean(dim=0)

        # ---- edge-case rule (Eq. 6): clamp pde into (0, 1) so the ------
        # inverse CDF below stays finite
        one_over_S = 1.0 / S
        pde = torch.where(pde == 0, torch.full_like(pde, one_over_S), pde)
        pde = torch.where(pde == 1, torch.full_like(pde, 1 - one_over_S), pde)

        # ---- NPDE (Eq. 7): inverse standard-normal CDF -----------------
        npde = N01.icdf(pde)
        npde_out[i, valid_idx] = npde

    return npde_out, valid_out
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def compute_npde_in_batch(
    y_obs: TensorType["B", "I", "T"],
    y_sim: TensorType["S", "B", "I", "T"],
    mask: TensorType["B", "I", "T"],
    eps: float = 1e-8,
) -> TensorType["B", "I", "T"]:
    """Compute NPDE for each element in a batch.

    Parameters
    ----------
    y_obs : [B, I, T] Observed values per batch item (context observations).
    y_sim : [S, B, I, T] Simulated predictions.
    mask  : [B, I, T] Validity mask for observations.

    Returns
    -------
    Tensor of shape [B, I, T] with NPDE values (NaN where invalid).
    """
    B = y_obs.size(0)
    results = []
    for b in range(B):
        # BUGFIX: compute_npde_full returns a (npde, valid_mask) tuple; the
        # original appended the whole tuple, so torch.stack crashed on a
        # list of tuples. Only the NPDE tensor is stacked here.
        npde_b, _ = compute_npde_full(y_obs[b], y_sim[:, b], mask[b], eps)
        results.append(npde_b)
    return torch.stack(results, dim=0)
|
| 347 |
+
|
| 348 |
+
def shapiro_wilk_normality(npde: TensorType["T"]) -> Tuple[float, float]:
    """Shapiro-Wilk normality test (statistic, p-value) over the finite NPDE entries."""
    finite_vals = npde[torch.isfinite(npde)].detach().cpu().numpy()
    statistic, p_value = stats.shapiro(finite_vals)
    return float(statistic), float(p_value)
|
| 353 |
+
|
| 354 |
+
def qq_plot(npde: TensorType["T"], train:bool =False, epoch:str|int = "na", **kwargs) -> str | None:
    """Draw a normal QQ plot of the finite NPDE values.

    Args:
        npde: Tensor containing NPDE values (non-finite entries are dropped).
        train: If True, the figure is saved to ``qq_plot_epoch_<epoch>.png``
            and the path is returned; otherwise the plot is shown.
        epoch: Label used in the saved file name.

    Returns:
        File path if saved, None otherwise.
    """
    finite_vals = npde[torch.isfinite(npde)].detach().cpu().numpy()

    figure = plt.figure()
    stats.probplot(finite_vals, dist="norm", plot=plt)

    if not train:
        plt.show()
        return None

    out_path = f"qq_plot_epoch_{epoch}.png"
    figure.savefig(out_path, bbox_inches="tight")
    plt.close(figure)
    return out_path
|
| 380 |
+
|
| 381 |
+
def vcp_from_sample(model,databatch_list,empirical_databatch,train=False):
    """Build a VPC image from model samples of held-out individuals.

    Each held-out sample tensor has shape [S, B=1, I=1, T, 1]; the per-
    individual samples are concatenated into the [S, I, T] layout that
    `vpc` expects, compared against the empirical context observations.
    """
    paired_samples = model.sample(databatch_list, use_unique_times=True, num_samples=30)
    stacked_obs = combine_samples([pair[0] for pair in paired_samples])
    stacked_times = combine_samples([pair[1] for pair in paired_samples])
    print(stacked_obs.shape)
    # All individuals share one grid, so the first row of times suffices.
    grid_times = stacked_times[0, 0, :]
    print(grid_times.shape)
    ctx_obs, ctx_times = extract_context_by_mask(empirical_databatch)
    return vpc(grid_times, stacked_obs, ctx_obs, ctx_times, train=train)
|
| 395 |
+
|
| 396 |
+
def vpc_from_empirical(databatch_list,databatch_list_context,model,train=False,image_name="vpc.png",samples_number=100,y_scale=None):
    """Generate a VPC from empirical data by sampling new individuals per batch.

    Samples `samples_number` new individuals for every held-out databatch
    tuple, aligns them on the time grid of the context individual with the
    most valid observations, and hands everything to `vpc`.

    NOTE(review): assumes each entry of `databatch_list` is a tuple whose
    first element exposes `target_obs` / `target_obs_time`, and that
    `databatch_list_context[0]` is a B=1 AICME batch — confirm with callers.
    """
    aicme = databatch_list_context[0]
    # Observed target trajectories and their times, one array per held-out tuple.
    patients = [db_tuple[0].target_obs.cpu().detach().numpy() for db_tuple in databatch_list]
    patients_time = [db_tuple[0].target_obs_time.cpu().detach().numpy() for db_tuple in databatch_list]
    # Context individual with the most valid observations defines the time grid.
    max_time_index = aicme.context_obs_mask.sum(axis=2).squeeze().argmax()
    all_samples_times = aicme.context_obs_time[0,max_time_index,aicme.context_obs_mask[0,max_time_index]]
    all_samples = []
    for db_tuple in databatch_list:
        # samples_time is unused; only the sampled trajectories are kept.
        samples,samples_time = model.sample_new_individual(db_tuple,samples_number)
        all_samples.append(samples)
    all_samples = torch.cat(all_samples,dim=2).squeeze()
    # Restrict the samples to the same valid time points as the reference grid.
    all_samples = all_samples[:,:,aicme.context_obs_mask[0,max_time_index]]
    vpc(all_samples_times, all_samples, patients, patients_time,train=train,image_name=image_name,y_scale=y_scale)
|
| 409 |
+
|
| 410 |
+
def vpc(test_time, MetaStudies, patients, patients_time, train=True, image_name="vpc.png", y_scale=None):
    """
    Render a Visual Predictive Check (VPC) from simulated and observed data.

    Args:
        test_time: 1D tensor of simulation time points, shape [T]; a
            higher-dimensional tensor is squeezed down first.
        MetaStudies: 3D tensor of simulated data, shape [M, P, T].
        patients: list of 1D arrays/tensors of observed concentrations.
        patients_time: list of 1D arrays/tensors of matching observation times.
        train: when True, save the figure to ``image_name`` and return the
            path; when False, show it interactively.
        image_name: output file name used when train=True.
        y_scale: "log" switches the y-axis to log scale; None keeps it linear.
    """
    if len(test_time.shape) > 1:
        test_time = test_time.squeeze()

    time_np = test_time.detach().cpu().numpy()
    sims_np = MetaStudies.detach().cpu().numpy()

    # Percentiles over individuals within each study (axis=1), then the
    # median across studies, yielding one curve per requested percentile.
    wanted = [5, 25, 50, 75, 95]
    per_study = np.percentile(sims_np, wanted, axis=1)  # [5, M, T]
    curves = np.percentile(per_study, 50, axis=1)  # [5, T]
    p05, p25, p50, p75, p95 = curves

    plt.figure(figsize=(10, 6))
    plt.fill_between(time_np, p05, p95, color='blue', alpha=0.2, label='5th-95th Percentile')
    plt.fill_between(time_np, p25, p75, color='blue', alpha=0.4, label='25th-75th Percentile')
    plt.plot(time_np, p50, color='blue', label='Median (50th Percentile)')

    # Overlay the raw observations on top of the simulated bands.
    for obs, times in zip(patients, patients_time):
        plt.scatter(times, obs, color='red', alpha=0.6, s=20)

    plt.xlabel('Time (hours)')
    plt.ylabel('Concentration (g/L)')
    if y_scale == "log":
        plt.yscale('log')

    plt.title('Visual Predictive Check (VPC)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    if train:
        plt.savefig(image_name)
        plt.close()
        return image_name
    plt.show()
    plt.close()
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def get_unique_target_times(
    db_list: List[AICMECompartmentsDataBatch]
) -> TensorType[1, 1, "U", 1]:
    """
    Collect the sorted unique target-observation times across a list of batches.

    Given P databatches, each with ``.target_obs_time`` of shape
    [B, t_ind, num_obs_t, 1], flattens every time point across all batches,
    deduplicates, and returns the sorted unique values.

    Args:
        db_list: list of length P of AICMECompartmentsDataBatch.

    Returns:
        Tensor of shape [1, 1, U, 1], where U is the number of unique
        target-observation times across every batch. (The docstring and
        annotation previously claimed [U, 1], which contradicted the
        actual return shape.)
    """
    # 1) Flatten each batch's times to a 1-D vector of length B*t_ind*num_obs_t.
    flat_times = [
        db.target_obs_time.squeeze(-1).reshape(-1)
        for db in db_list
    ]
    # 2) Concatenate all P batches -> [(P * B * t_ind * num_obs_t)]
    all_times = torch.cat(flat_times, dim=0)

    # 3) torch.unique returns sorted unique values -> [U]
    unique = torch.unique(all_times)

    # 4) Add singleton batch/individual/feature dims -> [1, 1, U, 1]
    return unique.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
|
sim_priors_pk/metrics/quantiles_coverage.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torchtyping import TensorType
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def compute_predictive_quantiles(
    pred_values: TensorType["B", "S", "T", 1],
    pred_mask: TensorType["B", "S", "T"] | TensorType["B", "T"],
    alpha: float,
) -> Tuple[
    TensorType["B", "T", 1],
    TensorType["B", "T", 1],
]:
    """
    Compute the central (alpha/2, 1-alpha/2) predictive quantile envelopes
    across stochastic samples.

    Parameters
    ----------
    pred_values : [B, S, T, 1]
        Predicted sample trajectories.
    pred_mask : [B, T] or [B, S, T]
        Boolean validity mask; a [B, T] mask is shared across all S samples,
        a [B, S, T] mask is individual-specific.
    alpha : float
        Significance level (e.g. 0.05 for a 90% interval... per quantile pair).

    Returns
    -------
    q_low, q_high : [B, T, 1]
        Lower and upper quantile curves; time points with no valid sample
        are left at 0.
    """
    B, S, T, _ = pred_values.shape
    dev = pred_values.device

    # Broadcast a shared [B, T] mask to a per-sample [B, S, T] mask.
    if pred_mask.ndim == 2:
        pred_mask = pred_mask.unsqueeze(1).repeat(1, S, 1)

    lo_rows, hi_rows = [], []
    for batch_idx in range(B):
        lo_row = torch.zeros(T, device=dev)
        hi_row = torch.zeros(T, device=dev)

        # At each time step, take quantiles over the valid samples only.
        for time_idx in range(T):
            keep = pred_mask[batch_idx, :, time_idx]
            if keep.any():
                subset = pred_values[batch_idx, keep, time_idx, 0]
                lo_row[time_idx] = subset.quantile(alpha / 2)
                hi_row[time_idx] = subset.quantile(1 - alpha / 2)
            # otherwise the entry stays 0.0 (padding convention)

        lo_rows.append(lo_row.unsqueeze(-1))
        hi_rows.append(hi_row.unsqueeze(-1))

    return torch.stack(lo_rows, dim=0), torch.stack(hi_rows, dim=0)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def interpolate_quantiles_to_obs_times(
    q_low: TensorType["B", "Tpred", 1],
    q_high: TensorType["B", "Tpred", 1],
    pred_times: TensorType["B", "Tpred", 1],
    pred_mask: TensorType["B", "Tpred"],
    real_times: TensorType["B", "I", "Treal", 1],
    real_mask: TensorType["B", "I", "Treal"],
) -> Tuple[
    TensorType["B", "I", "Treal", 1],
    TensorType["B", "I", "Treal", 1],
]:
    """
    Interpolate predictive quantile bands (q_low, q_high) to the irregular
    observation times of real data.

    Parameters
    ----------
    q_low, q_high : TensorType["B", "Tpred", 1]
        Predictive lower and upper quantile curves at distinct time grid points.
    pred_times : TensorType["B", "Tpred", 1]
        Time grid corresponding to the quantile curves.
    pred_mask : TensorType["B", "Tpred"]
        Boolean mask marking valid predictive times per batch.
    real_times : TensorType["B", "I", "Treal", 1]
        Observation times for each individual and batch.
    real_mask : TensorType["B", "I", "Treal"]
        Mask indicating valid observed time points.

    Returns
    -------
    q_low_interp, q_high_interp : Tuple[
        TensorType["B", "I", "Treal", 1],
        TensorType["B", "I", "Treal", 1],
    ]
        Interpolated quantile band values at each observed time, padded
        where invalid.

    Notes
    -----
    - Uses linear interpolation between nearest predictive time knots.
    - Out-of-range times are clamped to the nearest boundary quantile.
    - Invalid (masked) observations are returned as zeros.
    - NOTE(review): the masked entries of pred_times are assumed to be in
      ascending order — required for torch.searchsorted below; confirm
      against the callers that build the prediction grid.
    """

    B, I, Treal, _ = real_times.shape
    device = real_times.device

    q_low_interp_list, q_high_interp_list = [], []

    for b in range(B):
        # Extract valid predictive points for this batch
        valid_mask_b = pred_mask[b]  # [Tpred]
        valid_T = valid_mask_b.sum().item()
        if valid_T < 2:
            # Degenerate case: not enough points for interpolation;
            # fall back to an all-zero band for every individual.
            q_low_interp_list.append(torch.zeros(I, Treal, 1, device=device))
            q_high_interp_list.append(torch.zeros(I, Treal, 1, device=device))
            continue

        t_pred = pred_times[b, valid_mask_b, 0]  # [T_b]
        ql = q_low[b, valid_mask_b, 0]  # [T_b]
        qh = q_high[b, valid_mask_b, 0]  # [T_b]

        # For each individual, interpolate its observation times
        q_low_i, q_high_i = [], []
        for i in range(I):
            t_obs = real_times[b, i, :, 0]  # [Treal]
            valid_obs = real_mask[b, i]  # [Treal]

            # Clamp obs times into predictive range (boundary clamping
            # makes out-of-range points pick up the edge quantile values).
            t_clamped = torch.clamp(t_obs, t_pred.min(), t_pred.max())

            # Use searchsorted to find bracketing indices
            idx_right = torch.searchsorted(t_pred, t_clamped)
            # idx_left is derived from the *unclamped* idx_right; both edge
            # cases collapse to t_L == t_R, where the clamped denom below
            # yields a zero interpolation weight.
            idx_left = (idx_right - 1).clamp(min=0)
            idx_right = idx_right.clamp(max=valid_T - 1)

            # Gather times and values for interpolation
            t_L, t_R = t_pred[idx_left], t_pred[idx_right]
            ql_L, ql_R = ql[idx_left], ql[idx_right]
            qh_L, qh_R = qh[idx_left], qh[idx_right]

            # Clamp avoids division by zero when t_L == t_R (duplicate or
            # boundary knots): w_R then evaluates to ~0 and the left value wins.
            denom = (t_R - t_L).clamp(min=1e-8)
            w_R = (t_clamped - t_L) / denom
            w_L = 1.0 - w_R

            ql_interp = w_L * ql_L + w_R * ql_R
            qh_interp = w_L * qh_L + w_R * qh_R

            # Zero out invalid times
            ql_interp = ql_interp.masked_fill(~valid_obs, 0.0)
            qh_interp = qh_interp.masked_fill(~valid_obs, 0.0)

            q_low_i.append(ql_interp.unsqueeze(-1))
            q_high_i.append(qh_interp.unsqueeze(-1))

        q_low_interp_list.append(torch.stack(q_low_i, dim=0))  # [I, Treal, 1]
        q_high_interp_list.append(torch.stack(q_high_i, dim=0))  # [I, Treal, 1]

    q_low_interp = torch.stack(q_low_interp_list, dim=0)  # [B, I, Treal, 1]
    q_high_interp = torch.stack(q_high_interp_list, dim=0)  # [B, I, Treal, 1]

    return q_low_interp, q_high_interp
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def compute_time_weighted_coverage(
    real_values: TensorType["B", "I", "Treal", 1],
    real_times: TensorType["B", "I", "Treal", 1],
    real_mask: TensorType["B", "I", "Treal"],
    q_low_interp: TensorType["B", "I", "Treal", 1],
    q_high_interp: TensorType["B", "I", "Treal", 1],
    reduce: bool = True,
) -> TensorType["B"]:
    """
    Compute time-weighted coverage fraction of observations within predictive bands.

    Parameters
    ----------
    real_values, real_times : [B, I, Treal, 1]
        Observed values and their observation times.
    real_mask : [B, I, Treal]
        Boolean mask marking valid observations.
    q_low_interp, q_high_interp : [B, I, Treal, 1]
        Predictive band interpolated to the observation times.
    reduce : bool, default True
        If True, return the Δt-weighted coverage summed per batch element
        ([B]). If False, return the per-observation weighted contributions
        ([B, I, Treal]) for custom aggregation. (This flag was previously
        accepted but silently ignored; the default path is unchanged.)

    Returns
    -------
    TensorType["B"] when reduce=True, else TensorType["B", "I", "Treal"].
    """
    # An observation is covered when it lies inside the closed band.
    covered = (real_values >= q_low_interp) & (real_values <= q_high_interp)
    covered = covered.squeeze(-1) & real_mask  # [B, I, Treal]

    # Δt weights via forward differences along time; the prepend makes the
    # first step zero, so the first observation of each series has no weight.
    dt = torch.diff(real_times, dim=2, prepend=real_times[:, :, :1])
    dt = dt.squeeze(-1) * real_mask  # [B, I, Treal]
    dt_sum = dt.sum(dim=(1, 2), keepdim=True).clamp(min=1e-8)
    weights = dt / dt_sum  # normalized time weights

    weighted = covered.float() * weights
    if reduce:
        return weighted.sum(dim=(1, 2))  # [B]
    return weighted  # per-observation contributions
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def compute_interval_score(
    real_values: TensorType["B", "I", "Treal", 1],
    real_times: TensorType["B", "I", "Treal", 1],
    real_mask: TensorType["B", "I", "Treal"],
    q_low_interp: TensorType["B", "I", "Treal", 1],
    q_high_interp: TensorType["B", "I", "Treal", 1],
    alpha: float,
) -> TensorType["B"]:
    """
    Time-weighted interval score (Gneiting & Raftery, 2007): the band width
    plus a 2/alpha penalty per unit of miscoverage, averaged with Δt weights
    over individuals and time for each batch element.
    """
    # Per-point score: width + (2/alpha) * distance outside the band.
    band_width = (q_high_interp - q_low_interp).abs()
    under_shoot = (q_low_interp - real_values).clamp(min=0)
    over_shoot = (real_values - q_high_interp).clamp(min=0)

    pointwise = band_width + (2 / alpha) * (under_shoot + over_shoot)
    pointwise = pointwise.squeeze(-1) * real_mask  # [B, I, Treal]

    # Δt weights, normalized over (individual, time) per batch element.
    step = torch.diff(real_times, dim=2, prepend=real_times[:, :, :1]).squeeze(-1)
    step = step * real_mask
    total = step.sum(dim=(1, 2), keepdim=True).clamp(min=1e-8)

    # Weighted mean per batch
    return (pointwise * (step / total)).sum(dim=(1, 2))
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def compute_percentile_coverage(
    pred_values,
    pred_times,
    pred_mask,
    real_values,
    real_times,
    real_mask,
    alpha: float = 0.05,
):
    """
    Evaluate predictive-interval calibration against observed trajectories.

    Pipeline:
        1. :func:`compute_predictive_quantiles` — central (alpha/2, 1-alpha/2)
           band across the S stochastic samples.
        2. :func:`interpolate_quantiles_to_obs_times` — align the band with
           each individual's irregular observation times.
        3. :func:`compute_time_weighted_coverage` and
           :func:`compute_interval_score` — Δt-weighted coverage fraction and
           proper interval score per batch element.

    Parameters
    ----------
    pred_values : TensorType["B", "S", "T_pred", 1]
        Stochastic model predictions (multiple samples per batch element).
    pred_times : TensorType["B", "T_pred", 1]
        Prediction time grid, shared across stochastic samples.
    pred_mask : TensorType["B", "T_pred"]
        Boolean validity mask for the prediction grid.
    real_values : TensorType["B", "I", "T_real", 1]
        Observed values per batch element and individual.
    real_times : TensorType["B", "I", "T_real", 1]
        Observation times matching ``real_values``.
    real_mask : TensorType["B", "I", "T_real"]
        Boolean validity mask for the observations.
    alpha : float, default 0.05
        Significance level; the band spans quantiles alpha/2 and 1-alpha/2
        (alpha=0.05 → 0.025 and 0.975, i.e. the 95% central interval).
        Smaller alpha yields wider, more conservative bands.

    Returns
    -------
    dict[str, TensorType["B"]]
        - ``"coverage"``: Δt-weighted fraction of observations inside the
          band; ≈ 1-alpha for a well-calibrated model.
        - ``"interval_score"``: Gneiting & Raftery (2007) interval score;
          lower means sharper and better calibrated.

    References
    ----------
    Gneiting, T. & Raftery, A. E. (2007). *Strictly Proper Scoring Rules,
    Prediction, and Estimation*. JASA, 102(477), 359-378.
    """
    band_low, band_high = compute_predictive_quantiles(pred_values, pred_mask, alpha)
    band_low_obs, band_high_obs = interpolate_quantiles_to_obs_times(
        band_low, band_high, pred_times, pred_mask, real_times, real_mask
    )

    return {
        "coverage": compute_time_weighted_coverage(
            real_values, real_times, real_mask, band_low_obs, band_high_obs
        ),
        "interval_score": compute_interval_score(
            real_values, real_times, real_mask, band_low_obs, band_high_obs, alpha
        ),
    }
|
sim_priors_pk/metrics/sampling_quality.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluate sampling quality of a model based on Visual Predictive Checks or Normalized Prediction Distribution Errors (NPDEs).
|
| 2 |
+
# Input for both evaluations: a StudyJSON object containing the observed data and a List[StudyJSON] containing replicates of simulated data from the model.
|
| 3 |
+
# This way, both neural networks and NLME models can be evaluated using the same code, as long as they can produce the required StudyJSON objects.
|
| 4 |
+
|
| 5 |
+
from typing import List, Optional, Sequence
|
| 6 |
+
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from scipy.stats import chi2, norm, shapiro, ttest_1samp
|
| 11 |
+
|
| 12 |
+
from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def json_to_dataframe(study_json: StudyJSON) -> pd.DataFrame:
    """
    Flatten a StudyJSON object into a tidy pandas DataFrame.

    Args:
        study_json (StudyJSON): Study to convert; both its "context" and
            "target" sections are included.

    Returns:
        pd.DataFrame: One row per observation, with columns
        ["Type", "ID", "Time", "Value"].
    """

    pieces = []

    for section in ["context", "target"]:
        for position, individual in enumerate(study_json.get(section, [])):
            # Prefer name_id, else _id, else a deterministic fallback
            identifier = (
                individual.get("name_id")
                or individual.get("_id")
                or f"{section}_{position}"
            )

            pieces.append(
                pd.DataFrame(
                    {
                        "Type": section,
                        "ID": str(identifier),  # ensure it's a string
                        "Time": individual["observation_times"],
                        "Value": individual["observations"],
                    }
                )
            )

    if not pieces:
        # Keep the documented schema even when the study holds no entries.
        return pd.DataFrame(columns=["Type", "ID", "Time", "Value"])
    return pd.concat(pieces, ignore_index=True)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def json_list_to_dataframe(study_list: List[StudyJSON]) -> pd.DataFrame:
    """
    Convert a list of StudyJSON objects to a pandas DataFrame for easier analysis.

    Args:
        study_list (List[StudyJSON]): The list of StudyJSON objects to convert.

    Returns:
        pd.DataFrame: A DataFrame with columns
        ["Type", "ID", "Time", "Value", "Replicate"], where "Replicate" is the
        index of the study in `study_list`.
    """

    frames = []

    for replicate_idx, study in enumerate(study_list):
        df = json_to_dataframe(study)
        df["Replicate"] = replicate_idx
        frames.append(df)

    if frames:
        return pd.concat(frames, ignore_index=True)
    # Bug fix: previously returned a column-less pd.DataFrame() here, which
    # contradicted the documented schema and json_to_dataframe's empty case.
    return pd.DataFrame(columns=["Type", "ID", "Time", "Value", "Replicate"])
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def validate_npde_vpc_inputs(
    data: pd.DataFrame, simulations: pd.DataFrame, differentTimesError: bool = True
) -> None:
    """
    Validate the inputs for NPDE / VPC calculation.

    Args:
        data (pd.DataFrame): The observed data with columns
            ["Type", "ID", "Time", "Value"].
        simulations (pd.DataFrame): The simulated data with columns
            ["Type", "ID", "Time", "Value", "Replicate"] — all replicates
            concatenated into a single DataFrame (as produced by
            json_list_to_dataframe). Note: the previous annotation said
            List[pd.DataFrame], but every call site passes one DataFrame and
            the body indexes it as such; the annotation is now correct and the
            `# type: ignore` is gone.
        differentTimesError (bool): Whether to raise an error if observation
            times differ between individuals (default: True).

    Returns:
        None: If the inputs are valid.

    Raises:
        ValueError: If observations and predictions do not share exactly the
            same set of (Type, ID, Time) keys, or — when differentTimesError
            is True — if observation times differ between individuals.
    """

    key_cols = ["Type", "ID", "Time"]

    # Compare the deduplicated, sorted key sets of both frames; predictions
    # must cover exactly the observed (Type, ID, Time) combinations.
    obs_keys = data[key_cols].drop_duplicates().sort_values(key_cols).reset_index(drop=True)
    pred_keys = simulations[key_cols].drop_duplicates().sort_values(key_cols).reset_index(drop=True)

    if not obs_keys.equals(pred_keys):
        raise ValueError("Observations and predictions are not structurally identical.")

    if differentTimesError:
        # Every individual must share the same (multi)set of observation times.
        if (data.groupby("ID")["Time"].apply(lambda x: tuple(sorted(x))).nunique()) != 1:
            raise ValueError("Observation times differ between individuals.")

    return None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def compute_npde_data(data: StudyJSON, simulations: List[StudyJSON]) -> np.ndarray:
    """
    Compute Normalized Prediction Distribution Errors (NPDEs) for a given StudyJSON and a list of simulated StudyJSONs.

    Each observation's percentile within its simulated replicate distribution
    is computed (a truncated empirical CDF) and mapped through the standard
    normal inverse CDF; a well-calibrated model yields N(0, 1) NPDEs.

    Args:
        data (StudyJSON): The observed data in StudyJSON
        simulations (List[StudyJSON]): A list of StudyJSON objects representing simulated data from the model.

    Returns:
        np.ndarray: An array of NPDE values, one per (Type, ID, Time) key in
        the pivoted prediction table.
    """
    # Extract observed values and predicted values from the StudyJSON objects and validate them before calculating NPDEs.
    observed_values = json_to_dataframe(data)
    predicted_values = json_list_to_dataframe(simulations)

    validate_npde_vpc_inputs(observed_values, predicted_values, differentTimesError=False)

    # Merge observations and predictions into a single DataFrame for NPDE calculation.
    key_cols = ["Type", "ID", "Time"]

    # Wide layout: one row per (Type, ID, Time) key, one column per replicate.
    # NOTE(review): pivot raises if a key repeats within a replicate —
    # presumably guaranteed by the validation above; confirm for ties in Time.
    pred_wide = predicted_values.pivot(index=key_cols, columns="Replicate", values="Value")

    obs_indexed = observed_values.set_index(key_cols)

    # Left-join aligns the observed value to each key row of the wide table.
    combined = pred_wide.join(obs_indexed["Value"].rename("Observed"))

    # Calculate NPDEs for each replicate and return as a numpy array.
    replicate_cols = pred_wide.columns

    pred_vals = combined[replicate_cols].values
    obs_vals = combined["Observed"].values

    # Empirical CDF (truncated to avoid 0 and 1) for each observation based on the predicted distribution from the replicates.
    pde = (pred_vals <= obs_vals[:, None]).sum(axis=1) / (len(replicate_cols) + 1) + 0.5 / (
        len(replicate_cols) + 1
    )
    # Map percentiles to z-scores via the standard normal inverse CDF.
    npde = norm.ppf(pde)

    return npde
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def npde_plot(npde_values: np.ndarray) -> None:
    """
    Draw a quantile-quantile plot of NPDE values against the standard normal.

    Args:
        npde_values (np.ndarray): NPDE values to plot; points on the red
            dashed identity line indicate good calibration.

    Returns:
        None
    """
    count = len(npde_values)
    empirical = np.sort(npde_values)
    # Plotting positions i/(n+1) avoid the 0 and 1 endpoints.
    theoretical = norm.ppf((np.arange(count) + 1) / (count + 1))

    plt.figure(figsize=(6, 6))
    plt.title("Q-Q Plot of NPDE Values")
    plt.xlabel("Theoretical Quantiles")
    plt.ylabel("Empirical Quantiles")
    plt.plot(theoretical, empirical, marker="o", linestyle="")
    plt.plot(theoretical, theoretical, color="red", linestyle="--")
    plt.grid()
    plt.show()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def npde_pvalues(npde_values: np.ndarray) -> dict:
    """
    Test NPDE values against their theoretical N(0, 1) distribution.

    Args:
        npde_values (np.ndarray): NPDE values to summarize.

    Returns:
        dict: p-values keyed by test:
            - "mean": one-sample t-test for zero mean.
            - "variance": two-sided chi-squared test for unit variance.
            - "normality": Shapiro-Wilk test for normality.
    """

    # SciPy ships no one-sample variance test, so compute it directly:
    # under H0 (unit variance), (n-1) * s^2 follows chi2 with n-1 dof.
    n = len(npde_values)
    var_stat = (n - 1) * np.var(npde_values, ddof=1)
    lower_tail = chi2.cdf(var_stat, df=n - 1)
    variance_pvalue = 2 * min(lower_tail, 1 - lower_tail)

    return {
        "mean": ttest_1samp(npde_values, 0).pvalue,  # type: ignore
        "variance": variance_pvalue,
        "normality": shapiro(npde_values).pvalue,
    }
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def compute_vpc_data(
    data: StudyJSON,
    simulations: Sequence[StudyJSON],
    quantiles: List[float] = [0.05, 0.5, 0.95],
    confidence: float = 0.9,
    n_bins: Optional[int] = None,
    binning: str = "equal_count",  # "equal_count" or "equal_width"
) -> pd.DataFrame:
    """
    Compute data for a Visual Predictive Check (VPC) plot for the given StudyJSON and a list of simulated StudyJSONs.

    Args:
        data (StudyJSON): The observed data in StudyJSON format.
        simulations (Sequence[StudyJSON]): A list of simulated data in StudyJSON format.
        quantiles (List[float]): The quantiles to display in the VPC plot (default: [0.05, 0.5, 0.95]).
        confidence (float): The confidence level for the prediction intervals (default: 0.9).
        n_bins (Optional[int]): Number of time bins. If None, no binning is performed and
            observed/simulated time points are validated to match exactly.
        binning (str): Binning strategy when ``n_bins`` is given: "equal_count"
            (quantile-based bins via ``pd.qcut``) or "equal_width" (uniform bins
            over the observed time range).

    Returns:
        pd.DataFrame: One row per (Time, Quantile) with columns "Obs" (observed
            quantile) and "LowerPred"/"UpperPred" (prediction-interval bounds
            derived from the simulations).

    Raises:
        ValueError: If ``binning`` is not "equal_count" or "equal_width".
    """

    observed_values = json_to_dataframe(data)
    predicted_values = json_list_to_dataframe(simulations)

    # Two-sided prediction interval: split (1 - confidence) equally across both tails.
    alpha_low = (1 - confidence) / 2
    alpha_high = 1 - alpha_low

    # --------------------------------
    # Binning (if requested OR if needed)
    # --------------------------------
    if n_bins is not None:
        validate_npde_vpc_inputs(observed_values, predicted_values, differentTimesError=False)

        match binning:
            case "equal_count":
                observed_values["TimeBin"] = pd.qcut(
                    observed_values["Time"], q=n_bins, duplicates="drop"
                )

                # Use same bin edges for predicted
                bins = observed_values["TimeBin"].cat.categories
                predicted_values["TimeBin"] = pd.cut(predicted_values["Time"], bins=bins)

            case "equal_width":
                tmin = observed_values["Time"].min()
                tmax = observed_values["Time"].max()
                bins = np.linspace(tmin, tmax, n_bins + 1)

                observed_values["TimeBin"] = pd.cut(
                    observed_values["Time"], bins=bins, include_lowest=True
                )
                predicted_values["TimeBin"] = pd.cut(
                    predicted_values["Time"], bins=bins, include_lowest=True
                )

            case _:
                raise ValueError("binning must be 'equal_width' or 'equal_count'")

        # Use bin midpoint for plotting
        bin_midpoints = (
            observed_values.groupby("TimeBin", observed=False)["Time"].mean().rename("Time")
        )

        # Replace Time with bin midpoint
        observed_values["Time"] = observed_values["TimeBin"].map(bin_midpoints)
        predicted_values["Time"] = predicted_values["TimeBin"].map(bin_midpoints)

        # Drop bin column
        observed_values = observed_values.drop(columns="TimeBin")
        predicted_values = predicted_values.drop(columns="TimeBin")

    else:
        validate_npde_vpc_inputs(observed_values, predicted_values, differentTimesError=True)  # type: ignore

    # --------------------------------
    # Quantile calculation
    # --------------------------------
    # Observed quantiles per time point; "level_1" is the quantile level
    # produced by .quantile() after reset_index().
    vpc_obs = (
        observed_values.groupby("Time")["Value"]
        .quantile(quantiles)  # type: ignore
        .rename("Obs")
        .reset_index()
        .rename(columns={"level_1": "Quantile"})
    )

    # Per-replicate quantiles first, then across replicates take the
    # (alpha_low, alpha_high) quantiles of each simulated quantile to form
    # the prediction interval; pivot spreads the two PI bounds into columns.
    vpc_pred = (
        predicted_values.groupby(["Time", "Replicate"])["Value"]
        .quantile(quantiles)  # type: ignore
        .rename("SimQuantile")
        .reset_index()
        .rename(columns={"level_2": "Quantile"})
        .groupby(["Time", "Quantile"])["SimQuantile"]
        .quantile([alpha_low, alpha_high])  # type: ignore
        .rename("VPC")
        .reset_index()
        .rename(columns={"level_2": "PI"})
        .pivot(index=["Time", "Quantile"], columns="PI", values="VPC")
        .reset_index()
        .rename(columns={alpha_low: "LowerPred", alpha_high: "UpperPred"})
    )

    # Left merge keeps every observed (Time, Quantile) row even if the
    # simulations are missing a matching entry.
    vpc_data = vpc_obs.merge(vpc_pred, on=["Time", "Quantile"], how="left")

    return vpc_data
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def vpc_plot(vpc_data: pd.DataFrame, ax=None, log_y: bool = False):
    """
    Create a Visual Predictive Check (VPC) plot for the given VPC data.

    Args:
        vpc_data (pd.DataFrame): The VPC data to plot, with columns "Time",
            "Quantile", "Obs", "LowerPred" and "UpperPred" (as produced by
            compute_vpc_data).
        ax: Optional matplotlib axis to plot on. If None, a new figure and axis
            will be created.
        log_y: Whether to use a logarithmic scale for the y-axis (default: False).

    Returns:
        The matplotlib axis the VPC was drawn on.

    Raises:
        ValueError: If the data does not contain exactly 3 quantiles, or if a
            log scale is requested while non-positive values are present.
    """

    quantiles = np.sort(vpc_data["Quantile"].unique())

    # Enforce exactly 3 quantiles: the color scheme below assumes a
    # lower/median/upper triplet.
    if len(quantiles) != 3:
        raise ValueError(f"Expected exactly 3 quantiles, got {len(quantiles)}: {quantiles}")

    # Default axis management
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Log-scale option
    if log_y:
        # Safety check: log scale requires strictly positive values
        y_cols = ["Obs", "LowerPred", "UpperPred"]
        if (vpc_data[y_cols] <= 0).any().any():
            raise ValueError("Log scale requested but non-positive values detected.")
        ax.set_yscale("log")

    # Color scheme: lower, median, upper (outer quantiles share a color)
    colors = ["tab:blue", "tab:orange", "tab:blue"]

    # Map sorted quantiles to colors
    q_to_color = dict(zip(quantiles, colors))

    # Plot observed quantiles. BUGFIX: the original re-read
    # `vpc_data["Quantile"].unique()` here, discarding the sorted order and
    # making the plotting/legend order depend on row order; we keep the
    # sorted array so output ordering is deterministic.
    for q in quantiles:
        subset = vpc_data[vpc_data["Quantile"] == q]

        color = q_to_color[q]
        is_median = np.isclose(q, 0.5)

        ax.plot(
            subset["Time"],
            subset["Obs"],
            marker="o",
            color=color,
            linewidth=2 if is_median else 1,
            label=f"Observed {q:.0%}",
        )

        ax.fill_between(
            subset["Time"],
            subset["LowerPred"],
            subset["UpperPred"],
            color=color,
            alpha=0.25,
            label=f"Simulated {q:.0%} PI",
        )

    # Keep legend outside the plotting area to avoid occluding trajectories.
    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.18),
        ncol=3,
        frameon=False,
    )
    ax.figure.subplots_adjust(bottom=0.25)
    return ax
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
if __name__ == "__main__":
    # Example usage: a tiny two-subject study with two simulated replicates.
    observed_data = StudyJSON(
        context=[
            IndividualJSON(name_id="1", observation_times=[0, 1, 2], observations=[10, 20, 30]),
            IndividualJSON(name_id="2", observation_times=[0, 1, 2], observations=[11, 21, 31]),
        ]
    )

    simulated_data = [
        StudyJSON(
            context=[
                IndividualJSON(name_id="1", observation_times=[0, 1, 2], observations=[12, 22, 32]),
                IndividualJSON(name_id="2", observation_times=[0, 1, 2], observations=[13, 21, 30]),
            ]
        ),
        StudyJSON(
            context=[
                IndividualJSON(name_id="1", observation_times=[0, 1, 2], observations=[8, 18, 28]),
                IndividualJSON(name_id="2", observation_times=[0, 1, 2], observations=[11, 19, 27]),
            ]
        ),
    ]
    # Convert to dataframes for visualization (optional)
    observed_values = json_to_dataframe(observed_data)
    simulated_values = json_list_to_dataframe(simulated_data)

    validate_npde_inputs(observed_values, simulated_values)
    npde_results = calculate_npde(observed_data, simulated_data)

    print("NPDE Results:", npde_results)

    # BUGFIX: was `create_vpc_data`, which is not defined in this module;
    # the VPC builder defined above is `compute_vpc_data`.
    vpc_data = compute_vpc_data(observed_data, simulated_data)

    vpc_plot(vpc_data)
|