cesarali commited on
Commit
5686f5b
·
verified ·
1 Parent(s): 269fdb7

manual runtime bundle push from load_and_push.ipynb

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. README.md +131 -0
  3. config.json +445 -0
  4. configuration_sim_priors_pk.py +42 -0
  5. modeling_sim_priors_pk.py +123 -0
  6. pytorch_model.bin +3 -0
  7. requirements.txt +4 -0
  8. sim_priors_pk/.DS_Store +0 -0
  9. sim_priors_pk/__init__.py +43 -0
  10. sim_priors_pk/config_classes/__init__.py +0 -0
  11. sim_priors_pk/config_classes/data_config.py +375 -0
  12. sim_priors_pk/config_classes/diffusion_pk_config.py +327 -0
  13. sim_priors_pk/config_classes/flow_pk_config.py +534 -0
  14. sim_priors_pk/config_classes/node_pk_config.py +518 -0
  15. sim_priors_pk/config_classes/source_process_config.py +52 -0
  16. sim_priors_pk/config_classes/training_config.py +96 -0
  17. sim_priors_pk/config_classes/utils.py +14 -0
  18. sim_priors_pk/config_classes/yaml_fallback.py +143 -0
  19. sim_priors_pk/data/README.md +86 -0
  20. sim_priors_pk/data/__init__.py +12 -0
  21. sim_priors_pk/data/data_empirical/__init__.py +35 -0
  22. sim_priors_pk/data/data_empirical/builder.py +1139 -0
  23. sim_priors_pk/data/data_empirical/json_schema.py +372 -0
  24. sim_priors_pk/data/data_empirical/json_stats.py +201 -0
  25. sim_priors_pk/data/data_empirical/simulx_to_json.py +71 -0
  26. sim_priors_pk/data/data_generation/__init__.py +0 -0
  27. sim_priors_pk/data/data_generation/compartment_models.py +721 -0
  28. sim_priors_pk/data/data_generation/compartment_models_management.py +1338 -0
  29. sim_priors_pk/data/data_generation/dosing_models.py +0 -0
  30. sim_priors_pk/data/data_generation/observations_classes.py +1776 -0
  31. sim_priors_pk/data/data_generation/observations_functions.py +69 -0
  32. sim_priors_pk/data/data_generation/study_population_stats.py +185 -0
  33. sim_priors_pk/data/data_preprocessing/__init__.py +0 -0
  34. sim_priors_pk/data/data_preprocessing/data_preprocessing_utils.py +321 -0
  35. sim_priors_pk/data/data_preprocessing/raw_to_tensors_bundles.py +360 -0
  36. sim_priors_pk/data/data_preprocessing/tensors_to_databatch.py +72 -0
  37. sim_priors_pk/data/datasets/aicme_batch.py +167 -0
  38. sim_priors_pk/data/datasets/aicme_datasets.py +1874 -0
  39. sim_priors_pk/data/extra/compartment_models_vectorized.py +182 -0
  40. sim_priors_pk/data/extra/kernels.py +28 -0
  41. sim_priors_pk/hub_runtime/README.md +187 -0
  42. sim_priors_pk/hub_runtime/__init__.py +19 -0
  43. sim_priors_pk/hub_runtime/configuration_sim_priors_pk.py +42 -0
  44. sim_priors_pk/hub_runtime/modeling_sim_priors_pk.py +123 -0
  45. sim_priors_pk/hub_runtime/runtime_bundle.py +269 -0
  46. sim_priors_pk/hub_runtime/runtime_contract.py +662 -0
  47. sim_priors_pk/metrics/__init__.py +0 -0
  48. sim_priors_pk/metrics/pk_metrics.py +490 -0
  49. sim_priors_pk/metrics/quantiles_coverage.py +310 -0
  50. sim_priors_pk/metrics/sampling_quality.py +409 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 César A. Ojeda
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: apache-2.0
5
+ library_name: generative-pk
6
+ datasets:
7
+ - simulated
8
+ metrics:
9
+ - rmse
10
+ - npde
11
+ tags:
12
+ - generative
13
+ - predictive
14
+ ---
15
+
16
+ # Hierarchical Neural Process for Pharmacokinetic Data
17
+
18
+ ## Overview
19
+ An Amortized Context Neural Process Generative model for Pharmacokinetic Modelling
20
+
21
+ **Model details:**
22
+ - **Authors:** César Ojeda (@cesarali)
23
+ - **License:** Apache 2.0
24
+
25
+ ## Intended use
26
+ Sample Drug Concentration Behavior and Sample and Prediction of New Points or new Individual
27
+ ## Runtime Bundle
28
+
29
+ This repository is the consumer-facing runtime bundle for this PK model.
30
+
31
+ - Runtime repo: `cesarali/AICME-runtime`
32
+ - Native training/artifact repo: `cesarali/AICMEPK_cluster`
33
+ - Supported tasks: `generate`, `predict`
34
+ - Default task: `generate`
35
+ - Load path: `AutoModel.from_pretrained(..., trust_remote_code=True)`
36
+
37
+ ### Installation
38
+
39
+ You do **not** need to install `sim_priors_pk` to use this runtime bundle.
40
+
41
+ `transformers` is the public loading entrypoint, but `transformers` alone is
42
+ not sufficient because this is a PyTorch model with custom runtime code. A
43
+ reliable consumer environment is:
44
+
45
+ ```bash
46
+ pip install torch transformers huggingface_hub lightning datasets pandas torchtyping gpytorch pot torchdiffeq torchsde ruamel.yaml pyyaml
47
+ ```
48
+
49
+ ### Python Usage
50
+
51
+ ```python
52
+ from transformers import AutoModel
53
+
54
+ model = AutoModel.from_pretrained("cesarali/AICME-runtime", trust_remote_code=True)
55
+
56
+ studies = [
57
+ {
58
+ "context": [
59
+ {
60
+ "name_id": "ctx_0",
61
+ "observations": [0.2, 0.5, 0.3],
62
+ "observation_times": [0.5, 1.0, 2.0],
63
+ "dosing": [1.0],
64
+ "dosing_type": ["oral"],
65
+ "dosing_times": [0.0],
66
+ "dosing_name": ["oral"],
67
+ }
68
+ ],
69
+ "target": [],
70
+ "meta_data": {"study_name": "demo", "substance_name": "drug_x"},
71
+ }
72
+ ]
73
+
74
+ outputs = model.run_task(
75
+ task="generate",
76
+ studies=studies,
77
+ num_samples=4,
78
+ )
79
+ print(outputs["results"][0]["samples"])
80
+ ```
81
+
82
+ ### Predictive Sampling
83
+
84
+ ```python
85
+ from transformers import AutoModel
86
+
87
+ model = AutoModel.from_pretrained("cesarali/AICME-runtime", trust_remote_code=True)
88
+
89
+ predict_studies = [
90
+ {
91
+ "context": [
92
+ {
93
+ "name_id": "ctx_0",
94
+ "observations": [0.2, 0.5, 0.3],
95
+ "observation_times": [0.5, 1.0, 2.0],
96
+ "dosing": [1.0],
97
+ "dosing_type": ["oral"],
98
+ "dosing_times": [0.0],
99
+ "dosing_name": ["oral"],
100
+ }
101
+ ],
102
+ "target": [
103
+ {
104
+ "name_id": "tgt_0",
105
+ "observations": [0.25, 0.31],
106
+ "observation_times": [0.5, 1.0],
107
+ "remaining": [0.0, 0.0, 0.0],
108
+ "remaining_times": [2.0, 4.0, 8.0],
109
+ "dosing": [1.0],
110
+ "dosing_type": ["oral"],
111
+ "dosing_times": [0.0],
112
+ "dosing_name": ["oral"],
113
+ }
114
+ ],
115
+ "meta_data": {"study_name": "demo", "substance_name": "drug_x"},
116
+ }
117
+ ]
118
+
119
+ outputs = model.run_task(
120
+ task="predict",
121
+ studies=predict_studies,
122
+ num_samples=4,
123
+ )
124
+ print(outputs["results"][0]["samples"][0]["target"][0]["prediction_samples"])
125
+ ```
126
+
127
+ ### Notes
128
+
129
+ - `trust_remote_code=True` is required because this model uses custom Hugging Face Hub runtime code.
130
+ - The consumer API is `transformers` + `run_task(...)`; the consumer does not need a local clone of this repository.
131
+ - This runtime bundle is intentionally separate from the native training export so you can evaluate both distribution paths in parallel.
config.json ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architecture_name": "AICMEPK",
3
+ "architectures": [
4
+ "PKHubModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_sim_priors_pk.PKHubConfig",
8
+ "AutoModel": "modeling_sim_priors_pk.PKHubModel"
9
+ },
10
+ "builder_config": {
11
+ "max_context_individuals": 10,
12
+ "max_context_observations": 15,
13
+ "max_context_remaining": 15,
14
+ "max_target_individuals": 1,
15
+ "max_target_observations": 5,
16
+ "max_target_remaining": 12
17
+ },
18
+ "default_task": "generate",
19
+ "experiment_config": {
20
+ "comet_ai_key": null,
21
+ "context_observations": {
22
+ "add_rem": true,
23
+ "drop_time_zero_observations": false,
24
+ "empirical_number_of_obs": false,
25
+ "generative_bias": false,
26
+ "max_num_obs": 15,
27
+ "max_past": 5,
28
+ "min_past": 3,
29
+ "past_time_ratio": 0.1,
30
+ "split_past_future": false,
31
+ "type": "pk_peak_half_life"
32
+ },
33
+ "debug_test": false,
34
+ "dosing": {
35
+ "logdose_mean_range": [
36
+ -2.0,
37
+ 2.0
38
+ ],
39
+ "logdose_std_range": [
40
+ 0.1,
41
+ 0.5
42
+ ],
43
+ "num_individuals": 10,
44
+ "route_options": [
45
+ "oral",
46
+ "iv"
47
+ ],
48
+ "route_weights": [
49
+ 0.8,
50
+ 0.2
51
+ ],
52
+ "same_route": true,
53
+ "time": 0.0
54
+ },
55
+ "experiment_dir": "/work/ojedamarin/Projects/Pharma/Results/comet/uai/7195d8f55b5d4684a766a69d5a736d28",
56
+ "experiment_indentifier": null,
57
+ "experiment_name": "uai",
58
+ "experiment_type": "nodepk",
59
+ "hf_model_card_path": [
60
+ "hf_model_cards",
61
+ "AICME-PK_Readme.md"
62
+ ],
63
+ "hf_model_name": "AICMEPK_cluster",
64
+ "hugging_face_token": null,
65
+ "meta_study": {
66
+ "V_tmag_range": [
67
+ 0.001,
68
+ 0.001
69
+ ],
70
+ "V_tscl_range": [
71
+ 1,
72
+ 5
73
+ ],
74
+ "drug_id_options": [
75
+ "Drug_A",
76
+ "Drug_B",
77
+ "Drug_C"
78
+ ],
79
+ "k_1p_tmag_range": [
80
+ 0.01,
81
+ 0.02
82
+ ],
83
+ "k_1p_tscl_range": [
84
+ 1,
85
+ 5
86
+ ],
87
+ "k_a_tmag_range": [
88
+ 0.01,
89
+ 0.02
90
+ ],
91
+ "k_a_tscl_range": [
92
+ 1,
93
+ 5
94
+ ],
95
+ "k_e_tmag_range": [
96
+ 0.01,
97
+ 0.02
98
+ ],
99
+ "k_e_tscl_range": [
100
+ 1,
101
+ 5
102
+ ],
103
+ "k_p1_tmag_range": [
104
+ 0.01,
105
+ 0.02
106
+ ],
107
+ "k_p1_tscl_range": [
108
+ 1,
109
+ 5
110
+ ],
111
+ "log_V_mean_range": [
112
+ 2,
113
+ 8
114
+ ],
115
+ "log_V_std_range": [
116
+ 0.2,
117
+ 0.6
118
+ ],
119
+ "log_k_1p_mean_range": [
120
+ -4,
121
+ 0
122
+ ],
123
+ "log_k_1p_std_range": [
124
+ 0.2,
125
+ 0.6
126
+ ],
127
+ "log_k_a_mean_range": [
128
+ -1,
129
+ 2
130
+ ],
131
+ "log_k_a_std_range": [
132
+ 0.2,
133
+ 0.6
134
+ ],
135
+ "log_k_e_mean_range": [
136
+ -5,
137
+ 0
138
+ ],
139
+ "log_k_e_std_range": [
140
+ 0.2,
141
+ 0.6
142
+ ],
143
+ "log_k_p1_mean_range": [
144
+ -4,
145
+ -1
146
+ ],
147
+ "log_k_p1_std_range": [
148
+ 0.2,
149
+ 0.6
150
+ ],
151
+ "num_individuals_range": [
152
+ 5,
153
+ 10
154
+ ],
155
+ "num_peripherals_range": [
156
+ 1,
157
+ 3
158
+ ],
159
+ "rel_ruv_range": [
160
+ 0.001,
161
+ 0.01
162
+ ],
163
+ "solver_method": "rk4",
164
+ "time_num_steps": 100,
165
+ "time_start": 0.0,
166
+ "time_stop": 16.0
167
+ },
168
+ "mix_data": {
169
+ "evaluate_prediction_steps_past": 5,
170
+ "keep_tempfile": false,
171
+ "log_and_max": false,
172
+ "log_and_z": false,
173
+ "log_transform": false,
174
+ "n_of_databatches": null,
175
+ "n_of_permutations": 3,
176
+ "n_of_target_individuals": 1,
177
+ "normalize_by_max": true,
178
+ "normalize_time": true,
179
+ "recreate_tempfile": false,
180
+ "sample_size_for_generative_evaluation": null,
181
+ "sample_size_for_generative_evaluation_end_of_training": 500,
182
+ "sample_size_for_generative_evaluation_val": 10,
183
+ "store_in_tempfile": false,
184
+ "tempfile_path": [
185
+ "preprocessed",
186
+ "simulated_ou_as_rates"
187
+ ],
188
+ "test_empirical_datasets": [
189
+ "cesarali/lenuzza-2016",
190
+ "cesarali/Indometacin",
191
+ "cesarali/Theophylline"
192
+ ],
193
+ "test_size": 64,
194
+ "tqdm_progress": false,
195
+ "train_size": 12800,
196
+ "val_size": 256,
197
+ "z_score_normalization": false
198
+ },
199
+ "my_results_path": "/work/ojedamarin/Projects/Pharma/Results/",
200
+ "name_str": "AICMEPK",
201
+ "network": {
202
+ "activation": "ReLU",
203
+ "aggregator_num_heads": 8,
204
+ "aggregator_type": "mean",
205
+ "combine_latent_mode": "mlp",
206
+ "cov_proj_dim": 16,
207
+ "decoder_attention_layers": 2,
208
+ "decoder_hidden_dim": 512,
209
+ "decoder_name": "TransformerDecoder",
210
+ "decoder_num_layers": 4,
211
+ "decoder_rnn_hidden_dim": 256,
212
+ "drift_activation": "Tanh",
213
+ "drift_num_layers": 2,
214
+ "dropout": 0.1,
215
+ "encoder_rnn_hidden_dim": 256,
216
+ "exclusive_node_step": true,
217
+ "ignore_logvar": true,
218
+ "individual_encoder_name": "RNNContextEncoder",
219
+ "individual_encoder_number_of_heads": 4,
220
+ "init_hidden_num_layers": 2,
221
+ "input_encoding_hidden_dim": 128,
222
+ "kl_weight": 1.0,
223
+ "loss_name": "log_nll",
224
+ "node_step": true,
225
+ "norm": "layer",
226
+ "output_head_num_layers": 3,
227
+ "prediction_latent_deterministic": false,
228
+ "prediction_only": false,
229
+ "reconstruction_only": false,
230
+ "rnn_decoder_number_of_layers": 4,
231
+ "rnn_individual_encoder_number_of_layers": 4,
232
+ "scale_dosing_amounts": true,
233
+ "study_latent_deterministic": false,
234
+ "time_obs_encoder_hidden_dim": 256,
235
+ "time_obs_encoder_output_dim": 256,
236
+ "use_attention": true,
237
+ "use_invariance_loss": false,
238
+ "use_kl_i": true,
239
+ "use_kl_i_np": true,
240
+ "use_kl_init": true,
241
+ "use_kl_s": true,
242
+ "use_self_attention": true,
243
+ "use_time_deltas": true,
244
+ "zi_latent_dim": 128
245
+ },
246
+ "run_index": 0,
247
+ "tags": [
248
+ "AICME",
249
+ "AISTATS-2026",
250
+ "camera-ready"
251
+ ],
252
+ "target_observations": {
253
+ "add_rem": true,
254
+ "drop_time_zero_observations": false,
255
+ "empirical_number_of_obs": 2,
256
+ "generative_bias": false,
257
+ "max_num_obs": 15,
258
+ "max_past": 5,
259
+ "min_past": 3,
260
+ "past_time_ratio": 0.1,
261
+ "split_past_future": true,
262
+ "type": "pk_peak_half_life"
263
+ },
264
+ "train": {
265
+ "amsgrad": false,
266
+ "batch_size": 64,
267
+ "betas": [
268
+ 0.9,
269
+ 0.999
270
+ ],
271
+ "callbacks_scheduler": {
272
+ "checkpoint_used_in_end": [
273
+ "end",
274
+ "best",
275
+ "log_rmse"
276
+ ],
277
+ "include_end": true,
278
+ "keep_temp_files": false,
279
+ "max_samples_per_group": 500,
280
+ "percent_step": 0.1,
281
+ "skip_sanity_check": true,
282
+ "store_samples": true,
283
+ "task_during": [
284
+ {
285
+ "fn_key": "pk.predictive.images",
286
+ "log_prefix": "Synthetic",
287
+ "n_samples": 1,
288
+ "name": "synthetic/predictive_images",
289
+ "sample_source": "val_batch",
290
+ "save_to_disk": true,
291
+ "split": "val",
292
+ "task_cfg": {
293
+ "label": "Synthetic",
294
+ "milestone_stride": 1
295
+ }
296
+ },
297
+ {
298
+ "fn_key": "pk.generative.images",
299
+ "log_prefix": "Synthetic",
300
+ "n_samples": 10,
301
+ "name": "synthetic/new_individuals_images",
302
+ "sample_source": "val_batch",
303
+ "save_to_disk": true,
304
+ "split": "val",
305
+ "task_cfg": {
306
+ "label": "Synthetic",
307
+ "milestone_stride": 1
308
+ }
309
+ },
310
+ {
311
+ "fn_key": "pk.predictive.metrics",
312
+ "log_prefix": "Empirical",
313
+ "n_samples": 1,
314
+ "name": "empirical/predictive_metrics",
315
+ "sample_source": "empirical_set",
316
+ "save_to_disk": false,
317
+ "split": "empirical_heldout",
318
+ "task_cfg": {
319
+ "label": "Empirical",
320
+ "milestone_stride": 5
321
+ }
322
+ },
323
+ {
324
+ "checkpoint_metric": true,
325
+ "checkpoint_metric_name": "log_rmse",
326
+ "checkpoint_mode": "min",
327
+ "fn_key": "pk.empirical.summary",
328
+ "log_prefix": "Empirical",
329
+ "n_samples": 0,
330
+ "name": "empirical/summary",
331
+ "sample_source": "val_batch",
332
+ "save_to_disk": false,
333
+ "split": "val",
334
+ "task_cfg": {
335
+ "label": "Empirical",
336
+ "milestone_stride": 5,
337
+ "selected_summary_drugs": [
338
+ "paracetamol glucuronide",
339
+ "midazolam"
340
+ ],
341
+ "summary_metric": "log_rmse"
342
+ }
343
+ }
344
+ ],
345
+ "tasks_end": [
346
+ {
347
+ "fn_key": "pk.predictive.metrics",
348
+ "log_prefix": "Empirical",
349
+ "n_samples": 1,
350
+ "name": "empirical/predictive_metrics",
351
+ "sample_source": "empirical_set",
352
+ "save_to_disk": false,
353
+ "split": "empirical_heldout",
354
+ "task_cfg": {
355
+ "label": "Empirical"
356
+ }
357
+ },
358
+ {
359
+ "fn_key": "pk.predictive.images",
360
+ "log_prefix": "Empirical",
361
+ "n_samples": 1,
362
+ "name": "empirical/predictive_images",
363
+ "sample_source": "empirical_set",
364
+ "save_to_disk": true,
365
+ "split": "empirical_heldout",
366
+ "task_cfg": {
367
+ "label": "Empirical"
368
+ }
369
+ },
370
+ {
371
+ "fn_key": "pk.vpc.npde_pvalues",
372
+ "log_prefix": "Empirical",
373
+ "n_samples": 500,
374
+ "name": "empirical/vpc_npde_pvalues",
375
+ "sample_source": "empirical_set",
376
+ "save_to_disk": false,
377
+ "split": "empirical_no_heldout",
378
+ "task_cfg": {
379
+ "label": "Empirical"
380
+ }
381
+ },
382
+ {
383
+ "fn_key": "pk.vpc.images",
384
+ "log_prefix": "Empirical",
385
+ "n_samples": 500,
386
+ "name": "empirical/vpc_images",
387
+ "sample_source": "empirical_set",
388
+ "save_to_disk": true,
389
+ "split": "empirical_no_heldout",
390
+ "task_cfg": {
391
+ "label": "Empirical"
392
+ }
393
+ },
394
+ {
395
+ "fn_key": "pk.empirical.summary",
396
+ "log_prefix": "Empirical",
397
+ "n_samples": 0,
398
+ "name": "empirical/summary",
399
+ "sample_source": "val_batch",
400
+ "save_to_disk": false,
401
+ "split": "val",
402
+ "task_cfg": {
403
+ "label": "Empirical",
404
+ "selected_summary_drugs": [
405
+ "paracetamol glucuronide",
406
+ "midazolam"
407
+ ],
408
+ "summary_metric": "log_rmse"
409
+ }
410
+ }
411
+ ],
412
+ "tasks_validation": []
413
+ },
414
+ "epochs": 100,
415
+ "eps": 1e-08,
416
+ "gradient_clip_val": 0.5,
417
+ "learning_rate": 0.0001,
418
+ "log_interval": 1,
419
+ "num_batch_plot": 1,
420
+ "num_workers": 8,
421
+ "optimizer_name": "AdamW",
422
+ "persistent_workers": true,
423
+ "scheduler_name": "CosineAnnealingLR",
424
+ "scheduler_params": {
425
+ "T_max": 1000,
426
+ "eta_min": 5e-05,
427
+ "last_epoch": -1
428
+ },
429
+ "shuffle_val": true,
430
+ "weight_decay": 0.0001
431
+ },
432
+ "upload_to_hf_hub": true,
433
+ "verbose": false
434
+ },
435
+ "experiment_type": "nodepk",
436
+ "io_schema_version": "studyjson-v1",
437
+ "model_type": "sim_priors_pk",
438
+ "original_repo_id": "cesarali/AICMEPK_cluster",
439
+ "runtime_repo_id": "cesarali/AICME-runtime",
440
+ "supported_tasks": [
441
+ "generate",
442
+ "predict"
443
+ ],
444
+ "transformers_version": "4.52.4"
445
+ }
configuration_sim_priors_pk.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face configuration for self-contained PK runtime bundles."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+ from sim_priors_pk.hub_runtime.runtime_contract import STUDY_JSON_IO_VERSION
10
+
11
+
12
+ class PKHubConfig(PretrainedConfig):
13
+ """Public Hub config describing a consumer-facing PK runtime bundle."""
14
+
15
+ model_type = "sim_priors_pk"
16
+
17
+ def __init__(
18
+ self,
19
+ architecture_name: Optional[str] = None,
20
+ experiment_type: str = "nodepk",
21
+ experiment_config: Optional[Dict[str, Any]] = None,
22
+ builder_config: Optional[Dict[str, Any]] = None,
23
+ supported_tasks: Optional[List[str]] = None,
24
+ default_task: Optional[str] = None,
25
+ io_schema_version: str = STUDY_JSON_IO_VERSION,
26
+ original_repo_id: Optional[str] = None,
27
+ runtime_repo_id: Optional[str] = None,
28
+ **kwargs,
29
+ ) -> None:
30
+ super().__init__(**kwargs)
31
+ self.architecture_name = architecture_name
32
+ self.experiment_type = experiment_type
33
+ self.experiment_config = dict(experiment_config or {})
34
+ self.builder_config = dict(builder_config or {})
35
+ self.supported_tasks = list(supported_tasks or [])
36
+ self.default_task = default_task or (self.supported_tasks[0] if self.supported_tasks else None)
37
+ self.io_schema_version = io_schema_version
38
+ self.original_repo_id = original_repo_id
39
+ self.runtime_repo_id = runtime_repo_id
40
+
41
+
42
+ __all__ = ["PKHubConfig"]
modeling_sim_priors_pk.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face AutoModel wrapper for consumer-facing PK runtime bundles."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, Optional, Sequence, Union
6
+
7
+ import torch
8
+ from transformers import PreTrainedModel
9
+
10
+ from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
11
+ from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
12
+ from sim_priors_pk.hub_runtime.runtime_contract import (
13
+ RuntimeBuilderConfig,
14
+ build_batch_from_studies,
15
+ infer_supported_tasks,
16
+ instantiate_backbone_from_hub_config,
17
+ normalize_studies_input,
18
+ split_runtime_samples,
19
+ validate_studies_for_task,
20
+ )
21
+ from sim_priors_pk.models.amortized_inference.generative_pk import (
22
+ NewGenerativeMixin,
23
+ NewPredictiveMixin,
24
+ )
25
+
26
+
27
class PKHubModel(PreTrainedModel):
    """Thin wrapper exposing a stable StudyJSON runtime API on top of PK models."""

    config_class = PKHubConfig
    base_model_prefix = "backbone"

    def __init__(self, config: PKHubConfig, backbone: Optional[torch.nn.Module] = None) -> None:
        super().__init__(config)
        if backbone is None:
            # No backbone supplied: rebuild it from the hub config payloads.
            backbone = instantiate_backbone_from_hub_config(config)
        self.backbone = backbone
        # Runtime bundles are inference-only, so the backbone stays in eval mode.
        self.backbone.eval()

    def forward(self, *args, **kwargs):
        """Delegate raw forward calls to the wrapped PK backbone."""

        return self.backbone(*args, **kwargs)

    @property
    def supported_tasks(self) -> Sequence[str]:
        """Tasks supported by this runtime model."""

        configured = getattr(self.config, "supported_tasks", [])
        # Prefer the explicit config list; otherwise infer from the backbone.
        return tuple(configured or infer_supported_tasks(self.backbone))

    @torch.inference_mode()
    def run_task(
        self,
        *,
        task: str,
        studies: Union[StudyJSON, Sequence[StudyJSON]],
        num_samples: int = 1,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the public StudyJSON inference contract for the requested task."""

        supported_tasks = list(self.supported_tasks)
        if task not in supported_tasks:
            raise ValueError(
                f"Unsupported task {task!r}. Supported tasks: {supported_tasks or 'none'}."
            )
        sample_count = int(num_samples)
        if sample_count < 1:
            raise ValueError("num_samples must be >= 1.")

        # Canonicalize and validate the StudyJSON payload before batching.
        canonical_studies = normalize_studies_input(studies)
        builder_config = RuntimeBuilderConfig.from_dict(self.config.builder_config)
        validate_studies_for_task(canonical_studies, task=task, builder_config=builder_config)

        # Rebuild the meta-dosing object from the experiment config when a
        # dosing payload is present; otherwise reuse the backbone's own.
        experiment_config_payload = getattr(self.config, "experiment_config", {})
        meta_dosing_payload = experiment_config_payload.get("dosing", {})
        if meta_dosing_payload:
            meta_dosing = self.backbone.meta_dosing.__class__(**meta_dosing_payload)
        else:
            meta_dosing = self.backbone.meta_dosing

        batch = build_batch_from_studies(
            canonical_studies,
            builder_config=builder_config,
            meta_dosing=meta_dosing,
        ).to(self.device)

        if task == "generate":
            if not isinstance(self.backbone, NewGenerativeMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support generate.")
            output_studies = self.backbone.sample_new_individuals_to_studyjson(
                batch,
                sample_size=sample_count,
                num_steps=kwargs.get("num_steps"),
            )
        elif task == "predict":
            if not isinstance(self.backbone, NewPredictiveMixin):
                raise ValueError(f"Backbone {type(self.backbone).__name__} does not support predict.")
            # The predictive API is batch-list based; wrap and unwrap the batch.
            output_studies = self.backbone.sample_individual_prediction_from_batch_list_to_studyjson(
                [batch],
                sample_size=sample_count,
            )[0]
        else:
            raise ValueError(f"Unsupported task {task!r}.")

        results = []
        for index, study in enumerate(output_studies):
            results.append(
                {
                    "input_index": index,
                    "samples": split_runtime_samples(task, study),
                }
            )

        return {
            "task": task,
            "io_schema_version": self.config.io_schema_version,
            "model_info": {
                "architecture_name": self.config.architecture_name,
                "experiment_type": self.config.experiment_type,
                "supported_tasks": supported_tasks,
                "runtime_repo_id": self.config.runtime_repo_id,
                "original_repo_id": self.config.original_repo_id,
            },
            "results": results,
        }
121
+
122
+
123
+ __all__ = ["PKHubModel"]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec19d3a6970fcda03332a75ea0b12bb53e17e2d945088ef46e28c74f73195c84
3
+ size 37495779
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ pytest==8.3.5
2
+ ipython==9.2.0
3
+ comet_ml==3.49.6
4
+ matplotlib==3.10.1 # plotting only — not required for inference
sim_priors_pk/.DS_Store ADDED
Binary file (6.15 kB). View file
 
sim_priors_pk/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+
4
+ def _load_key_file(path: Path) -> str | None:
5
+ """Return the contents of a key file if it exists, otherwise ``None``."""
6
+
7
+ try:
8
+ return path.read_text(encoding="utf-8").strip()
9
+ except FileNotFoundError:
10
+ return None
11
+ except OSError:
12
+ # If the file is unreadable we surface the issue by returning ``None``
13
+ # so callers can decide how to handle missing credentials.
14
+ return None
15
+
16
+
17
+ base_dir = Path(__file__).resolve().parent
18
+ project_dir = (base_dir / "..").resolve()
19
+ data_dir = project_dir / "data"
20
+ test_resources_dir = project_dir / "tests" / "resources"
21
+ results_dir = project_dir / "results"
22
+ reports_dir = project_dir / "reports"
23
+ config_dir = project_dir / "config_files"
24
+
25
+ comet_keys_file = project_dir / "COMET_KEYS.txt"
26
+ hf_keys_file = project_dir / "KEYS.txt"
27
+
28
+ COMET_KEY = _load_key_file(comet_keys_file)
29
+ HUGGINGFACE_KEY = _load_key_file(hf_keys_file)
30
+
31
+ __all__ = [
32
+ "COMET_KEY",
33
+ "HUGGINGFACE_KEY",
34
+ "base_dir",
35
+ "comet_keys_file",
36
+ "config_dir",
37
+ "data_dir",
38
+ "hf_keys_file",
39
+ "project_dir",
40
+ "reports_dir",
41
+ "results_dir",
42
+ "test_resources_dir",
43
+ ]
sim_priors_pk/config_classes/__init__.py ADDED
File without changes
sim_priors_pk/config_classes/data_config.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import List, Dict, Tuple, Optional, Union
4
+ import warnings
5
+
6
+ try: # pragma: no cover - exercised indirectly via configuration loading
7
+ import yaml # type: ignore
8
+ except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
9
+ from sim_priors_pk.config_classes import yaml_fallback as yaml
10
+
11
+ try: # pragma: no cover - optional dependency for downstream modules
12
+ import torch # type: ignore
13
+ except ModuleNotFoundError: # pragma: no cover - torch is not required for configuration loading
14
+ torch = None # type: ignore
15
+
16
@dataclass
class SimpleMetaStudyConfig:
    """
    Minimal configuration for the synthetic (non-mechanistic) PK simulator.
    Used when `simple_mode=True` is detected in the YAML file.
    """

    simple_mode: bool = True

    # --- keep same naming as MetaStudyConfig for compatibility ---
    num_individuals: int = 16
    num_individuals_range: Tuple[int, int] = (16, 16)  # mirrors MetaStudyConfig to avoid downstream errors

    # Simulation time grid.
    time_start: float = 0.0
    time_stop: float = 24.0
    time_num_steps: int = 40

    # Shape parameters of the synthetic curves (sampling ranges).
    band_scale_range: Tuple[float, float] = (0.1, 0.3)
    baseline_range: Tuple[float, float] = (0.0, 0.1)
    decay_rate_range: Tuple[float, float] = (0.3, 0.6)
    # NOTE(review): presumably the probability of the exponential shape
    # (else the pulse); the original comment was inconsistent with the
    # default (0.5 vs "65%") — TODO confirm against the simulator.
    p1: float = 0.5
    num_peripherals_range: Tuple[int, int] = (1, 3)

    solver_method: str = "dummy"  # no mechanistic ODE solve in simple mode
    drug_id_options: List[str] = field(default_factory=lambda: ["DummyDrug"])

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "SimpleMetaStudyConfig":
        """Load the configuration from a YAML file.

        Keys may appear at the top level or nested under ``meta_study``.
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            cfg = yaml.safe_load(handle) or {}
        cfg = cfg.get("meta_study", cfg)

        # Ensure backward compatibility if YAML only defines num_individuals
        if "num_individuals_range" not in cfg and "num_individuals" in cfg:
            n = cfg["num_individuals"]
            cfg["num_individuals_range"] = (n, n)

        return cls(**cfg)
54
+
55
+
56
@dataclass
class MetaStudyConfig:
    """
    This class contains the configuration for the compartment study.
    i.e. it specifies the parameters to sample the population which
    in turns will sample the individuals.
    """

    # Study-level choices.
    drug_id_options: List[str] = field(default_factory=lambda: ["Drug_A", "Drug_B", "Drug_C"])
    num_individuals_range: Tuple[int, int] = (20, 20)

    # Population-level PK parameter priors. Each rate/volume parameter gets a
    # range for the mean and std of its log-scale distribution, plus
    # "tmag"/"tscl" ranges (presumably time-magnitude / time-scale modifiers
    # consumed by the simulator — TODO confirm in compartment_models).
    num_peripherals_range: Tuple[int, int] = (1, 3)
    log_k_a_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_k_a_std_range: Tuple[float, float] = (0.1, 0.5)
    k_a_tmag_range: Tuple[float, float] = (0.01, 0.1)
    k_a_tscl_range: Tuple[float, float] = (1.0, 5.0)
    log_k_e_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_k_e_std_range: Tuple[float, float] = (0.1, 0.5)
    k_e_tmag_range: Tuple[float, float] = (0.01, 0.1)
    k_e_tscl_range: Tuple[float, float] = (1.0, 5.0)
    log_V_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_V_std_range: Tuple[float, float] = (0.1, 0.5)
    V_tmag_range: Tuple[float, float] = (0.01, 0.1)
    V_tscl_range: Tuple[float, float] = (1.0, 5.0)
    log_k_1p_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_k_1p_std_range: Tuple[float, float] = (0.1, 0.5)
    k_1p_tmag_range: Tuple[float, float] = (0.01, 0.1)
    k_1p_tscl_range: Tuple[float, float] = (1.0, 5.0)
    log_k_p1_mean_range: Tuple[float, float] = (-1.5, 1.5)
    log_k_p1_std_range: Tuple[float, float] = (0.1, 0.5)
    k_p1_tmag_range: Tuple[float, float] = (0.01, 0.1)
    k_p1_tscl_range: Tuple[float, float] = (1.0, 5.0)

    # Parameters for observation noise (relative residual unexplained variability).
    rel_ruv_range: Tuple[float, float] = (0.05, 0.3)

    # Parameters for generating time_points
    time_start: float = 0.0
    time_stop: float = 10.0
    time_num_steps: int = 100

    # parameters for solver
    solver_method: str = "rk4"

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "MetaStudyConfig":
        """Instantiate the meta-study configuration from a YAML file.

        Keys may appear at the top level or nested under a ``meta_study``
        section. Raises ``TypeError`` if the resolved section is not a mapping.
        """

        with open(file_path, "r", encoding="utf-8") as handle:
            config_dict = yaml.safe_load(handle) or {}

        if isinstance(config_dict, dict) and "meta_study" in config_dict:
            config_dict = config_dict.get("meta_study") or {}

        if not isinstance(config_dict, dict):
            raise TypeError("Expected 'meta_study' section in YAML to be a mapping.")

        return cls(**config_dict)
114
+
115
+
116
@dataclass
class ObservationsConfig:
    """High-level knobs describing an observation strategy."""

    # ``None`` (e.g. YAML ``type: null``) is treated as the legacy
    # ``pk_peak_half_life`` strategy by the observation factory.
    type: Optional[str] = "pk_peak_half_life"
    add_rem: bool = True  # forced to True by __post_init__ when split_past_future is set
    split_past_future: bool = False
    min_past: Optional[int] = None  # must be non-None when split_past_future=True
    max_past: Optional[int] = None  # must be non-None when split_past_future=True
    max_num_obs: int = 10
    empirical_number_of_obs: int = 2
    # When True, entries at non-positive times are excluded from sampled
    # observations (e.g. concentration at dosing time t=0).
    drop_time_zero_observations: bool = False

    # Strategy specific semantic controls (do not affect tensor shapes directly)
    past_time_ratio: float = 0.1  # Used by random strategies with fixed boundary
    # Sampling policy for split-past/future strategies:
    # - False: sample uniformly in [min_past, max_past]
    # - True: sample 0 with prob. 0.5, otherwise sample uniformly
    #   in [max(1, min_past), max_past]
    generative_bias: bool = False

    def __post_init__(self):
        """Validate field combinations and enforce cross-field invariants."""
        if not isinstance(self.generative_bias, bool):
            raise ValueError("generative_bias must be a boolean (true/false)")

        if self.split_past_future:
            if self.min_past is None or self.max_past is None:
                raise ValueError(
                    "min_past and max_past must be provided when split_past_future=True"
                )
            if self.min_past < 0:
                raise ValueError("min_past must be non-negative")
            if self.max_past < self.min_past:
                raise ValueError("max_past must be >= min_past")
            # Splitting past/future always requires the remainder flag.
            self.add_rem = True

    @classmethod
    def from_yaml(
        cls,
        file_path: Union[str, os.PathLike],
        section: Optional[str] = None,
    ) -> "ObservationsConfig":
        """Instantiate an observation configuration from a YAML file.

        When ``section`` is given, that top-level key is loaded (``KeyError``
        if missing). Otherwise a single ``context_observations`` /
        ``target_observations`` section is auto-detected; having both is
        ambiguous and raises ``ValueError``.
        """

        with open(file_path, "r", encoding="utf-8") as handle:
            config_dict = yaml.safe_load(handle) or {}

        if not isinstance(config_dict, dict):
            raise TypeError("Expected YAML content to be a mapping.")

        if section is not None:
            if section not in config_dict:
                raise KeyError(f"Section '{section}' not found in YAML file '{file_path}'.")
            config_dict = config_dict.get(section) or {}
        else:
            potential_sections = [
                key for key in ("context_observations", "target_observations") if key in config_dict
            ]
            if len(potential_sections) > 1:
                raise ValueError(
                    "Multiple observation sections found; specify which one to load using the 'section' argument."
                )
            if potential_sections:
                config_dict = config_dict.get(potential_sections[0]) or {}

        if not isinstance(config_dict, dict):
            raise TypeError("Expected observation configuration to be provided as a mapping.")

        return cls(**config_dict)
189
+
190
+
191
@dataclass
class MixDataConfig:
    """
    Specifies how the mixed databatch is constructed, i.e. whether the
    decoder variable is one full path or only the future steps of a path.
    """

    test_empirical_datasets: List[str] = field(default_factory=lambda: ["cesarali/lenuzza-2016"])
    # Deprecated fields removed (unused in current training flow):
    # pretraining_*, val_protocol, test_protocol, split_strategy, split_seed.
    evaluate_prediction_steps_past: int = 4  # length of past is kept fixed for evaluation
    sample_size_for_generative_evaluation_val: Optional[int] = None
    # Number of generative samples (S) used for validation-time callback
    # evaluation (new individuals and VPC/NPDE consumers). Defaults to 10.
    sample_size_for_generative_evaluation_end_of_training: Optional[int] = None
    # Number of generative samples (S) used for end-of-training callback
    # evaluation (empirical end hooks). Defaults to 500.
    sample_size_for_generative_evaluation: Optional[int] = None
    # Deprecated legacy alias for both values above. When set and the new
    # fields are not provided, the same value is applied to both stages.
    # Value/time normalization flags consumed by PKScaler.
    # Precedence for value scaling:
    # 1) log_and_z=True -> "log_and_z"
    # 2) log_and_max=True -> "log_and_max"
    # 3) log_transform=True -> "log"
    # 4) z_score_normalization=True -> "zscore"
    # 5) normalize_by_max=True -> "max"
    # 6) otherwise -> "none"
    z_score_normalization: bool = False
    # Explicit single switch for log + z-score scaling in PKScaler.
    log_and_z: bool = False
    # Explicit single switch for log + max scaling in PKScaler.
    log_and_max: bool = False
    normalize_by_max: bool = True
    normalize_time: bool = True

    n_of_permutations: int = 1
    n_of_databatches: Optional[int] = None  # deprecated alias for n_of_permutations
    n_of_target_individuals: int = 1  # ignored for LOO/NO_TARGET
    # Log-only transform flag consumed by PKScaler (value_method="log").
    # This is no longer handled in the dataset/datamodule path.
    log_transform: bool = False  # Matches node-pk-1804.yaml

    store_in_tempfile: bool = False  # When True dataset is generated and saved to a temporary file
    keep_tempfile: bool = False  # Don't delete the temporary file on cleanup
    recreate_tempfile: bool = False  # Regenerate file even if it already exists

    # (subdirectory, filename) pair for the generated temp file.
    tempfile_path: Tuple[str, str] = (
        "preprocessed",
        "simulated_ou_as_rates.tr",
    )

    tqdm_progress: bool = False  # Show progress bar when generating temp files
    # DATA SIZES
    train_size: int = 1000
    val_size: int = 100
    test_size: int = 100

    def __post_init__(self) -> None:
        """Resolve deprecated aliases, apply defaults, and validate sizes.

        Order matters here: legacy aliases are copied into the new fields
        before defaults are applied, and deprecation warnings are emitted
        after the values have been resolved.
        """
        # Legacy alias: n_of_databatches overrides n_of_permutations only
        # when the latter was left at its default.
        if self.n_of_databatches is not None and self.n_of_permutations == 1:
            self.n_of_permutations = self.n_of_databatches
            warnings.warn(
                "n_of_databatches is deprecated; use n_of_permutations",
                DeprecationWarning,
            )
        legacy_sample_size = self.sample_size_for_generative_evaluation
        if (
            self.sample_size_for_generative_evaluation_val is None
            and legacy_sample_size is not None
        ):
            self.sample_size_for_generative_evaluation_val = int(legacy_sample_size)
        if (
            self.sample_size_for_generative_evaluation_end_of_training is None
            and legacy_sample_size is not None
        ):
            self.sample_size_for_generative_evaluation_end_of_training = int(legacy_sample_size)

        # Stage-specific defaults (see field comments above).
        if self.sample_size_for_generative_evaluation_val is None:
            self.sample_size_for_generative_evaluation_val = 10
        if self.sample_size_for_generative_evaluation_end_of_training is None:
            self.sample_size_for_generative_evaluation_end_of_training = 500

        if int(self.sample_size_for_generative_evaluation_val) < 1:
            raise ValueError("sample_size_for_generative_evaluation_val must be >= 1")
        if int(self.sample_size_for_generative_evaluation_end_of_training) < 1:
            raise ValueError("sample_size_for_generative_evaluation_end_of_training must be >= 1")

        # Normalize to plain ints (YAML may deliver e.g. strings or floats).
        self.sample_size_for_generative_evaluation_val = int(
            self.sample_size_for_generative_evaluation_val
        )
        self.sample_size_for_generative_evaluation_end_of_training = int(
            self.sample_size_for_generative_evaluation_end_of_training
        )

        if legacy_sample_size is not None:
            warnings.warn(
                "sample_size_for_generative_evaluation is deprecated; use "
                "sample_size_for_generative_evaluation_val and "
                "sample_size_for_generative_evaluation_end_of_training",
                DeprecationWarning,
            )
        if self.n_of_permutations < 1:
            raise ValueError("n_of_permutations must be >= 1")

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "MixDataConfig":
        """Instantiate the mix-data configuration from a YAML file.

        Keys may appear at the top level or nested under ``mix_data`` /
        ``mix_data_config`` (first match wins).
        """

        with open(file_path, "r", encoding="utf-8") as handle:
            config_dict = yaml.safe_load(handle) or {}

        if isinstance(config_dict, dict):
            for key in ("mix_data", "mix_data_config"):
                if key in config_dict and isinstance(config_dict[key], dict):
                    config_dict = config_dict[key]
                    break

        if not isinstance(config_dict, dict):
            raise TypeError("Expected mix data configuration to be provided as a mapping.")

        return cls(**config_dict)
313
+
314
+
315
@dataclass
class MetaDosingConfig:
    """
    Config for specifying meta dosing information.
    """

    num_individuals: int = 10
    same_route: bool = True
    logdose_mean_range: Tuple[float, float] = (-2, 2)
    logdose_std_range: Tuple[float, float] = (0.1, 0.5)
    route_options: List[str] = field(default_factory=lambda: ["oral", "iv"])
    route_weights: List[float] = field(default_factory=lambda: [0.8, 0.2])
    time: float = 0.0

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "MetaDosingConfig":
        """Build a meta-dosing configuration from a YAML file.

        The dosing keys may live at the top level of the file or be nested
        under a ``dosing`` section; a non-mapping section raises TypeError.
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle)
        parsed = parsed or {}

        # Unwrap the optional "dosing" section (an empty section counts as {}).
        if isinstance(parsed, dict) and "dosing" in parsed:
            parsed = parsed.get("dosing") or {}

        if not isinstance(parsed, dict):
            raise TypeError("Expected 'dosing' section in YAML to be a mapping.")

        return cls(**parsed)
343
+
344
@dataclass
class MetaDosingWithDurationConfig(MetaDosingConfig):
    """
    Config for specifying meta dosing information including iv infusions.

    Extends :class:`MetaDosingConfig` with per-route infusion probabilities
    and an infusion-duration sampling range.
    """

    # Per-route probability of drawing an infusion duration:
    # no duration for oral, 50% chance of infusion for iv.
    route_duration_weights: Dict[str, float] = field(default_factory=lambda: {"oral": 0.0, "iv": 0.5})
    duration_range: Tuple[float, float] = (0.5, 2.0)  # Duration of infusion; 0.0 means bolus
352
+
353
@dataclass
class DosingConfig:
    """
    Config for a single dose event: the amount ``dose`` administered via
    ``route`` at time ``time`` (defaults describe one oral dose at t = 0).
    """

    dose: float = 1.0
    route: str = "oral"
    time: float = 0.0
363
+
364
+
365
@dataclass
class DosingWithDurationConfig:
    """
    Config for specifying dosing information. It holds the amount D of a dose
    given at time t = 0, optionally with an infusion duration.
    """

    dose: float = 1.0
    route: str = "oral"
    time: float = 0.0
    duration: float = 0.0  # Duration of infusion; 0.0 means bolus
sim_priors_pk/config_classes/diffusion_pk_config.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from copy import deepcopy
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import yaml # type: ignore
7
+
8
+ from sim_priors_pk.config_classes.data_config import (
9
+ MetaDosingConfig,
10
+ MetaStudyConfig,
11
+ MixDataConfig,
12
+ ObservationsConfig,
13
+ SimpleMetaStudyConfig,
14
+ )
15
+ from sim_priors_pk.config_classes.node_pk_config import EncoderDecoderNetworkConfig
16
+ from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
17
+ from sim_priors_pk.config_classes.training_config import TrainingConfig
18
+ from sim_priors_pk.config_classes.utils import TupleSafeLoader
19
+
20
+
21
@dataclass
class DiffusionPKExperimentConfig:
    """Experiment configuration dedicated to diffusion PK models.

    Aggregates network, source-process, data, observation, study, dosing and
    training sub-configurations, and knows how to assemble itself from a YAML
    file that may reference other YAML files (``data_config``,
    ``training_config``, ``model_config``) or inline the same sections.
    """

    experiment_type: str = "diffusionpk"
    name_str: str = "ContinuousDiffusionPK"
    diffusion_type: str = "continuous"  # "continuous" or "discrete"

    # Credentials / experiment bookkeeping (None means "not configured").
    comet_ai_key: Optional[str] = None
    experiment_name: str = "diffusion_pk_compartments"
    hugging_face_token: Optional[str] = None
    upload_to_hf_hub: bool = True
    hf_model_name: str = "DiffusionPK_test"
    # Path components of the HF model card (joined downstream). The default
    # has two components, so the annotation is variadic.
    hf_model_card_path: Tuple[str, ...] = ("hf_model_card", "DIFFUSION-PK_Readme.md")

    tags: List[str] = field(default_factory=lambda: ["diffusion-pk", "B-0"])
    # NOTE: "indentifier" is a historical misspelling kept because it is part
    # of the public YAML/config interface.
    experiment_indentifier: Optional[str] = None
    my_results_path: Optional[str] = None
    experiment_dir: Optional[str] = None
    verbose: bool = False
    run_index: int = 0
    debug_test: bool = False

    # Diffusion training knob: predict unit Gaussian noise or correlated noise.
    predict_gaussian_noise: bool = True

    network: EncoderDecoderNetworkConfig = field(default_factory=EncoderDecoderNetworkConfig)
    source_process: SourceProcessConfig = field(default_factory=SourceProcessConfig)
    mix_data: MixDataConfig = field(default_factory=MixDataConfig)

    context_observations: ObservationsConfig = field(default_factory=ObservationsConfig)
    target_observations: ObservationsConfig = field(default_factory=ObservationsConfig)

    # from_yaml may substitute a SimpleMetaStudyConfig when simple_mode is set.
    meta_study: Union[MetaStudyConfig, SimpleMetaStudyConfig] = field(default_factory=MetaStudyConfig)
    dosing: MetaDosingConfig = field(default_factory=MetaDosingConfig)

    train: TrainingConfig = field(default_factory=TrainingConfig)

    @staticmethod
    def from_yaml(file_path: str) -> "DiffusionPKExperimentConfig":
        """Initialize the experiment configuration from a YAML file.

        Resolution order for every sub-section: an explicitly referenced /
        nested section (``*_config`` keys, possibly pointing at other YAML
        files relative to ``file_path``) provides the base values, and
        top-level inline keys (``mix_data``, ``meta_study``, ...) override
        them.

        Raises:
            TypeError: if the YAML root or a referenced section is not a mapping.
            ValueError: if ``experiment_type`` is present but not 'diffusionpk'.
        """

        with open(file_path, "r") as file:
            config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}

        if not isinstance(config_dict, dict):
            raise TypeError("Expected experiment YAML to be a mapping.")

        exp_type = config_dict.get("experiment_type")
        if exp_type is not None and str(exp_type).lower() != "diffusionpk":
            raise ValueError(
                "Expected experiment_type 'diffusionpk' for DiffusionPKExperimentConfig, "
                f"got {exp_type!r}."
            )

        # Referenced YAML files are resolved relative to the experiment file.
        base_dir = os.path.dirname(os.path.abspath(file_path))

        data_cfg_dict = (
            DiffusionPKExperimentConfig._load_ref_yaml(config_dict.get("data_config"), base_dir)
            or {}
        )
        training_cfg_dict = (
            DiffusionPKExperimentConfig._load_ref_yaml(config_dict.get("training_config"), base_dir)
            or {}
        )
        model_cfg_dict = (
            DiffusionPKExperimentConfig._load_ref_yaml(config_dict.get("model_config"), base_dir)
            or {}
        )

        # Observation sections: experiment-level takes precedence over the
        # data-config-level section; fall back to inline keys in data config.
        observations_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "observations_config"
        )
        if observations_section is None:
            observations_section = DiffusionPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "observations_config"
            )
        if observations_section is not None:
            context_observations_base = observations_section.get("context_observations")
            target_observations_base = observations_section.get("target_observations")
        else:
            context_observations_base = data_cfg_dict.get("context_observations")
            target_observations_base = data_cfg_dict.get("target_observations")

        mix_data_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "mix_data_config"
        )
        if mix_data_section is None:
            mix_data_section = DiffusionPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "mix_data_config"
            )
        if mix_data_section is None:
            mix_data_section = data_cfg_dict.get("mix_data")

        meta_study_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_study_config"
        )
        if meta_study_section is None:
            meta_study_section = DiffusionPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_study_config"
            )
        meta_dosing_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_dosing_config"
        )
        if meta_dosing_section is None:
            meta_dosing_section = DiffusionPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_dosing_config"
            )

        # The meta-study mapping may live in its own section, inside the
        # meta-dosing file, or inline in the data config.
        meta_study_base = DiffusionPKExperimentConfig._extract_config_mapping(
            meta_study_section, "meta_study"
        )
        if meta_study_base is None and meta_dosing_section is not None:
            meta_study_base = DiffusionPKExperimentConfig._extract_config_mapping(
                meta_dosing_section, "meta_study"
            )
        if meta_study_base is None:
            meta_study_base = data_cfg_dict.get("meta_study")

        dosing_base = DiffusionPKExperimentConfig._extract_config_mapping(
            meta_dosing_section, "dosing"
        )
        if dosing_base is None:
            dosing_base = data_cfg_dict.get("dosing")

        # Inline top-level keys override the resolved base sections.
        mix_data_cfg = DiffusionPKExperimentConfig._merge_dicts(
            mix_data_section, config_dict.get("mix_data")
        )
        context_obs_cfg = DiffusionPKExperimentConfig._merge_dicts(
            context_observations_base, config_dict.get("context_observations")
        )
        target_obs_cfg = DiffusionPKExperimentConfig._merge_dicts(
            target_observations_base, config_dict.get("target_observations")
        )
        meta_study_cfg = DiffusionPKExperimentConfig._merge_dicts(
            meta_study_base, config_dict.get("meta_study")
        )
        dosing_cfg = DiffusionPKExperimentConfig._merge_dicts(
            dosing_base, config_dict.get("dosing")
        )

        train_section = training_cfg_dict.get("train", training_cfg_dict)
        train_cfg = DiffusionPKExperimentConfig._merge_dicts(
            train_section, config_dict.get("train")
        )

        network_section = model_cfg_dict.get("network", model_cfg_dict)
        network_cfg = DiffusionPKExperimentConfig._merge_dicts(
            network_section, config_dict.get("network")
        )

        # Source process: experiment-level section, then model-config section,
        # then legacy "source_process"/"noise_model" inline keys.
        source_section = DiffusionPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "source_config"
        )
        if source_section is None:
            source_section = DiffusionPKExperimentConfig._resolve_config_section(
                model_cfg_dict, base_dir, "source_config"
            )
        if source_section is None:
            source_section = model_cfg_dict.get("source_process") or model_cfg_dict.get("noise_model")
        source_section = DiffusionPKExperimentConfig._extract_config_mapping(
            source_section, "source_process"
        )
        if isinstance(source_section, dict) and "noise_model" in source_section:
            source_section = source_section.get("noise_model")
        source_cfg = DiffusionPKExperimentConfig._merge_dicts(
            source_section, config_dict.get("source_process")
        )

        # "simple_mode" selects the lightweight synthetic simulator config.
        if meta_study_cfg.get("simple_mode", False):
            meta_study_instance = SimpleMetaStudyConfig(**meta_study_cfg)
        else:
            meta_study_instance = MetaStudyConfig(**meta_study_cfg)

        # Drop keys TrainingConfig does not understand (forward compatibility).
        train_cfg = TrainingConfig._filter_kwargs(train_cfg)

        return DiffusionPKExperimentConfig(
            experiment_type=str(config_dict.get("experiment_type", "diffusionpk")).lower(),
            name_str=config_dict.get("name_str", "ContinuousDiffusionPK"),
            diffusion_type=config_dict.get("diffusion_type", "continuous"),
            tags=config_dict.get("tags", ["diffusion-pk", "B-0"]),
            experiment_name=config_dict.get("experiment_name", "diffusion_pk_compartments"),
            experiment_indentifier=config_dict.get("experiment_indentifier", None),
            my_results_path=config_dict.get("my_results_path", None),
            experiment_dir=config_dict.get("experiment_dir", None),
            comet_ai_key=config_dict.get("comet_ai_key", None),
            hugging_face_token=config_dict.get("hugging_face_token", None),
            upload_to_hf_hub=config_dict.get("upload_to_hf_hub", True),
            hf_model_name=config_dict.get("hf_model_name", "DiffusionPK_test"),
            hf_model_card_path=tuple(
                config_dict.get(
                    "hf_model_card_path", ("hf_model_card", "DIFFUSION-PK_Readme.md")
                )
            ),
            debug_test=config_dict.get("debug_test", False),
            predict_gaussian_noise=bool(config_dict.get("predict_gaussian_noise", True)),
            network=EncoderDecoderNetworkConfig(**network_cfg),
            source_process=SourceProcessConfig(**source_cfg),
            mix_data=MixDataConfig(**mix_data_cfg),
            context_observations=ObservationsConfig(**context_obs_cfg),
            target_observations=ObservationsConfig(**target_obs_cfg),
            meta_study=meta_study_instance,
            dosing=MetaDosingConfig(**dosing_cfg),
            train=TrainingConfig(**train_cfg),
        )

    @staticmethod
    def _merge_dicts(
        base_dict: Optional[Dict[str, Any]], override_dict: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Merge two optional dictionaries returning a new dictionary.

        ``override_dict`` wins on key conflicts; ``base_dict`` is deep-copied
        so callers' inputs are never mutated.
        """

        merged: Dict[str, Any] = {}

        if base_dict:
            if not isinstance(base_dict, dict):
                raise TypeError(
                    "Expected base_dict to be a mapping when merging configuration sections."
                )
            merged = deepcopy(base_dict)

        if override_dict:
            if not isinstance(override_dict, dict):
                raise TypeError(
                    "Expected override_dict to be a mapping when merging configuration sections."
                )
            merged.update(override_dict)

        return merged

    @staticmethod
    def _extract_config_mapping(
        section: Optional[Dict[str, Any]], nested_key: str
    ) -> Optional[Dict[str, Any]]:
        """Return ``section[nested_key]`` if present, otherwise the section itself.

        An explicit ``nested_key: null`` is treated as "no configuration".
        """

        if section is None:
            return None

        if not isinstance(section, dict):
            raise TypeError(
                "Expected configuration section to be a mapping when extracting nested"
                f" '{nested_key}' values."
            )

        if nested_key in section:
            nested_value = section[nested_key]
            if nested_value is None:
                return None
            if not isinstance(nested_value, dict):
                raise TypeError(
                    f"Expected '{nested_key}' section to be a mapping when extracting configuration values."
                )
            return nested_value

        return section

    @staticmethod
    def _load_ref_yaml(
        ref: Optional[Union[str, Dict[str, Any]]], base_dir: str
    ) -> Optional[Dict[str, Any]]:
        """Load a referenced YAML block or return inline dictionaries as-is.

        String references are interpreted as paths, relative to ``base_dir``
        unless absolute.
        """

        if ref is None:
            return None

        if isinstance(ref, dict):
            return ref

        if isinstance(ref, str):
            ref_path = ref
            if not os.path.isabs(ref_path):
                ref_path = os.path.join(base_dir, ref_path)

            with open(ref_path, "r") as handle:
                return yaml.load(handle, Loader=TupleSafeLoader) or {}

        raise TypeError("Expected configuration reference to be a mapping or string path.")

    @staticmethod
    def _resolve_config_section(
        cfg_dict: Dict[str, Any], base_dir: str, key: str
    ) -> Optional[Dict[str, Any]]:
        """Resolve ``cfg_dict[key]`` which may be inline, a path, or a ``_ref``.

        Returns None when the key is absent or explicitly null.
        """

        if key not in cfg_dict:
            return None

        section = cfg_dict[key]

        if section is None:
            return None

        if isinstance(section, dict):
            ref_value = section.get("_ref") if "_ref" in section else None
            if ref_value is not None:
                loaded = DiffusionPKExperimentConfig._load_ref_yaml(ref_value, base_dir)
                return loaded or {}
            return section

        if isinstance(section, str):
            loaded = DiffusionPKExperimentConfig._load_ref_yaml(section, base_dir)
            return loaded or {}

        raise TypeError(
            f"Expected configuration section '{key}' to be a mapping or string reference."
        )
sim_priors_pk/config_classes/flow_pk_config.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from copy import deepcopy
4
+ from dataclasses import asdict, dataclass, field, fields
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ try: # pragma: no cover - exercised indirectly via configuration loading
8
+ import yaml # type: ignore
9
+ except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
10
+ from sim_priors_pk.config_classes import yaml_fallback as yaml
11
+
12
+ try: # pragma: no cover - optional dependency for HF integration
13
+ from transformers import PretrainedConfig # type: ignore
14
+ except ModuleNotFoundError: # pragma: no cover - allow configuration utilities without transformers
15
+
16
+ class PretrainedConfig: # type: ignore
17
+ def __init__(self, **kwargs):
18
+ super().__init__()
19
+
20
+ from sim_priors_pk.config_classes.data_config import (
21
+ MetaDosingConfig,
22
+ MetaStudyConfig,
23
+ MixDataConfig,
24
+ ObservationsConfig,
25
+ SimpleMetaStudyConfig,
26
+ )
27
+ from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
28
+ from sim_priors_pk.config_classes.training_config import TrainingConfig
29
+ from sim_priors_pk.config_classes.utils import TupleSafeLoader
30
+
31
+
32
+ def _to_float(x: Any) -> float:
33
+ try:
34
+ v = float(x)
35
+ except Exception:
36
+ return math.inf
37
+ # guard against NaN
38
+ if math.isnan(v):
39
+ return math.inf
40
+ return v
41
+
42
+
43
+ def _raise_flowpk_network_migration() -> None:
44
+ raise ValueError(
45
+ "FlowPK configs no longer accept a 'network' section. "
46
+ "Please rename 'network' to 'vector_field' and set 'experiment_type: flowpk' in your YAML."
47
+ )
48
+
49
+
50
@dataclass
class VectorFieldPKConfig:
    """Configuration for the transformer vector field used by FlowPK."""

    # --- transformer vector-field hyper-parameters ---
    hidden_dim: int = 64
    fourier_modes: int = 16
    use_spectral_qkv: bool = False
    time_fourier_max_freq: int = 64
    encoder_num_heads: int = 4
    decoder_num_heads: int = 4
    encoder_attention_layers: int = 2
    decoder_attention_layers: int = 2
    dropout: float = 0.0

    # --- latent / conditioning settings required by the vector field ---
    cov_proj_dim: int = 16  # p in the paper
    combine_latent_mode: str = "mlp"  # Options: "mlp", "sum"
    zi_latent_dim: int = 200

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "VectorFieldPKConfig":
        """Instantiate the vector field configuration from a YAML file."""
        with open(file_path, "r", encoding="utf-8") as handle:
            raw = yaml.safe_load(handle) or {}

        if isinstance(raw, dict):
            # Legacy layouts used a 'network' section; reject with a migration hint.
            if "network" in raw:
                _raise_flowpk_network_migration()
            # Accept either a bare mapping or one nested under 'vector_field'.
            if "vector_field" in raw:
                raw = raw.get("vector_field") or {}

        if not isinstance(raw, dict):
            raise TypeError("Expected 'vector_field' section in YAML to be a mapping.")

        return cls(**raw)
87
+
88
+
89
@dataclass
class FlowPKExperimentConfig:
    """Experiment configuration for FlowPK (vector field only).

    Aggregates experiment bookkeeping (naming, tracking keys, HF-hub upload
    options), the vector-field network settings, the data/observation/dosing
    sections, and the training configuration into one serializable dataclass.
    """

    # --- experiment identity / bookkeeping ---
    experiment_type: str = "flowpk"
    name_str: str = "FlowPK"
    comet_ai_key: Optional[str] = None  # Comet ML API key; None disables remote tracking
    experiment_name: str = "flow_pk_compartments"
    hugging_face_token: Optional[str] = None  # token used when uploading to the HF hub
    upload_to_hf_hub: bool = True
    hf_model_name: str = "FlowPK_test"
    # NOTE(review): annotated as a 3-tuple but the default carries two entries —
    # confirm the intended arity with callers before tightening the annotation.
    hf_model_card_path: Tuple[str, str, str] = ("hf_model_card", "CVAE_Readme.md")

    tags: List[str] = field(default_factory=lambda: ["flow-pk", "B-0"])
    experiment_indentifier: Optional[str] = None  # (sic) misspelling kept for backward compatibility
    my_results_path: Optional[str] = None
    experiment_dir: Optional[str] = None
    verbose: bool = False
    run_index: int = 0
    debug_test: bool = False
    # Default Euler integration steps used by FlowPK sampling when callers
    # do not provide ``num_steps`` explicitly (for example VPC callbacks).
    flow_num_steps: int = 50

    # --- nested configuration sections ---
    vector_field: VectorFieldPKConfig = field(default_factory=VectorFieldPKConfig)
    source_process: SourceProcessConfig = field(default_factory=SourceProcessConfig)
    mix_data: MixDataConfig = field(default_factory=MixDataConfig)

    context_observations: ObservationsConfig = field(default_factory=ObservationsConfig)
    target_observations: ObservationsConfig = field(default_factory=ObservationsConfig)

    meta_study: MetaStudyConfig = field(default_factory=MetaStudyConfig)
    dosing: MetaDosingConfig = field(default_factory=MetaDosingConfig)

    train: TrainingConfig = field(default_factory=TrainingConfig)

    @staticmethod
    def from_yaml(file_path: str) -> "FlowPKExperimentConfig":
        """Initializes the class from a YAML file.

        Supports both monolithic experiment YAML files as well as files that
        reference dedicated data, training, and model configuration YAMLs.
        Section lookup order is: experiment YAML first, then the referenced
        data/model/training config; inline experiment-level overrides win
        over referenced-file values via ``_merge_dicts``.
        """

        with open(file_path, "r") as file:
            config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}

        if not isinstance(config_dict, dict):
            raise TypeError("Expected experiment YAML to be a mapping.")

        # Guard against loading a YAML meant for a different experiment family.
        exp_type = config_dict.get("experiment_type")
        if exp_type is not None and str(exp_type).lower() != "flowpk":
            raise ValueError(
                f"Expected experiment_type 'flowpk' for FlowPKExperimentConfig, got {exp_type!r}."
            )

        # Legacy 'network' sections are rejected with a migration hint.
        if "network" in config_dict:
            _raise_flowpk_network_migration()

        # Referenced YAML paths are resolved relative to the experiment file.
        base_dir = os.path.dirname(os.path.abspath(file_path))

        data_cfg_dict = (
            FlowPKExperimentConfig._load_ref_yaml(config_dict.get("data_config"), base_dir) or {}
        )
        training_cfg_dict = (
            FlowPKExperimentConfig._load_ref_yaml(config_dict.get("training_config"), base_dir)
            or {}
        )
        model_cfg_dict = (
            FlowPKExperimentConfig._load_ref_yaml(config_dict.get("model_config"), base_dir) or {}
        )

        if isinstance(model_cfg_dict, dict) and "network" in model_cfg_dict:
            _raise_flowpk_network_migration()

        # Observations: experiment-level section takes precedence, then the
        # data config section, then bare keys on the data config itself.
        observations_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "observations_config"
        )
        if observations_section is None:
            observations_section = FlowPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "observations_config"
            )
        if observations_section is not None:
            context_observations_base = observations_section.get("context_observations")
            target_observations_base = observations_section.get("target_observations")
        else:
            context_observations_base = data_cfg_dict.get("context_observations")
            target_observations_base = data_cfg_dict.get("target_observations")

        # Mix-data: same precedence chain, with a final fallback to an
        # inline 'mix_data' mapping in the data config.
        mix_data_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "mix_data_config"
        )
        if mix_data_section is None:
            mix_data_section = FlowPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "mix_data_config"
            )
        if mix_data_section is None:
            mix_data_section = data_cfg_dict.get("mix_data")

        meta_study_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_study_config"
        )
        if meta_study_section is None:
            meta_study_section = FlowPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_study_config"
            )
        meta_dosing_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "meta_dosing_config"
        )
        if meta_dosing_section is None:
            meta_dosing_section = FlowPKExperimentConfig._resolve_config_section(
                data_cfg_dict, base_dir, "meta_dosing_config"
            )

        # Meta-study may live in its own section, inside the dosing section,
        # or as a bare 'meta_study' key on the data config.
        meta_study_base = FlowPKExperimentConfig._extract_config_mapping(
            meta_study_section, "meta_study"
        )
        if meta_study_base is None and meta_dosing_section is not None:
            meta_study_base = FlowPKExperimentConfig._extract_config_mapping(
                meta_dosing_section, "meta_study"
            )
        if meta_study_base is None:
            meta_study_base = data_cfg_dict.get("meta_study")

        dosing_base = FlowPKExperimentConfig._extract_config_mapping(meta_dosing_section, "dosing")
        if dosing_base is None:
            dosing_base = data_cfg_dict.get("dosing")

        mix_data_inline = config_dict.get("mix_data")
        if mix_data_inline is not None and not isinstance(mix_data_inline, dict):
            raise TypeError("Expected 'mix_data' section in experiment YAML to be a mapping.")

        # Backward compatibility: allow mix-data keys at experiment top-level.
        # Nested `mix_data:` values take precedence over these legacy top-level keys.
        legacy_mix_data_inline = {
            field_meta.name: config_dict[field_meta.name]
            for field_meta in fields(MixDataConfig)
            if field_meta.name in config_dict
        }
        merged_mix_data_inline = dict(legacy_mix_data_inline)
        if isinstance(mix_data_inline, dict):
            merged_mix_data_inline.update(mix_data_inline)

        # Merge referenced-file sections with inline experiment overrides
        # (inline values win on key collisions).
        mix_data_cfg = FlowPKExperimentConfig._merge_dicts(
            mix_data_section, merged_mix_data_inline
        )
        context_obs_cfg = FlowPKExperimentConfig._merge_dicts(
            context_observations_base, config_dict.get("context_observations")
        )
        target_obs_cfg = FlowPKExperimentConfig._merge_dicts(
            target_observations_base, config_dict.get("target_observations")
        )
        meta_study_cfg = FlowPKExperimentConfig._merge_dicts(
            meta_study_base, config_dict.get("meta_study")
        )
        dosing_cfg = FlowPKExperimentConfig._merge_dicts(dosing_base, config_dict.get("dosing"))

        # Training YAML may either nest under 'train' or be the mapping itself.
        train_section = training_cfg_dict.get("train", training_cfg_dict)
        train_cfg = FlowPKExperimentConfig._merge_dicts(train_section, config_dict.get("train"))

        # Model YAML may either nest under 'vector_field' or be the mapping itself.
        vector_field_section = model_cfg_dict.get("vector_field", model_cfg_dict)
        if isinstance(vector_field_section, dict) and "network" in vector_field_section:
            _raise_flowpk_network_migration()
        vector_field_cfg = FlowPKExperimentConfig._merge_dicts(
            vector_field_section, config_dict.get("vector_field")
        )

        # Source process: 'source_config' section, then legacy
        # 'source_process' / 'noise_model' keys on the model config.
        source_section = FlowPKExperimentConfig._resolve_config_section(
            config_dict, base_dir, "source_config"
        )
        if source_section is None:
            source_section = FlowPKExperimentConfig._resolve_config_section(
                model_cfg_dict, base_dir, "source_config"
            )
        if source_section is None:
            source_section = model_cfg_dict.get("source_process") or model_cfg_dict.get("noise_model")
        source_section = FlowPKExperimentConfig._extract_config_mapping(
            source_section, "source_process"
        )
        if isinstance(source_section, dict) and "noise_model" in source_section:
            source_section = source_section.get("noise_model")

        source_cfg = FlowPKExperimentConfig._merge_dicts(
            source_section, config_dict.get("source_process")
        )

        # -----------------------------------------------------------------
        # Choose MetaStudy class dynamically (simple vs full)
        # -----------------------------------------------------------------
        if meta_study_cfg.get("simple_mode", False):
            meta_study_instance = SimpleMetaStudyConfig(**meta_study_cfg)
        else:
            meta_study_instance = MetaStudyConfig(**meta_study_cfg)

        # Drop any keys TrainingConfig does not accept before construction.
        train_cfg = TrainingConfig._filter_kwargs(train_cfg)

        return FlowPKExperimentConfig(
            experiment_type=str(config_dict.get("experiment_type", "flowpk")).lower(),
            name_str=config_dict.get("name_str", "FlowPK"),
            tags=config_dict.get("tags", ["flow-pk", "B-0"]),
            experiment_name=config_dict.get("experiment_name", "flow_pk_compartments"),
            experiment_indentifier=config_dict.get("experiment_indentifier", None),
            my_results_path=config_dict.get("my_results_path", None),
            experiment_dir=config_dict.get("experiment_dir", None),
            comet_ai_key=config_dict.get("comet_ai_key", None),
            hugging_face_token=config_dict.get("hugging_face_token", None),
            upload_to_hf_hub=config_dict.get("upload_to_hf_hub", True),
            hf_model_name=config_dict.get("hf_model_name", "FlowPK_test"),
            hf_model_card_path=tuple(
                config_dict.get("hf_model_card_path", ("hf_model_card", "CVAE_Readme.md"))
            ),
            debug_test=config_dict.get("debug_test", False),
            flow_num_steps=int(config_dict.get("flow_num_steps", 50)),
            vector_field=VectorFieldPKConfig(**vector_field_cfg),
            source_process=SourceProcessConfig(**source_cfg),
            mix_data=MixDataConfig(**mix_data_cfg),
            context_observations=ObservationsConfig(**context_obs_cfg),
            target_observations=ObservationsConfig(**target_obs_cfg),
            meta_study=meta_study_instance,
            dosing=MetaDosingConfig(**dosing_cfg),
            train=TrainingConfig(**train_cfg),
        )

    @staticmethod
    def _merge_dicts(
        base_dict: Optional[Dict[str, Any]], override_dict: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Merge two optional dictionaries returning a new dictionary.

        ``override_dict`` keys win. The base is deep-copied so callers'
        mappings are never mutated. Raises ``TypeError`` if either argument
        is truthy but not a mapping.
        """

        merged: Dict[str, Any] = {}

        if base_dict:
            if not isinstance(base_dict, dict):
                raise TypeError(
                    "Expected base_dict to be a mapping when merging configuration sections."
                )
            merged = deepcopy(base_dict)

        if override_dict:
            if not isinstance(override_dict, dict):
                raise TypeError(
                    "Expected override_dict to be a mapping when merging configuration sections."
                )
            merged.update(override_dict)

        return merged

    @staticmethod
    def _extract_config_mapping(
        section: Optional[Dict[str, Any]], nested_key: str
    ) -> Optional[Dict[str, Any]]:
        """Return a nested configuration mapping or the section itself.

        If ``section`` contains ``nested_key``, that sub-mapping is returned
        (or ``None`` when it is explicitly null); otherwise the whole section
        is assumed to already be the requested mapping.
        """

        if section is None:
            return None

        if not isinstance(section, dict):
            raise TypeError(
                "Expected configuration section to be a mapping when extracting nested"
                f" '{nested_key}' values."
            )

        if nested_key in section:
            nested_value = section[nested_key]
            if nested_value is None:
                return None
            if not isinstance(nested_value, dict):
                raise TypeError(
                    f"Expected '{nested_key}' section to be a mapping when extracting configuration values."
                )
            return nested_value

        return section

    @staticmethod
    def _load_ref_yaml(
        ref: Optional[Union[str, Dict[str, Any]]], base_dir: str
    ) -> Optional[Dict[str, Any]]:
        """Load a referenced YAML block or return inline dictionaries as-is.

        String references are treated as file paths, resolved relative to
        ``base_dir`` unless absolute.
        """

        if ref is None:
            return None

        if isinstance(ref, dict):
            return ref

        if isinstance(ref, str):
            ref_path = ref
            if not os.path.isabs(ref_path):
                ref_path = os.path.join(base_dir, ref_path)

            with open(ref_path, "r") as handle:
                return yaml.load(handle, Loader=TupleSafeLoader) or {}

        raise TypeError("Expected configuration reference to be a mapping or string path.")

    @staticmethod
    def _resolve_config_section(
        cfg_dict: Dict[str, Any], base_dir: str, key: str
    ) -> Optional[Dict[str, Any]]:
        """Resolve nested configuration references within a configuration block.

        A section may be an inline mapping, a mapping carrying a ``_ref``
        path to another YAML file, or a bare string path; string/``_ref``
        forms are loaded from disk.
        """

        if key not in cfg_dict:
            return None

        section = cfg_dict[key]

        if section is None:
            return None

        if isinstance(section, dict):
            ref_value = section.get("_ref") if "_ref" in section else None
            if ref_value is not None:
                loaded = FlowPKExperimentConfig._load_ref_yaml(ref_value, base_dir)
                return loaded or {}
            return section

        if isinstance(section, str):
            loaded = FlowPKExperimentConfig._load_ref_yaml(section, base_dir)
            return loaded or {}

        raise TypeError(
            f"Expected configuration section '{key}' to be a mapping or string reference."
        )

    def to_yaml(self, file_path: str) -> None:
        """Saves the class to a YAML file."""
        with open(file_path, "w") as file:
            yaml.dump(asdict(self), file, default_flow_style=False)
418
+
419
+
420
class HFFlowPKConfig(PretrainedConfig):
    """
    HF config wrapping FlowPKExperimentConfig plus tracked metrics.

    Canonical storage:
      self.tracking: dict with shape
        {
          "best": { "<metric_name>": {"value": float, "step": int|None, "epoch": int|None} },
          "meta": { ...optional... }
        }

    Backward compat:
      - Accepts legacy keys like best_val_loss / best_val_rmse.
      - Mirrors best["val_rmse"] to `best_val_loss` if you still use that elsewhere.
    """

    model_type = "flow_pk"

    def __init__(self, **kwargs):
        # --- extract tracking / legacy keys before super().__init__ ---
        tracking = kwargs.pop("tracking", None)

        # legacy keys (accept either; normalize into tracking)
        legacy_best_val_loss = kwargs.pop("best_val_loss", None)
        legacy_best_val_rmse = kwargs.pop("best_val_rmse", None)

        super().__init__(**kwargs)

        # copy remaining config fields
        # (every key left in kwargs is mirrored onto the instance so nested
        # FlowPK sections survive HF serialization round-trips)
        for k, v in kwargs.items():
            setattr(self, k, v)

        # initialize tracking; anything that is not a dict is discarded and
        # replaced with the canonical empty schema
        if tracking is None or not isinstance(tracking, dict):
            tracking = {"best": {}, "meta": {}}
        tracking.setdefault("best", {})
        tracking.setdefault("meta", {})
        self.tracking: Dict[str, Any] = tracking

        # fold legacy into canonical schema if present
        # (best_val_loss takes precedence over best_val_rmse when both exist)
        legacy = legacy_best_val_loss if legacy_best_val_loss is not None else legacy_best_val_rmse
        if legacy is not None:
            # choose a canonical metric name; I'd recommend "val_rmse" if that’s what it is.
            self.set_best("val_rmse", legacy)

        # optional alias for older codepaths
        self._sync_legacy_aliases()

    # --------- public API ----------
    def set_best(
        self,
        metric_name: str,
        value: Any,
        *,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
    ) -> None:
        """Record ``value`` as the best seen for ``metric_name`` with provenance."""
        v = _to_float(value)
        self.tracking["best"][metric_name] = {"value": v, "step": step, "epoch": epoch}
        self._sync_legacy_aliases()

    def get_best(self, metric_name: str, default: float = math.inf) -> float:
        """Return the stored best value for ``metric_name``, or ``default``."""
        d = self.tracking.get("best", {}).get(metric_name)
        if not d:
            return float(default)
        return _to_float(d.get("value", default))

    def is_better(
        self,
        metric_name: str,
        candidate_value: Any,
        *,
        higher_is_better: bool = False,
    ) -> bool:
        """Return ``True`` when ``candidate_value`` strictly beats the stored best."""
        cand = _to_float(candidate_value)
        best = self.get_best(metric_name, default=(-math.inf if higher_is_better else math.inf))
        return cand > best if higher_is_better else cand < best

    def update_if_better(
        self,
        metric_name: str,
        candidate_value: Any,
        *,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
        higher_is_better: bool = False,
    ) -> bool:
        """Store ``candidate_value`` if it beats the current best; report whether it did."""
        if self.is_better(metric_name, candidate_value, higher_is_better=higher_is_better):
            self.set_best(metric_name, candidate_value, step=step, epoch=epoch)
            return True
        return False

    # --------- construction ----------
    @classmethod
    def from_flowpk(cls, flowpk_cfg, **tracked_best: float) -> "HFFlowPKConfig":
        """
        tracked_best: e.g. val_rmse=..., val_nll=..., val_crps=...
        """
        # flowpk_cfg is expected to be a dataclass (FlowPKExperimentConfig)
        cfg_dict = asdict(flowpk_cfg)
        cfg = cls(**cfg_dict)
        for k, v in tracked_best.items():
            cfg.set_best(k, v)
        return cfg

    # --------- internal ----------
    def _sync_legacy_aliases(self) -> None:
        """
        Keep a legacy scalar field for older code that expects `best_val_loss`.
        Here we mirror it to `best["val_rmse"]` by convention.
        """
        # if val_rmse exists, mirror it; otherwise inf
        self.best_val_loss = self.get_best("val_rmse", default=math.inf)
532
+
533
+
534
# Backwards-compatible alias: older code imports ``FlowPKConfig``.
FlowPKConfig = FlowPKExperimentConfig
sim_priors_pk/config_classes/node_pk_config.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from copy import deepcopy
4
+ from dataclasses import asdict, dataclass, field
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ import yaml # type: ignore
8
+ from transformers import PretrainedConfig
9
+
10
+ from sim_priors_pk.config_classes.data_config import (
11
+ MetaDosingConfig,
12
+ MetaStudyConfig,
13
+ MixDataConfig,
14
+ ObservationsConfig,
15
+ SimpleMetaStudyConfig,
16
+ )
17
+ from sim_priors_pk.config_classes.training_config import TrainingConfig
18
+ from sim_priors_pk.config_classes.utils import TupleSafeLoader
19
+
20
+
21
+ def _to_float(x: Any) -> float:
22
+ try:
23
+ v = float(x)
24
+ except Exception:
25
+ return math.inf
26
+ # guard against NaN
27
+ if math.isnan(v):
28
+ return math.inf
29
+ return v
30
+
31
+
32
@dataclass
class EncoderDecoderNetworkConfig:
    """Configuration for the encoder-decoder network used by NodePK models."""

    # --- encoder ---
    individual_encoder_name: str = "RNNContextEncoder"
    time_obs_encoder_hidden_dim: int = 200
    time_obs_encoder_output_dim: int = 200
    rnn_individual_encoder_number_of_layers: int = 2
    individual_encoder_number_of_heads: int = 4
    encoder_rnn_hidden_dim: int = 128
    input_encoding_hidden_dim: int = 128
    zi_latent_dim: int = 200
    use_attention: bool = True
    use_self_attention: bool = False
    use_time_deltas: bool = True

    # --- decoder ---
    decoder_name: str = "RNNDecoder"
    decoder_num_layers: int = 2
    decoder_attention_layers: int = 2
    decoder_hidden_dim: int = 128
    decoder_rnn_hidden_dim: int = 200
    rnn_decoder_number_of_layers: int = 4
    node_step: bool = True
    exclusive_node_step: bool = False
    cov_proj_dim: int = 16  # p in the paper
    ignore_logvar: bool = True  # sampling

    # --- aggregator ---
    aggregator_type: str = "attention"  # attention, mean
    aggregator_num_heads: int = 8

    # --- loss selection: reconstruction vs prediction ---
    prediction_only: bool = False
    reconstruction_only: bool = False

    # --- latent behaviour flags ---
    study_latent_deterministic: bool = False  # disable sampling of the study latent
    prediction_latent_deterministic: bool = False  # deterministic individual latent for prediction
    combine_latent_mode: str = "mlp"  # Options: "mlp", "sum"

    # --- MLP settings (init_hidden, output heads, drift) ---
    init_hidden_num_layers: int = 2
    output_head_num_layers: int = 2
    drift_num_layers: int = 3
    dropout: float = 0.1
    activation: str = "ReLU"  # For init/logvar/mean
    drift_activation: str = "Tanh"
    norm: str = "layer"  # Options: "layer", "batch", None

    # --- loss ---
    loss_name: str = "nll"  # Options: "nll", "log_nll", "rmse", mv_nll

    # --- latent node pk ---
    kl_weight: float = 1.0

    # --- KL regularisation flags ---
    use_kl_s: bool = True
    use_kl_i: bool = True
    use_kl_i_np: bool = True
    use_kl_init: bool = True
    use_invariance_loss: bool = True

    # Optional scaling for dosing amount inputs (route types remain unscaled)
    scale_dosing_amounts: bool = True

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "EncoderDecoderNetworkConfig":
        """Instantiate the network configuration from a YAML file."""
        with open(file_path, "r", encoding="utf-8") as handle:
            raw = yaml.safe_load(handle) or {}

        # Accept either a bare mapping or one nested under 'network'.
        if isinstance(raw, dict) and "network" in raw:
            raw = raw.get("network") or {}

        if not isinstance(raw, dict):
            raise TypeError("Expected 'network' section in YAML to be a mapping.")

        return cls(**raw)
119
+
120
+
121
+ @dataclass
122
+ class NodePKExperimentConfig:
123
+ """Experiment configuration for NodePK-family models."""
124
+
125
+ experiment_type: str = "nodepk"
126
+ name_str: str = "NodePK"
127
+ comet_ai_key: str = None
128
+ experiment_name: str = "node_pk_compartments"
129
+ hugging_face_token: str = None
130
+ upload_to_hf_hub: bool = True
131
+ hf_model_name: str = "NodePK_test"
132
+ hf_model_card_path: Tuple[str, str, str] = ("hf_model_card", "CVAE_Readme.md")
133
+
134
+ tags: List[str] = field(default_factory=lambda: ["node-pk", "B-0"])
135
+ experiment_indentifier: str = None
136
+ my_results_path: str = None
137
+ experiment_dir: str = None
138
+ verbose: bool = False
139
+ run_index: int = 0
140
+ debug_test: bool = False
141
+
142
+ network: EncoderDecoderNetworkConfig = field(default_factory=EncoderDecoderNetworkConfig)
143
+ mix_data: MixDataConfig = field(default_factory=MixDataConfig)
144
+
145
+ context_observations: ObservationsConfig = field(default_factory=ObservationsConfig)
146
+ target_observations: ObservationsConfig = field(default_factory=ObservationsConfig)
147
+
148
+ meta_study: MetaStudyConfig = field(default_factory=MetaStudyConfig)
149
+ dosing: MetaDosingConfig = field(default_factory=MetaDosingConfig)
150
+
151
+ train: TrainingConfig = field(default_factory=TrainingConfig)
152
+
153
+ @staticmethod
154
+ def from_yaml(file_path: str) -> "NodePKExperimentConfig":
155
+ """Initializes the class from a YAML file.
156
+
157
+ Supports both monolithic experiment YAML files as well as files that
158
+ reference dedicated data, training, and model configuration YAMLs.
159
+ """
160
+
161
+ with open(file_path, "r") as file:
162
+ config_dict = yaml.load(file, Loader=TupleSafeLoader) or {}
163
+
164
+ exp_type = None
165
+ if isinstance(config_dict, dict):
166
+ exp_type = config_dict.get("experiment_type")
167
+ if exp_type is not None and str(exp_type).lower() != "nodepk":
168
+ raise ValueError(
169
+ f"Expected experiment_type 'nodepk' for NodePKExperimentConfig, got {exp_type!r}."
170
+ )
171
+
172
+ base_dir = os.path.dirname(os.path.abspath(file_path))
173
+
174
+ data_cfg_dict = (
175
+ NodePKExperimentConfig._load_ref_yaml(config_dict.get("data_config"), base_dir) or {}
176
+ )
177
+ training_cfg_dict = (
178
+ NodePKExperimentConfig._load_ref_yaml(config_dict.get("training_config"), base_dir)
179
+ or {}
180
+ )
181
+ model_cfg_dict = (
182
+ NodePKExperimentConfig._load_ref_yaml(config_dict.get("model_config"), base_dir) or {}
183
+ )
184
+
185
+
186
+ observations_section = NodePKExperimentConfig._resolve_config_section(
187
+ config_dict, base_dir, "observations_config"
188
+ )
189
+ if observations_section is None:
190
+ observations_section = NodePKExperimentConfig._resolve_config_section(
191
+ data_cfg_dict, base_dir, "observations_config"
192
+ )
193
+ if observations_section is not None:
194
+ context_observations_base = observations_section.get("context_observations")
195
+ target_observations_base = observations_section.get("target_observations")
196
+ else:
197
+ context_observations_base = data_cfg_dict.get("context_observations")
198
+ target_observations_base = data_cfg_dict.get("target_observations")
199
+
200
+ mix_data_section = NodePKExperimentConfig._resolve_config_section(
201
+ config_dict, base_dir, "mix_data_config"
202
+ )
203
+ if mix_data_section is None:
204
+ mix_data_section = NodePKExperimentConfig._resolve_config_section(
205
+ data_cfg_dict, base_dir, "mix_data_config"
206
+ )
207
+ if mix_data_section is None:
208
+ mix_data_section = data_cfg_dict.get("mix_data")
209
+
210
+ meta_study_section = NodePKExperimentConfig._resolve_config_section(
211
+ config_dict, base_dir, "meta_study_config"
212
+ )
213
+ if meta_study_section is None:
214
+ meta_study_section = NodePKExperimentConfig._resolve_config_section(
215
+ data_cfg_dict, base_dir, "meta_study_config"
216
+ )
217
+ meta_dosing_section = NodePKExperimentConfig._resolve_config_section(
218
+ config_dict, base_dir, "meta_dosing_config"
219
+ )
220
+ if meta_dosing_section is None:
221
+ meta_dosing_section = NodePKExperimentConfig._resolve_config_section(
222
+ data_cfg_dict, base_dir, "meta_dosing_config"
223
+ )
224
+
225
+ meta_study_base = NodePKExperimentConfig._extract_config_mapping(
226
+ meta_study_section, "meta_study"
227
+ )
228
+ if meta_study_base is None and meta_dosing_section is not None:
229
+ meta_study_base = NodePKExperimentConfig._extract_config_mapping(
230
+ meta_dosing_section, "meta_study"
231
+ )
232
+ if meta_study_base is None:
233
+ meta_study_base = data_cfg_dict.get("meta_study")
234
+
235
+ dosing_base = NodePKExperimentConfig._extract_config_mapping(meta_dosing_section, "dosing")
236
+ if dosing_base is None:
237
+ dosing_base = data_cfg_dict.get("dosing")
238
+
239
+ mix_data_cfg = NodePKExperimentConfig._merge_dicts(
240
+ mix_data_section, config_dict.get("mix_data")
241
+ )
242
+ context_obs_cfg = NodePKExperimentConfig._merge_dicts(
243
+ context_observations_base, config_dict.get("context_observations")
244
+ )
245
+ target_obs_cfg = NodePKExperimentConfig._merge_dicts(
246
+ target_observations_base, config_dict.get("target_observations")
247
+ )
248
+ meta_study_cfg = NodePKExperimentConfig._merge_dicts(
249
+ meta_study_base, config_dict.get("meta_study")
250
+ )
251
+ dosing_cfg = NodePKExperimentConfig._merge_dicts(dosing_base, config_dict.get("dosing"))
252
+
253
+ train_section = training_cfg_dict.get("train", training_cfg_dict)
254
+ train_cfg = NodePKExperimentConfig._merge_dicts(train_section, config_dict.get("train"))
255
+
256
+ network_section = model_cfg_dict.get("network", model_cfg_dict)
257
+ network_cfg = NodePKExperimentConfig._merge_dicts(
258
+ network_section, config_dict.get("network")
259
+ )
260
+
261
+ # -----------------------------------------------------------------
262
+ # Choose MetaStudy class dynamically (simple vs full)
263
+ # -----------------------------------------------------------------
264
+ if meta_study_cfg.get("simple_mode", False):
265
+ meta_study_instance = SimpleMetaStudyConfig(**meta_study_cfg)
266
+ else:
267
+ meta_study_instance = MetaStudyConfig(**meta_study_cfg)
268
+
269
+ train_cfg = TrainingConfig._filter_kwargs(train_cfg)
270
+
271
+ return NodePKExperimentConfig(
272
+ experiment_type=str(config_dict.get("experiment_type", "nodepk")).lower(),
273
+ name_str=config_dict.get("name_str", "ExampleModel"),
274
+ tags=config_dict.get("tags", ["node-pk", "B-0"]),
275
+ experiment_name=config_dict.get("experiment_name", "aicme_compartments"),
276
+ experiment_indentifier=config_dict.get("experiment_indentifier", None),
277
+ my_results_path=config_dict.get("my_results_path", None),
278
+ experiment_dir=config_dict.get("experiment_dir", None),
279
+ comet_ai_key=config_dict.get("comet_ai_key", None),
280
+ hugging_face_token=config_dict.get("hugging_face_token", None),
281
+ upload_to_hf_hub=config_dict.get("upload_to_hf_hub", True),
282
+ hf_model_name=config_dict.get("hf_model_name", "NodePK_test"),
283
+ hf_model_card_path=tuple(
284
+ config_dict.get("hf_model_card_path", ("hf_model_card", "CVAE_Readme.md"))
285
+ ),
286
+ debug_test=config_dict.get("debug_test", False),
287
+ network=EncoderDecoderNetworkConfig(**network_cfg),
288
+ mix_data=MixDataConfig(**mix_data_cfg),
289
+ context_observations=ObservationsConfig(**context_obs_cfg),
290
+ target_observations=ObservationsConfig(**target_obs_cfg),
291
+ meta_study=meta_study_instance,
292
+ dosing=MetaDosingConfig(**dosing_cfg),
293
+ train=TrainingConfig(**train_cfg),
294
+ )
295
+
296
+ @staticmethod
297
+ def _merge_dicts(
298
+ base_dict: Optional[Dict[str, Any]], override_dict: Optional[Dict[str, Any]]
299
+ ) -> Dict[str, Any]:
300
+ """Merge two optional dictionaries returning a new dictionary."""
301
+
302
+ merged: Dict[str, Any] = {}
303
+
304
+ if base_dict:
305
+ if not isinstance(base_dict, dict):
306
+ raise TypeError(
307
+ "Expected base_dict to be a mapping when merging configuration sections."
308
+ )
309
+ merged = deepcopy(base_dict)
310
+
311
+ if override_dict:
312
+ if not isinstance(override_dict, dict):
313
+ raise TypeError(
314
+ "Expected override_dict to be a mapping when merging configuration sections."
315
+ )
316
+ merged.update(override_dict)
317
+
318
+ return merged
319
+
320
+ @staticmethod
321
+ def _extract_config_mapping(
322
+ section: Optional[Dict[str, Any]], nested_key: str
323
+ ) -> Optional[Dict[str, Any]]:
324
+ """Return a nested configuration mapping or the section itself."""
325
+
326
+ if section is None:
327
+ return None
328
+
329
+ if not isinstance(section, dict):
330
+ raise TypeError(
331
+ "Expected configuration section to be a mapping when extracting nested"
332
+ f" '{nested_key}' values."
333
+ )
334
+
335
+ if nested_key in section:
336
+ nested_value = section[nested_key]
337
+ if nested_value is None:
338
+ return None
339
+ if not isinstance(nested_value, dict):
340
+ raise TypeError(
341
+ f"Expected '{nested_key}' section to be a mapping when extracting configuration values."
342
+ )
343
+ return nested_value
344
+
345
+ return section
346
+
347
+ @staticmethod
348
+ def _load_ref_yaml(
349
+ ref: Optional[Union[str, Dict[str, Any]]], base_dir: str
350
+ ) -> Optional[Dict[str, Any]]:
351
+ """Load a referenced YAML block or return inline dictionaries as-is."""
352
+
353
+ if ref is None:
354
+ return None
355
+
356
+ if isinstance(ref, dict):
357
+ return ref
358
+
359
+ if isinstance(ref, str):
360
+ ref_path = ref
361
+ if not os.path.isabs(ref_path):
362
+ ref_path = os.path.join(base_dir, ref_path)
363
+
364
+ with open(ref_path, "r") as handle:
365
+ return yaml.load(handle, Loader=TupleSafeLoader) or {}
366
+
367
+ raise TypeError("Expected configuration reference to be a mapping or string path.")
368
+
369
+ @staticmethod
370
+ def _resolve_config_section(
371
+ cfg_dict: Dict[str, Any], base_dir: str, key: str
372
+ ) -> Optional[Dict[str, Any]]:
373
+ """Resolve nested configuration references within a configuration block."""
374
+
375
+ if key not in cfg_dict:
376
+ return None
377
+
378
+ section = cfg_dict[key]
379
+
380
+ if section is None:
381
+ return None
382
+
383
+ if isinstance(section, dict):
384
+ ref_value = section.get("_ref") if "_ref" in section else None
385
+ if ref_value is not None:
386
+ loaded = NodePKExperimentConfig._load_ref_yaml(ref_value, base_dir)
387
+ return loaded or {}
388
+ return section
389
+
390
+ if isinstance(section, str):
391
+ loaded = NodePKExperimentConfig._load_ref_yaml(section, base_dir)
392
+ return loaded or {}
393
+
394
+ raise TypeError(
395
+ f"Expected configuration section '{key}' to be a mapping or string reference."
396
+ )
397
+
398
    def to_yaml(self, file_path: str):
        """Serialise this experiment configuration to ``file_path`` as YAML.

        The dataclass is converted with :func:`dataclasses.asdict`, so nested
        configuration dataclasses are written as plain mappings; block style
        (``default_flow_style=False``) keeps the output human-editable.
        """
        with open(file_path, "w") as file:
            yaml.dump(asdict(self), file, default_flow_style=False)
402
+
403
+
404
+ NodePKConfig = NodePKExperimentConfig
405
+
406
+
407
class HFNodePKConfig(PretrainedConfig):
    """
    HF config wrapping NodePKConfig plus tracked metrics.

    Canonical storage:
        self.tracking: dict with shape
          {
            "best": { "<metric_name>": {"value": float, "step": int|None, "epoch": int|None} },
            "meta": { ...optional... }
          }

    Backward compat:
        - Accepts legacy keys like best_val_loss / best_val_rmse.
        - Mirrors best["val_rmse"] to `best_val_loss` if you still use that elsewhere.
    """

    model_type = "node_pk"

    def __init__(self, **kwargs):
        """Build the config from arbitrary keyword fields plus tracking state.

        ``tracking`` (and the legacy ``best_val_loss`` / ``best_val_rmse``
        scalars) are popped out before ``super().__init__`` so the HF base
        class never sees them; everything else is forwarded and also set as
        an attribute on this instance.
        """
        # --- extract tracking / legacy keys before super().__init__ ---
        tracking = kwargs.pop("tracking", None)

        # legacy keys (accept either; normalize into tracking)
        legacy_best_val_loss = kwargs.pop("best_val_loss", None)
        legacy_best_val_rmse = kwargs.pop("best_val_rmse", None)

        super().__init__(**kwargs)

        # copy remaining config fields
        # NOTE(review): PretrainedConfig may already store unknown kwargs as
        # attributes depending on the transformers version — this loop makes
        # that explicit and unconditional.
        for k, v in kwargs.items():
            setattr(self, k, v)

        # initialize tracking; any non-dict value is discarded and replaced
        # by the empty canonical schema
        if tracking is None or not isinstance(tracking, dict):
            tracking = {"best": {}, "meta": {}}
        tracking.setdefault("best", {})
        tracking.setdefault("meta", {})
        self.tracking: Dict[str, Any] = tracking

        # fold legacy into canonical schema if present; best_val_loss wins
        # over best_val_rmse when both are supplied
        legacy = legacy_best_val_loss if legacy_best_val_loss is not None else legacy_best_val_rmse
        if legacy is not None:
            # canonical metric name for the legacy scalar
            self.set_best("val_rmse", legacy)

        # optional alias for older codepaths
        self._sync_legacy_aliases()

    # --------- public API ----------
    def set_best(
        self,
        metric_name: str,
        value: Any,
        *,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
    ) -> None:
        """Record ``value`` as the best seen for ``metric_name`` (unconditionally)."""
        v = _to_float(value)
        self.tracking["best"][metric_name] = {"value": v, "step": step, "epoch": epoch}
        self._sync_legacy_aliases()

    def get_best(self, metric_name: str, default: float = math.inf) -> float:
        """Return the best recorded value for ``metric_name``, or ``default``."""
        d = self.tracking.get("best", {}).get(metric_name)
        if not d:
            return float(default)
        return _to_float(d.get("value", default))

    def is_better(
        self,
        metric_name: str,
        candidate_value: Any,
        *,
        higher_is_better: bool = False,
    ) -> bool:
        """Return True when ``candidate_value`` beats the stored best.

        With no stored best, the comparison baseline is +/-inf so any finite
        candidate wins. Ties are NOT considered better.
        """
        cand = _to_float(candidate_value)
        best = self.get_best(metric_name, default=(-math.inf if higher_is_better else math.inf))
        return cand > best if higher_is_better else cand < best

    def update_if_better(
        self,
        metric_name: str,
        candidate_value: Any,
        *,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
        higher_is_better: bool = False,
    ) -> bool:
        """Store ``candidate_value`` if it improves on the best; return whether it did."""
        if self.is_better(metric_name, candidate_value, higher_is_better=higher_is_better):
            self.set_best(metric_name, candidate_value, step=step, epoch=epoch)
            return True
        return False

    # --------- construction ----------
    @classmethod
    def from_nodepk(cls, nodepk_cfg, **tracked_best: float) -> "HFNodePKConfig":
        """Build an HF config from a NodePK dataclass instance.

        tracked_best: e.g. val_rmse=..., val_nll=..., val_crps=...
        Each entry is recorded via :meth:`set_best`.
        """
        cfg_dict = asdict(nodepk_cfg)
        cfg = cls(**cfg_dict)
        for k, v in tracked_best.items():
            cfg.set_best(k, v)
        return cfg

    # --------- internal ----------
    def _sync_legacy_aliases(self) -> None:
        """
        Keep a legacy scalar field for older code that expects `best_val_loss`.
        Here we mirror it to `best["val_rmse"]` by convention.
        """
        # if val_rmse exists, mirror it; otherwise inf
        self.best_val_loss = self.get_best("val_rmse", default=math.inf)
sim_priors_pk/config_classes/source_process_config.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Union
4
+
5
+ try: # pragma: no cover - exercised indirectly via configuration loading
6
+ import yaml # type: ignore
7
+ except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
8
+ from sim_priors_pk.config_classes import yaml_fallback as yaml
9
+
10
+
11
@dataclass
class SourceProcessConfig:
    """Configuration of the stochastic source process for flow/diffusion PK models.

    Accepted ``source_type`` values (case-insensitive):
    ``"gaussian_process"``/``"gp"``, ``"ornstein_uhlenbeck"``/``"ou"``,
    ``"wiener"``, and ``"normal"``/``"gaussian"``.
    """

    source_type: str = "gaussian_process"

    # Gaussian-process hyper-parameters (RBF or OU kernels).
    gp_length_scale: float = 0.1
    gp_variance: float = 1.0
    gp_eps: float = 1e-8
    gp_transform: str = 'softplus'  # transformation applied to the sampled noise, e.g. 'softplus', 'exp'

    # Flow-matching additive noise settings (used only in FlowPK).
    flow_sigma: float = 1e-4
    flow_num_steps: int = 100
    use_OT_coupling: bool = False

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "SourceProcessConfig":
        """Build a :class:`SourceProcessConfig` from a YAML file.

        The document may carry the fields at the top level or nest them under
        one of the section names ``source_process``, ``source`` or
        ``noise_model`` — the first match wins and an empty section falls
        back to all defaults.
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            raw = yaml.safe_load(handle) or {}

        if isinstance(raw, dict):
            section_key = next(
                (k for k in ("source_process", "source", "noise_model") if k in raw),
                None,
            )
            if section_key is not None:
                raw = raw.get(section_key) or {}

        if not isinstance(raw, dict):
            raise TypeError("Expected source process configuration to be a mapping.")

        return cls(**raw)
+ return cls(**config_dict)
sim_priors_pk/config_classes/training_config.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass, field, fields
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+ import yaml
6
+
7
+
8
@dataclass
class SchedulerTaskConfig:
    """Typed configuration for one scheduler task.

    ``name`` identifies the task and ``fn_key`` selects which callback
    function runs it; the remaining fields control sampling, logging and
    metric-based checkpointing for that task. Free-form options go in
    ``task_cfg``.
    """

    name: str
    fn_key: str
    n_samples: int = 0
    sample_source: str = "unconditional"
    split: str = "val"
    empirical_name: Optional[str] = None
    save_to_disk: bool = True
    log_prefix: str = "val"
    use_ema: bool = False
    # Metric checkpointing: when enabled, checkpoint_metric_name names the
    # tracked metric and checkpoint_mode chooses min/max selection.
    checkpoint_metric: bool = False
    checkpoint_metric_name: Optional[str] = None
    checkpoint_mode: str = "min"
    # Arbitrary task-specific options passed through untouched.
    task_cfg: Dict[str, Any] = field(default_factory=dict)
25
+
26
+
27
@dataclass
class SchedulerConfig:
    """Typed configuration for scheduler-driven callback execution.

    Groups global execution/caching knobs with three task lists run at
    different points of training (validation, during training, and at the
    end). Defaults run every task once per full pass (``percent_step=1.0``)
    and include the end-of-training evaluation.
    """

    percent_step: float = 1.0
    include_end: bool = True
    skip_sanity_check: bool = True
    store_samples: bool = True
    max_samples_per_group: int = 32
    keep_temp_files: bool = False
    cache_dir: Optional[str] = None
    # Supported selectors are:
    # - ``end`` for in-memory train-end weights,
    # - ``last`` / ``best`` for experiment checkpoint callbacks,
    # - scheduler-managed metric checkpoint names emitted by tasks.
    checkpoint_used_in_end: List[str] = field(default_factory=lambda: ["end"])
    tasks_validation: List[SchedulerTaskConfig] = field(default_factory=list)
    task_during: List[SchedulerTaskConfig] = field(default_factory=list)
    tasks_end: List[SchedulerTaskConfig] = field(default_factory=list)
46
+
47
+
48
@dataclass
class TrainingConfig:
    """Optimiser, dataloader and scheduler settings for a training run."""

    epochs: int = 20
    batch_size: int = 8
    gradient_clip_val: float = 1.0
    optimizer_name: str = "AdamW"
    learning_rate: float = 0.0001
    weight_decay: float = 1.0e-4
    num_workers: int = 3
    persistent_workers: bool = True
    shuffle_val: bool = True

    num_batch_plot: int = 1
    log_interval: int = 1  # Frequency of logging and visualization

    # Scheduler-driven PK evaluation and visualization.
    callbacks_scheduler: Optional[Union[SchedulerConfig, Dict[str, Any]]] = None

    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
    eps: float = 1.0e-8
    amsgrad: bool = False
    scheduler_name: str = "CosineAnnealingLR"
    scheduler_params: Dict[str, Union[float, int]] = field(
        default_factory=lambda: {"T_max": 1000, "eta_min": 5.0e-5, "last_epoch": -1}
    )

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "TrainingConfig":
        """Create a :class:`TrainingConfig` from a YAML file.

        Accepts either a top-level mapping of training fields or a document
        carrying a ``train`` section; unknown keys are silently dropped via
        :meth:`_filter_kwargs`.
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            raw = yaml.safe_load(handle) or {}

        if isinstance(raw, dict) and "train" in raw:
            raw = raw.get("train") or {}

        if not isinstance(raw, dict):
            raise TypeError("Expected 'train' section in YAML to be a mapping.")

        return cls(**cls._filter_kwargs(raw))

    @classmethod
    def _filter_kwargs(cls, raw: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only keys matching dataclass fields (drops deprecated flags)."""
        if not isinstance(raw, dict):
            return {}
        known = {f.name for f in fields(cls)}
        return {name: raw[name] for name in raw if name in known}
sim_priors_pk/config_classes/utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try: # pragma: no cover - exercised indirectly via configuration loading
2
+ import yaml # type: ignore
3
+ from yaml import SafeLoader # type: ignore
4
+ except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
5
+ from sim_priors_pk.config_classes import yaml_fallback as yaml
6
+ SafeLoader = yaml.SafeLoader
7
+
8
class TupleSafeLoader(SafeLoader):
    """SafeLoader variant that deserialises ``!!python/tuple`` nodes as tuples.

    A plain ``SafeLoader`` does not handle the python-specific tuple tag;
    this subclass registers a constructor for it so tuples written by
    ``yaml.dump`` round-trip as tuples instead of failing to load.
    """

    def construct_python_tuple(self, node):
        # Convert the YAML sequence (e.g., [0.01, 0.1]) into a tuple
        return tuple(self.construct_sequence(node))

# Register the constructor for the fully qualified tag
TupleSafeLoader.add_constructor('tag:yaml.org,2002:python/tuple', TupleSafeLoader.construct_python_tuple)
sim_priors_pk/config_classes/yaml_fallback.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal YAML loader fallback used when PyYAML is unavailable."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ import json
7
+ from typing import Any, Dict, List
8
+
9
+
10
class SafeLoader:
    """Stand-in for :class:`yaml.SafeLoader` when PyYAML is absent.

    Constructors are recorded for API compatibility only; the fallback
    parser in this module never consults them.
    """

    # Class-level registry shared via normal attribute lookup.
    _constructors: Dict[str, Any] = {}

    @classmethod
    def add_constructor(cls, tag: str, constructor: Any) -> None:
        """Remember ``constructor`` under ``tag``."""
        cls._constructors.update({tag: constructor})
18
+
19
+
20
+ Loader = SafeLoader
21
+
22
+
23
+ def _convert_scalar(value: str) -> Any:
24
+ lowered = value.lower()
25
+ if lowered in {"true", "yes"}:
26
+ return True
27
+ if lowered in {"false", "no"}:
28
+ return False
29
+ if lowered in {"null", "none", "~"}:
30
+ return None
31
+
32
+ if value.startswith("[") or value.startswith("{") or value.startswith("("):
33
+ try:
34
+ return ast.literal_eval(value)
35
+ except (SyntaxError, ValueError):
36
+ pass
37
+
38
+ if value.startswith("\"") and value.endswith("\""):
39
+ return value[1:-1]
40
+ if value.startswith("'") and value.endswith("'"):
41
+ return value[1:-1]
42
+
43
+ try:
44
+ if "." in value or "e" in lowered:
45
+ return float(value)
46
+ return int(value)
47
+ except ValueError:
48
+ pass
49
+
50
+ return value
51
+
52
+
53
def _parse_lines(lines: List[str], indent: int = 0) -> Any:
    """Recursively parse ``lines`` (consumed in place) at one indent level.

    Returns a dict for a mapping block or a list for a sequence block; an
    empty/exhausted block yields ``{}``. Mixing mapping entries and sequence
    items at the same level raises ``ValueError``. Nested blocks are assumed
    to be indented by two extra spaces relative to their parent.
    """
    mapping: Dict[str, Any] = {}
    sequence: List[Any] = []
    # Undecided until the first real entry is seen; then locked to list/map.
    is_list: bool | None = None

    while lines:
        line = lines[0]
        stripped = line.lstrip()

        # Skip blank lines and comment lines.
        if not stripped or stripped.startswith("#"):
            lines.pop(0)
            continue

        current_indent = len(line) - len(stripped)

        # A dedent ends this block — except for list items, which are allowed
        # at a shallower indent than their parent key (common YAML style).
        if current_indent < indent and not stripped.startswith("- "):
            break

        if stripped.startswith("- "):
            if is_list is False:
                raise ValueError("Mixed mapping and sequence at the same level is unsupported.")
            is_list = True

            lines.pop(0)
            item_value = stripped[2:].strip()

            # Bare "-": the item body is the following, more-indented block.
            if not item_value:
                sequence.append(_parse_lines(lines, current_indent + 2))
                continue

            # "- key:": single-key mapping item whose value is a nested block.
            if item_value.endswith(":"):
                key = item_value[:-1].strip()
                value = _parse_lines(lines, current_indent + 2)
                sequence.append({key: value})
                continue

            # "- scalar": plain sequence element.
            sequence.append(_convert_scalar(item_value))
            continue

        if is_list is True:
            raise ValueError("Mixed mapping and sequence at the same level is unsupported.")

        is_list = False

        lines.pop(0)
        if ":" not in stripped:
            raise ValueError(f"Invalid mapping entry: '{stripped}'.")

        # Split only on the first colon so values may contain colons.
        key, value_part = stripped.split(":", 1)
        key = key.strip()
        value_part = value_part.strip()

        if value_part:
            mapping[key] = _convert_scalar(value_part)
        else:
            # No inline value: the value is the following nested block.
            mapping[key] = _parse_lines(lines, current_indent + 2)

    if is_list:
        return sequence
    return mapping
113
+
114
+
115
def safe_load(stream: Any) -> Any:
    """Parse YAML from a string or text stream with the fallback parser.

    Mirrors :func:`yaml.safe_load` closely enough for this package: empty
    input yields ``None``; non-string content raises ``TypeError``.
    """
    content = stream.read() if hasattr(stream, "read") else stream

    if not isinstance(content, str):
        raise TypeError("YAML content must be a string or text stream.")

    split = content.splitlines()
    if not split:
        return None
    return _parse_lines(list(split))
128
+
129
+
130
def load(stream: Any, Loader: Any | None = None) -> Any:  # noqa: N803 - API compatibility
    """Fallback mirror of :func:`yaml.load`; the ``Loader`` argument is ignored."""
    del Loader  # accepted only for signature compatibility with PyYAML
    return safe_load(stream)
134
+
135
+
136
def dump(data: Any, stream: Any | None = None, default_flow_style: bool | None = None) -> str:
    """Serialise ``data`` as indented JSON (the fallback's YAML stand-in).

    When ``stream`` is provided the text is written to it and an empty
    string is returned; otherwise the serialised text is returned.
    ``default_flow_style`` is accepted only for API compatibility.
    """
    rendered = json.dumps(data, indent=2)
    if stream is None:
        return rendered
    stream.write(rendered)
    return ""
sim_priors_pk/data/README.md ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # `sim_priors_pk.data` Package Guide
2
+
3
+ This guide documents the purpose of every subpackage that lives under `sim_priors_pk/data`. Use it as a quick reference when wiring new data pipelines or navigating the simulated pharmacokinetic (PK) workflow.
4
+
5
+ ## Configuration Preamble
6
+
7
+ Simulations are configured by combining reusable YAML files with the dataclasses that live in `sim_priors_pk.config_classes`. The YAML files populate those classes; configurations can be defined either directly from the dataclasses or by reading the files.
8
+
9
+ - **YAML files (`config_files/`)** – Ready-to-use experiment definitions grouped under `config_files/experiment_configs`. For example, the `node-pk` folder contains `base-homogeneous.*.yaml` files that describe meta-study, dosing, and observation settings.
10
+ - **Config dataclasses (`sim_priors_pk/config_classes/`)** – Python dataclasses (`MetaStudyConfig`, `MetaDosingConfig`, `ObservationsConfig`, and friends) that parse those YAML files or can be instantiated directly in code when you need programmatic overrides.
11
+
12
+ When you load configurations in tests or scripts, prefer `MetaStudyConfig.from_yaml(...)` and similar helpers. They keep the simulation code aligned with the canonical YAML layout while still allowing you to craft configurations in pure Python when necessary.
13
+
14
+ ## Top-Level Layout
15
+
16
+ These are the files that matter for the handling of simulations and data:
17
+
18
+ ```
19
+
20
+ sim_priors_pk/
21
+ ├── config_files/
22
+ │ └── experiment_configs/
23
+ ├── scripts/
24
+ ├── sim_priors_pk/
25
+ │ ├── config_classes/
26
+ │ └── data/
27
+ │ ├── data_empirical/
28
+ │ ├── data_generation/
29
+ │ ├── data_preprocessing/
30
+ │ ├── datasets/
31
+ │ └── extra/
32
+ └── tests/
33
+ └── data/
34
+ └── simulation_data/
35
+ └── test_simulations.py
36
+ ```
37
+
38
+ Each directory is described below together with the most important entry points it exposes.
39
+
40
+ ## `data_empirical`
41
+
42
+ Defines the data contracts used across the project.
43
+ These contracts specify the canonical JSON schema (StudyJSON, IndividualJSON) that standardizes how pharmacokinetic studies are represented — both empirical and simulated.
44
+ They serve as the interface between raw datasets, tensor batches, and model-ready data structures, ensuring a unified format throughout the pipeline. These helpers make it straightforward to load Hugging Face datasets or local JSON files, validate them, and materialise PyTorch-compatible batches.
45
+
46
+
47
+ ## `data_generation`
48
+
49
+ Simulation building blocks used to synthesise PK trajectories under configurable dosing and observation schemes.
50
+
51
+ * [`compartment_models.py`](data_generation/compartment_models.py) implements the stochastic sampling of population/individual PK parameters and the compartmental simulation loops.
52
+ * [`observations_classes.py`](data_generation/observations_classes.py) describe observation strategies (e.g. sparse vs. dense sampling) and utilities to realise them.
53
+ * [`compartment_models_management.py`](data_generation/compartment_models_management.py) orchestrates the full simulation workflow: it takes the meta-configuration, samples individual and dosing configurations, runs the compartmental simulations, applies the observation strategy, and assembles complete ensembles of studies in the data contracts.
54
+
55
+ Together these modules allow you to go from configuration dataclasses to simulated studies that mirror the empirical format.
56
+
57
+ ## `data_preprocessing`
58
+
59
+ Deprecated.
60
+
61
+ ## `datasets`
62
+
63
+ Lightning-ready dataset/dataloader factories.
64
+
65
+ - [`aicme_datasets.py`](datasets/aicme_datasets.py) defines `AICMECompartmentsDataBatch` and related PyTorch Lightning `DataModule` wrappers that harmonise both empirical and simulated studies for downstream training.
66
+
67
+
68
+ ## Putting It All Together
69
+
70
+ A typical workflow is:
71
+
72
+ 1. **Configure**: Use `sim_priors_pk.config_classes` to describe study, dosing, and observation priors.
73
+ 2. **Simulate**: Call into `data_generation` to sample synthetic studies or to augment empirical cohorts.
74
+ 3. **Serialise or load**: Store simulations as JSON, or load existing JSON/CSV with `data_empirical` and `data_preprocessing`.
75
+ 4. **Batch**: Wrap tensors using `datasets.AICMECompartmentsDataModule` for consumption by modules in `sim_priors_pk.models` and training scripts.
76
+
77
+ Refer back to this document whenever you onboard a new collaborator or reorganise data flows—the sections above stay aligned with the current code base.
78
+
79
+ ## Worked Examples and Tests
80
+
81
+ Integration-style tests in `tests/data/simulation_data/test_simulations.py` demonstrate how the configuration pieces fit together:
82
+
83
+ - `test_prepare_full_simulation_to_study_json` shows how YAML-driven configs from `config_files/experiment_configs/node-pk` feed into `prepare_full_simulation_to_study_json` and culminate in a canonical `StudyJSON`.
84
+ - `test_prepare_ensemble_of_simulations` builds on the same configuration files to generate an ensemble of studies and persists them to disk, illustrating how bulk simulations can be orchestrated.
85
+
86
+ Use these tests as executable documentation whenever you need to follow the end-to-end flow from configuration files to simulated study artefacts.
sim_priors_pk/data/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility namespace for data-related modules.
2
+
3
+ This package groups empirical, generation, preprocessing, and dataset
4
+ helpers so they can be imported with the ``sim_priors_pk.data`` prefix.
5
+ """
6
+
7
+ __all__ = [
8
+ "data_empirical",
9
+ "data_generation",
10
+ "data_preprocessing",
11
+ "datasets",
12
+ ]
sim_priors_pk/data/data_empirical/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities for working with empirical JSON study data."""
2
+
3
+ try: # pragma: no cover - optional torch dependency
4
+ from .builder import (
5
+ JSON2AICMEBuilder,
6
+ EmpiricalBatchConfig,
7
+ held_out_ind_json,
8
+ held_out_list_json,
9
+ load_empirical_json_batches,
10
+ load_empirical_json_batches_as_dm,
11
+ load_empirical_hf_batches_as_dm,
12
+ databatch_to_study_jsons,
13
+ prediction_to_study_jsons,
14
+ )
15
+ except ModuleNotFoundError as exc: # pragma: no cover - allow missing torch
16
+ if exc.name != "torch":
17
+ raise
18
+ JSON2AICMEBuilder = EmpiricalBatchConfig = None # type: ignore
19
+ held_out_ind_json = held_out_list_json = None # type: ignore
20
+ load_empirical_json_batches = load_empirical_json_batches_as_dm = None # type: ignore
21
+ databatch_to_study_jsons = prediction_to_study_jsons = None # type: ignore
22
+
23
+ __all__ = [
24
+ "json_schema",
25
+ "JSON2AICMEBuilder",
26
+ "EmpiricalBatchConfig",
27
+ "held_out_ind_json",
28
+ "held_out_list_json",
29
+ "load_empirical_json_batches",
30
+ "load_empirical_json_batches_as_dm",
31
+ "load_empirical_hf_batches_as_dm",
32
+ "databatch_to_study_jsons",
33
+ "prediction_to_study_jsons",
34
+ "json_stats",
35
+ ]
sim_priors_pk/data/data_empirical/builder.py ADDED
@@ -0,0 +1,1139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
7
+
8
+ import torch
9
+ from datasets import load_dataset
10
+ from torchtyping import TensorType as TT
11
+
12
+ from sim_priors_pk.config_classes.data_config import (
13
+ MetaDosingConfig,
14
+ )
15
+ from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch
16
+
17
+ if TYPE_CHECKING: # pragma: no cover - imported only for type hints
18
+ from sim_priors_pk.data.datasets.aicme_datasets import AICMECompartmentsDataModule
19
+
20
+ from .json_schema import IndividualJSON, StudyJSON, canonicalize_study
21
+ from .json_stats import EmpiricalJSONStats, compute_json_stats
22
+
23
+
24
@dataclass
class EmpiricalBatchConfig:
    """Configuration for empirical batch construction.

    Attributes
    ----------
    pad_value_time:
        Value used to pad time tensors.
    pad_value_obs:
        Value used to pad observation tensors.
    max_databatch_size:
        Maximum number of studies that can be stacked into a single batch.
    max_individuals:
        Maximum number of individuals per context or target block.
    max_observations:
        Maximum number of observation time points per individual.
    max_remaining:
        Maximum number of remaining time points per individual.
    max_context_individuals / max_target_individuals:
        Optional overrides specifying separate capacities for context and
        target individual counts.
    max_context_observations / max_target_observations:
        Optional overrides specifying per-block observation capacities.
    max_context_remaining / max_target_remaining:
        Optional overrides specifying per-block remaining simulation
        capacities.
    """

    # Padding fill values.
    pad_value_time: float = 0.0
    pad_value_obs: float = 0.0
    # Shared capacity defaults.
    max_databatch_size: int = 8
    max_individuals: int = 1
    max_observations: int = 0
    max_remaining: int = 0
    # Per-block overrides; ``None`` means "fall back to the shared default".
    max_context_individuals: Optional[int] = None
    max_target_individuals: Optional[int] = None
    max_context_observations: Optional[int] = None
    max_target_observations: Optional[int] = None
    max_context_remaining: Optional[int] = None
    max_target_remaining: Optional[int] = None
64
+
65
+
66
+ class JSON2AICMEBuilder:
67
+ """Convert empirical study JSON to :class:`AICMECompartmentsDataBatch`.
68
+
69
+ The builder pads context and target individuals to fixed sizes and
70
+ assembles the :class:`AICMECompartmentsDataBatch` expected by the models.
71
+ """
72
+
73
    def __init__(self, cfg: EmpiricalBatchConfig) -> None:
        """Store the padding/capacity configuration used by all build steps."""
        self.cfg = cfg
75
+
76
+ def _ctx_cap(self) -> int:
77
+ return (
78
+ self.cfg.max_context_individuals
79
+ if self.cfg.max_context_individuals is not None
80
+ else self.cfg.max_individuals
81
+ )
82
+
83
+ def _tgt_cap(self) -> int:
84
+ return (
85
+ self.cfg.max_target_individuals
86
+ if self.cfg.max_target_individuals is not None
87
+ else self.cfg.max_individuals
88
+ )
89
+
90
+ def _ctx_obs_cap(self) -> int:
91
+ return (
92
+ self.cfg.max_context_observations
93
+ if self.cfg.max_context_observations is not None
94
+ else self.cfg.max_observations
95
+ )
96
+
97
+ def _tgt_obs_cap(self) -> int:
98
+ return (
99
+ self.cfg.max_target_observations
100
+ if self.cfg.max_target_observations is not None
101
+ else self.cfg.max_observations
102
+ )
103
+
104
+ def _ctx_rem_cap(self) -> int:
105
+ return (
106
+ self.cfg.max_context_remaining
107
+ if self.cfg.max_context_remaining is not None
108
+ else self.cfg.max_remaining
109
+ )
110
+
111
+ def _tgt_rem_cap(self) -> int:
112
+ return (
113
+ self.cfg.max_target_remaining
114
+ if self.cfg.max_target_remaining is not None
115
+ else self.cfg.max_remaining
116
+ )
117
+
118
    def _block_from_inds(
        self,
        inds: List[IndividualJSON],
        *,
        max_individuals: int,
        obs_cap: int,
        rem_cap: int,
    ) -> Dict[str, TT]:
        """Assemble tensors for a list of individuals.

        Padding is applied so that each block has the same number of
        individuals (``max_individuals``) and time steps
        (``max_observations``/``max_remaining``). Individuals beyond the
        capacity are silently dropped and per-individual series longer than
        the capacity are truncated.

        Returns a dict with keys ``obs``/``time``/``mask`` of shape [I, ET]
        and ``rem``/``rem_time``/``rem_mask`` of shape [I, R] (R may be 0);
        masks are True exactly where real (non-padded) values were written.

        NOTE(review): assumes each individual dict carries parallel,
        equal-length lists under "observations"/"observation_times" (and
        optionally "remaining"/"remaining_times") — only the observation
        length is used to set the valid mask; confirm upstream validation.
        """

        # Clamp capacities so negative configuration values behave as zero.
        I_max = max(0, max_individuals)
        ET = max(0, obs_cap)
        R = max(0, rem_cap)

        # Pre-filled padded tensors; the loop below overwrites valid entries.
        obs_tensor = torch.full((I_max, ET), self.cfg.pad_value_obs)  # [I, ET]
        time_tensor = torch.full((I_max, ET), self.cfg.pad_value_time)  # [I, ET]
        mask_tensor = torch.zeros((I_max, ET), dtype=torch.bool)  # [I, ET]

        # With R == 0 the "remaining" tensors degenerate to zero-width views.
        rem_tensor = (
            torch.full((I_max, R), self.cfg.pad_value_obs) if R else torch.zeros(I_max, 0)
        )  # [I, R]
        rem_time_tensor = (
            torch.full((I_max, R), self.cfg.pad_value_time) if R else torch.zeros(I_max, 0)
        )  # [I, R]
        rem_mask_tensor = (
            torch.zeros((I_max, R), dtype=torch.bool)
            if R
            else torch.zeros(I_max, 0, dtype=torch.bool)
        )  # [I, R]

        for i, ind in enumerate(inds[:I_max]):
            obs = torch.tensor(ind.get("observations", []), dtype=torch.float32)  # [ET?]
            time = torch.tensor(ind.get("observation_times", []), dtype=torch.float32)  # [ET?]
            # Truncate to capacity; mask marks the copied prefix as valid.
            L = min(obs.shape[0], ET)
            obs_tensor[i, :L] = obs[:L]
            time_tensor[i, :L] = time[:L]
            mask_tensor[i, :L] = True

            rem = torch.tensor(ind.get("remaining", []), dtype=torch.float32)  # [R?]
            rem_t = torch.tensor(ind.get("remaining_times", []), dtype=torch.float32)  # [R?]
            Lr = min(rem.shape[0], R)
            if R:
                rem_tensor[i, :Lr] = rem[:Lr]
                rem_time_tensor[i, :Lr] = rem_t[:Lr]
                rem_mask_tensor[i, :Lr] = True

        return {
            "obs": obs_tensor,
            "time": time_tensor,
            "mask": mask_tensor,
            "rem": rem_tensor,
            "rem_time": rem_time_tensor,
            "rem_mask": rem_mask_tensor,
        }
177
+
178
+ def build_study_batch(
179
+ self, study: StudyJSON, meta_dosing: MetaDosingConfig
180
+ ) -> AICMECompartmentsDataBatch:
181
+ """Build a batch for a single study.
182
+ DOES NOT USES OBSERVATIONS STRATEGIESM,
183
+ takes the observation structure as given by the JSON data
184
+
185
+ Parameters
186
+ ----------
187
+ study:
188
+ Canonicalised representation of one study.
189
+ meta_dosing:
190
+ Global dosing configuration.
191
+
192
+ Returns
193
+ -------
194
+ AICMECompartmentsDataBatch
195
+ Batch with ``B=1``.
196
+ """
197
+
198
+ study = canonicalize_study(study)
199
+ ctx_cap = self._ctx_cap()
200
+ tgt_cap = self._tgt_cap()
201
+
202
+ ctx_block = self._block_from_inds(
203
+ study["context"],
204
+ max_individuals=ctx_cap,
205
+ obs_cap=self._ctx_obs_cap(),
206
+ rem_cap=self._ctx_rem_cap(),
207
+ )
208
+ tgt_block = self._block_from_inds(
209
+ study["target"],
210
+ max_individuals=tgt_cap,
211
+ obs_cap=self._tgt_obs_cap(),
212
+ rem_cap=self._tgt_rem_cap(),
213
+ )
214
+
215
+ route_vocab = {r: i for i, r in enumerate(meta_dosing.route_options)}
216
+
217
+ def _dose_route(inds: List[IndividualJSON], I_max: int):
218
+ amounts = torch.zeros(1, I_max, dtype=torch.float32) # [1, I]
219
+ routes = torch.zeros(1, I_max, dtype=torch.long) # [1, I]
220
+ for i, ind in enumerate(inds[:I_max]):
221
+ if ind.get("dosing"):
222
+ amounts[0, i] = ind["dosing"][0]
223
+ routes[0, i] = route_vocab.get(ind["dosing_type"][0], 0)
224
+ return amounts, routes
225
+
226
+ c_dose, c_route = _dose_route(study["context"], ctx_cap)
227
+ t_dose, t_route = _dose_route(study["target"], tgt_cap)
228
+
229
+ def _unsqueeze(block):
230
+ obs = block["obs"].unsqueeze(0).unsqueeze(-1) # [1, I, ET, 1]
231
+ time = block["time"].unsqueeze(0).unsqueeze(-1) # [1, I, ET, 1]
232
+ mask = block["mask"].unsqueeze(0) # [1, I, ET]
233
+ rem = block["rem"].unsqueeze(0).unsqueeze(-1) # [1, I, R, 1]
234
+ rem_time = block["rem_time"].unsqueeze(0).unsqueeze(-1) # [1, I, R, 1]
235
+ rem_mask = block["rem_mask"].unsqueeze(0) # [1, I, R]
236
+ return obs, time, mask, rem, rem_time, rem_mask
237
+
238
+ t_obs, t_time, t_mask, t_rem, t_rem_time, t_rem_mask = _unsqueeze(tgt_block)
239
+ c_obs, c_time, c_mask, c_rem, c_rem_time, c_rem_mask = _unsqueeze(ctx_block)
240
+
241
+ mask_ctx_inds = torch.zeros(1, ctx_cap, dtype=torch.bool) # [1, I]
242
+ mask_ctx_inds[0, : min(len(study["context"]), ctx_cap)] = True
243
+ mask_tgt_inds = torch.zeros(1, tgt_cap, dtype=torch.bool) # [1, I]
244
+ mask_tgt_inds[0, : min(len(study["target"]), tgt_cap)] = True
245
+
246
+ study_name = [study["meta_data"]["study_name"]]
247
+ substance_name = [study["meta_data"].get("substance_name", "")]
248
+
249
+ context_subject_name = [
250
+ [
251
+ study["context"][i].get("name_id", "") if i < len(study["context"]) else ""
252
+ for i in range(ctx_cap)
253
+ ]
254
+ ]
255
+ target_subject_name = [
256
+ [
257
+ study["target"][i].get("name_id", "") if i < len(study["target"]) else ""
258
+ for i in range(tgt_cap)
259
+ ]
260
+ ]
261
+
262
+ batch = AICMECompartmentsDataBatch(
263
+ target_obs=t_obs,
264
+ target_obs_time=t_time,
265
+ target_obs_mask=t_mask,
266
+ target_rem_sim=t_rem,
267
+ target_rem_sim_time=t_rem_time,
268
+ target_rem_sim_mask=t_rem_mask,
269
+ context_obs=c_obs,
270
+ context_obs_time=c_time,
271
+ context_obs_mask=c_mask,
272
+ context_rem_sim=c_rem,
273
+ context_rem_sim_time=c_rem_time,
274
+ context_rem_sim_mask=c_rem_mask,
275
+ target_dosing_amounts=t_dose,
276
+ target_dosing_route_types=t_route,
277
+ context_dosing_amounts=c_dose,
278
+ context_dosing_route_types=c_route,
279
+ mask_context_individuals=mask_ctx_inds,
280
+ mask_target_individuals=mask_tgt_inds,
281
+ study_name=study_name,
282
+ context_subject_name=context_subject_name,
283
+ target_subject_name=target_subject_name,
284
+ substance_name=substance_name,
285
+ time_scales=torch.tensor([[0.0, 0.0]]),
286
+ is_empirical=True,
287
+ )
288
+ return batch
289
+
290
+ @staticmethod
291
+ def _stack_B(
292
+ batches: List[AICMECompartmentsDataBatch],
293
+ ) -> AICMECompartmentsDataBatch:
294
+ """Concatenate ``batches`` along the batch dimension ``B``.
295
+
296
+ Each input batch must have ``B=1``; the returned batch will have
297
+ ``B=len(batches)`` with index order preserved.
298
+ """
299
+
300
+ if not batches:
301
+ raise ValueError("batches must not be empty")
302
+
303
+ stacked_fields = []
304
+ for values in zip(*batches):
305
+ first = values[0]
306
+ if isinstance(first, torch.Tensor):
307
+ stacked_fields.append(torch.cat(values, dim=0)) # [B, ...]
308
+ elif isinstance(first, list):
309
+ if first and isinstance(first[0], list):
310
+ merged_nested: List[List[str]] = []
311
+ for v in values:
312
+ merged_nested.extend(v)
313
+ stacked_fields.append(merged_nested)
314
+ else:
315
+ merged: List[str] = []
316
+ for v in values:
317
+ merged.extend(v)
318
+ stacked_fields.append(merged)
319
+ else:
320
+ stacked_fields.append(first)
321
+ return AICMECompartmentsDataBatch(*stacked_fields)
322
+
323
+ def build_one_aicmebatch(
324
+ self, studies: List[StudyJSON], meta_dosing: MetaDosingConfig
325
+ ) -> AICMECompartmentsDataBatch:
326
+ """Build a single batch from multiple studies.
327
+
328
+ Parameters
329
+ ----------
330
+ studies:
331
+ List of studies to combine. The resulting batch will have
332
+ ``B=len(studies)``.
333
+ meta_dosing:
334
+ Global dosing configuration shared across studies.
335
+
336
+ Returns
337
+ -------
338
+ AICMECompartmentsDataBatch
339
+ Combined batch with batch dimension indexing the supplied
340
+ studies in order.
341
+ """
342
+
343
+ per_study = [self.build_study_batch(s, meta_dosing) for s in studies]
344
+ return self._stack_B(per_study)
345
+
346
+ def build_one_aicmebatch_as_dataset(
347
+ self,
348
+ studies: List[StudyJSON],
349
+ context_strategy,
350
+ target_strategy,
351
+ meta_dosing: MetaDosingConfig,
352
+ *,
353
+ return_studies: bool = False, # ← debugging flag (default = True)
354
+ ) -> List[Union[AICMECompartmentsDataBatch, List[StudyJSON]]]:
355
+ """Create batches mirroring ``AICMECompartmentsDataset`` processing.
356
+
357
+ For each study we generate leave-one-out permutations using
358
+ :func:`held_out_ind_json`. The provided ``context_strategy`` and
359
+ ``target_strategy`` are then used to apply the same empirical splitting
360
+ between observed and remaining measurements as performed in
361
+ :class:`AICMECompartmentsDataset`. Each permutation across all studies is
362
+ stacked along the batch dimension ``B``.
363
+
364
+ Parameters
365
+ ----------
366
+ studies:
367
+ List of empirical studies. Each study is expected to contain only a
368
+ context block; target individuals are produced via leave-one-out
369
+ permutations.
370
+ context_strategy / target_strategy:
371
+ Observation strategies matching those used by
372
+ :class:`AICMECompartmentsDataset` for shaping context and target
373
+ data respectively.
374
+ meta_dosing:
375
+ Global dosing configuration.
376
+ return_studies:
377
+ If True (default), return the intermediate permuted study dicts
378
+ instead of building full ``AICMECompartmentsDataBatch`` objects.
379
+ Useful for debugging.
380
+
381
+ Returns
382
+ -------
383
+ List[Union[AICMECompartmentsDataBatch, List[StudyJSON]]]
384
+ If `return_studies` is True → list of permuted study dicts.
385
+ If `return_studies` is False → list of ``AICMECompartmentsDataBatch``.
386
+ """
387
+ canon_studies = [canonicalize_study(s, drop_tgt_too_few=False) for s in studies]
388
+ max_perm = max(len(s["context"]) for s in canon_studies)
389
+ per_study_perms = [held_out_ind_json(s, max_perm) for s in canon_studies]
390
+
391
+ batches = []
392
+ for perm_idx in range(max_perm):
393
+ permuted_studies = [
394
+ self._process_one_study_perm(
395
+ study_perms[perm_idx], context_strategy, target_strategy
396
+ )
397
+ for study_perms in per_study_perms
398
+ ]
399
+ if return_studies:
400
+ batches.append(permuted_studies) # debugging: raw dicts
401
+ else:
402
+ batches.append(self.build_one_aicmebatch(permuted_studies, meta_dosing))
403
+ return batches
404
+
405
+ def build_one_aicmebatch_as_dataset_no_heldout(
406
+ self,
407
+ studies: List[StudyJSON],
408
+ context_strategy,
409
+ target_strategy,
410
+ meta_dosing: MetaDosingConfig,
411
+ *,
412
+ return_studies: bool = False,
413
+ ) -> List[Union[AICMECompartmentsDataBatch, List[StudyJSON]]]:
414
+ """Create a single empirical batch without leave-one-out targets.
415
+
416
+ This method mirrors :meth:`build_one_aicmebatch_as_dataset` preprocessing
417
+ but does not move any individual from context to target. All individuals
418
+ remain in context and the returned list has a single element.
419
+
420
+ Parameters
421
+ ----------
422
+ studies:
423
+ List of empirical studies.
424
+ context_strategy / target_strategy:
425
+ Observation strategies matching those used by
426
+ :class:`AICMECompartmentsDataset`.
427
+ meta_dosing:
428
+ Global dosing configuration.
429
+ return_studies:
430
+ If ``True``, return the processed ``StudyJSON`` records instead of a
431
+ fully built :class:`AICMECompartmentsDataBatch`.
432
+
433
+ Returns
434
+ -------
435
+ List[Union[AICMECompartmentsDataBatch, List[StudyJSON]]]
436
+ A list with length one containing either processed studies or one
437
+ ``AICMECompartmentsDataBatch``.
438
+ """
439
+ canon_studies = [canonicalize_study(s, drop_tgt_too_few=False) for s in studies]
440
+ context_only_studies: List[StudyJSON] = []
441
+ for study in canon_studies:
442
+ all_inds = list(study.get("context", [])) + list(study.get("target", []))
443
+ context_only_studies.append(
444
+ {
445
+ "context": all_inds,
446
+ "target": [],
447
+ "meta_data": dict(study.get("meta_data", {})),
448
+ }
449
+ )
450
+ processed_studies = [
451
+ self._process_one_study_perm(study, context_strategy, target_strategy)
452
+ for study in context_only_studies
453
+ ]
454
+
455
+ if return_studies:
456
+ return [processed_studies]
457
+ return [self.build_one_aicmebatch(processed_studies, meta_dosing)]
458
+
459
+ def _process_one_study_perm(
460
+ self,
461
+ study: StudyJSON,
462
+ context_strategy,
463
+ target_strategy,
464
+ ) -> StudyJSON:
465
+ """Turn one permuted study into tensors and apply strategies."""
466
+ processed = {"context": [], "target": [], "meta_data": study["meta_data"]}
467
+
468
+ for block, inds, strat in (
469
+ ("context", study["context"], context_strategy),
470
+ ("target", study["target"], target_strategy),
471
+ ):
472
+ processed[block] = self._process_block(inds, strat)
473
+
474
+ return processed
475
+
476
+ def _process_block(
477
+ self,
478
+ inds: List[IndividualJSON],
479
+ strat,
480
+ ) -> List[IndividualJSON]:
481
+ """Convert a list of individuals into padded tensors, then apply strategy."""
482
+ if not inds:
483
+ return []
484
+
485
+ obs, times, mask = self._pack_individuals(inds)
486
+ obs_o, time_o, mask_o, rem_o, rem_t, rem_m = strat.generate_empirical(obs, times, mask)
487
+
488
+ return self._rebuild_individuals(inds, obs_o, time_o, mask_o, rem_o, rem_t, rem_m)
489
+
490
+ def _pack_individuals(
491
+ self,
492
+ inds: List[IndividualJSON],
493
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
494
+ """Pad individuals into (obs, times, mask)."""
495
+ I = len(inds)
496
+ ET = max(len(ind["observations"]) for ind in inds)
497
+ obs = torch.full((I, ET), self.cfg.pad_value_obs)
498
+ times = torch.full((I, ET), self.cfg.pad_value_time)
499
+ mask = torch.zeros((I, ET), dtype=torch.bool)
500
+
501
+ for i, ind in enumerate(inds):
502
+ o = torch.tensor(ind["observations"], dtype=torch.float32)
503
+ t = torch.tensor(ind["observation_times"], dtype=torch.float32)
504
+ L = o.shape[0]
505
+ obs[i, :L], times[i, :L], mask[i, :L] = o, t, True
506
+ return obs, times, mask
507
+
508
+ def _rebuild_individuals(
509
+ self,
510
+ inds: List[IndividualJSON],
511
+ obs_o: torch.Tensor,
512
+ time_o: torch.Tensor,
513
+ mask_o: torch.Tensor,
514
+ rem_o: Optional[torch.Tensor],
515
+ rem_t: Optional[torch.Tensor],
516
+ rem_m: Optional[torch.Tensor],
517
+ ) -> List[IndividualJSON]:
518
+ """Convert tensors back to JSON-like dicts for each individual."""
519
+ block_inds = []
520
+ for i in range(obs_o.shape[0]):
521
+ ind_dict: IndividualJSON = {
522
+ "observations": obs_o[i][mask_o[i]].tolist(),
523
+ "observation_times": time_o[i][mask_o[i]].tolist(),
524
+ }
525
+ name_id = inds[i].get("name_id") if i < len(inds) else None
526
+ if name_id:
527
+ ind_dict["name_id"] = name_id
528
+ if rem_o is not None and rem_m is not None:
529
+ ind_dict["remaining"] = rem_o[i][rem_m[i]].tolist()
530
+ ind_dict["remaining_times"] = rem_t[i][rem_m[i]].tolist()
531
+ block_inds.append(ind_dict)
532
+ return block_inds
533
+
534
+
535
def databatch_to_study_jsons(
    batch: AICMECompartmentsDataBatch,
    meta_dosing: MetaDosingConfig,
) -> list[StudyJSON]:
    """Convert an ``AICMECompartmentsDataBatch`` back to ``StudyJSON`` records.

    Parameters
    ----------
    batch:
        Batch carrying tensors with a leading study dimension ``B``.
    meta_dosing:
        Dosing configuration used to decode route type indices.

    Returns
    -------
    List[StudyJSON]
        One study per element along ``B``. Missing ``study_name`` or
        ``substance_name`` entries are replaced by the fallback
        placeholders ``study_{b}`` and ``substance_{b}``.
    """
    route_options = meta_dosing.route_options

    def _block(
        b: int,
        obs,
        time,
        mask,
        rem,
        rem_time,
        rem_mask,
        doses,
        routes,
        ind_mask,
        names,
    ) -> list[IndividualJSON]:
        """Decode one context/target block of study ``b`` to dicts."""
        individuals: list[IndividualJSON] = []
        name_list = names[b] if b < len(names) else []
        for i in range(obs.shape[1]):
            if not ind_mask[b, i]:
                continue  # padded slot, no real individual here
            ind: IndividualJSON = {}
            if i < len(name_list) and name_list[i]:
                ind["name_id"] = name_list[i]
            valid = mask[b, i]  # [T]
            ind["observations"] = obs[b, i, :, 0][valid].tolist()
            ind["observation_times"] = time[b, i, :, 0][valid].tolist()
            valid_rem = rem_mask[b, i]  # [R]
            rem_vals = rem[b, i, :, 0][valid_rem].tolist()
            if rem_vals:
                ind["remaining"] = rem_vals
                ind["remaining_times"] = rem_time[b, i, :, 0][valid_rem].tolist()
            dose = float(doses[b, i].item())
            route_idx = int(routes[b, i].item())
            if dose or route_idx:
                # Indices outside the configured vocabulary are kept as
                # their string representation.
                route = (
                    route_options[route_idx]
                    if route_idx < len(route_options)
                    else str(route_idx)
                )
                ind["dosing"] = [dose]
                ind["dosing_type"] = [route]
                ind["dosing_times"] = [meta_dosing.time]
                ind["dosing_name"] = [route]
            individuals.append(ind)
        return individuals

    studies: list[StudyJSON] = []
    for b in range(batch.context_obs.shape[0]):
        meta = {
            "study_name": (
                batch.study_name[b]
                if b < len(batch.study_name) and batch.study_name[b]
                else f"study_{b}"
            ),
            "substance_name": (
                batch.substance_name[b]
                if b < len(batch.substance_name) and batch.substance_name[b]
                else f"substance_{b}"
            ),
        }
        ctx = _block(
            b,
            batch.context_obs,
            batch.context_obs_time,
            batch.context_obs_mask,
            batch.context_rem_sim,
            batch.context_rem_sim_time,
            batch.context_rem_sim_mask,
            batch.context_dosing_amounts,
            batch.context_dosing_route_types,
            batch.mask_context_individuals,
            batch.context_subject_name,
        )
        tgt = _block(
            b,
            batch.target_obs,
            batch.target_obs_time,
            batch.target_obs_mask,
            batch.target_rem_sim,
            batch.target_rem_sim_time,
            batch.target_rem_sim_mask,
            batch.target_dosing_amounts,
            batch.target_dosing_route_types,
            batch.mask_target_individuals,
            batch.target_subject_name,
        )
        studies.append({"context": ctx, "target": tgt, "meta_data": meta})
    return studies
643
+
644
+
645
def prediction_to_study_jsons(
    prediction_sample: TT["S", "B", "It", "Tr", 1],
    prediction_time: TT["S", "B", "It", "Tr", 1],
    batch: AICMECompartmentsDataBatch,
    meta_dosing: MetaDosingConfig,
) -> list[StudyJSON]:
    """Attach prediction samples to study records.

    Parameters
    ----------
    prediction_sample:
        Predicted trajectories with a leading sample dimension ``S``.
    prediction_time:
        Time points corresponding to ``prediction_sample``.
    batch:
        Original :class:`AICMECompartmentsDataBatch` used to generate the
        predictions.
    meta_dosing:
        Dosing configuration for route decoding.

    Returns
    -------
    list[StudyJSON]
        Studies with ``prediction_samples`` and ``prediction_times`` fields
        in each predicted target individual.

    Notes
    -----
    Some predictive samplers (for example FlowPK individual prediction) may
    return predictions for only a subset of target individuals compared
    with the original batch. In that case only the first ``It`` target
    entries are kept (``It`` inferred from ``prediction_sample``) so JSON
    plots and exported records stay aligned with the predicted tensors.
    """
    studies = databatch_to_study_jsons(batch, meta_dosing)
    _, n_studies, n_targets, _, _ = prediction_sample.shape  # [S, B, It, Tr, 1]
    for b in range(n_studies):
        # Keep studies aligned with the number of predicted targets.
        targets = studies[b]["target"][:n_targets]
        studies[b]["target"] = targets
        for i in range(min(n_targets, len(targets))):
            targets[i]["prediction_samples"] = prediction_sample[:, b, i, :, 0].tolist()  # [S, Tr]
            targets[i]["prediction_times"] = prediction_time[0, b, i, :, 0].tolist()  # [Tr]
    return studies
691
+
692
+
693
def simulation_obs_to_study_json(
    obs_out: torch.Tensor,
    time_out: torch.Tensor,
    mask_out: torch.Tensor,
    rem_sim: Optional[torch.Tensor],
    rem_time: Optional[torch.Tensor],
    rem_mask: Optional[torch.Tensor],
    dosing_config_array: list,
    dosing_amounts: torch.Tensor,
    study_config,
    idx: int,
) -> StudyJSON:
    """Convert processed simulation tensors into a :class:`StudyJSON` entry.

    Parameters
    ----------
    obs_out, time_out, mask_out:
        Observed concentrations and time points for the simulated
        individuals; ``mask_out`` identifies valid entries in the padded
        tensors.
    rem_sim, rem_time, rem_mask:
        Optional remaining (unobserved) trajectory tensors. When provided
        they must share the leading dimensions of ``obs_out`` with
        ``rem_mask`` marking valid entries.
    dosing_config_array:
        Sequence of per-individual dosing configuration objects.
    dosing_amounts:
        Tensor with the dosing amount per individual.
    study_config:
        Study configuration; only the ``drug_id`` attribute is read, if
        present.
    idx:
        Index used to label the generated study name.

    Returns
    -------
    StudyJSON
        JSON-compatible dictionary describing the context block of the
        simulation (``target`` is left empty).
    """
    has_remaining = rem_sim is not None and rem_time is not None and rem_mask is not None
    context: list[IndividualJSON] = []

    for ind_idx in range(obs_out.shape[0]):
        valid = mask_out[ind_idx].to(torch.bool)
        individual: IndividualJSON = {
            "name_id": f"context_{ind_idx}",
            "observations": obs_out[ind_idx][valid].tolist(),
            "observation_times": time_out[ind_idx][valid].tolist(),
        }

        if has_remaining:
            valid_rem = rem_mask[ind_idx].to(torch.bool)
            if valid_rem.any():
                individual["remaining"] = rem_sim[ind_idx][valid_rem].tolist()
                individual["remaining_times"] = rem_time[ind_idx][valid_rem].tolist()

        dosing_cfg = dosing_config_array[ind_idx]
        dose = float(dosing_amounts[ind_idx].item())
        route = getattr(dosing_cfg, "route", "")
        if dose or route:
            individual["dosing"] = [dose]
            individual["dosing_type"] = [route]
            individual["dosing_times"] = [float(getattr(dosing_cfg, "time", 0.0))]
            individual["dosing_name"] = [route]

        context.append(individual)

    return {
        "context": context,
        "target": [],
        "meta_data": {
            "study_name": f"simulated_study_{idx}",
            "substance_name": getattr(study_config, "drug_id", "simulated_substance"),
        },
    }
778
+
779
+
780
def held_out_ind_json(study: StudyJSON, max_held_out_individuals: int) -> List[StudyJSON]:
    """Create study permutations with one individual moved to target.

    Parameters
    ----------
    study:
        Study JSON containing only context individuals (``target`` must be
        empty).
    max_held_out_individuals:
        Maximum number of permutations to generate.

    Returns
    -------
    List[StudyJSON]
        ``max_held_out_individuals`` studies where each of the first
        ``len(context)`` entries moves one context individual into the
        target block. Remaining entries repeat the original study with an
        empty target.
    """
    context = list(study.get("context", []))
    meta = dict(study.get("meta_data", {}))
    n_perms = min(max_held_out_individuals, len(context))

    permutations: List[StudyJSON] = [
        {
            "context": context[:held_out] + context[held_out + 1 :],
            "target": [context[held_out]],
            "meta_data": meta,
        }
        for held_out in range(n_perms)
    ]
    # Pad with the unchanged study (empty target) up to the requested count.
    filler = {"context": context, "target": [], "meta_data": meta}
    permutations.extend(filler for _ in range(max_held_out_individuals - len(permutations)))
    return permutations
811
+
812
+
813
def held_out_list_json(
    builder: JSON2AICMEBuilder,
    studies: List[StudyJSON],
    meta_dosing: MetaDosingConfig,
    max_held_out_individuals: int,
) -> List[AICMECompartmentsDataBatch]:
    """Generate batches for leave-one-out permutations across studies.

    Parameters
    ----------
    builder:
        Instance used to convert studies to
        :class:`AICMECompartmentsDataBatch`.
    studies:
        Studies where only the context block is populated.
    meta_dosing:
        Global dosing configuration.
    max_held_out_individuals:
        Maximum number of held-out permutations per study.

    Returns
    -------
    List[AICMECompartmentsDataBatch]
        ``max_held_out_individuals`` batches; the ``i``-th batch stacks the
        ``i``-th permutation of every study along the batch dimension.
    """
    per_study = [held_out_ind_json(s, max_held_out_individuals) for s in studies]
    batches: List[AICMECompartmentsDataBatch] = []
    for perm_idx in range(max_held_out_individuals):
        perm_slice = [perms[perm_idx] for perms in per_study]
        batches.append(builder.build_one_aicmebatch(perm_slice, meta_dosing))
    return batches
846
+
847
+
848
def load_empirical_json_batches(
    json_path: Path,
    meta_dosing: Optional[MetaDosingConfig] = None,
    stats: Optional[EmpiricalJSONStats] = None,
    datamodule: Optional[AICMECompartmentsDataModule] = None,
) -> List[AICMECompartmentsDataBatch]:
    """Load an empirical study JSON file and build leave-one-out batches.

    All individuals are placed in the context block before the
    leave-one-out permutations are generated.

    Parameters
    ----------
    json_path:
        Path to a JSON file containing a list of :class:`StudyJSON` records.
    meta_dosing:
        Global dosing configuration. If ``None`` a default
        :class:`MetaDosingConfig` is used.
    stats:
        Pre-computed statistics describing the dataset. When ``None`` the
        statistics are calculated from ``json_path`` via
        :func:`compute_json_stats`.
    datamodule:
        Optional synthetic data module providing shape information via
        :meth:`AICMECompartmentsDataModule.obtain_shapes`. When given,
        these shapes override those inferred from ``stats``.

    Returns
    -------
    List[AICMECompartmentsDataBatch]
        Leave-one-out batches constructed from the studies in ``json_path``.

    Raises
    ------
    ValueError
        If the JSON payload is not a list or contains no studies.

    Notes
    -----
    The statistics determine the number of leave-one-out permutations,
    while the padding shapes ``(max_individuals, max_observations,
    max_remaining)`` come from ``datamodule.obtain_shapes()`` when a
    datamodule is supplied.
    """
    # The file is expected to contain a list of StudyJSON records.
    with json_path.open() as f:
        raw_studies = json.load(f)

    if not isinstance(raw_studies, list):
        raise ValueError("Expected JSON file to contain a list of StudyJSON records")

    # Ensure data quality.
    canon_studies: List[StudyJSON] = [
        canonicalize_study(s, drop_tgt_too_few=False) for s in raw_studies
    ]
    if not canon_studies:
        raise ValueError("No studies found in JSON file")

    # Move every individual into the context block.
    studies: List[StudyJSON] = [
        {
            "context": list(s.get("context", [])) + list(s.get("target", [])),
            "target": [],
            "meta_data": s.get("meta_data", {}),
        }
        for s in canon_studies
    ]

    # BUG FIX: ``stats`` is needed on BOTH branches below (it sets the
    # number of leave-one-out permutations at the end). The original only
    # computed it on the non-datamodule path, so passing a datamodule
    # without ``stats`` crashed with AttributeError on ``stats.max_...``.
    # A caller-provided ``stats`` is now also honored instead of being
    # silently recomputed.
    if stats is None:
        stats = compute_json_stats(canon_studies)

    # Define padding shapes.
    if datamodule is not None:
        max_inds, max_obs, max_rem = datamodule.obtain_shapes()  # (I, T, R)
        ctx_cap = getattr(datamodule.train_dataset, "max_context_individuals", max_inds)
        tgt_cap = getattr(datamodule.train_dataset, "n_of_target_individuals", max_inds)
    else:
        max_inds = stats.max_total_individuals
        max_obs = stats.max_observations
        max_rem = stats.max_remaining
        ctx_cap = max_inds
        tgt_cap = max_inds

    # The maximum batch size is chosen so all empirical studies fit at once.
    cfg = EmpiricalBatchConfig(
        max_databatch_size=len(studies),
        max_individuals=max_inds,
        max_observations=max_obs,
        max_remaining=max_rem,
        max_context_individuals=ctx_cap,
        max_target_individuals=tgt_cap,
    )
    builder = JSON2AICMEBuilder(cfg)
    meta = meta_dosing or MetaDosingConfig()

    return held_out_list_json(
        builder, studies, meta, max_held_out_individuals=stats.max_total_individuals
    )
942
+
943
+
944
def load_empirical_json_batches_as_dm(
    json_path: Optional[Path] = None,
    meta_dosing: Optional[MetaDosingConfig] = None,
    stats: Optional[EmpiricalJSONStats] = None,
    datamodule: Optional[AICMECompartmentsDataModule] = None,
    raw_studies: Optional[List[StudyJSON]] = None,
    *,
    held_out: bool = True,
) -> List[AICMECompartmentsDataBatch]:
    """Load empirical studies and build batches using datamodule strategies.

    Mirrors the empirical preprocessing performed by
    :class:`AICMECompartmentsDataset` by relying on the observation
    strategies of the provided :class:`AICMECompartmentsDataModule` and
    using :meth:`JSON2AICMEBuilder.build_one_aicmebatch_as_dataset`.

    Parameters
    ----------
    json_path:
        Path to a JSON file containing a list of :class:`StudyJSON`
        records. Ignored when ``raw_studies`` is supplied.
    meta_dosing:
        Global dosing configuration. If ``None`` a default
        :class:`MetaDosingConfig` is used.
    stats:
        Pre-computed statistics describing the dataset. When ``None`` they
        are calculated via :func:`compute_json_stats`.
    datamodule:
        Synthetic data module providing observation strategies and shape
        information via :meth:`AICMECompartmentsDataModule.obtain_shapes`.
        The module must be provided; its shapes override those inferred
        from ``stats``.
    raw_studies:
        Already-loaded ``StudyJSON`` records; takes precedence over
        ``json_path``.
    held_out:
        If ``True`` (default), build leave-one-out permutations (one
        empirical individual in target). If ``False``, keep all empirical
        individuals in context and return a single batch.

    Returns
    -------
    List[AICMECompartmentsDataBatch]
        Batches constructed from the studies using the datamodule's
        strategies.

    Raises
    ------
    ValueError
        If ``datamodule`` is missing, neither ``json_path`` nor
        ``raw_studies`` is given, the JSON payload is not a list, no
        studies are found, or the datamodule lacks strategies.
    """
    if datamodule is None:
        raise ValueError("datamodule must be provided to supply observation strategies")

    if raw_studies is None:
        # BUG FIX: fail with a clear message instead of an AttributeError
        # (``'NoneType' object has no attribute 'open'``) when neither a
        # path nor pre-loaded studies are supplied.
        if json_path is None:
            raise ValueError("either json_path or raw_studies must be provided")
        with json_path.open() as f:
            raw_studies = json.load(f)

    if not isinstance(raw_studies, list):
        raise ValueError("Expected JSON file to contain a list of StudyJSON records")

    canon_studies: List[StudyJSON] = [
        canonicalize_study(s, drop_tgt_too_few=False) for s in raw_studies
    ]

    if stats is None:
        stats = compute_json_stats(canon_studies)

    # Move every individual into the context block.
    studies: List[StudyJSON] = []
    for study in canon_studies:
        all_inds = list(study.get("context", [])) + list(study.get("target", []))
        studies.append({"context": all_inds, "target": [], "meta_data": study.get("meta_data", {})})

    if not studies:
        raise ValueError("No studies found in JSON file")

    max_inds, max_obs, max_rem = datamodule.obtain_shapes()  # (I, T, R)
    ctx_cap = getattr(datamodule.train_dataset, "max_context_individuals", max_inds)
    tgt_cap = getattr(datamodule.train_dataset, "n_of_target_individuals", max_inds)
    context_strategy = getattr(datamodule, "context_strategy", None)
    # For empirical targets we prefer the dedicated datamodule override
    # (legacy PK behavior + fixed capacities), falling back to target_strategy.
    target_strategy = getattr(datamodule, "empirical_target_strategy", None)
    if target_strategy is None:
        target_strategy = getattr(datamodule, "target_strategy", None)
    if context_strategy is None or target_strategy is None:
        raise ValueError("datamodule is missing context or target strategies")

    ctx_obs_cap, ctx_rem_cap = context_strategy.get_shapes()
    tgt_obs_cap, tgt_rem_cap = target_strategy.get_shapes()

    cfg = EmpiricalBatchConfig(
        max_databatch_size=len(studies),
        max_individuals=max_inds,
        max_observations=max_obs,
        max_remaining=max_rem,
        max_context_individuals=ctx_cap,
        max_target_individuals=tgt_cap,
        max_context_observations=ctx_obs_cap,
        max_target_observations=tgt_obs_cap,
        max_context_remaining=ctx_rem_cap,
        max_target_remaining=tgt_rem_cap,
    )
    builder = JSON2AICMEBuilder(cfg)
    meta = meta_dosing or MetaDosingConfig()

    if held_out:
        return builder.build_one_aicmebatch_as_dataset(
            studies, context_strategy, target_strategy, meta
        )
    return builder.build_one_aicmebatch_as_dataset_no_heldout(
        studies, context_strategy, target_strategy, meta
    )
1050
+
1051
+
1052
def load_empirical_hf_batches_as_dm(
    repo_id: str,
    split: str = "train",
    meta_dosing: Optional[MetaDosingConfig] = None,
    stats: Optional[EmpiricalJSONStats] = None,
    datamodule: Optional[AICMECompartmentsDataModule] = None,
    *,
    held_out: bool = True,
) -> List[AICMECompartmentsDataBatch]:
    """Load a StudyJSON dataset from Hugging Face Hub and build batches.

    Each dataset row is treated as one ``StudyJSON`` record. All individuals
    (context + target) of every study are merged into the context, and the
    datamodule's padding capacities and observation strategies are used to
    assemble ``AICMECompartmentsDataBatch`` objects.

    Parameters
    ----------
    repo_id:
        Hugging Face dataset id.
    split:
        Dataset split to load.
    meta_dosing:
        Dosing configuration; a default ``MetaDosingConfig`` is used when
        ``None``.
    stats:
        Optional precomputed dataset statistics.
    datamodule:
        Datamodule providing empirical shape and strategy information.
        Required; a ``ValueError`` is raised when missing.
    held_out:
        If ``True`` (default), build leave-one-out permutations. If ``False``,
        keep all empirical individuals in context and return a single batch.

    Returns
    -------
    List[AICMECompartmentsDataBatch]
        Batches built by ``JSON2AICMEBuilder`` from the loaded studies.

    Raises
    ------
    ValueError
        If ``datamodule`` is missing, the dataset is empty, or the datamodule
        lacks context/target strategies.
    """

    if datamodule is None:
        raise ValueError("datamodule must be provided to supply observation strategies")

    # Load from HF Hub; rows are dict-like mappings conforming to StudyJSON.
    ds = load_dataset(repo_id, split=split)
    raw_studies = [dict(study) for study in ds]  # Hugging Face rows are dict-like

    # Canonicalize without dropping sparse targets — empirical targets are
    # kept regardless of observation count.
    canon_studies: List[StudyJSON] = [
        canonicalize_study(s, drop_tgt_too_few=False) for s in raw_studies
    ]

    # NOTE(review): ``stats`` is computed here when not supplied but never
    # used afterwards — presumably kept for parity with the JSON-file loader;
    # confirm whether it can be removed or should be returned to the caller.
    if stats is None:
        stats = compute_json_stats(canon_studies)

    # Fold every individual (context + target) into the context so the
    # builder can re-split them according to the strategies below.
    studies: List[StudyJSON] = []
    for study in canon_studies:
        all_inds = list(study.get("context", [])) + list(study.get("target", []))
        studies.append({"context": all_inds, "target": [], "meta_data": study.get("meta_data", {})})

    if not studies:
        raise ValueError("No studies found in HF dataset")

    # Padding capacities: (individuals, observations, remaining) = (I, T, R).
    max_inds, max_obs, max_rem = datamodule.obtain_shapes()
    ctx_cap = getattr(datamodule.train_dataset, "max_context_individuals", max_inds)
    tgt_cap = getattr(datamodule.train_dataset, "n_of_target_individuals", max_inds)
    context_strategy = getattr(datamodule, "context_strategy", None)
    # For empirical targets we prefer the dedicated datamodule override
    # (legacy PK behavior + fixed capacities), falling back to target_strategy.
    target_strategy = getattr(datamodule, "empirical_target_strategy", None)
    if target_strategy is None:
        target_strategy = getattr(datamodule, "target_strategy", None)
    if context_strategy is None or target_strategy is None:
        raise ValueError("datamodule is missing context or target strategies")

    # Per-side observation/remaining capacities come from the strategies.
    ctx_obs_cap, ctx_rem_cap = context_strategy.get_shapes()
    tgt_obs_cap, tgt_rem_cap = target_strategy.get_shapes()

    cfg = EmpiricalBatchConfig(
        max_databatch_size=len(studies),
        max_individuals=max_inds,
        max_observations=max_obs,
        max_remaining=max_rem,
        max_context_individuals=ctx_cap,
        max_target_individuals=tgt_cap,
        max_context_observations=ctx_obs_cap,
        max_target_observations=tgt_obs_cap,
        max_context_remaining=ctx_rem_cap,
        max_target_remaining=tgt_rem_cap,
    )
    builder = JSON2AICMEBuilder(cfg)
    meta = meta_dosing or MetaDosingConfig()

    # ``held_out`` toggles between leave-one-out permutations and a single
    # all-in-context batch.
    if held_out:
        return builder.build_one_aicmebatch_as_dataset(
            studies, context_strategy, target_strategy, meta
        )
    return builder.build_one_aicmebatch_as_dataset_no_heldout(
        studies, context_strategy, target_strategy, meta
    )
sim_priors_pk/data/data_empirical/json_schema.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TypedDict schemas for empirical pharmacokinetic JSON inputs."""
2
+
3
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypedDict
4
+
5
+ try: # pragma: no cover - optional torch dependency
6
+ import torch
7
+ from torchtyping import TensorType as TT
8
+ except ModuleNotFoundError: # pragma: no cover - allow missing torch
9
+ torch = None # type: ignore
10
+ TT = object # type: ignore
11
+
12
+ if TYPE_CHECKING: # pragma: no cover - typing only
13
+ from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch
14
+
15
+
16
class IndividualJSON(TypedDict, total=False):
    """Schema for a single individual's PK data.

    All keys are optional (``total=False``); helpers in this module validate
    the combinations they require at runtime.

    Optional ``prediction_samples`` and ``prediction_times`` fields allow
    storing model forecasts for the individual's future trajectory.
    Each element in ``prediction_samples`` corresponds to a full simulated
    trajectory for the times listed in ``prediction_times``.
    ``prediction_mean`` and ``prediction_std`` hold per-time-point summaries
    of those samples, as produced by :func:`prediction_stats` and preserved
    by :func:`canonicalize_individual`.
    """

    name_id: str  # subject identifier
    observations: List[float]  # measured concentrations
    observation_times: List[float]  # times of the measurements above
    remaining: List[float]  # held-out values, disjoint from observation_times
    remaining_times: List[float]  # times of the held-out values
    dosing: List[float]  # dose amounts (all dosing_* lists share one length)
    dosing_type: List[str]  # administration route per dose
    dosing_times: List[float]  # time of each dosing event
    dosing_name: List[str]  # label of each dosing event
    prediction_samples: List[List[float]]  # sampled trajectories (S x T)
    prediction_times: List[float]  # decode times for the samples
    prediction_mean: List[float]  # per-time-point mean over samples
    prediction_std: List[float]  # per-time-point std (population) over samples
    covariates: Dict[str, object]  # free-form subject covariates
37
+
38
+
39
class StudyJSON(TypedDict):
    """Schema for a full study consisting of context and target individuals.

    All three keys are required. ``meta_data`` must carry non-empty
    ``study_name`` and ``substance_name`` entries to pass
    :func:`canonicalize_study` validation.
    """

    context: List[IndividualJSON]  # conditioning individuals
    target: List[IndividualJSON]  # individuals to predict
    meta_data: Dict[str, str]  # study-level metadata (study_name, substance_name)
45
+
46
+
47
# Default minimum number of observations an individual must have to be kept
# by the canonicalization helpers (0 means nobody is dropped).
MIN_OBS_DEFAULT = 0
48
+
49
+
50
class ValidationError(Exception):
    """Error raised when input records violate the :class:`StudyJSON` schema."""
54
+
55
+
56
def canonicalize_individual(
    ind: IndividualJSON,
    *,
    min_obs: int = MIN_OBS_DEFAULT,
    drop_if_too_few: bool = True,
) -> Optional[IndividualJSON]:
    """Return a cleaned-up copy of ``ind`` without mutating the input.

    The canonical form has observations sorted by ascending time with
    duplicate time stamps removed (first occurrence wins), remaining points
    disjoint from observation times, and dosing fields validated as an
    all-or-nothing group of equal length.

    Parameters
    ----------
    ind:
        Individual record to canonicalize; never modified in place.
    min_obs:
        Minimum number of observations required to keep the individual.
        Defaults to :data:`MIN_OBS_DEFAULT`.
    drop_if_too_few:
        When ``True`` and fewer than ``min_obs`` observations survive
        de-duplication, ``None`` is returned instead of a record.

    Returns
    -------
    Optional[IndividualJSON]
        The canonicalized record, or ``None`` when the individual is dropped.

    Raises
    ------
    ValidationError
        If required observation fields are missing, paired lists differ in
        length, or dosing fields are only partially provided.
    """

    # --- observations & times ---
    if "observations" not in ind or "observation_times" not in ind:
        raise ValidationError("observations and observation_times are required")

    values = list(ind["observations"])
    stamps = list(ind["observation_times"])
    if len(values) != len(stamps):
        raise ValidationError("observations and observation_times must match in length")

    # Stable sort by time, then keep only the first value seen at each stamp.
    ordered = sorted(zip(stamps, values), key=lambda pair: pair[0])
    kept_values: List[float] = []
    kept_times: List[float] = []
    known = set()
    for stamp, value in ordered:
        if stamp not in known:
            known.add(stamp)
            kept_times.append(stamp)
            kept_values.append(value)

    if drop_if_too_few and len(kept_values) < min_obs:
        return None

    result: IndividualJSON = {}
    if "name_id" in ind:
        result["name_id"] = ind["name_id"]
    result["observations"] = kept_values
    result["observation_times"] = kept_times

    # --- remaining: must come as a pair, and entries colliding with
    # observation times are silently dropped ---
    if "remaining" in ind or "remaining_times" in ind:
        if not ("remaining" in ind and "remaining_times" in ind):
            raise ValidationError(
                "both remaining and remaining_times required when one is provided"
            )
        rem_values = list(ind["remaining"])
        rem_stamps = list(ind["remaining_times"])
        if len(rem_values) != len(rem_stamps):
            raise ValidationError("remaining and remaining_times must match in length")
        taken = set(kept_times)
        surviving = [(t, r) for t, r in zip(rem_stamps, rem_values) if t not in taken]
        result["remaining"] = [r for _, r in surviving]
        result["remaining_times"] = [t for t, _ in surviving]

    # --- dosing: all-or-nothing group with equal lengths ---
    dosing_fields = ("dosing", "dosing_type", "dosing_times", "dosing_name")
    available = [key for key in dosing_fields if key in ind]
    if available:
        if len(available) != len(dosing_fields):
            raise ValidationError("all dosing fields must be present when dosing is provided")
        if len({len(ind[key]) for key in dosing_fields}) != 1:  # type: ignore[index]
            raise ValidationError("dosing fields must have equal lengths")
        for key in dosing_fields:
            result[key] = list(ind[key])  # type: ignore[index]

    # --- covariates: shallow copy ---
    if "covariates" in ind:
        result["covariates"] = dict(ind["covariates"])

    # --- prediction fields: copied verbatim (values are lists) ---
    if "prediction_samples" in ind:
        result["prediction_samples"] = [list(row) for row in ind["prediction_samples"]]
    if "prediction_times" in ind:
        result["prediction_times"] = list(ind["prediction_times"])
    if "prediction_mean" in ind:
        result["prediction_mean"] = list(ind["prediction_mean"])
    if "prediction_std" in ind:
        result["prediction_std"] = list(ind["prediction_std"])

    return result
176
+
177
+
178
def canonicalize_study(
    study: StudyJSON,
    *,
    min_obs_ctx: int = MIN_OBS_DEFAULT,
    min_obs_tgt: int = MIN_OBS_DEFAULT,
    drop_tgt_too_few: bool = True,
) -> StudyJSON:
    """Canonicalize every individual in ``study`` and validate its meta data.

    Context individuals are never dropped; target individuals may be dropped
    when they carry fewer than ``min_obs_tgt`` observations and
    ``drop_tgt_too_few`` is set.

    Raises
    ------
    ValidationError
        If ``meta_data`` lacks a non-empty ``study_name`` or
        ``substance_name``.
    """

    meta = study.get("meta_data", {})
    if not (meta.get("study_name") and meta.get("substance_name")):
        raise ValidationError("meta_data must include non-empty study_name and substance_name")

    cleaned_context = [
        person
        for person in (
            canonicalize_individual(raw, min_obs=min_obs_ctx, drop_if_too_few=False)
            for raw in study.get("context", [])
        )
        if person is not None
    ]
    cleaned_target = [
        person
        for person in (
            canonicalize_individual(raw, min_obs=min_obs_tgt, drop_if_too_few=drop_tgt_too_few)
            for raw in study.get("target", [])
        )
        if person is not None
    ]

    return {
        "context": cleaned_context,
        "target": cleaned_target,
        "meta_data": dict(meta),
    }
209
+
210
+
211
def studies_from_sampled_targets(
    *,
    db: "AICMECompartmentsDataBatch",
    samples: "TT['S', 'B', 'T', 1]",
    times: "TT['B', 'T', 1]",
    mask: "TT['B', 'T']",
    route_options: Sequence[str],
    dosing_time: float,
    name_prefix: str = "new_individual",
) -> List[StudyJSON]:
    """Convert sampled trajectories into :class:`StudyJSON` records.

    Parameters
    ----------
    db:
        Batch containing the conditioning study information. Only the fields
        accessed in this function are required, allowing reuse with compatible
        NamedTuple implementations used throughout the project.
    samples, times, mask:
        Output tensors from ``sample_new_individual`` where ``samples`` carries
        the simulated trajectories, ``times`` their corresponding decode times
        and ``mask`` selects valid entries along the temporal dimension.
    route_options:
        Lookup table translating dosing route indices into human readable
        labels. Indices outside the provided range are returned as their string
        representation.
    dosing_time:
        Absolute time at which the dosing event occurred. Used for both context
        and newly sampled target individuals when dosing information is
        present.
    name_prefix:
        Prefix for generated target individual identifiers. Defaults to
        ``"new_individual"``.

    Returns
    -------
    list[StudyJSON]
        One ``StudyJSON`` per batch element in ``db``.
    """

    # Guard: ``torch`` is an optional module-level import in this file.
    if torch is None:
        raise ValidationError("torch is required to build StudyJSON records from tensors")

    # samples is indexed [s, b, t, 0] below, so it is rank 4: (S, B, T, 1).
    S, B, _, _ = samples.shape
    studies: List[StudyJSON] = []

    for b in range(B):
        # Fall back to synthetic names when the batch carries none.
        study_name = (
            db.study_name[b] if b < len(db.study_name) and db.study_name[b] else f"study_{b}"
        )
        substance_name = (
            db.substance_name[b]
            if b < len(db.substance_name) and db.substance_name[b]
            else f"substance_{b}"
        )

        # --- context individuals: copy observed data out of the batch ---
        context_list: List[IndividualJSON] = []
        # db.context_obs is indexed [b, i, :, 0] below, i.e. (B, I, T, 1).
        I = db.context_obs.shape[1]
        for i in range(I):
            # Skip padded (inactive) individual slots.
            if not db.mask_context_individuals[b, i]:
                continue

            ind: IndividualJSON = {}
            if b < len(db.context_subject_name) and i < len(db.context_subject_name[b]):
                name = db.context_subject_name[b][i]
                if name:
                    ind["name_id"] = name

            # Boolean mask selects the valid observation entries.
            obs_i = db.context_obs[b, i, :, 0]
            time_i = db.context_obs_time[b, i, :, 0]
            mask_i = db.context_obs_mask[b, i]
            ind["observations"] = obs_i[mask_i].tolist()
            ind["observation_times"] = time_i[mask_i].tolist()

            # Remaining (held-out) points exist only when the R axis is non-empty.
            if db.context_rem_sim.shape[2] > 0:
                rem_i = db.context_rem_sim[b, i, :, 0]
                rem_t = db.context_rem_sim_time[b, i, :, 0]
                rem_m = db.context_rem_sim_mask[b, i]
                rem_vals = rem_i[rem_m].tolist()
                rem_times = rem_t[rem_m].tolist()
                if rem_vals:
                    ind["remaining"] = rem_vals
                    ind["remaining_times"] = rem_times

            # A single dosing event per individual; dose == 0 with route 0 is
            # treated as "no dosing information".
            dose = float(db.context_dosing_amounts[b, i].item())
            route_idx = int(db.context_dosing_route_types[b, i].item())
            if dose or route_idx:
                route = (
                    route_options[route_idx] if route_idx < len(route_options) else str(route_idx)
                )
                ind["dosing"] = [dose]
                ind["dosing_type"] = [route]
                ind["dosing_times"] = [dosing_time]
                ind["dosing_name"] = [route]

            context_list.append(ind)

        # --- target individuals: one per sampled trajectory ---
        target_list: List[IndividualJSON] = []
        valid_mask = mask[b]
        valid_times = times[b, valid_mask, 0].tolist()
        for s in range(S):
            traj = samples[s, b, valid_mask, 0].tolist()
            ind: IndividualJSON = {
                "name_id": f"{name_prefix}_{s}",
                "observations": traj,
                "observation_times": valid_times,
            }
            # All samples share the batch's first target dosing slot.
            # NOTE(review): index 0 assumes a single target dosing entry per
            # batch element — confirm against the batch layout.
            dose = float(db.target_dosing_amounts[b, 0].item())
            route_idx = int(db.target_dosing_route_types[b, 0].item())
            if dose or route_idx:
                route = (
                    route_options[route_idx] if route_idx < len(route_options) else str(route_idx)
                )
                ind["dosing"] = [dose]
                ind["dosing_type"] = [route]
                ind["dosing_times"] = [dosing_time]
                ind["dosing_name"] = [route]
            target_list.append(ind)

        studies.append(
            {
                "context": context_list,
                "target": target_list,
                "meta_data": {
                    "study_name": study_name,
                    "substance_name": substance_name,
                },
            }
        )

    return studies
342
+
343
+
344
def prediction_stats(study: StudyJSON) -> StudyJSON:
    """Attach per-time-point prediction summaries to target individuals.

    Every target individual carrying a non-empty ``prediction_samples`` list
    gains ``prediction_mean`` and ``prediction_std`` entries, computed across
    the sample axis (population standard deviation, ``unbiased=False``).

    Parameters
    ----------
    study:
        ``StudyJSON`` record containing prediction samples. The mapping is
        modified in place and returned for convenience.

    Returns
    -------
    StudyJSON
        The same study object with summary fields added to its targets.

    Raises
    ------
    ValidationError
        If prediction samples are present but torch is unavailable.
    """

    for person in study.get("target", []):
        raw_samples = person.get("prediction_samples")
        if not raw_samples:
            continue
        if torch is None:
            raise ValidationError("torch is required to compute prediction summaries")
        stacked = torch.tensor(raw_samples)  # (num_samples, num_times)
        person["prediction_mean"] = stacked.mean(dim=0).tolist()
        person["prediction_std"] = stacked.std(dim=0, unbiased=False).tolist()
    return study
sim_priors_pk/data/data_empirical/json_stats.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is only used for checking the shapes of the empirical data that are passed to the Dataloader"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import defaultdict
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List, Sequence, Set
8
+
9
+ from .json_schema import StudyJSON
10
+
11
+
12
@dataclass
class EmpiricalJSONStats:
    """Aggregate statistics describing a collection of empirical studies.

    Instances are produced by :func:`compute_json_stats` and expose:
    the ranges of context/target cohort sizes across studies
    (``min/max_context_individuals``, ``min/max_target_individuals``);
    the extremal observed values (``min/max_observation``); the sorted list
    of distinct ``substances``; the largest combined cohort
    (``max_total_individuals``); the longest per-individual observation and
    remaining sequences (``max_observations``, ``max_remaining``);
    ``substance_summaries`` mapping each substance to its individual count,
    min/max observation counts, and the sorted unique time steps seen
    (observation plus remaining times); and ``studies_by_substance`` mapping
    each substance to the studies that reference it.
    """

    min_context_individuals: int
    max_context_individuals: int
    min_target_individuals: int
    max_target_individuals: int
    min_observation: float
    max_observation: float
    substances: List[str]
    max_total_individuals: int
    max_observations: int
    max_remaining: int
    substance_summaries: Dict[str, Dict[str, object]]
    studies_by_substance: Dict[str, List[StudyJSON]]

    def studies_for_substance(self, substance: str) -> List[StudyJSON]:
        """Return a copy of the study list recorded for ``substance``.

        Unknown substances yield an empty list.
        """

        matching = self.studies_by_substance.get(substance, [])
        return list(matching)

    def get_substance_summary(self, substance: str) -> Dict[str, object]:
        """Return a copy of the per-substance summary dictionary.

        The summary contains ``individual_count``, ``min_observations``,
        ``max_observations`` and ``observation_time_steps``; an empty dict is
        returned for unknown substances.
        """

        summary = self.substance_summaries.get(substance)
        return {} if summary is None else dict(summary)
94
+
95
+
96
def compute_json_stats(studies: Sequence[StudyJSON]) -> EmpiricalJSONStats:
    """Aggregate summary statistics over empirical pharmacokinetic studies.

    Parameters
    ----------
    studies:
        Sequence of :class:`StudyJSON` records to aggregate.

    Returns
    -------
    EmpiricalJSONStats
        Global cohort-size and value ranges, sequence-length maxima,
        per-substance summaries, and a substance-to-studies grouping.
        ``min_observation``/``max_observation`` are NaN when no observation
        values exist at all.
    """

    ctx_sizes: List[int] = []
    tgt_sizes: List[int] = []
    value_min = float("inf")
    value_max = float("-inf")
    longest_obs = 0
    longest_rem = 0
    biggest_study = 0

    per_substance_counts: Dict[str, int] = defaultdict(int)
    per_substance_obs_lengths: Dict[str, List[int]] = defaultdict(list)
    per_substance_times: Dict[str, Set[float]] = defaultdict(set)
    grouped: Dict[str, List[StudyJSON]] = defaultdict(list)
    names: Set[str] = set()

    for study in studies:
        context = study.get("context", [])
        target = study.get("target", [])
        ctx_sizes.append(len(context))
        tgt_sizes.append(len(target))
        biggest_study = max(biggest_study, len(context) + len(target))

        substance = study.get("meta_data", {}).get("substance_name")
        if substance:
            names.add(substance)
            grouped[substance].append(study)

        for person in context + target:
            obs = person.get("observations", [])
            longest_obs = max(longest_obs, len(obs))
            longest_rem = max(longest_rem, len(person.get("remaining", [])))
            if obs:
                value_min = min(value_min, min(obs))
                value_max = max(value_max, max(obs))
            if substance:
                per_substance_counts[substance] += 1
                per_substance_obs_lengths[substance].append(len(obs))
                # Unique time grid per substance mixes observation and
                # remaining time stamps.
                per_substance_times[substance].update(person.get("observation_times", []))
                per_substance_times[substance].update(person.get("remaining_times", []))

    summaries: Dict[str, Dict[str, object]] = {}
    for name in sorted(names):
        lengths = per_substance_obs_lengths.get(name, [])
        summaries[name] = {
            "individual_count": per_substance_counts.get(name, 0),
            "min_observations": min(lengths) if lengths else 0,
            "max_observations": max(lengths) if lengths else 0,
            "observation_time_steps": sorted(per_substance_times.get(name, set())),
        }

    return EmpiricalJSONStats(
        min_context_individuals=int(min(ctx_sizes, default=0)),
        max_context_individuals=int(max(ctx_sizes, default=0)),
        min_target_individuals=int(min(tgt_sizes, default=0)),
        max_target_individuals=int(max(tgt_sizes, default=0)),
        min_observation=float(value_min) if value_min != float("inf") else float("nan"),
        max_observation=float(value_max) if value_max != float("-inf") else float("nan"),
        substances=sorted(names),
        max_total_individuals=int(biggest_study),
        max_observations=int(longest_obs),
        max_remaining=int(longest_rem),
        substance_summaries=summaries,
        studies_by_substance={name: list(group) for name, group in grouped.items()},
    )
sim_priors_pk/data/data_empirical/simulx_to_json.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tools for converting simulx output .csv files (simulation from an NLME model) to study JSON format
3
+ """
4
+
5
+ import csv
6
+ from collections import defaultdict
7
+ from typing import Sequence
8
+
9
+ from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
10
+
11
+
12
def simulx_to_json(
    csv_path,
    study_name="simulated_study",
    substance_name="Drug_A",
    dosing_type="oral"
) -> Sequence[StudyJSON]:
    """Convert a Simulx simulation CSV into :class:`StudyJSON` records.

    The CSV must contain ``rep``, ``ID``, ``TIME``, ``value`` and ``AMOUNT``
    columns, with ``"."`` marking missing entries. Each replicate (``rep``)
    becomes one study whose individuals are placed in the ``context`` list;
    the ``target`` list is present but empty (the ``StudyJSON`` schema
    declares it as required, and downstream canonicalization re-splits
    individuals as needed).

    Parameters
    ----------
    csv_path:
        Path to the Simulx output CSV file.
    study_name:
        Base name for the generated studies; ``_rep{n}`` is appended.
    substance_name:
        Substance recorded in each study's ``meta_data``.
    dosing_type:
        Route label stored in both ``dosing_type`` and ``dosing_name``.

    Returns
    -------
    Sequence[StudyJSON]
        One study per replicate, ordered by replicate index.
    """

    # rep -> ID -> accumulated individual record
    reps = defaultdict(lambda: defaultdict(lambda: {
        "observations": [],
        "observation_times": [],
        "dosing": [],
        "dosing_type": [],
        "dosing_times": [],
        "dosing_name": []
    }))

    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            rep = int(row["rep"])
            id_ = row["ID"]
            time = float(row["TIME"])

            # Observations ("." marks a missing value).
            if row["value"] != ".":
                reps[rep][id_]["observations"].append(float(row["value"]))
                reps[rep][id_]["observation_times"].append(time)

            # Dosing events are assumed to occur at TIME == 0.
            if time == 0 and row["AMOUNT"] != ".":
                reps[rep][id_]["dosing"].append(float(row["AMOUNT"]))
                reps[rep][id_]["dosing_times"].append(0.0)
                reps[rep][id_]["dosing_type"].append(dosing_type)
                reps[rep][id_]["dosing_name"].append(dosing_type)

    # Build final output: one StudyJSON per replicate, in replicate order.
    output = []
    for rep, ids in sorted(reps.items()):
        contexts = [
            {"name_id": f"context_{id_}", **data} for id_, data in ids.items()
        ]

        # "target" is required by the StudyJSON schema; it starts out empty.
        study: StudyJSON = {
            "context": contexts,
            "target": [],
            "meta_data": {
                "study_name": f"{study_name}_rep{rep}",
                "substance_name": substance_name
            }
        }
        output.append(study)

    return output
68
+
69
# Manual smoke run: convert the bundled indometacin Simulx export to StudyJSON.
if __name__ == "__main__":
    output = simulx_to_json(csv_path="data/raw_nlme_simulx/indometacin-test-data.csv")
71
+
sim_priors_pk/data/data_generation/__init__.py ADDED
File without changes
sim_priors_pk/data/data_generation/compartment_models.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass, field
3
+ from typing import Callable, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torchdiffeq import odeint
8
+ from torchtyping import TensorType
9
+
10
+ from sim_priors_pk.config_classes.data_config import (
11
+ DosingConfig,
12
+ DosingWithDurationConfig,
13
+ MetaDosingConfig,
14
+ MetaDosingWithDurationConfig,
15
+ MetaStudyConfig,
16
+ )
17
+ from sim_priors_pk.config_classes.node_pk_config import NodePKExperimentConfig
18
+
19
+
20
@dataclass
class StudyConfig:
    """Configuration of a single simulated PK study.

    Rate constants and the central volume are parameterized on the log scale
    (``log_*_mean``/``log_*_std``), with ``*_tmag``/``*_tscl`` controlling the
    magnitude and scale of their time-dependent variation. Peripheral-related
    fields are lists with one entry per peripheral compartment.
    """

    drug_id: str  # identifier of the simulated drug
    num_individuals: int  # population size of the study
    num_peripherals: int  # number of peripheral compartments
    # Absorption rate constant k_a (gut -> central)
    log_k_a_mean: float  # mean (log scale)
    log_k_a_std: float  # std (log scale)
    k_a_tmag: float  # magnitude of time-dependent variation
    k_a_tscl: float  # scale of time-dependent variation
    # Elimination rate constant k_e (central compartment)
    log_k_e_mean: float  # mean (log scale)
    log_k_e_std: float  # std (log scale)
    k_e_tmag: float  # magnitude of time-dependent variation
    k_e_tscl: float  # scale of time-dependent variation
    # Volume V of the central compartment
    log_V_mean: float  # mean (log scale)
    log_V_std: float  # std (log scale)
    V_tmag: float  # magnitude of time-dependent variation
    V_tscl: float  # scale of time-dependent variation
    # Transfer rates central -> peripheral, one entry per peripheral
    log_k_1p_mean: List[float]  # means (log scale)
    log_k_1p_std: List[float]  # stds (log scale)
    k_1p_tmag: List[float]  # magnitudes of time-dependent variation
    k_1p_tscl: List[float]  # scales of time-dependent variation
    # Transfer rates peripheral -> central, one entry per peripheral
    log_k_p1_mean: List[float]  # means (log scale)
    log_k_p1_std: List[float]  # stds (log scale)
    k_p1_tmag: List[float]  # magnitudes of time-dependent variation
    k_p1_tscl: List[float]  # scales of time-dependent variation
    time_start: float  # start time of the study window
    time_stop: float  # end time of the study window
    rel_ruv: float  # relative residual unexplained variability for the study
52
+
53
+
54
@dataclass
class IndividualConfig:
    """Configuration of a single simulated individual.

    Rate constants and the central volume are time-dependent and therefore
    stored as callables ``t -> value``; every default is constant in time.
    """

    num_peripherals: int = 2  # number of peripheral compartments
    k_a: Callable[[float], float] = lambda t: 0.1  # absorption rate (gut -> central)
    k_e: Callable[[float], float] = lambda t: 0.05  # elimination rate (central)
    V: Callable[[float], float] = lambda t: 0.05  # volume of the central compartment
    # Transfer rates central -> peripheral, one callable per peripheral.
    k_1p: List[Callable[[float], float]] = field(
        default_factory=lambda: [lambda t: 0.01, lambda t: 0.01]
    )
    # Transfer rates peripheral -> central, one callable per peripheral.
    k_p1: List[Callable[[float], float]] = field(
        default_factory=lambda: [lambda t: 0.01, lambda t: 0.01]
    )
    rel_ruv: float = 0.1  # relative residual unexplained variability
71
+
72
+
73
def sample_study_config(config: MetaStudyConfig):
    """Draw one concrete StudyConfig from the ranges in a MetaStudyConfig.

    Scalar hyper-parameters are sampled uniformly from their configured
    (low, high) ranges; per-peripheral parameters are sampled once per
    drawn peripheral compartment.
    """
    drug_id = random.choice(config.drug_id_options)
    num_individuals = random.randint(*config.num_individuals_range)
    num_peripherals = random.randint(*config.num_peripherals_range)

    def draw(range_attr):
        # One uniform sample from a (low, high) range stored on the meta config.
        return random.uniform(*getattr(config, range_attr))

    def draw_per_peripheral(range_attr):
        # One uniform sample per sampled peripheral compartment.
        return [draw(range_attr) for _ in range(num_peripherals)]

    # Keyword arguments are evaluated in order, so the random-draw sequence
    # matches the parameter listing below.
    return StudyConfig(
        drug_id=drug_id,
        num_individuals=num_individuals,
        num_peripherals=num_peripherals,
        log_k_a_mean=draw("log_k_a_mean_range"),
        log_k_a_std=draw("log_k_a_std_range"),
        k_a_tmag=draw("k_a_tmag_range"),
        k_a_tscl=draw("k_a_tscl_range"),
        log_k_e_mean=draw("log_k_e_mean_range"),
        log_k_e_std=draw("log_k_e_std_range"),
        k_e_tmag=draw("k_e_tmag_range"),
        k_e_tscl=draw("k_e_tscl_range"),
        log_V_mean=draw("log_V_mean_range"),
        log_V_std=draw("log_V_std_range"),
        V_tmag=draw("V_tmag_range"),
        V_tscl=draw("V_tscl_range"),
        log_k_1p_mean=draw_per_peripheral("log_k_1p_mean_range"),
        log_k_1p_std=draw_per_peripheral("log_k_1p_std_range"),
        k_1p_tmag=draw_per_peripheral("k_1p_tmag_range"),
        k_1p_tscl=draw_per_peripheral("k_1p_tscl_range"),
        log_k_p1_mean=draw_per_peripheral("log_k_p1_mean_range"),
        log_k_p1_std=draw_per_peripheral("log_k_p1_std_range"),
        k_p1_tmag=draw_per_peripheral("k_p1_tmag_range"),
        k_p1_tscl=draw_per_peripheral("k_p1_tscl_range"),
        time_start=config.time_start,
        time_stop=config.time_stop,
        rel_ruv=draw("rel_ruv_range"),
    )
138
+
139
+
140
def sample_rate_function(mean_rate, variability, variability_type="sinusoidal"):
    """Build a time-dependent rate function around ``mean_rate``.

    :param mean_rate: Baseline rate constant.
    :param variability: Amplitude (sinusoidal) or decay rate (decaying)
        of the time modulation.
    :param variability_type: Either ``"sinusoidal"`` or ``"decaying"``.
    :return: Callable mapping time ``t`` to the modulated rate.
    :raises ValueError: If ``variability_type`` is not recognised.

    NOTE(review): ``t`` is fed to ``torch.sin``/``torch.exp``, so callers
    are presumably expected to pass torch tensors — confirm with call sites.
    """
    if variability_type == "sinusoidal":
        return lambda t: mean_rate + variability * torch.sin(t)
    if variability_type == "decaying":
        return lambda t: mean_rate * torch.exp(-variability * t)
    raise ValueError(f"Unknown variability_type: {variability_type}")
159
+
160
+
161
def simulate_ou_process(
    mu: float, sigma: float, theta: float, dt: float, T: float, seed: Optional[int] = None
) -> np.ndarray:
    """Euler-Maruyama simulation of a mean-reverting Ornstein-Uhlenbeck process.

    Integrates dX = theta * (mu - X) dt + sigma dW on [0, T] with step dt,
    starting from the stationary distribution N(mu, sigma^2 / (2 * theta)).
    Returns an array of int(T / dt) samples.
    """
    if seed is not None:
        np.random.seed(seed)

    n_steps = int(T / dt)
    path = np.zeros(n_steps)

    # Initial value drawn from the stationary distribution of the process.
    path[0] = np.random.normal(mu, np.sqrt(sigma**2 / (2 * theta)))

    sqrt_dt = np.sqrt(dt)  # Brownian increment scale, constant over the loop
    for step in range(1, n_steps):
        increment = np.random.normal(0, sqrt_dt)
        path[step] = path[step - 1] + theta * (mu - path[step - 1]) * dt + sigma * increment

    return path
179
+
180
+
181
def sample_individual_configs(study_config: StudyConfig, n: Optional[int] = None):
    """
    Samples parameters for a population of individuals.

    Baseline parameters are drawn per individual from lognormal
    distributions; time-dependent variability is added by modulating
    k_a, k_e and V with exponentiated Ornstein-Uhlenbeck paths and the
    peripheral exchange rates with a sinusoidal factor.

    Parameters
    ----------
    study_config : StudyConfig
        Configuration object with parameter distributions.
    n : int, optional
        Number of individuals to sample. If None, defaults to
        study_config.num_individuals.

    Returns
    -------
    List[IndividualConfig]
        A list of sampled individual configurations.
    """
    num_individuals = n if n is not None else study_config.num_individuals
    individual_configs = []

    dt = 0.1  # OU discretization step
    ou_horizon = study_config.time_stop - study_config.time_start

    for _ in range(num_individuals):
        # Baseline (time-constant) parameters from lognormal distributions.
        k_a = np.random.lognormal(study_config.log_k_a_mean, study_config.log_k_a_std)
        k_e = np.random.lognormal(study_config.log_k_e_mean, study_config.log_k_e_std)
        V = np.random.lognormal(study_config.log_V_mean, study_config.log_V_std)
        k_1p = [
            np.random.lognormal(mean, std)
            for mean, std in zip(study_config.log_k_1p_mean, study_config.log_k_1p_std)
        ]
        k_p1 = [
            np.random.lognormal(mean, std)
            for mean, std in zip(study_config.log_k_p1_mean, study_config.log_k_p1_std)
        ]

        # Ornstein-Uhlenbeck processes (applied on the log scale) for
        # time-dependent variability of k_a, k_e and V.
        ou_k_a = k_a * np.exp(
            simulate_ou_process(
                0,
                study_config.k_a_tmag * np.sqrt(2 * study_config.k_a_tscl),
                study_config.k_a_tmag,
                dt,
                ou_horizon,
            )
        )
        ou_k_e = k_e * np.exp(
            simulate_ou_process(
                0,
                study_config.k_e_tmag * np.sqrt(2 * study_config.k_e_tscl),
                study_config.k_e_tmag,
                dt,
                ou_horizon,
            )
        )
        ou_V = V * np.exp(
            simulate_ou_process(
                0,
                study_config.V_tmag * np.sqrt(2 * study_config.V_tscl),
                study_config.V_tmag,
                dt,
                ou_horizon,
            )
        )

        # BUGFIX: build the interpolation grid from the OU sample length.
        # np.arange(time_start, time_stop, dt) can be one element longer
        # than the int(T / dt) samples simulate_ou_process returns whenever
        # T / dt is not an exact integer, which made np.interp raise on a
        # length mismatch between xp and fp.
        ou_times = study_config.time_start + dt * np.arange(len(ou_k_a))

        # Time-dependent rate functions: piecewise-linear interpolation of
        # the OU paths. Defaults bind each individual's own arrays so the
        # closures are not affected by later loop iterations.
        def k_a_fn(t, ou_k_a=ou_k_a, ou_times=ou_times):
            return np.interp(t, ou_times, ou_k_a)

        def k_e_fn(t, ou_k_e=ou_k_e, ou_times=ou_times):
            return np.interp(t, ou_times, ou_k_e)

        def V_fn(t, ou_V=ou_V, ou_times=ou_times):
            return np.interp(t, ou_times, ou_V)

        # Peripheral exchange rates (sinusoidal modulation as placeholder).
        k_1p_fn = [
            lambda t,
            k_1p_i=k_1p[i],
            tmag_i=study_config.k_1p_tmag[i],
            tscl_i=study_config.k_1p_tscl[i]: k_1p_i * (1 + tmag_i * np.sin(t / tscl_i))
            for i in range(len(k_1p))
        ]
        k_p1_fn = [
            lambda t,
            k_p1_i=k_p1[i],
            tmag_i=study_config.k_p1_tmag[i],
            tscl_i=study_config.k_p1_tscl[i]: k_p1_i * (1 + tmag_i * np.sin(t / tscl_i))
            for i in range(len(k_p1))
        ]

        # Create config for this individual
        config = IndividualConfig(
            num_peripherals=study_config.num_peripherals,
            k_a=k_a_fn,
            k_e=k_e_fn,
            V=V_fn,
            k_1p=k_1p_fn,
            k_p1=k_p1_fn,
            rel_ruv=study_config.rel_ruv,
        )
        individual_configs.append(config)

    return individual_configs
285
+
286
+
287
def create_dynamic_ode_matrix(config: IndividualConfig, t: float):
    """Assemble the time-t system matrix A of the linear compartment model.

    State ordering is [gut, central, peripheral_1, ..., peripheral_P],
    so the dynamics are dy/dt = A @ y.

    :param config: IndividualConfig with time-dependent rate callables
    :param t: Current time
    :return: (2 + P, 2 + P) torch tensor
    """
    n_states = 2 + config.num_peripherals  # gut, central, and peripherals
    A = torch.zeros((n_states, n_states))

    k_a = config.k_a(t)
    k_e = config.k_e(t)

    # Gut loses drug via absorption; the central compartment gains it and
    # loses drug via elimination.
    A[0, 0] = -k_a
    A[1, 0] = k_a
    A[1, 1] = -k_e

    for p in range(config.num_peripherals):
        k_out = config.k_1p[p](t)  # central -> peripheral p
        k_in = config.k_p1[p](t)  # peripheral p -> central
        A[1, 1] = A[1, 1] - k_out  # distribution drain on the central compartment
        A[1, 2 + p] = k_in  # return flow into central
        A[2 + p, 1] = k_out  # inflow into peripheral p
        A[2 + p, 2 + p] = -k_in  # outflow from peripheral p

    return A
314
+
315
+
316
def create_dynamic_ode_matrix_batched(configs, t, num_peripherals):
    """
    Creates batched ODE matrices for multiple individuals.

    Parameters
    ----------
    configs : list
        List of IndividualConfig objects.
    t : float
        Current time point.
    num_peripherals : int
        Number of peripheral compartments (same for all individuals; only
        the first `num_peripherals` entries of each config's k_1p / k_p1
        are evaluated).

    Returns
    -------
    A_all : torch.Tensor
        Tensor of shape (N, M, M) with M = 2 + num_peripherals, containing
        one compartment-model system matrix per individual (dy/dt = A @ y,
        state order [gut, central, peripherals...]).
    """
    # (The redundant function-local `import torch` was removed; torch is
    # imported at module level.)
    N = len(configs)
    M = 2 + num_peripherals
    A_all = torch.zeros((N, M, M), dtype=torch.float32)

    # Evaluate every individual's time-dependent rates once at time t.
    k_a_all = torch.tensor([config.k_a(t) for config in configs], dtype=torch.float32)
    k_e_all = torch.tensor([config.k_e(t) for config in configs], dtype=torch.float32)
    k_1p_all = torch.tensor(
        [[config.k_1p[i](t) for i in range(num_peripherals)] for config in configs],
        dtype=torch.float32,
    )
    k_p1_all = torch.tensor(
        [[config.k_p1[i](t) for i in range(num_peripherals)] for config in configs],
        dtype=torch.float32,
    )

    # Populate the batched ODE matrices:
    A_all[:, 0, 0] = -k_a_all  # gut loses drug via absorption
    A_all[:, 1, 0] = k_a_all  # absorption into central
    A_all[:, 1, 1] = -k_e_all - k_1p_all.sum(dim=1)  # elimination + distribution drain
    A_all[:, 1, 2 : 2 + num_peripherals] = k_p1_all  # return flow peripherals -> central
    A_all[:, 2 : 2 + num_peripherals, 1] = k_1p_all  # inflow central -> peripherals
    for i in range(num_peripherals):
        A_all[:, 2 + i, 2 + i] = -k_p1_all[:, i]  # outflow from each peripheral

    return A_all
362
+
363
+
364
def sample_study(
    individual_config_array, dosing_config_array, t: torch.Tensor, solver_method: str = "rk4"
) -> Tuple[
    torch.Tensor,  # [N, T] concentration profiles
    torch.Tensor,  # [N, T] time points
    torch.Tensor,  # [N] dosing amounts
    torch.Tensor,  # [N] dosing route types (0 = oral, 1 = iv)
]:
    """
    Simulates the pharmacokinetic study for a group of individuals and returns
    concentration profiles, time points, and dosing metadata.

    Parameters
    ----------
    individual_config_array : list
        List of IndividualConfig objects for each individual.
    dosing_config_array : list
        List of DosingConfig objects for each individual.
    t : torch.Tensor
        A 1D tensor of time points [T].
    solver_method : str
        Solver name forwarded to odeint (default "rk4").

    Returns
    -------
    full_simulation : torch.Tensor
        Concentration profiles [N, T] (central amount / V, with
        multiplicative Gaussian residual error of scale rel_ruv).
    full_simulation_times : torch.Tensor
        Time points [N, T].
    dosing_amounts : torch.Tensor
        Dosing amounts [N].
    dosing_route_types : torch.Tensor
        Route types [N], 0 = oral, 1 = iv.

    Raises
    ------
    ValueError
        If the individual and dosing config counts differ, or a dosing
        route is neither "oral" nor "iv".
    """
    # Sanity check
    if len(individual_config_array) != len(dosing_config_array):
        raise ValueError("Number of individuals and dosing configurations must match.")

    N = len(individual_config_array)
    num_peripherals_list = [cfg.num_peripherals for cfg in individual_config_array]
    all_same_peripherals = all(n == num_peripherals_list[0] for n in num_peripherals_list)

    # Extract dosing info
    dosing_amounts = torch.tensor(
        [cfg.dose for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]
    routes_str = [cfg.route for cfg in dosing_config_array]
    route_map = {"oral": 0, "iv": 1}
    dosing_route_types = torch.tensor([route_map[r] for r in routes_str], dtype=torch.int64)  # [N]

    if all_same_peripherals:
        # Fast path: every individual shares the state dimension, so the
        # whole study can be solved as one batched ODE system.
        P = num_peripherals_list[0]
        M = 2 + P
        y0 = torch.zeros((N, M), dtype=torch.float32)
        is_oral = dosing_route_types == 0
        is_iv = dosing_route_types == 1
        y0[is_oral, 0] = dosing_amounts[is_oral]  # oral dose starts in the gut
        y0[is_iv, 1] = dosing_amounts[is_iv]  # iv bolus starts in the central compartment

        def ode_func(t, y):
            A_all = create_dynamic_ode_matrix_batched(individual_config_array, t.item(), P)
            return torch.bmm(A_all, y.unsqueeze(-1)).squeeze(-1)

        y = odeint(ode_func, y0, t, method=solver_method)  # [T, N, M]
        V_all = torch.tensor(
            [[cfg.V(ti.item()) for ti in t] for cfg in individual_config_array], dtype=torch.float32
        )  # [N, T]
        full_simulation = y[:, :, 1].T / V_all  # central amount -> concentration [N, T]
        # BUGFIX: apply each individual's own rel_ruv (matching the
        # per-individual branch below); previously only individual 0's
        # rel_ruv was used for the whole batch. Identical behavior when all
        # individuals share one rel_ruv, as sample_individual_configs produces.
        rel_ruv_all = torch.tensor(
            [cfg.rel_ruv for cfg in individual_config_array], dtype=torch.float32
        ).unsqueeze(1)  # [N, 1]
        full_simulation *= 1 + torch.randn_like(full_simulation) * rel_ruv_all
    else:
        # Slow path: heterogeneous state dimensions force one solve per individual.
        full_simulation = []
        for config, dosing_config in zip(individual_config_array, dosing_config_array):
            P = config.num_peripherals
            M = 2 + P
            if dosing_config.route == "oral":
                y0 = torch.tensor([dosing_config.dose] + [0.0] * (M - 1), dtype=torch.float32)
            elif dosing_config.route == "iv":
                y0 = torch.tensor([0.0, dosing_config.dose] + [0.0] * (M - 2), dtype=torch.float32)
            else:
                raise ValueError(f"Unsupported route: {dosing_config.route}")

            def ode_func(t, y, config=config):
                # Default-bind config to avoid late-binding surprises.
                A = create_dynamic_ode_matrix(config, t.item())
                return torch.matmul(A, y)

            y = odeint(ode_func, y0, t, method=solver_method)  # [T, M]
            V = torch.tensor([config.V(ti.item()) for ti in t], dtype=torch.float32)  # [T]
            concentration = y[:, 1] / V
            concentration *= 1 + torch.randn_like(concentration) * config.rel_ruv
            full_simulation.append(concentration)

        full_simulation = torch.stack(full_simulation)  # [N, T]

    full_times = t.unsqueeze(0).repeat(N, 1)  # [N, T]

    return full_simulation, full_times, dosing_amounts, dosing_route_types
460
+
461
+
462
def sample_study_with_duration(
    individual_config_array,
    dosing_config_array: List[DosingWithDurationConfig],
    t: torch.Tensor,
    solver_method: str = "rk4",
) -> Tuple[
    torch.Tensor,  # [N, T] concentration profiles
    torch.Tensor,  # [N, T] time points
    torch.Tensor,  # [N] dosing amounts
    torch.Tensor,  # [N] dosing route types (0 = oral, 1 = iv)
]:
    """
    Simulates the pharmacokinetic study for a group of individuals and returns
    concentration profiles, time points, and dosing metadata.

    This is a parallel implementation to sample_study that additionally
    supports iv infusion dosing: a nonzero duration is modeled as a constant
    input rate dose/duration into the central compartment while
    t < duration. Once validated, the two can be merged.

    Parameters
    ----------
    individual_config_array : list
        List of IndividualConfig objects for each individual.
    dosing_config_array : list
        List of DosingWithDurationConfig objects for each individual.
    t : torch.Tensor
        A 1D tensor of time points [T].
    solver_method : str
        Solver name forwarded to odeint (default "rk4").

    Returns
    -------
    full_simulation : torch.Tensor
        Concentration profiles [N, T].
    full_simulation_times : torch.Tensor
        Time points [N, T].
    dosing_amounts : torch.Tensor
        Dosing amounts [N].
    dosing_route_types : torch.Tensor
        Route types [N], 0 = oral, 1 = iv.
    """
    # Sanity check
    if len(individual_config_array) != len(dosing_config_array):
        raise ValueError("Number of individuals and dosing configurations must match.")

    N = len(individual_config_array)
    num_peripherals_list = [cfg.num_peripherals for cfg in individual_config_array]
    all_same_peripherals = all(n == num_peripherals_list[0] for n in num_peripherals_list)

    # Extract dosing info
    dosing_amounts = torch.tensor(
        [cfg.dose for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]
    routes_str = [cfg.route for cfg in dosing_config_array]
    route_map = {"oral": 0, "iv": 1}
    dosing_route_types = torch.tensor([route_map[r] for r in routes_str], dtype=torch.int64)  # [N]
    dosing_durations = torch.tensor(
        [cfg.duration for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]

    if all_same_peripherals and all(dosing_durations == 0):
        # Fast path: homogeneous state dimension and no infusions, so the
        # whole study is a single batched bolus/oral ODE solve.
        P = num_peripherals_list[0]
        M = 2 + P  # gut, central, peripherals

        y0 = torch.zeros((N, M), dtype=torch.float32)
        is_oral = dosing_route_types == 0
        is_iv_bolus = dosing_route_types == 1
        y0[is_oral, 0] = dosing_amounts[is_oral]  # oral dose starts in the gut
        y0[is_iv_bolus, 1] = dosing_amounts[is_iv_bolus]  # iv bolus starts in central

        def ode_func(t, y):
            A_all = create_dynamic_ode_matrix_batched(individual_config_array, t.item(), P)
            return torch.bmm(A_all, y.unsqueeze(-1)).squeeze(-1)

        y = odeint(ode_func, y0, t, method=solver_method)  # [T, N, M]
        V_all = torch.tensor(
            [[cfg.V(ti.item()) for ti in t] for cfg in individual_config_array], dtype=torch.float32
        )  # [N, T]
        full_simulation = y[:, :, 1].T / V_all  # [N, T]
        # BUGFIX: apply each individual's own rel_ruv (matching the
        # per-individual branch below); previously only individual 0's
        # rel_ruv was used for the whole batch.
        rel_ruv_all = torch.tensor(
            [cfg.rel_ruv for cfg in individual_config_array], dtype=torch.float32
        ).unsqueeze(1)  # [N, 1]
        full_simulation *= 1 + torch.randn_like(full_simulation) * rel_ruv_all
    else:
        full_simulation = []
        for config, dosing_config in zip(individual_config_array, dosing_config_array):
            P = config.num_peripherals
            M = 2 + P  # gut, central, peripherals
            if dosing_config.route == "oral":
                assert dosing_config.duration == 0, "Oral dosing cannot have a duration."
                y0 = torch.tensor([dosing_config.dose] + [0.0] * (M - 1), dtype=torch.float32)
            elif dosing_config.route == "iv":
                if dosing_config.duration > 0:
                    # Infusion: the dose enters through the forcing term in
                    # ode_func, so all compartments start empty.
                    y0 = torch.zeros(M, dtype=torch.float32)
                else:
                    # Bolus: the full dose starts in the central compartment.
                    y0 = torch.tensor(
                        [0.0, dosing_config.dose] + [0.0] * (M - 2), dtype=torch.float32
                    )
            else:
                raise ValueError(f"Unsupported route: {dosing_config.route}")

            def ode_func(t, y, config=config, dosing_config=dosing_config):
                # Defaults bind the loop variables to avoid late-binding surprises.
                A = create_dynamic_ode_matrix(config, t.item())
                b = torch.zeros_like(y)
                if (
                    dosing_config.route == "iv"
                    and dosing_config.duration > 0
                    and t.item() < dosing_config.duration
                ):
                    # During infusion, add a constant rate into the central compartment.
                    b[1] = dosing_config.dose / dosing_config.duration
                return torch.matmul(A, y) + b

            y = odeint(ode_func, y0, t, method=solver_method)  # [T, M]
            V = torch.tensor([config.V(ti.item()) for ti in t], dtype=torch.float32)  # [T]
            concentration = y[:, 1] / V
            concentration *= 1 + torch.randn_like(concentration) * config.rel_ruv
            full_simulation.append(concentration)

        full_simulation = torch.stack(full_simulation)  # [N, T]

    full_times = t.unsqueeze(0).repeat(N, 1)  # [N, T]

    return full_simulation, full_times, dosing_amounts, dosing_route_types
588
+
589
+
590
def derive_timescale_parameters(config: StudyConfig, meta_config: MetaStudyConfig):
    """
    Derive peak time and terminal half life for typical parameters,
    which can then be used to inform a study-specific sampling schedule.

    Returns a torch.Tensor [tmax, t12], both clamped to lie inside the
    study's observation window.
    """
    k_a = np.exp(config.log_k_a_mean)
    k_e = np.exp(config.log_k_e_mean)

    # One-compartment-with-absorption peak time, ln(k_a/k_e)/(k_a - k_e),
    # written with both numerator and denominator negated.
    # NOTE(review): divides by zero if k_a == k_e — assumed never sampled
    # exactly equal; confirm if that can occur.
    tmax = (np.log(k_e) - np.log(k_a)) / (k_e - k_a)

    # Mean-residence-time approximation for the terminal half-life; the
    # peripheral-compartment contribution is intentionally not included yet.
    mean_residence_time = 1 / k_e
    t12 = np.log(2) * mean_residence_time

    # Keep both timescales inside the observation window, and tmax below t12.
    if t12 > meta_config.time_stop:
        t12 = float(meta_config.time_stop / 2.0)
    if tmax > t12:
        tmax = float(t12 * 0.5)

    return torch.Tensor([tmax, t12])
610
+
611
+
612
def sample_dosing_configs(config: MetaDosingConfig):
    """
    Sample per-individual dosing configurations from a meta dosing config.

    The administration route is either shared across the whole study or
    drawn per individual, according to ``config.same_route``. Doses are
    lognormal, with log-mean and log-std each drawn uniformly from their
    configured ranges; with logdose_std_range = (0, 0) every individual
    receives the identical dose.
    """
    n = config.num_individuals

    if config.same_route:
        shared_route = np.random.choice(config.route_options, p=config.route_weights)
        routes = np.repeat(shared_route, n)
    else:
        routes = np.random.choice(config.route_options, p=config.route_weights, size=n)

    # Draw lognormal distribution parameters, then one dose per individual.
    logdose_mean = np.random.uniform(*config.logdose_mean_range)
    logdose_std = np.random.uniform(*config.logdose_std_range)
    doses = np.random.lognormal(logdose_mean, logdose_std, size=n)

    return [
        DosingConfig(dose=doses[i], route=routes[i], time=config.time) for i in range(n)
    ]
642
+
643
+
644
def sample_dosing_with_duration_configs(config: MetaDosingWithDurationConfig):
    """
    Sample per-individual dosing configurations (with infusion durations)
    from a meta dosing config.

    Routes are shared or per-individual depending on ``config.same_route``.
    Doses are lognormal with uniformly drawn log-mean/log-std; with
    logdose_std_range = (0, 0) the dose is identical for all individuals.
    Each individual receives an infusion duration with probability
    ``config.route_duration_weights[route]``; otherwise the duration is 0
    (bolus/oral).
    """
    n = config.num_individuals

    if config.same_route:
        shared_route = np.random.choice(config.route_options, p=config.route_weights)
        routes = np.repeat(shared_route, n)
    else:
        routes = np.random.choice(config.route_options, p=config.route_weights, size=n)

    # Candidate infusion durations; zeroed out below for non-infusion individuals.
    duration_raw = np.random.uniform(
        config.duration_range[0], config.duration_range[1], size=n
    )

    # Draw lognormal distribution parameters, then one dose per individual.
    logdose_mean = np.random.uniform(*config.logdose_mean_range)
    logdose_std = np.random.uniform(*config.logdose_std_range)
    doses = np.random.lognormal(logdose_mean, logdose_std, size=n)

    dosing_configs = []
    for i in range(n):
        # Bernoulli draw: infusion (1) or instantaneous dose (0) for this route.
        duration_flag = np.random.binomial(1, config.route_duration_weights[routes[i]], size=1)[0]
        dosing_configs.append(
            DosingWithDurationConfig(
                dose=doses[i],
                route=routes[i],
                time=config.time,
                duration=duration_raw[i] * duration_flag,
            )
        )

    return dosing_configs
688
+
689
+
690
def get_random_simulation(
    model_config: NodePKExperimentConfig,
) -> Tuple[TensorType["I", "T"], TensorType["I", "T"]]:
    """
    Generates random simulation data based on the model configuration.

    Args:
        model_config (NodePKExperimentConfig): Configuration for the simulation.

    Returns:
        Tuple[TensorType["I", "T"], TensorType["I", "T"]]: random simulation
        points and the corresponding time steps, each of shape [I, T].
    """
    meta = model_config.meta_study
    num_individuals = meta.num_individuals_range[0]
    num_steps = meta.time_num_steps

    # One shared time grid over the study window, replicated per individual.
    grid = torch.linspace(meta.time_start, meta.time_stop, num_steps, dtype=torch.float32)
    time_steps = grid.unsqueeze(0).repeat(num_individuals, 1)  # [I, T]

    # Uniform random values scaled down by the study horizon.
    simulation_points = torch.rand(num_individuals, num_steps) / meta.time_stop  # [I, T]

    return simulation_points, time_steps
sim_priors_pk/data/data_generation/compartment_models_management.py ADDED
@@ -0,0 +1,1338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pyright: reportAssignmentType=false
2
+ # compartment_models_management.py
3
+ import json
4
+ import logging
5
+ from dataclasses import replace
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torchtyping import TensorType
12
+
13
+ from sim_priors_pk.config_classes.data_config import (
14
+ DosingConfig,
15
+ DosingWithDurationConfig,
16
+ MetaDosingConfig,
17
+ MetaDosingWithDurationConfig,
18
+ MetaStudyConfig,
19
+ ObservationsConfig,
20
+ )
21
+ from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
22
+ from sim_priors_pk.data.data_generation.compartment_models import (
23
+ StudyConfig,
24
+ derive_timescale_parameters,
25
+ sample_dosing_configs,
26
+ sample_dosing_with_duration_configs,
27
+ sample_individual_configs,
28
+ sample_study,
29
+ sample_study_config,
30
+ sample_study_with_duration,
31
+ )
32
+ from sim_priors_pk.data.data_generation.observations_classes import ObservationStrategyFactory
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
if TYPE_CHECKING:  # pragma: no cover - typing only
    from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON
else:  # pragma: no cover - runtime fallback avoids heavy import cycle
    # NOTE(review): StudyJSON is also imported unconditionally at module top
    # (from json_schema), so this branch rebinds it to a plain dict alias at
    # runtime and shadows the real class — confirm whether the eager import
    # above is still needed or should be removed.
    IndividualJSON = Dict[str, object]
    StudyJSON = Dict[str, object]
41
+
42
+
43
def is_valid_simulation(sim: torch.Tensor) -> bool:
    """Return True if the simulation is numerically valid.

    Valid means every value is finite, non-negative, and strictly below 10.

    FIX: the result is wrapped in bool() so the function honors its
    ``-> bool`` annotation; previously it returned a 0-dim torch tensor.
    """
    return bool(torch.isfinite(sim).all() and (sim >= 0).all() and (sim < 10).all())
46
+
47
+
48
def sample_dosing_configs_repeated_target(config: MetaDosingConfig, n_targets: int):
    """
    Generate dosing configs where all target individuals share the same
    dose and route.

    Parameters
    ----------
    config : MetaDosingConfig
        Meta dosing configuration (its num_individuals field may be ignored).
    n_targets : int
        Number of target individuals to generate.

    Returns
    -------
    List[DosingConfig]
        The same dosing config repeated `n_targets` times.
    """
    # One shared route for every target.
    shared_route = np.random.choice(config.route_options, p=config.route_weights)

    # One shared lognormal dose, with log-mean/log-std drawn uniformly.
    mu = np.random.uniform(*config.logdose_mean_range)
    sigma = np.random.uniform(*config.logdose_std_range)
    shared_dose = float(np.random.lognormal(mu, sigma))

    return [
        DosingConfig(dose=shared_dose, route=shared_route, time=config.time)
        for _ in range(n_targets)
    ]
78
+
79
+
80
def sample_dosing_with_duration_configs_repeated_target(
    config: MetaDosingWithDurationConfig, n_targets: int
):
    """
    Generate dosing configs (with duration) where all target individuals
    share the same dose, route and infusion duration.

    Parameters
    ----------
    config : MetaDosingWithDurationConfig
        Meta dosing configuration with duration (its num_individuals field
        may be ignored).
    n_targets : int
        Number of target individuals to generate.

    Returns
    -------
    List[DosingWithDurationConfig]
        The same dosing config repeated `n_targets` times.
    """
    # One shared route for every target.
    shared_route = np.random.choice(config.route_options, p=config.route_weights)

    # NOTE(review): here the shared duration is the uniform sample *scaled*
    # by the route's duration weight, whereas sample_dosing_with_duration_configs
    # uses that weight as a Bernoulli probability for an infusion flag —
    # confirm which semantics is intended.
    weight = config.route_duration_weights[shared_route]
    duration_sample = np.random.uniform(*config.duration_range)
    shared_duration = weight * duration_sample

    # One shared lognormal dose, with log-mean/log-std drawn uniformly.
    mu = np.random.uniform(*config.logdose_mean_range)
    sigma = np.random.uniform(*config.logdose_std_range)
    shared_dose = float(np.random.lognormal(mu, sigma))

    return [
        DosingWithDurationConfig(
            dose=shared_dose, route=shared_route, time=config.time, duration=shared_duration
        )
        for _ in range(n_targets)
    ]
118
+
119
+
120
+ # ──────────────────────────────────────────────────────────────
121
+ # NEW: split where *all* individuals are target
122
+ # ──────────────────────────────────────────────────────────────
123
def split_context_only(
    full_simulation: torch.Tensor,
    full_simulation_times: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, list[int]]:
    """Treat the entire cohort as context.

    No target split is performed: both tensors are passed through untouched
    and every individual index is reported as a context index.
    """
    context_indices = [*range(full_simulation.shape[0])]
    return full_simulation, full_simulation_times, context_indices
131
+
132
+
133
def split_simulations_repeated_target(
    full_simulation: torch.Tensor,
    full_simulation_times: torch.Tensor,
) -> Tuple[
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    torch.Tensor,
    torch.Tensor,
    list[int],
    list[int],
]:
    """Treat every individual as a target; the context set is empty.

    Counterpart of :func:`split_context_only` — the tensors pass through
    unchanged in the *target* slots, while the context tensors are ``None``
    and the context index list is empty.

    Parameters
    ----------
    full_simulation : torch.Tensor [N, T]
    full_simulation_times : torch.Tensor [N, T]

    Returns
    -------
    context_simulation : None
    context_simulation_times : None
    target_simulation : torch.Tensor [N, T]
    target_simulation_times : torch.Tensor [N, T]
    context_indices : []
    target_indices : list[int] = [0,...,N-1]
    """
    all_indices = [*range(full_simulation.shape[0])]
    return None, None, full_simulation, full_simulation_times, [], all_indices
173
+
174
+
175
def _generate_full_simulation(
    meta_study_config: MetaStudyConfig,
    meta_dosing_config: MetaDosingConfig,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    StudyConfig,
    list[DosingConfig],
    int,
]:
    """Internal helper returning the raw tensors alongside sampling metadata.

    Samples a study configuration, per-individual PK parameters, and a dosing
    regimen, then integrates the study on a uniform time grid. Invalid
    (numerically broken) simulations are re-sampled in a loop — the previous
    recursive retry could exhaust Python's recursion limit on a long unlucky
    streak.

    Parameters
    ----------
    meta_study_config:
        Population / solver sampling configuration.
    meta_dosing_config:
        Dosing-distribution configuration; its ``num_individuals`` field is
        overridden with the sampled study size.
    retry_on_invalid:
        When ``True`` (default), invalid simulations are re-sampled.
        Otherwise a :class:`RuntimeError` is raised immediately.
    idx:
        Starting attempt offset, kept for diagnostics and backwards
        compatibility with the earlier recursive implementation.

    Returns
    -------
    tuple
        ``(full_sim, full_times, dosing_amounts, dosing_routes, time_points,
        time_scales, study_config, dosing_config_array, failed_attempts)``
        where ``failed_attempts`` counts the invalid simulations discarded
        before the returned one.

    Raises
    ------
    RuntimeError
        If an invalid simulation is produced and ``retry_on_invalid`` is
        ``False``.
    """
    failed_attempts = 0
    attempt_depth = idx
    while True:
        study_config = sample_study_config(meta_study_config)
        indiv_config_array = sample_individual_configs(study_config)
        time_scales = derive_timescale_parameters(study_config, meta_study_config)

        time_points = torch.linspace(
            meta_study_config.time_start,
            meta_study_config.time_stop,
            meta_study_config.time_num_steps,
            dtype=torch.float32,
        )

        # The dosing sampler needs to know the sampled cohort size.
        local_meta_dosing = replace(
            meta_dosing_config, num_individuals=study_config.num_individuals
        )
        dosing_config_array = sample_dosing_configs(local_meta_dosing)

        full_sim, full_times, dosing_amounts, dosing_routes = sample_study(
            indiv_config_array,
            dosing_config_array,
            time_points,
            meta_study_config.solver_method,
        )

        if is_valid_simulation(full_sim):
            return (
                full_sim,
                full_times,
                dosing_amounts,
                dosing_routes,
                time_points,
                time_scales,
                study_config,
                dosing_config_array,
                failed_attempts,
            )

        # Invalid simulation: warn once the attempt count grows suspicious.
        attempt_number = attempt_depth + 1
        if attempt_number > 5:
            logger.warning(
                "Invalid simulation encountered during attempt %d (recursion depth %d); retry_on_invalid=%s.",
                attempt_number,
                attempt_depth,
                retry_on_invalid,
            )
        if not retry_on_invalid:
            raise RuntimeError("Invalid simulation")
        failed_attempts += 1
        attempt_depth += 1
264
+
265
+
266
def _generate_full_simulation_with_duration(
    meta_study_config: MetaStudyConfig,
    meta_dosing_config: MetaDosingWithDurationConfig,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    StudyConfig,
    list[DosingConfig],
    int,
]:
    """Internal helper returning the raw tensors alongside sampling metadata.

    Parallel implementation of :func:`_generate_full_simulation` that supports
    dosing with duration; once validated the two can be merged. Invalid
    simulations are re-sampled in a loop — the previous recursive retry could
    exhaust Python's recursion limit on a long unlucky streak.

    Parameters
    ----------
    meta_study_config:
        Population / solver sampling configuration.
    meta_dosing_config:
        Dosing-with-duration configuration; its ``num_individuals`` field is
        overridden with the sampled study size.
    retry_on_invalid:
        When ``True`` (default), invalid simulations are re-sampled.
        Otherwise a :class:`RuntimeError` is raised immediately.
    idx:
        Starting attempt offset, kept for diagnostics and backwards
        compatibility with the earlier recursive implementation.

    Returns
    -------
    tuple
        ``(full_sim, full_times, dosing_amounts, dosing_routes, time_points,
        time_scales, study_config, dosing_config_array, failed_attempts)``.

    Raises
    ------
    RuntimeError
        If an invalid simulation is produced and ``retry_on_invalid`` is
        ``False``.
    """
    failed_attempts = 0
    attempt_depth = idx
    while True:
        study_config = sample_study_config(meta_study_config)
        indiv_config_array = sample_individual_configs(study_config)
        time_scales = derive_timescale_parameters(study_config, meta_study_config)

        time_points = torch.linspace(
            meta_study_config.time_start,
            meta_study_config.time_stop,
            meta_study_config.time_num_steps,
            dtype=torch.float32,
        )

        # The dosing sampler needs to know the sampled cohort size.
        local_meta_dosing = replace(
            meta_dosing_config, num_individuals=study_config.num_individuals
        )
        dosing_config_array = sample_dosing_with_duration_configs(local_meta_dosing)

        full_sim, full_times, dosing_amounts, dosing_routes = sample_study_with_duration(
            indiv_config_array,
            dosing_config_array,
            time_points,
            meta_study_config.solver_method,
        )

        if is_valid_simulation(full_sim):
            return (
                full_sim,
                full_times,
                dosing_amounts,
                dosing_routes,
                time_points,
                time_scales,
                study_config,
                dosing_config_array,
                failed_attempts,
            )

        # Invalid simulation: warn once the attempt count grows suspicious.
        attempt_number = attempt_depth + 1
        if attempt_number > 5:
            logger.warning(
                "Invalid simulation encountered during attempt %d (recursion depth %d); retry_on_invalid=%s.",
                attempt_number,
                attempt_depth,
                retry_on_invalid,
            )
        if not retry_on_invalid:
            raise RuntimeError("Invalid simulation")
        failed_attempts += 1
        attempt_depth += 1
359
+
360
+
361
def _generate_simple_exp_simulation(
    meta_study_config,
) -> Tuple[
    TensorType["N", "T"],  # full_simulation
    TensorType["N", "T"],  # full_simulation_times
    TensorType["N"],  # dosing_amounts
    TensorType["N"],  # dosing_route_types
    TensorType["T"],  # time_points
    TensorType[2],  # time_scales [tmax, t12]
]:
    """Minimal synthetic PK-like simulator based on a single exponential.

    One decay rate ``k ~ U(decay_rate_range)`` is drawn per RUN and shared by
    all individuals; per-individual variation enters only through a normally
    distributed intercept. Only ``band_scale_range``, ``baseline_range`` and
    ``decay_rate_range`` are consumed from the config (with defaults).

    Per-RUN derivations:
        baseline_run   ~ U(baseline_range)
        band_scale_run ~ U(band_scale_range)
        decay_rate     ~ U(decay_rate_range)
        intercept_mean = 1.0 + baseline_run
        intercept_std  = 0.5 * band_scale_run
    """
    # Hyperparameters, falling back to defaults when absent from the config.
    n_indiv: int = getattr(meta_study_config, "num_individuals", 16)
    n_steps: int = getattr(meta_study_config, "time_num_steps", 40)
    t_start: float = getattr(meta_study_config, "time_start", 0.0)
    t_stop: float = getattr(meta_study_config, "time_stop", 24.0)

    band_scale_range = getattr(meta_study_config, "band_scale_range", (0.1, 0.3))
    baseline_range = getattr(meta_study_config, "baseline_range", (0.0, 0.1))
    decay_rate_range = getattr(meta_study_config, "decay_rate_range", (0.3, 0.6))

    def _uniform_draw(lo, hi):
        # One unseeded uniform scalar in [lo, hi).
        return (torch.rand(1) * (hi - lo) + lo).item()

    # Per-RUN draws; the call order fixes the RNG stream.
    band_scale_run = _uniform_draw(*band_scale_range)
    baseline_run = _uniform_draw(*baseline_range)
    decay_k = _uniform_draw(*decay_rate_range)  # shared by all individuals

    intercept_mean = 1.0 + baseline_run
    intercept_std = 0.5 * band_scale_run

    # Shared single-exponential shape on the time grid; f(0) = 1.
    grid: TensorType["T", 1] = torch.linspace(t_start, t_stop, n_steps).unsqueeze(-1)
    shape_fn: TensorType["T", 1] = torch.exp(-decay_k * grid)

    # Per-individual intercepts, clamped to stay non-negative.
    intercepts: TensorType["N", 1, 1] = torch.normal(
        mean=float(intercept_mean),
        std=float(intercept_std),
        size=(n_indiv, 1, 1),
    ).clamp_min(0.0)

    # Scale the shared shape per individual and add the run-level baseline.
    traces: TensorType["N", "T", 1] = intercepts * shape_fn.unsqueeze(0) + baseline_run
    traces = traces.clamp_min(0.0)  # numerical safety

    # Dummy dosing tensors and heuristic time scales.
    dosing_amounts: TensorType["N"] = torch.zeros(n_indiv)
    dosing_routes: TensorType["N"] = torch.zeros(n_indiv)
    span = float(t_stop - t_start)
    time_scales: TensorType[2] = torch.tensor(
        [0.3 * span, 0.75 * span], dtype=torch.float32
    )

    return (
        traces.squeeze(-1),  # full_simulation [N, T]
        grid.expand(n_indiv, -1, -1).squeeze(-1),  # full_simulation_times [N, T]
        dosing_amounts,
        dosing_routes,
        grid.squeeze(-1),  # time_points [T]
        time_scales,
    )
455
+
456
+
457
def _generate_pulse_simulation(
    meta_study_config,
) -> Tuple[
    TensorType["N", "T"],  # full_simulation
    TensorType["N", "T"],  # full_simulation_times
    TensorType["N"],  # dosing_amounts
    TensorType["N"],  # dosing_route_types
    TensorType["T"],  # time_points
    TensorType[2],  # time_scales [t_peak, t_half_tail]
]:
    """Pulse-shaped synthetic simulator (rise -> peak -> decay).

    Optional config attributes (all with safe defaults):
        num_individuals, time_start, time_stop, time_num_steps
        band_scale_range: (lo, hi)  # intercept std via 0.5 * band_scale_run
        baseline_range:   (lo, hi)  # per-run vertical offset for all traces
        decay_rate_range: (lo, hi)  # per-run tail rate r; larger => faster decay

    Per-RUN construction:
        duration = time_stop - time_start
        t_peak   = 0.30 * duration
        r        ~ U(decay_rate_range)
        beta     = 1 / r
        alpha    = 1 + r * t_peak           # Gamma(alpha, beta) peaks near t_peak
        f(t)     = t^(alpha-1) * exp(-t/beta), normalized so max(f) = 1

        baseline_run   ~ U(baseline_range)
        band_scale_run ~ U(band_scale_range)
        intercept_mean = 1.0 + baseline_run
        intercept_std  = 0.5 * band_scale_run

    Per INDIVIDUAL:
        intercept_i ~ Normal(intercept_mean, intercept_std), clamped >= 0
        trace_i(t)  = intercept_i * f(t) + baseline_run, clamped >= 0
    """
    # Basic hyperparameters with safe fallbacks.
    n_indiv: int = getattr(meta_study_config, "num_individuals", 16)
    n_steps: int = getattr(meta_study_config, "time_num_steps", 40)
    t_start: float = getattr(meta_study_config, "time_start", 0.0)
    t_stop: float = getattr(meta_study_config, "time_stop", 24.0)

    band_scale_range = getattr(meta_study_config, "band_scale_range", (0.1, 0.3))
    baseline_range = getattr(meta_study_config, "baseline_range", (0.0, 0.1))
    decay_rate_range = getattr(meta_study_config, "decay_rate_range", (0.3, 0.6))

    def _uniform_draw(lo, hi):
        # One unseeded uniform scalar in [lo, hi).
        return (torch.rand(1) * (hi - lo) + lo).item()

    # Per-RUN draws; the call order fixes the RNG stream.
    band_scale_run = _uniform_draw(*band_scale_range)
    baseline_run = _uniform_draw(*baseline_range)
    tail_rate = _uniform_draw(*decay_rate_range)  # shared across this run

    window = float(t_stop - t_start)
    peak_time = 0.30 * window  # desired peak position
    gamma_scale = 1.0 / max(tail_rate, 1e-6)  # tail scale (beta)
    gamma_shape = 1.0 + tail_rate * peak_time  # alpha > 1 => rise then decay

    # Guardrail: keep alpha comfortably above 1 for a proper pulse shape.
    if gamma_shape <= 1.05:
        gamma_shape = 1.05

    # Gamma-shaped pulse on the time grid (unnormalized; 0 at t=0 when alpha>1).
    grid: TensorType["T"] = torch.linspace(t_start, t_stop, n_steps)  # [T]
    elapsed = grid - t_start  # time since study start
    pulse = (elapsed.clamp_min(0.0) ** (gamma_shape - 1.0)) * torch.exp(
        -elapsed / gamma_scale
    )
    # Normalize to max = 1 so that the intercept alone controls amplitude.
    pulse = pulse / torch.amax(pulse).clamp_min(1e-12)

    # Per-individual intercepts, clamped to stay non-negative.
    intercept_mean = 1.0 + baseline_run
    intercept_std = 0.5 * band_scale_run
    intercepts: TensorType["N", 1] = torch.normal(
        mean=float(intercept_mean),
        std=float(intercept_std),
        size=(n_indiv, 1),
    ).clamp_min(0.0)

    # Scale the shared pulse per individual, add the run baseline, clamp.
    traces: TensorType["N", "T"] = (intercepts * pulse.unsqueeze(0)) + baseline_run
    traces = traces.clamp_min(0.0)

    # Dummy dosing tensors.
    dosing_amounts: TensorType["N"] = torch.zeros(n_indiv)
    dosing_routes: TensorType["N"] = torch.zeros(n_indiv)

    # Time scales: peak position plus an approximate post-peak half-life.
    tail_half_life = peak_time + (
        torch.log(torch.tensor(2.0)) / max(tail_rate, 1e-6)
    ).item()
    time_scales: TensorType[2] = torch.tensor(
        [peak_time, tail_half_life], dtype=torch.float32
    )

    return (
        traces,  # full_simulation [N, T]
        grid.unsqueeze(0).expand(n_indiv, -1),  # full_simulation_times [N, T]
        dosing_amounts,
        dosing_routes,
        grid,  # time_points [T]
        time_scales,
    )
581
+
582
+
583
def _generate_simple_simulation(
    meta_study_config,
) -> Tuple[
    TensorType["N", "T"],
    TensorType["N", "T"],
    TensorType["N"],
    TensorType["N"],
    TensorType["T"],
    TensorType[2],
]:
    """Randomly dispatch between the exponential and pulse toy simulators.

    With probability ``p1`` (config attribute, default 0.5, clamped into
    [0, 1]) the single-exponential generator runs; otherwise the pulse
    generator runs.
    """
    mix_prob = float(getattr(meta_study_config, "p1", 0.5))
    mix_prob = min(max(mix_prob, 0.0), 1.0)  # clamp into [0, 1]

    use_exponential = torch.rand(1).item() < mix_prob
    generator = (
        _generate_simple_exp_simulation if use_exponential else _generate_pulse_simulation
    )
    return generator(meta_study_config)
609
+
610
+
611
def prepare_full_simulation(
    meta_study_config,
    meta_dosing_config,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> Tuple[
    TensorType["N", "T", 1],
    TensorType["N", "T"],
    TensorType["N"],
    TensorType["N"],
    TensorType["T"],
    TensorType[2],
]:
    """Generate a full INDIVIDUAL study simulation (before context/target split).

    Thin wrapper bundling the common steps shared across dataset generators.
    If ``meta_study_config.simple_mode`` is truthy, the simplified synthetic
    generator is used instead of the mechanistic one.
    """
    if getattr(meta_study_config, "simple_mode", False):
        return _generate_simple_simulation(meta_study_config)

    raw = _generate_full_simulation(
        meta_study_config,
        meta_dosing_config,
        retry_on_invalid=retry_on_invalid,
        idx=idx,
    )
    # Keep only the tensors; study config, dosing configs and the failure
    # counter (elements 6-8) are internal sampling metadata.
    return raw[:6]
653
+
654
+
655
def prepare_full_simulation_with_duration(
    meta_study_config,
    meta_dosing_config,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> Tuple[
    TensorType["N", "T", 1],
    TensorType["N", "T"],
    TensorType["N"],
    TensorType["N"],
    TensorType["T"],
    TensorType[2],
]:
    """Generate a full INDIVIDUAL study simulation (before context/target split).

    Thin wrapper bundling the common steps shared across dataset generators.
    If ``meta_study_config.simple_mode`` is truthy, the simplified synthetic
    generator is used instead of the mechanistic one.

    Parallel implementation of :func:`prepare_full_simulation` that supports
    dosing with duration; once validated the two can be merged.
    """
    if getattr(meta_study_config, "simple_mode", False):
        return _generate_simple_simulation(meta_study_config)

    raw = _generate_full_simulation_with_duration(
        meta_study_config,
        meta_dosing_config,
        retry_on_invalid=retry_on_invalid,
        idx=idx,
    )
    # Keep only the tensors; study config, dosing configs and the failure
    # counter (elements 6-8) are internal sampling metadata.
    return raw[:6]
700
+
701
+
702
+ def _ensure_strictly_increasing_observations(
703
+ obs_times: list[float], obs_vals: list[list[float]], *, individual_id: str
704
+ ) -> None:
705
+ """Validate that the provided observation times are strictly increasing.
706
+
707
+ Parameters
708
+ ----------
709
+ obs_times:
710
+ Sequence of observation timestamps extracted from the simulator.
711
+ obs_vals:
712
+ Sequence of observation values sampled at ``obs_times``.
713
+ individual_id:
714
+ Identifier of the individual being validated. Included in the
715
+ diagnostic error message to simplify debugging when duplicates are
716
+ detected in batched runs.
717
+ """
718
+
719
+ if len(obs_times) != len(obs_vals):
720
+ raise ValueError(
721
+ "Observation times must be sorted and match the number of observations. "
722
+ f"Received lengths times={len(obs_times)} and values={len(obs_vals)} for "
723
+ f"{individual_id}. Observations={obs_vals}, times={obs_times}."
724
+ )
725
+
726
+ for idx_time in range(len(obs_times) - 1):
727
+ if obs_times[idx_time] >= obs_times[idx_time + 1]:
728
+ raise ValueError(
729
+ "Observation times must be sorted and match the number of observations. "
730
+ f"Detected non-increasing times for {individual_id} at position {idx_time}. "
731
+ f"Observations={obs_vals}, times={obs_times}."
732
+ )
733
+
734
+
735
def prepare_full_simulation_to_study_json(
    meta_study_config: MetaStudyConfig,
    observation_config: ObservationsConfig,
    meta_dosing_config: MetaDosingConfig,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> tuple[StudyJSON, int]:
    """Generate a full simulation and convert it into a :class:`StudyJSON` record.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration describing the population and numerical solver.
        If meta_study_config.simple_mode is True, uses simplified synthetic data.
    observation_config:
        Configuration for the observation strategy used to extract measurements
        from the raw simulation. All generated observations are stored under
        the ``context`` section of the returned study.
    meta_dosing_config:
        Configuration describing the dosing regimen for each simulated
        individual.
    retry_on_invalid:
        When ``True`` (default) the function retries simulation sampling if the
        generated trajectories are numerically invalid.
    idx:
        Internal recursion depth counter exposed for debugging and testing.

    Returns
    -------
    tuple[StudyJSON, int]
        Canonical JSON representation of the simulated study with all
        individuals stored in the ``context`` field and an empty ``target``
        list, alongside the number of failed attempts before obtaining the
        valid simulation.
    """
    if getattr(meta_study_config, "simple_mode", False):
        # Simplified synthetic generation: there is no mechanistic study
        # config, so the metadata below falls back to a default substance name.
        (
            full_sim,
            full_times,
            dosing_amounts,
            _dosing_routes,
            _time_points,
            time_scales,
        ) = _generate_simple_simulation(meta_study_config)
        study_config = None  # getattr(None, "drug_id", default) yields the default
        dosing_config_array = [
            DosingConfig(dose=float(d), route="", time=0.0) for d in dosing_amounts
        ]
        failed_attempts = 0
    else:
        # Mechanistic simulation path.
        (
            full_sim,
            full_times,
            dosing_amounts,
            _dosing_routes,
            _time_points,
            time_scales,
            study_config,
            dosing_config_array,
            failed_attempts,
        ) = _generate_full_simulation(
            meta_study_config,
            meta_dosing_config,
            retry_on_invalid=retry_on_invalid,
            idx=idx,
        )

    # Extract sparse observations from the dense simulation.
    observation_strategy = ObservationStrategyFactory.from_config(
        observation_config, meta_study_config
    )
    obs_out, time_out, mask_out, rem_sim, rem_time, rem_mask, _ = observation_strategy.generate(
        full_simulation=full_sim,
        full_simulation_times=full_times,
        time_scales=time_scales,
    )

    context: list[IndividualJSON] = []
    num_individuals = full_sim.shape[0]

    for ind_idx in range(num_individuals):
        # Keep only the masked (observed) entries for this individual.
        mask = mask_out[ind_idx].to(torch.bool)
        observations = obs_out[ind_idx][mask].tolist()
        observation_times = time_out[ind_idx][mask].tolist()

        _ensure_strictly_increasing_observations(
            observation_times,
            observations,
            individual_id=f"context_{ind_idx}",
        )

        individual: IndividualJSON = {
            "name_id": f"context_{ind_idx}",
            "observations": observations,
            "observation_times": observation_times,
        }

        # Optionally attach the unobserved remainder of the trajectory.
        if rem_sim is not None and rem_time is not None and rem_mask is not None:
            rem_mask_row = rem_mask[ind_idx].to(torch.bool)
            if rem_mask_row.any():
                individual["remaining"] = rem_sim[ind_idx][rem_mask_row].tolist()
                individual["remaining_times"] = rem_time[ind_idx][rem_mask_row].tolist()

        # Record dosing only when a non-trivial dose or route exists.
        dosing_cfg = dosing_config_array[ind_idx]
        dose = float(dosing_amounts[ind_idx].item())
        route = getattr(dosing_cfg, "route", "")
        dosing_time = float(getattr(dosing_cfg, "time", 0.0))

        if dose or route:
            individual["dosing"] = [dose]
            individual["dosing_type"] = [route]
            individual["dosing_times"] = [dosing_time]
            individual["dosing_name"] = [route]

        context.append(individual)

    study_json: StudyJSON = {
        "context": context,
        "target": [],
        "meta_data": {
            "study_name": f"simulated_study_{idx}",
            "substance_name": getattr(study_config, "drug_id", "simulated_substance"),
        },
    }

    return study_json, failed_attempts
863
+
864
+
865
def prepare_full_simulation_with_repeated_targets(
    meta_study_config: MetaStudyConfig,
    meta_dosing_config: MetaDosingConfig,
    n_targets: int,
    *,
    different_dosing: bool = False,
    retry_on_invalid: bool = True,
    idx: int = 0,
):
    """
    Generate a context simulation (normal dosing) plus a new set of target
    individuals drawn from the same study population.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration for the PK population and solver.
    meta_dosing_config:
        Dosing-distribution configuration used for context and targets.
    n_targets:
        Number of target individuals to generate.
    different_dosing:
        If ``False`` (default), all target individuals share one repeated
        dosing configuration.
        If ``True``, each target individual gets an independent dosing sample
        from the same distribution used for context individuals.
    retry_on_invalid:
        When ``True`` (default), an invalid simulation triggers a full
        re-sample of the study; otherwise a :class:`RuntimeError` is raised.
    idx:
        Retry depth counter used for diagnostics.

    Returns
    -------
    context_sim, context_times,
    target_sim, target_times,
    dosing_amounts_ctx, dosing_routes_ctx,
    dosing_amounts_tgt, dosing_routes_tgt,
    time_points, time_scales
    """
    study_config = sample_study_config(meta_study_config)
    indiv_config_array = sample_individual_configs(study_config)
    time_scales = derive_timescale_parameters(study_config, meta_study_config)

    time_points = torch.linspace(
        meta_study_config.time_start,
        meta_study_config.time_stop,
        meta_study_config.time_num_steps,
        dtype=torch.float32,
    )

    # Context part: one independent dosing sample per context individual.
    local_meta_dosing_ctx = replace(
        meta_dosing_config, num_individuals=study_config.num_individuals
    )
    dosing_config_array_ctx = sample_dosing_configs(local_meta_dosing_ctx)

    full_sim, full_times, dosing_amounts_all, dosing_routes_all = sample_study(
        indiv_config_array,
        dosing_config_array_ctx,
        time_points,
        meta_study_config.solver_method,
    )
    if not is_valid_simulation(full_sim):
        if retry_on_invalid:
            # Re-sample the whole study; propagate retry_on_invalid so the
            # caller's choice also applies to subsequent attempts.
            return prepare_full_simulation_with_repeated_targets(
                meta_study_config,
                meta_dosing_config,
                n_targets,
                different_dosing=different_dosing,
                retry_on_invalid=retry_on_invalid,
                idx=idx + 1,
            )
        raise RuntimeError("Invalid context simulation")

    context_sim, context_times, ctx_idx = split_context_only(full_sim, full_times)
    dosing_amounts_ctx = dosing_amounts_all[ctx_idx]
    dosing_routes_ctx = dosing_routes_all[ctx_idx]

    # Target part: fresh individuals from the same study population.
    indiv_cfg_targets = sample_individual_configs(study_config, n=n_targets)
    local_meta_dosing_tgt = replace(meta_dosing_config, num_individuals=n_targets)
    if different_dosing:
        dosing_config_array_tgt = sample_dosing_configs(local_meta_dosing_tgt)
    else:
        dosing_config_array_tgt = sample_dosing_configs_repeated_target(
            local_meta_dosing_tgt, n_targets
        )

    full_sim_tgt, full_times_tgt, dosing_amounts_tgt, dosing_routes_tgt = sample_study(
        indiv_cfg_targets,
        dosing_config_array_tgt,
        time_points,
        meta_study_config.solver_method,
    )
    if not is_valid_simulation(full_sim_tgt):
        if retry_on_invalid:
            return prepare_full_simulation_with_repeated_targets(
                meta_study_config,
                meta_dosing_config,
                n_targets,
                different_dosing=different_dosing,
                retry_on_invalid=retry_on_invalid,
                idx=idx + 1,
            )
        raise RuntimeError("Invalid target simulation")

    _, _, target_sim, target_times, _, tgt_idx = split_simulations_repeated_target(
        full_sim_tgt, full_times_tgt
    )

    return (
        context_sim,
        context_times,
        target_sim,
        target_times,
        dosing_amounts_ctx,
        dosing_routes_ctx,
        dosing_amounts_tgt[tgt_idx],
        dosing_routes_tgt[tgt_idx],
        time_points,
        time_scales,
    )
978
+
979
+
980
def prepare_full_simulation_list_with_repeated_targets(
    meta_study_config: MetaStudyConfig,
    meta_dosing_config: MetaDosingConfig,
    n_targets: int,
    num_of_different_dosages: int,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
):
    """Generate one shared context and ``L`` target sets with repeated dosing.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration controlling PK population and solver behaviour.
    meta_dosing_config:
        Dosing-distribution configuration used for both context and targets.
    n_targets:
        Number of target individuals for each dosing condition.
    num_of_different_dosages:
        Number of target dosing conditions ``L``.
    retry_on_invalid:
        Whether to retry sampling when numerical invalid simulations are found.
    idx:
        Retry depth / attempt index used for diagnostics.

    Returns
    -------
    tuple
        ``(context_sim, context_times, dosing_amounts_ctx, dosing_routes_ctx,``
        ``target_simulations, target_times_list, target_dosing_amounts_list,``
        ``target_dosing_routes_list, time_points, time_scales)`` where each
        target list has length ``num_of_different_dosages``.

    Raises
    ------
    ValueError
        If ``num_of_different_dosages`` is negative.
    RuntimeError
        If an invalid simulation is produced and ``retry_on_invalid`` is
        ``False``.
    """

    if num_of_different_dosages < 0:
        raise ValueError("num_of_different_dosages must be non-negative")

    study_config = sample_study_config(meta_study_config)
    indiv_config_array = sample_individual_configs(study_config)
    time_scales = derive_timescale_parameters(study_config, meta_study_config)

    # [T]
    time_points = torch.linspace(
        meta_study_config.time_start,
        meta_study_config.time_stop,
        meta_study_config.time_num_steps,
        dtype=torch.float32,
    )

    # Context is sampled exactly once.
    local_meta_dosing_ctx = replace(
        meta_dosing_config, num_individuals=study_config.num_individuals
    )
    dosing_config_array_ctx = sample_dosing_configs(local_meta_dosing_ctx)
    full_sim, full_times, dosing_amounts_all, dosing_routes_all = sample_study(
        indiv_config_array,
        dosing_config_array_ctx,
        time_points,
        meta_study_config.solver_method,
    )
    if not is_valid_simulation(full_sim):
        if retry_on_invalid:
            # Propagate retry_on_invalid so the caller's choice also applies
            # to subsequent attempts.
            return prepare_full_simulation_list_with_repeated_targets(
                meta_study_config,
                meta_dosing_config,
                n_targets,
                num_of_different_dosages,
                retry_on_invalid=retry_on_invalid,
                idx=idx + 1,
            )
        raise RuntimeError("Invalid context simulation")

    # context_sim: [N_ctx, T], context_times: [N_ctx, T]
    context_sim, context_times, ctx_idx = split_context_only(full_sim, full_times)
    dosing_amounts_ctx = dosing_amounts_all[ctx_idx]
    dosing_routes_ctx = dosing_routes_all[ctx_idx]

    # Keep the same target PK individuals across all dosing conditions so that
    # only dosing changes across list elements.
    indiv_cfg_targets = sample_individual_configs(study_config, n=n_targets)
    local_meta_dosing_tgt = replace(meta_dosing_config, num_individuals=n_targets)

    target_simulations = []
    target_times_list = []
    target_dosing_amounts_list = []
    target_dosing_routes_list = []
    seen_dosing_signatures: set[tuple[str, float]] = set()

    for _ in range(num_of_different_dosages):
        attempts = 0
        while True:
            attempts += 1
            dosing_config_array_tgt = sample_dosing_configs_repeated_target(
                local_meta_dosing_tgt, n_targets
            )

            # Ensure distinct dosing regimens across list elements.
            dosing_signature = ("", 0.0)
            if n_targets > 0 and len(dosing_config_array_tgt) > 0:
                first_cfg = dosing_config_array_tgt[0]
                dosing_signature = (
                    str(getattr(first_cfg, "route", "")),
                    float(getattr(first_cfg, "dose", 0.0)),
                )
                if dosing_signature in seen_dosing_signatures and num_of_different_dosages > 1:
                    if attempts < 100:
                        continue
                    logger.warning(
                        "Could not sample a unique repeated target dosing signature after %d attempts.",
                        attempts,
                    )

            full_sim_tgt, full_times_tgt, dosing_amounts_tgt, dosing_routes_tgt = sample_study(
                indiv_cfg_targets,
                dosing_config_array_tgt,
                time_points,
                meta_study_config.solver_method,
            )
            if not is_valid_simulation(full_sim_tgt):
                # First retry locally (re-sample dosing only); after 100
                # attempts fall back to re-sampling the whole study.
                if retry_on_invalid and attempts < 100:
                    continue
                if retry_on_invalid:
                    return prepare_full_simulation_list_with_repeated_targets(
                        meta_study_config,
                        meta_dosing_config,
                        n_targets,
                        num_of_different_dosages,
                        retry_on_invalid=retry_on_invalid,
                        idx=idx + 1,
                    )
                raise RuntimeError("Invalid target simulation")

            _, _, target_sim, target_times, _, tgt_idx = split_simulations_repeated_target(
                full_sim_tgt, full_times_tgt
            )

            target_simulations.append(target_sim)
            target_times_list.append(target_times)
            target_dosing_amounts_list.append(dosing_amounts_tgt[tgt_idx])
            target_dosing_routes_list.append(dosing_routes_tgt[tgt_idx])
            if n_targets > 0:
                seen_dosing_signatures.add(dosing_signature)
            break

    return (
        context_sim,
        context_times,
        dosing_amounts_ctx,
        dosing_routes_ctx,
        target_simulations,
        target_times_list,
        target_dosing_amounts_list,
        target_dosing_routes_list,
        time_points,
        time_scales,
    )
1135
+
1136
+
1137
def prepare_ensemble_of_simulations(
    meta_study_config: MetaStudyConfig,
    observation_config: ObservationsConfig,
    meta_dosing_config: MetaDosingConfig,
    number_of_samples: int,
    file_name: Optional[str] = None,
    group_size: Optional[int] = None,
) -> tuple[list[StudyJSON] | list[list[StudyJSON]], float]:
    """Generate an ensemble of simulated studies.

    Repeatedly invokes :func:`prepare_full_simulation_to_study_json` to
    produce ``number_of_samples`` independent simulations. When ``file_name``
    is given, the (flat, ungrouped) ensemble is serialized to JSON for
    reproducibility and downstream processing.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration controlling the pharmacokinetic population and
        solver settings.
    observation_config:
        Observation strategy applied to each generated simulation.
    meta_dosing_config:
        Configuration describing the dosing regimen per simulated individual.
    number_of_samples:
        Number of simulations to generate.
    file_name:
        Optional path used to persist the generated ensemble as a JSON file.
    group_size:
        Optional number of studies per group. When provided, the return value
        is a list of lists with ``group_size`` elements each; trailing studies
        that do not fill a complete group are dropped.

    Returns
    -------
    tuple[list[StudyJSON] | list[list[StudyJSON]], float]
        Ensemble of simulated studies (flat or grouped) and the proportion of
        failed simulation attempts encountered while generating the ensemble.
    """

    ensemble: list[StudyJSON] = []
    failures = 0
    for sample_idx in range(number_of_samples):
        study_json, n_failed = prepare_full_simulation_to_study_json(
            meta_study_config=meta_study_config,
            observation_config=observation_config,
            meta_dosing_config=meta_dosing_config,
            idx=sample_idx,
        )
        ensemble.append(study_json)
        failures += n_failed

    # Persist the raw (ungrouped) ensemble when a path was supplied.
    if file_name:
        Path(file_name).write_text(json.dumps(ensemble, indent=2))

    # Failure rate over all attempts (successful + failed); 0.0 when nothing ran.
    attempts = failures + len(ensemble)
    failure_rate = (failures / attempts) if attempts > 0 else 0.0

    if group_size and group_size > 0:
        # Only full groups are kept; the incomplete tail is discarded.
        usable = (len(ensemble) // group_size) * group_size
        grouped = [
            ensemble[start : start + group_size]
            for start in range(0, usable, group_size)
        ]
        return grouped, failure_rate

    return ensemble, failure_rate
1208
+
1209
+
1210
def prepare_full_simulation_to_study_json_context_target(
    meta_study_config: MetaStudyConfig,
    observation_config: ObservationsConfig,
    meta_dosing_config_context: MetaDosingConfig,
    meta_dosing_config_target: MetaDosingConfig,
    *,
    retry_on_invalid: bool = True,
    idx: int = 0,
) -> tuple[StudyJSON, int]:
    """Generate a full simulation and convert it into a :class:`StudyJSON` record.
    Different dosing regimens are used for context and target individuals.

    Parameters
    ----------
    meta_study_config:
        Sampling configuration describing the population and numerical solver.
        If meta_study_config.simple_mode is True, uses simplified synthetic data.
    observation_config:
        Configuration for the observation strategy used to extract measurements
        from the raw simulation.
    meta_dosing_config_context:
        Configuration describing the dosing regimen for each simulated
        individual in the context set.
    meta_dosing_config_target:
        Configuration describing the dosing regimen for each simulated
        individual in the target set.
    retry_on_invalid:
        When ``True`` (default) the function retries simulation sampling if the
        generated trajectories are numerically invalid.
    idx:
        Internal recursion depth counter exposed for debugging and testing.

    Returns
    -------
    tuple[StudyJSON, int]
        Canonical JSON representation of the simulated study with the context
        individuals stored under ``context`` and the target individuals under
        ``target``, alongside the total number of failed attempts accumulated
        while generating both sections.
    """

    def prepare_section(name, meta_dosing_config):
        # Run one full simulation for this dosing config and convert every
        # individual into an IndividualJSON dict tagged "<name>_<index>".
        (
            full_sim,
            full_times,
            dosing_amounts,
            _dosing_routes,
            _time_points,
            time_scales,
            study_config,
            dosing_config_array,
            failed_attempts,
        ) = _generate_full_simulation(
            meta_study_config,
            meta_dosing_config,
            retry_on_invalid=retry_on_invalid,
            idx=idx,
        )

        observation_strategy = ObservationStrategyFactory.from_config(
            observation_config, meta_study_config
        )
        obs_out, time_out, mask_out, rem_sim, rem_time, rem_mask, _ = observation_strategy.generate(
            full_simulation=full_sim,
            full_simulation_times=full_times,
            time_scales=time_scales,
        )

        section: list[IndividualJSON] = []
        num_individuals = full_sim.shape[0]

        for ind_idx in range(num_individuals):
            # Keep only the time steps the strategy marked as observed.
            mask = mask_out[ind_idx].to(torch.bool)
            observations = obs_out[ind_idx][mask].tolist()
            observation_times = time_out[ind_idx][mask].tolist()

            _ensure_strictly_increasing_observations(
                observation_times,
                observations,
                individual_id=f"{name}_{ind_idx}",
            )

            individual: IndividualJSON = {
                "name_id": f"{name}_{ind_idx}",
                "observations": observations,
                "observation_times": observation_times,
            }

            # Optional "remaining" block is only attached when the strategy
            # produced remainder tensors AND this row has at least one entry.
            if rem_sim is not None and rem_time is not None and rem_mask is not None:
                rem_mask_row = rem_mask[ind_idx].to(torch.bool)
                if rem_mask_row.any():
                    individual["remaining"] = rem_sim[ind_idx][rem_mask_row].tolist()
                    individual["remaining_times"] = rem_time[ind_idx][rem_mask_row].tolist()

            dosing_cfg = dosing_config_array[ind_idx]
            dose = float(dosing_amounts[ind_idx].item())
            route = getattr(dosing_cfg, "route", "")
            dosing_time = float(getattr(dosing_cfg, "time", 0.0))

            # A zero dose with an empty route means "no dosing event": the
            # dosing keys are omitted entirely for that individual.
            if dose or route:
                individual["dosing"] = [dose]
                individual["dosing_type"] = [route]
                individual["dosing_times"] = [dosing_time]
                individual["dosing_name"] = [route]

            section.append(individual)

        return section, study_config, failed_attempts

    # Set RNG to have the same study config for both context and target.
    # NOTE(review): this re-seeds the *global* torch RNG, clobbering any
    # caller-side RNG state — confirm this is acceptable for all call sites.
    torch.manual_seed(42)
    context, study_config, failed_attempts_context = prepare_section(
        "context", meta_dosing_config_context
    )
    torch.manual_seed(42)
    target, _, failed_attempts_target = prepare_section("target", meta_dosing_config_target)

    study_json: StudyJSON = {
        "context": context,
        "target": target,
        "meta_data": {
            "study_name": f"simulated_study_{idx}",
            "substance_name": getattr(study_config, "drug_id", "simulated_substance"),
        },
    }
    failed_attempts = failed_attempts_context + failed_attempts_target

    return study_json, failed_attempts
sim_priors_pk/data/data_generation/dosing_models.py ADDED
File without changes
sim_priors_pk/data/data_generation/observations_classes.py ADDED
@@ -0,0 +1,1776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Callable, Optional, Tuple, List
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torchtyping import TensorType
7
+ from sim_priors_pk.config_classes.data_config import ObservationsConfig, MetaStudyConfig
8
+ from sim_priors_pk.data.data_generation.observations_functions import fix_past_time_random_selection
9
+
10
+
11
+ def _sample_past_count_with_bias(
12
+ low: int,
13
+ high: int,
14
+ *,
15
+ generative_bias: bool,
16
+ generator: torch.Generator,
17
+ device: torch.device,
18
+ ) -> int:
19
+ """Sample the number of past observations under the configured bias mode."""
20
+
21
+ if high <= 0:
22
+ return 0
23
+
24
+ if generative_bias:
25
+ sample_zero = int(torch.randint(0, 2, (1,), generator=generator, device=device).item()) == 0
26
+ if sample_zero:
27
+ return 0
28
+
29
+ rest_low = max(1, low)
30
+ if rest_low > high:
31
+ return 0
32
+ if rest_low == high:
33
+ return rest_low
34
+ return int(
35
+ torch.randint(
36
+ rest_low,
37
+ high + 1,
38
+ (1,),
39
+ generator=generator,
40
+ device=device,
41
+ ).item()
42
+ )
43
+
44
+ if low >= high:
45
+ return int(high)
46
+
47
+ return int(torch.randint(low, high + 1, (1,), generator=generator, device=device).item())
48
+
49
+
50
class ObservationStrategy(ABC):
    """Abstract base for strategies turning raw simulations into observations.

    Subclasses implement :meth:`_generate_raw` / :meth:`_get_shapes_raw`; the
    public :meth:`generate` / :meth:`get_shapes` wrappers apply the
    ``add_rem`` flag from the observations config on top of the raw results.
    """

    def __init__(self, observations_config: ObservationsConfig, meta_config: MetaStudyConfig):
        self.observations_config = observations_config
        self.meta_config = meta_config

    def _drop_non_positive_times_from_mask(self, times: Tensor, mask: Tensor) -> Tensor:
        """Optionally invalidate observations at non-positive timestamps.

        When ``drop_time_zero_observations=True`` in :class:`ObservationsConfig`,
        entries with ``time <= 0`` are excluded from downstream sampling.
        """
        drop_flag = getattr(self.observations_config, "drop_time_zero_observations", False)
        return mask & (times > 0) if drop_flag else mask

    def generate(
        self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
    ) -> Tuple[Tensor, ...]:
        """Delegate to the subclass generator and honour the ``add_rem`` flag."""
        obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._generate_raw(
            full_simulation, full_simulation_times, **kwargs
        )
        # Remainder tensors are suppressed entirely when not requested.
        if not self.observations_config.add_rem:
            rem_sim, rem_time, rem_mask = None, None, None
        return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, None

    @abstractmethod
    def _generate_raw(
        self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
    ) -> Tuple[
        Tensor,
        TensorType["B", "T_obs"],
        TensorType["B", "T_obs"],
        TensorType["B", "T_rem"],
        TensorType["B", "T_rem"],
        TensorType["B", "T_rem"],
    ]:
        """Generate observations and remaining sims raw, regardless of add_rem"""
        pass

    def get_shapes(self) -> Tuple[int, int]:
        """Return ``(max_obs, max_rem)``, zeroing ``max_rem`` when add_rem is off."""
        max_obs, max_rem = self._get_shapes_raw()
        return (max_obs, max_rem) if self.observations_config.add_rem else (max_obs, 0)

    @abstractmethod
    def _get_shapes_raw(self) -> Tuple[int, int]:
        """Return max observations and max remaining assuming add_rem=True"""
        pass
103
+
104
+
105
+ class PKPeakHalfLifeStrategy(ObservationStrategy):
106
+ """Observation strategy tailored to pharmacokinetic (PK) curves.
107
+
108
+ The strategy samples observations around the absorption peak and along the
109
+ elimination phase of a PK simulation. It uses a canonical grid composed of
110
+ four segments:
111
+
112
+ 1. Several points before the peak that are proportional to the configured
113
+ peak time.
114
+ 2. The peak itself.
115
+ 3. Several points after the peak spaced by multiples of the provided
116
+ half-life.
117
+ 4. Optional remainder points that are handed back to the caller when
118
+ ``add_rem`` is enabled.
119
+
120
+ For **synthetic simulations**, the strategy still uses this canonical grid
121
+ and nearest-neighbour alignment.
122
+
123
+ For **empirical data**, measurements are treated as already canonical:
124
+
125
+ * No canonical time grid construction.
126
+ * No time normalisation or template matching.
127
+ * No interpolation or re-scaling to canonical coordinates.
128
+
129
+ Empirical sequences are only padded / truncated to the internal capacity
130
+ implied by :class:`ObservationsConfig` and :class:`MetaStudyConfig`, and
131
+ then passed through the same past/future splitting logic.
132
+
133
+ Past/future splitting
134
+ ----------------------
135
+ When ``split_past_future=True``, the canonical sequence for each row is
136
+ split into:
137
+
138
+ * a *past* observation block of fixed width (``max_obs``), and
139
+ * an optional *remainder* block of width (``max_rem``).
140
+
141
+ In the default mode (no fixed past selection), the number of past points
142
+ is sampled according to ``generative_bias``:
143
+
144
+ * ``False`` samples in ``[min_past, max_past]``.
145
+ * ``True`` samples exactly ``0`` with probability 0.5 and, otherwise,
146
+ samples uniformly in ``[max(1, min_past), max_past]``.
147
+
148
+ Under ``generative_bias=False``, **short sequences** receive a special treatment: when
149
+ the number of valid canonical points is less than or equal to the
150
+ observation capacity, *all* valid points are placed in the observation
151
+ block and none are shifted into the remainder.
152
+
153
+ Fixed past selection
154
+ --------------------
155
+ Calling :meth:`fix_past_selection(k)` activates a strict mode in which
156
+ the strategy tries to expose exactly ``k`` earliest valid timestamps as
157
+ "past" for each series, subject to the following structural limits:
158
+
159
+ 1. The number of real data points available in the series.
160
+ 2. The observation capacity dictated by :meth:`_get_shapes_raw`.
161
+
162
+ Concretely, for each row:
163
+
164
+ * Let ``k`` be the fixed past count.
165
+ * Let ``total_valid`` be the number of valid canonical points.
166
+ * Let ``past_required = min(k, total_valid)``.
167
+
168
+ The observation block receives
169
+
170
+ * ``obs_count = min(past_required, max_obs)`` earliest valid points.
171
+
172
+ If ``past_required > obs_count`` (for example because ``k`` exceeds the
173
+ number of observation slots), the remaining required past events
174
+ ``past_required - obs_count`` are the *first entries* in the remainder
175
+ block (subject to the remainder capacity). This guarantees that, as long
176
+ as data and shapes allow, the first ``k`` valid timestamps appear in
177
+ ``obs`` + ``rem`` before any later timestamps.
178
+
179
+ Calling :meth:`release_past_selection()` returns to the default stochastic
180
+ behaviour governed by ``min_past``/``max_past``.
181
+ """
182
+
183
+ _PEAK_PHASE_MULTIPLIERS = (0.1, 0.2, 0.5, 0.8)
184
+ _POST_PEAK_HALF_LIFE_MULTIPLIERS = (
185
+ 0.25,
186
+ 0.50,
187
+ 1.00,
188
+ 2.00,
189
+ 4.00,
190
+ 6.00,
191
+ 8.00,
192
+ 9.00,
193
+ 14.0,
194
+ 19.0,
195
+ 30.0,
196
+ )
197
+ _RAW_CANONICAL_POINTS = len(_PEAK_PHASE_MULTIPLIERS) + 1 + len(_POST_PEAK_HALF_LIFE_MULTIPLIERS)
198
+
199
    def __init__(
        self, observations_config: ObservationsConfig, meta_config: MetaStudyConfig
    ) -> None:
        """Cache the observation-sampling settings read from the configs.

        Parameters
        ----------
        observations_config:
            Supplies ``max_num_obs``, the past/future split flag, the
            ``min_past``/``max_past`` bounds and ``generative_bias`` used
            throughout the strategy.
        meta_config:
            Study-level configuration forwarded to the base class.
        """
        super().__init__(observations_config, meta_config)
        self.max_num_obs = observations_config.max_num_obs
        self.split_past_future = observations_config.split_past_future
        self.min_past = observations_config.min_past
        self.max_past = observations_config.max_past
        self.generative_bias = observations_config.generative_bias
        # None → default random selection. When set, the strategy enforces a
        # strict fixed-past semantics as documented above.
        self._fixed_past_obs_count: Optional[int] = None
211
+
212
+ def fix_past_selection(self, obs_count: int) -> None:
213
+ """Activate strict ``k``-past behaviour.
214
+
215
+ When this mode is active and ``split_past_future=True``, every call to
216
+ :meth:`generate` or :meth:`generate_empirical` will:
217
+
218
+ * expose up to ``obs_count`` earliest valid timestamps in the
219
+ observation block, bounded by the available data and the observation
220
+ capacity;
221
+ * place any additional required past events (when ``obs_count`` is
222
+ larger than the observation capacity) at the *front* of the remainder
223
+ block (when a remainder is present).
224
+
225
+ The strategy is allowed to allocate fewer than ``obs_count`` past
226
+ events only when:
227
+
228
+ * the series contains fewer real data points than ``obs_count``, or
229
+ * the observation/remainder shapes leave insufficient slots.
230
+
231
+ In all other cases the earliest valid timestamps are allocated in the
232
+ order: observation block first, then remainder.
233
+ """
234
+
235
+ if not self.split_past_future:
236
+ # No split → fixed past count is meaningless.
237
+ return
238
+
239
+ if obs_count < self.min_past or obs_count > self.max_past:
240
+ raise ValueError(
241
+ "Fixed past observation count must lie within the configured min/max bounds."
242
+ )
243
+ self._fixed_past_obs_count = int(obs_count)
244
+
245
    def release_past_selection(self) -> None:
        """Return to the default random past selection behaviour.

        Clears any count previously set via :meth:`fix_past_selection`; the
        number of past points is again sampled according to the configured
        ``min_past``/``max_past`` bounds and ``generative_bias``.
        """
        self._fixed_past_obs_count = None
248
+
249
+ @classmethod
250
+ def _build_canonical_grid(
251
+ cls,
252
+ *,
253
+ t_peak: float,
254
+ t_half: float,
255
+ device: torch.device,
256
+ dtype: torch.dtype,
257
+ ) -> Tensor:
258
+ """Construct the canonical grid for a single simulation.
259
+
260
+ The grid covers the pre-peak, peak and post-peak regime of the curve by
261
+ scaling two fundamental quantities supplied at runtime: the time of the
262
+ peak concentration ``t_peak`` and the half-life ``t_half``. Both values
263
+ are expected to be expressed in the same units as the simulation time
264
+ axis.
265
+ """
266
+ before_peak = [mult * t_peak for mult in cls._PEAK_PHASE_MULTIPLIERS]
267
+ after_peak = [t_peak + mult * t_half for mult in cls._POST_PEAK_HALF_LIFE_MULTIPLIERS]
268
+ values = before_peak + [t_peak] + after_peak
269
+ return torch.tensor(values, device=device, dtype=dtype)
270
+
271
+ def _canonical_grid_capacity(self) -> int:
272
+ """Return the number of canonical grid points available.
273
+
274
+ The capacity is the minimum between the simulator resolution and the
275
+ theoretical number of canonical points. This ensures that the
276
+ observation tensors never attempt to gather indices outside the
277
+ original simulation.
278
+ """
279
+ time_steps = getattr(self.meta_config, "time_num_steps", self.max_num_obs)
280
+ return max(
281
+ 0,
282
+ min(int(self.max_num_obs), int(time_steps), self._RAW_CANONICAL_POINTS),
283
+ )
284
+
285
+ def _get_shapes_raw(self) -> Tuple[int, int]:
286
+ """Compute the maximum number of observation and remainder slots.
287
+
288
+ Returns
289
+ -------
290
+ max_obs, max_rem : int, int
291
+ * ``max_obs`` – maximum number of observation time-steps.
292
+ * ``max_rem`` – maximum number of remainder time-steps when
293
+ ``add_rem`` is enabled.
294
+
295
+ Raises
296
+ ------
297
+ ValueError
298
+ If a past/future split is requested but the canonical capacity
299
+ cannot satisfy the configured ``min_past`` requirement.
300
+ """
301
+ canonical_cap = self._canonical_grid_capacity()
302
+ if canonical_cap == 0:
303
+ return 0, 0
304
+
305
+ if self.split_past_future:
306
+ if canonical_cap < self.min_past:
307
+ raise ValueError("Canonical grid capacity is smaller than the configured min_past")
308
+ max_obs = min(self.max_past, canonical_cap)
309
+ max_rem = max(0, canonical_cap - self.min_past)
310
+ else:
311
+ max_obs = canonical_cap
312
+ max_rem = canonical_cap
313
+
314
+ return max_obs, max_rem
315
+
316
+ @staticmethod
317
+ def _deduplicate_sorted_indices(
318
+ idx: Tensor, valid_mask: Optional[Tensor] = None
319
+ ) -> Tuple[Tensor, Tensor]:
320
+ """Collapse repeated gather indices while preserving alignment.
321
+
322
+ ``idx`` is expected to be monotonically increasing. Consecutive
323
+ duplicates are collapsed into a single entry at the front of the tensor
324
+ and the corresponding ``valid_mask`` entries are shifted accordingly.
325
+ """
326
+ if valid_mask is None:
327
+ valid_mask = torch.ones_like(idx, dtype=torch.bool)
328
+
329
+ if idx.numel() <= 1:
330
+ return idx, valid_mask
331
+
332
+ duplicate_mask = torch.zeros_like(idx, dtype=torch.bool)
333
+ duplicate_mask[1:] = idx[1:] == idx[:-1]
334
+
335
+ if not duplicate_mask.any():
336
+ return idx, valid_mask
337
+
338
+ unique_mask = ~duplicate_mask
339
+ kept_idx = idx[unique_mask]
340
+ duplicate_idx = idx[duplicate_mask]
341
+
342
+ padded_idx = torch.empty_like(idx)
343
+ padded_idx[: kept_idx.numel()] = kept_idx
344
+ padded_idx[kept_idx.numel() :] = duplicate_idx
345
+
346
+ kept_valid = valid_mask[unique_mask]
347
+ padded_mask = torch.zeros_like(valid_mask)
348
+ padded_mask[: kept_valid.numel()] = kept_valid
349
+
350
+ return padded_idx, padded_mask
351
+
352
    def _assemble_from_canonical(
        self,
        canonical_vals: Tensor,
        canonical_times: Tensor,
        canonical_mask: Tensor,
        *,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """Convert canonical tensors into output observations.

        The canonical representation stores **all** admissible samples for a
        batch element. This helper slices the canonical tensors into the
        "past" observations that will be returned to the caller and (when
        requested) the "future" remainder.

        Allocation invariants
        ---------------------
        For each batch row:

        * Let ``valid_idx`` be the indices where ``canonical_mask`` is True,
          sorted in ascending order.
        * The observation block always receives the **earliest**
          ``obs_count`` indices from ``valid_idx``.
        * The remainder block (when present) receives later indices only; it
          never contains timestamps that precede those in the observation block.
        * Under ``generative_bias=False``, short sequences
          (``total_valid <= max_obs``) keep all valid points in the
          observation block and do not shift points to the remainder.

        When :meth:`fix_past_selection(k)` is active, we define::

            past_required = min(k, total_valid)

        and allocate:

        * ``obs_count = min(past_required, max_obs)`` to the observation
          block; and
        * any surplus past events ``past_required - obs_count`` at the **front**
          of the remainder block (subject to the remainder capacity), followed
          by any truly future points.

        Releasing the fixed selection returns to the stochastic behaviour
        controlled by ``generative_bias``.
        """
        max_obs, max_rem = self._get_shapes_raw()
        device = canonical_vals.device
        dtype = canonical_vals.dtype
        batch, _ = canonical_vals.shape

        # Output buffers are zero-filled; masks mark which slots were written.
        obs_out = torch.zeros(batch, max_obs, dtype=dtype, device=device)
        obs_time = torch.zeros_like(obs_out)
        obs_mask = torch.zeros(batch, max_obs, dtype=torch.bool, device=device)

        rem_sim = rem_time = rem_mask = None
        if max_rem > 0:
            rem_sim = torch.zeros(batch, max_rem, dtype=dtype, device=device)
            rem_time = torch.zeros_like(rem_sim)
            rem_mask = torch.zeros(batch, max_rem, dtype=torch.bool, device=device)

        # Fall back to the process-global generator so draws still respond to
        # torch.manual_seed when the caller does not supply a generator.
        gen = generator if generator is not None else torch.default_generator

        for row in range(batch):
            valid_idx = canonical_mask[row].nonzero(as_tuple=True)[0]
            total_valid = int(valid_idx.numel())
            if total_valid == 0:
                # Row has no admissible samples; its masks stay all-False.
                continue

            fixed_k = self._fixed_past_obs_count if self.split_past_future else None

            # ------------------------------------------------------------------
            # 1) Decide obs_count
            # ------------------------------------------------------------------
            if self.split_past_future and fixed_k is not None:
                # Strict fixed-past semantics. Structural limits:
                #   - real data (total_valid)
                #   - observation capacity (max_obs)
                past_required = min(fixed_k, total_valid)
                obs_capacity = min(max_obs, total_valid)
                obs_count = min(past_required, obs_capacity)
            else:
                # Default stochastic behaviour; the short-series fix is kept
                # for the non-biased mode only.
                if self.split_past_future:
                    low = min(self.min_past, total_valid)
                    high = min(self.max_past, total_valid)

                    # NOTE: this draw is made unconditionally so the RNG
                    # stream advances identically in both branches below.
                    sampled = _sample_past_count_with_bias(
                        low=low,
                        high=high,
                        generative_bias=self.generative_bias,
                        generator=gen,
                        device=device,
                    )

                    if (not self.generative_bias) and total_valid <= max_obs:
                        # Short-series fix: never push valid points into the
                        # remainder just to satisfy a random split.
                        obs_count = total_valid
                    else:
                        obs_count = min(sampled, max_obs)
                else:
                    obs_count = min(total_valid, max_obs)

            # Safety clamp.
            obs_count = max(0, min(obs_count, min(max_obs, total_valid)))

            # ------------------------------------------------------------------
            # 2) Fill observation block (earliest obs_count indices)
            # ------------------------------------------------------------------
            if obs_count > 0:
                take = valid_idx[:obs_count]
                obs_out[row, :obs_count] = canonical_vals[row, take]
                obs_time[row, :obs_count] = canonical_times[row, take]
                obs_mask[row, :obs_count] = True

            # ------------------------------------------------------------------
            # 3) Fill remainder block (if enabled)
            # ------------------------------------------------------------------
            if rem_sim is not None:
                if self.split_past_future and fixed_k is not None:
                    # Remaining required past events plus genuine future.
                    past_required = min(fixed_k, total_valid)
                    # indices that are still part of the fixed past window
                    # but did not fit into the observation block
                    extra_past_idx = valid_idx[obs_count:past_required]
                    future_idx = valid_idx[past_required:]

                    # Concatenate with extra-past first so surplus past events
                    # land at the front of the remainder, before any future.
                    candidates: List[Tensor] = []
                    if extra_past_idx.numel() > 0:
                        candidates.append(extra_past_idx)
                    if future_idx.numel() > 0:
                        candidates.append(future_idx)
                    if candidates:
                        remainder_candidates = torch.cat(candidates, dim=0)
                    else:
                        # Nothing left for the remainder: keep an empty index
                        # tensor so the count below resolves to zero.
                        remainder_candidates = valid_idx.new_empty((0,), dtype=valid_idx.dtype)
                else:
                    # Default behaviour: everything after the obs window.
                    remainder_candidates = valid_idx[obs_count:]

                rem_count = min(int(remainder_candidates.numel()), max_rem)
                if rem_count > 0:
                    rem_idx = remainder_candidates[:rem_count]
                    rem_sim[row, :rem_count] = canonical_vals[row, rem_idx]
                    rem_time[row, :rem_count] = canonical_times[row, rem_idx]
                    rem_mask[row, :rem_count] = True

        return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask
500
+
501
    def _align_simulation_to_canonical(
        self,
        full_simulation: Tensor,
        full_simulation_times: Tensor,
        *,
        time_scales: Tensor,
        num_obs_sampler: Optional[Callable[[int], Tensor]] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Gather canonical samples from a simulated PK curve.

        Synthetic behaviour is unchanged compared to the original strategy:
        we build a canonical grid, snap it to the nearest simulation times and
        optionally subsample points via ``num_obs_sampler``.
        """
        device = full_simulation.device
        dtype = full_simulation.dtype
        batch, _ = full_simulation.shape
        time_steps = int(full_simulation_times.size(1))

        # DataLoader workers may receive empty row slices (B=0). In that case
        # there is no reference timeline to align against; return an empty
        # canonical block and let _assemble_from_canonical create [B, *] outputs.
        if batch == 0 or time_steps == 0:
            zero = torch.zeros(batch, 0, dtype=dtype, device=device)
            mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
            return zero, zero, mask, time_scales.clone()

        canonical_cap = self._canonical_grid_capacity()
        if canonical_cap == 0:
            # Capacity can collapse to zero when the simulator resolution or
            # max_num_obs is zero; same empty-output contract as above.
            zero = torch.zeros(batch, 0, dtype=dtype, device=device)
            mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
            return zero, zero, mask, time_scales.clone()

        # Canonical grid anchored at this batch's (t_peak, t_half) scales,
        # truncated to the available capacity.
        grid = self._build_canonical_grid(
            t_peak=time_scales[0].item(),
            t_half=time_scales[1].item(),
            device=device,
            dtype=dtype,
        )[:canonical_cap]

        # NOTE(review): row 0 is used as the reference time axis and the same
        # gather indices are reused for every row — this assumes all batch
        # rows share one time grid; confirm with callers.
        ref_times = full_simulation_times[0]
        min_time = ref_times.min()
        max_time = ref_times.max()
        # Grid points outside the simulated window are flagged invalid.
        grid_valid_mask = (grid >= min_time) & (grid <= max_time)

        # Snap each grid point to its nearest simulated time, then sort the
        # indices (carrying the validity flags along) and collapse duplicates
        # so the gather indices are strictly increasing.
        idx = torch.cdist(grid[:, None], ref_times[:, None]).argmin(dim=1)
        idx, order = idx.sort()
        grid_valid_mask = grid_valid_mask[order]
        idx, grid_valid_mask = self._deduplicate_sorted_indices(idx, grid_valid_mask)

        gather_idx = idx[None, :].expand(batch, -1)
        batch_idx = torch.arange(batch, device=device)[:, None]

        canonical_vals = full_simulation[batch_idx, gather_idx]
        canonical_times = full_simulation_times[batch_idx, gather_idx]

        # Zero out slots whose grid point was out of range or a duplicate so
        # that padding is unambiguous downstream.
        invalid_slots = ~grid_valid_mask
        if invalid_slots.any():
            canonical_vals[:, invalid_slots] = 0
            canonical_times[:, invalid_slots] = 0

        # Per-row number of canonical points to keep; defaults to full capacity.
        if num_obs_sampler is None:
            total_counts = torch.full((batch,), canonical_cap, dtype=torch.long, device=device)
        else:
            sampled = num_obs_sampler(batch).to(device=device).long()
            total_counts = sampled.clamp(min=0, max=canonical_cap)

        max_valid = int(grid_valid_mask.sum().item())
        if max_valid == 0:
            total_counts.zero_()
        else:
            total_counts.clamp_(max=max_valid)

        # Rank valid slots left-to-right (-1 for invalid) so each row keeps
        # exactly its first ``total_counts`` valid canonical points.
        valid_order = grid_valid_mask.long().cumsum(dim=0) - 1
        valid_order = torch.where(
            grid_valid_mask,
            valid_order,
            torch.full_like(valid_order, -1, dtype=valid_order.dtype),
        )
        canonical_mask = grid_valid_mask[None, :] & (valid_order[None, :] < total_counts[:, None])
        canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

        return canonical_vals, canonical_times, canonical_mask, time_scales.clone()
584
+
585
    def _align_empirical_to_canonical(
        self,
        empirical_obs: Tensor,
        empirical_times: Tensor,
        empirical_mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """(Legacy) Project empirical observations onto the canonical grid.

        This method is retained for backward compatibility but is **not** used
        by :meth:`generate_empirical`, which now treats empirical data as
        already canonical. New code should avoid calling this helper.
        """
        device = empirical_obs.device
        dtype = empirical_obs.dtype
        batch, _ = empirical_obs.shape
        canonical_cap = self._canonical_grid_capacity()

        canonical_vals = torch.zeros(batch, canonical_cap, dtype=dtype, device=device)
        canonical_times = torch.zeros_like(canonical_vals)
        canonical_mask = torch.zeros(batch, canonical_cap, dtype=torch.bool, device=device)

        if canonical_cap == 0:
            return canonical_vals, canonical_times, canonical_mask

        for row in range(batch):
            valid_idx = empirical_mask[row].nonzero(as_tuple=True)[0]
            if valid_idx.numel() == 0:
                continue

            obs_row = empirical_obs[row, valid_idx]
            time_row = empirical_times[row, valid_idx]
            # Normalise times to roughly [0, 1]; the floor of 1.0 guards
            # against division by a (near-)zero last timestamp.
            max_time = torch.maximum(time_row.max(), torch.tensor(1.0, device=device))
            norm_time = time_row / max_time

            # Estimate t_peak from the sample maximum and t_half from the
            # first post-peak sample at or below half the peak value.
            peak_idx = obs_row.argmax().item()
            t_peak = norm_time[peak_idx].item()
            post_times = norm_time[peak_idx:]
            post_obs = obs_row[peak_idx:]
            half_level = obs_row[peak_idx] / 2
            below_half = (post_obs <= half_level).nonzero(as_tuple=True)[0]
            if below_half.numel() == 0:
                # Curve never decays to half the peak within the record:
                # fall back to the last observed (normalised) time.
                half_time = post_times[-1].item()
            else:
                half_time = post_times[below_half[0]].item()
            t_half = max(half_time - t_peak, 1e-3)

            # Canonical grid in normalised time, clamped into the span [0, 1].
            grid = self._build_canonical_grid(
                t_peak=t_peak if t_peak > 0 else 1e-3,
                t_half=t_half,
                device=device,
                dtype=dtype,
            )[:canonical_cap].clamp(max=1.0)

            # Map the grid back to the original time axis and pick the
            # nearest real measurement for each canonical slot.
            actual_grid = grid * max_time
            distances = torch.cdist(actual_grid[:, None], time_row[:, None])
            nearest = distances.argmin(dim=1)

            usable = min(time_row.numel(), grid.numel())
            if usable == 0:
                continue

            canonical_vals[row, :usable] = obs_row[nearest[:usable]]
            canonical_times[row, :usable] = time_row[nearest[:usable]]
            canonical_mask[row, :usable] = True

        canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

        return canonical_vals, canonical_times, canonical_mask
653
+
654
+ def _prepare_empirical_as_canonical(
655
+ self,
656
+ empirical_obs: Tensor,
657
+ empirical_times: Tensor,
658
+ empirical_mask: Tensor,
659
+ ) -> Tuple[Tensor, Tensor, Tensor]:
660
+ """Treat empirical observations as already canonical.
661
+
662
+ This helper:
663
+
664
+ * does **not** build any canonical grid;
665
+ * does **not** normalise or re-scale time;
666
+ * simply copies valid empirical points in their original order into
667
+ fixed-size tensors, padding with zeros / False as needed.
668
+
669
+ The resulting tensors have width equal to the canonical capacity so
670
+ that they can be passed to :meth:`_assemble_from_canonical`.
671
+ """
672
+ device = empirical_obs.device
673
+ dtype = empirical_obs.dtype
674
+ batch, _ = empirical_obs.shape
675
+ canonical_cap = self._canonical_grid_capacity()
676
+
677
+ canonical_vals = torch.zeros(batch, canonical_cap, dtype=dtype, device=device)
678
+ canonical_times = torch.zeros_like(canonical_vals)
679
+ canonical_mask = torch.zeros(batch, canonical_cap, dtype=torch.bool, device=device)
680
+
681
+ if canonical_cap == 0:
682
+ return canonical_vals, canonical_times, canonical_mask
683
+
684
+ for row in range(batch):
685
+ valid_idx = empirical_mask[row].nonzero(as_tuple=True)[0]
686
+ if valid_idx.numel() == 0:
687
+ continue
688
+
689
+ take_count = min(int(valid_idx.numel()), canonical_cap)
690
+ take_idx = valid_idx[:take_count]
691
+
692
+ canonical_vals[row, :take_count] = empirical_obs[row, take_idx]
693
+ canonical_times[row, :take_count] = empirical_times[row, take_idx]
694
+ canonical_mask[row, :take_count] = True
695
+
696
+ canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)
697
+
698
+ return canonical_vals, canonical_times, canonical_mask
699
+
700
+ def _generate_raw(
701
+ self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
702
+ ) -> Tuple[
703
+ Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
704
+ ]:
705
+ """Deterministic canonical PK sampling for synthetic simulations."""
706
+ time_scales: Optional[Tensor] = kwargs.get("time_scales")
707
+ if time_scales is None:
708
+ raise ValueError("time_scales must be provided for PKPeakHalfLifeStrategy")
709
+
710
+ canonical_vals, canonical_times, canonical_mask, rescaled = (
711
+ self._align_simulation_to_canonical(
712
+ full_simulation,
713
+ full_simulation_times,
714
+ time_scales=time_scales,
715
+ num_obs_sampler=kwargs.get("num_obs_sampler"),
716
+ )
717
+ )
718
+
719
+ obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
720
+ canonical_vals,
721
+ canonical_times,
722
+ canonical_mask,
723
+ generator=kwargs.get("generator"),
724
+ )
725
+
726
+ return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled
727
+
728
+ def _generate_random(
729
+ self,
730
+ full_simulation: Tensor,
731
+ full_simulation_times: Tensor,
732
+ *,
733
+ time_scales: Tensor,
734
+ generator: Optional[torch.Generator] = None,
735
+ ) -> Tuple[
736
+ Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
737
+ ]:
738
+ """Randomised variant of canonical observation generation.
739
+
740
+ The pre- and post-peak segments are sampled from uniform distributions
741
+ bounded by the canonical limits. This keeps the semantic meaning of the
742
+ selected points while injecting stochasticity that can improve
743
+ robustness during training.
744
+ """
745
+ device, dtype = full_simulation.device, full_simulation.dtype
746
+ batch = full_simulation.size(0)
747
+ time_steps = int(full_simulation_times.size(1))
748
+ if batch == 0 or time_steps == 0:
749
+ canonical_vals = torch.zeros(batch, 0, dtype=dtype, device=device)
750
+ canonical_times = torch.zeros(batch, 0, dtype=dtype, device=device)
751
+ canonical_mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
752
+ obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
753
+ canonical_vals, canonical_times, canonical_mask, generator=generator
754
+ )
755
+ return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, time_scales.clone()
756
+ t_peak, t_half = time_scales[0].item(), time_scales[1].item()
757
+
758
+ n_pre = len(self._PEAK_PHASE_MULTIPLIERS)
759
+ n_post = len(self._POST_PEAK_HALF_LIFE_MULTIPLIERS)
760
+
761
+ # Uniform samples before peak
762
+ pre_times = torch.rand(n_pre, device=device, dtype=dtype) * t_peak
763
+ # Always include the peak
764
+ peak_time = torch.tensor([t_peak], device=device, dtype=dtype)
765
+ # Uniform samples after peak
766
+ post_times = []
767
+ for mult in self._POST_PEAK_HALF_LIFE_MULTIPLIERS:
768
+ t_end = t_peak + mult * t_half
769
+ t_rand = torch.empty(1, device=device, dtype=dtype).uniform_(t_peak, t_end)
770
+ post_times.append(t_rand)
771
+ post_times = torch.cat(post_times, dim=0)
772
+
773
+ # Truncate to canonical capacity
774
+ grid = torch.cat([pre_times, peak_time, post_times], dim=0)
775
+ canonical_cap = self._canonical_grid_capacity()
776
+ grid = grid[:canonical_cap]
777
+
778
+ # Map grid to nearest simulation points
779
+ ref_times = full_simulation_times[0]
780
+ idx = torch.cdist(grid[:, None], ref_times[:, None]).argmin(dim=1)
781
+ idx, _ = idx.sort()
782
+ valid_mask = torch.ones_like(idx, dtype=torch.bool)
783
+ idx, valid_mask = self._deduplicate_sorted_indices(idx, valid_mask)
784
+ gather_idx = idx[None, :].expand(batch, -1)
785
+ batch_idx = torch.arange(batch, device=device)[:, None]
786
+
787
+ canonical_vals = full_simulation[batch_idx, gather_idx]
788
+ canonical_times = full_simulation_times[batch_idx, gather_idx]
789
+ invalid_slots = ~valid_mask
790
+ if invalid_slots.any():
791
+ canonical_vals[:, invalid_slots] = 0
792
+ canonical_times[:, invalid_slots] = 0
793
+
794
+ canonical_mask = valid_mask[None, :].expand(batch, -1).clone()
795
+ canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)
796
+
797
+ obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
798
+ canonical_vals, canonical_times, canonical_mask, generator=generator
799
+ )
800
+ return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, time_scales.clone()
801
+
802
+ def generate(
803
+ self,
804
+ full_simulation: Tensor,
805
+ full_simulation_times: Tensor,
806
+ **kwargs,
807
+ ) -> Tuple[
808
+ Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
809
+ ]:
810
+ """Generate PK observations for synthetic simulations.
811
+
812
+ With probability ``randomize_prob`` (default 0.5) the method delegates
813
+ to :meth:`_generate_random`; otherwise the deterministic
814
+ :meth:`_generate_raw` path is taken. Setting ``deterministic_only=True``
815
+ forces the deterministic branch. Both paths require ``time_scales`` and
816
+ honour the ``add_rem`` flag.
817
+ """
818
+ time_scales: Optional[Tensor] = kwargs.get("time_scales")
819
+ if time_scales is None:
820
+ raise ValueError("time_scales must be provided for PKPeakHalfLifeStrategy")
821
+
822
+ deterministic_only = kwargs.pop("deterministic_only", False)
823
+
824
+ use_random = False
825
+ if not deterministic_only:
826
+ use_random = torch.rand(()) < getattr(self, "randomize_prob", 0.5)
827
+
828
+ if use_random:
829
+ obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled = self._generate_random(
830
+ full_simulation,
831
+ full_simulation_times,
832
+ time_scales=time_scales,
833
+ generator=kwargs.get("generator"),
834
+ )
835
+ else:
836
+ obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled = self._generate_raw(
837
+ full_simulation,
838
+ full_simulation_times,
839
+ **kwargs,
840
+ )
841
+
842
+ if not self.observations_config.add_rem:
843
+ rem_sim = rem_time = rem_mask = None
844
+
845
+ return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled
846
+
847
+ def generate_empirical(
848
+ self,
849
+ empirical_obs: Tensor,
850
+ empirical_times: Tensor,
851
+ empirical_mask: Tensor,
852
+ *,
853
+ generator: Optional[torch.Generator] = None,
854
+ ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
855
+ """Generate observations from empirical data.
856
+
857
+ Empirical measurements are assumed to already live on their correct
858
+ time grid. This routine:
859
+
860
+ * does **not** perform canonical alignment or time normalisation;
861
+ * only pads / truncates sequences to match the internal capacity;
862
+ * applies past/future splitting via :meth:`_assemble_from_canonical`
863
+ using the configuration in :class:`ObservationsConfig`.
864
+
865
+ Synthetic simulations keep using the canonical alignment path.
866
+ """
867
+ canonical_vals, canonical_times, canonical_mask = self._prepare_empirical_as_canonical(
868
+ empirical_obs,
869
+ empirical_times,
870
+ empirical_mask,
871
+ )
872
+
873
+ obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
874
+ canonical_vals,
875
+ canonical_times,
876
+ canonical_mask,
877
+ generator=generator,
878
+ )
879
+
880
+ if not self.observations_config.add_rem:
881
+ rem_sim = rem_time = rem_mask = None
882
+
883
+ return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask
884
+
885
+
886
+ class PKPeakHalfLifeStrategyOld(ObservationStrategy):
887
+ """Observation strategy tailored to pharmacokinetic (PK) curves.
888
+
889
+ The strategy samples observations around the absorption peak and along the
890
+ elimination phase of a PK simulation. It uses a canonical grid composed of
891
+ four segments:
892
+
893
+ 1. Several points before the peak that are proportional to the configured
894
+ peak time.
895
+ 2. The peak itself.
896
+ 3. Several points after the peak spaced by multiples of the provided
897
+ half-life.
898
+ 4. Optional remainder points that are handed back to the caller when
899
+ ``add_rem`` is enabled.
900
+
901
+ The resulting observation tensor can be optionally split into "past" and
902
+ "future" observations according to :class:`ObservationsConfig`.
903
+
904
+ Parameters
905
+ ----------
906
+ observations_config:
907
+ Simulation-level configuration that defines sampling constraints such
908
+ as ``max_num_obs`` or the minimum/maximum number of "past" points when
909
+ a split is requested.
910
+ meta_config:
911
+ Meta-study configuration. Only the ``time_num_steps`` attribute is
912
+ used and allows clamping the canonical grid to the resolution of the
913
+ simulator.
914
+ """
915
+
916
+ _PEAK_PHASE_MULTIPLIERS = (0.1, 0.2, 0.5, 0.8)
917
+ _POST_PEAK_HALF_LIFE_MULTIPLIERS = (
918
+ 0.25,
919
+ 0.50,
920
+ 1.00,
921
+ 2.00,
922
+ 4.00,
923
+ 6.00,
924
+ 8.00,
925
+ 9.00,
926
+ 14.0,
927
+ 19.0,
928
+ 30.0,
929
+ )
930
+ _RAW_CANONICAL_POINTS = len(_PEAK_PHASE_MULTIPLIERS) + 1 + len(_POST_PEAK_HALF_LIFE_MULTIPLIERS)
931
+
932
+ def __init__(
933
+ self, observations_config: ObservationsConfig, meta_config: MetaStudyConfig
934
+ ) -> None:
935
+ super().__init__(observations_config, meta_config)
936
+ self.max_num_obs = observations_config.max_num_obs
937
+ self.split_past_future = observations_config.split_past_future
938
+ self.min_past = observations_config.min_past
939
+ self.max_past = observations_config.max_past
940
+ self.generative_bias = observations_config.generative_bias
941
+ # ``None`` indicates that the number of past observations should be
942
+ # sampled according to the standard strategy. When populated it forces
943
+ # :meth:`_assemble_from_canonical` to always select the provided number
944
+ # of past observations (within the valid range).
945
+ self._fixed_past_obs_count: Optional[int] = None
946
+
947
+ def fix_past_selection(self, obs_count: int) -> None:
948
+ """Force the past observation count to ``obs_count`` when splitting.
949
+
950
+ The override is only applied when ``split_past_future`` is enabled. The
951
+ provided ``obs_count`` must fall within ``[min_past, max_past]``.
952
+ """
953
+
954
+ if not self.split_past_future:
955
+ return
956
+
957
+ if obs_count < self.min_past or obs_count > self.max_past:
958
+ raise ValueError(
959
+ "Fixed past observation count must lie within the configured min/max bounds."
960
+ )
961
+ self._fixed_past_obs_count = int(obs_count)
962
+
963
+ def release_past_selection(self) -> None:
964
+ """Return to the default random past selection behaviour."""
965
+
966
+ self._fixed_past_obs_count = None
967
+
968
+ @classmethod
969
+ def _build_canonical_grid(
970
+ cls,
971
+ *,
972
+ t_peak: float,
973
+ t_half: float,
974
+ device: torch.device,
975
+ dtype: torch.dtype,
976
+ ) -> Tensor:
977
+ """Construct the canonical grid for a single simulation.
978
+
979
+ The grid covers the pre-peak, peak and post-peak regime of the curve by
980
+ scaling two fundamental quantities supplied at runtime: the time of the
981
+ peak concentration ``t_peak`` and the half-life ``t_half``. Both values
982
+ are expected to be expressed in the same units as the simulation time
983
+ axis.
984
+
985
+ Parameters
986
+ ----------
987
+ t_peak:
988
+ Estimated time of the concentration peak.
989
+ t_half:
990
+ Estimated half-life used to position post-peak points.
991
+ device, dtype:
992
+ Torch device and dtype for the returned tensor so that it matches
993
+ the simulation tensors that will be gathered later on.
994
+
995
+ Returns
996
+ -------
997
+ torch.Tensor
998
+ One-dimensional tensor containing monotonically increasing times
999
+ representing the canonical sampling grid.
1000
+ """
1001
+ before_peak = [mult * t_peak for mult in cls._PEAK_PHASE_MULTIPLIERS]
1002
+ after_peak = [t_peak + mult * t_half for mult in cls._POST_PEAK_HALF_LIFE_MULTIPLIERS]
1003
+ values = before_peak + [t_peak] + after_peak
1004
+ return torch.tensor(values, device=device, dtype=dtype)
1005
+
1006
+ def _canonical_grid_capacity(self) -> int:
1007
+ """Return the number of canonical grid points available.
1008
+
1009
+ The capacity is the minimum between the simulator resolution and the
1010
+ theoretical number of canonical points. This ensures that the
1011
+ observation tensors never attempt to gather indices outside the
1012
+ original simulation.
1013
+
1014
+ Returns
1015
+ -------
1016
+ int
1017
+ Maximum number of grid points that can be sampled for each
1018
+ simulation in the batch.
1019
+ """
1020
+ time_steps = getattr(self.meta_config, "time_num_steps", self.max_num_obs)
1021
+ return max(
1022
+ 0,
1023
+ min(int(self.max_num_obs), int(time_steps), self._RAW_CANONICAL_POINTS),
1024
+ )
1025
+
1026
+ def _get_shapes_raw(self) -> Tuple[int, int]:
1027
+ """Compute the maximum number of observation and remainder slots.
1028
+
1029
+ The method applies the canonical grid capacity alongside the
1030
+ ``split_past_future`` configuration to decide how many points can be
1031
+ surfaced directly as observations and how many should be exposed as
1032
+ "remaining" (future) points.
1033
+
1034
+ Returns
1035
+ -------
1036
+ tuple[int, int]
1037
+ The first entry is the maximum number of observations. The second
1038
+ entry is the maximum number of remaining observations when
1039
+ ``add_rem`` is enabled.
1040
+
1041
+ Raises
1042
+ ------
1043
+ ValueError
1044
+ If a past/future split is requested but the canonical capacity
1045
+ cannot satisfy the configured ``min_past`` requirement.
1046
+ """
1047
+ canonical_cap = self._canonical_grid_capacity()
1048
+ if canonical_cap == 0:
1049
+ return 0, 0
1050
+
1051
+ if self.split_past_future:
1052
+ if canonical_cap < self.min_past:
1053
+ raise ValueError("Canonical grid capacity is smaller than the configured min_past")
1054
+ max_obs = min(self.max_past, canonical_cap)
1055
+ max_rem = max(0, canonical_cap - self.min_past)
1056
+ else:
1057
+ max_obs = canonical_cap
1058
+ max_rem = canonical_cap
1059
+
1060
+ return max_obs, max_rem
1061
+
1062
+ @staticmethod
1063
+ def _deduplicate_sorted_indices(
1064
+ idx: Tensor, valid_mask: Optional[Tensor] = None
1065
+ ) -> Tuple[Tensor, Tensor]:
1066
+ """Collapse repeated gather indices while preserving alignment."""
1067
+
1068
+ if valid_mask is None:
1069
+ valid_mask = torch.ones_like(idx, dtype=torch.bool)
1070
+
1071
+ if idx.numel() <= 1:
1072
+ return idx, valid_mask
1073
+
1074
+ duplicate_mask = torch.zeros_like(idx, dtype=torch.bool)
1075
+ duplicate_mask[1:] = idx[1:] == idx[:-1]
1076
+
1077
+ if not duplicate_mask.any():
1078
+ return idx, valid_mask
1079
+
1080
+ unique_mask = ~duplicate_mask
1081
+ kept_idx = idx[unique_mask]
1082
+ duplicate_idx = idx[duplicate_mask]
1083
+
1084
+ padded_idx = torch.empty_like(idx)
1085
+ padded_idx[: kept_idx.numel()] = kept_idx
1086
+ padded_idx[kept_idx.numel() :] = duplicate_idx
1087
+
1088
+ kept_valid = valid_mask[unique_mask]
1089
+ padded_mask = torch.zeros_like(valid_mask)
1090
+ padded_mask[: kept_valid.numel()] = kept_valid
1091
+
1092
+ return padded_idx, padded_mask
1093
+
1094
    def _assemble_from_canonical(
        self,
        canonical_vals: Tensor,
        canonical_times: Tensor,
        canonical_mask: Tensor,
        *,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """Convert canonical tensors into output observations.

        The canonical representation stores **all** admissible samples for a
        batch element. This helper slices the canonical tensors into the
        "past" observations that will be returned to the caller and (when
        requested) the "future" remainder. The selection proceeds row by row:

        1. ``canonical_mask`` is inspected to identify the indices that contain
           valid information. These are the only points that may be surfaced.
        2. When ``split_past_future`` is ``False`` every valid point is treated
           as part of the observation history up to the configured capacity.
        3. Otherwise we randomly draw ``obs_count`` between ``min_past`` and
           ``max_past`` (capped by the number of valid canonical entries). The
           first ``obs_count`` indices become past observations while the
           remaining valid points are placed in the remainder tensors.

        Parameters
        ----------
        canonical_vals, canonical_times:
            Tensors produced by aligning the simulation or empirical data to
            the canonical grid.
        canonical_mask:
            Boolean tensor marking valid entries for each batch element.
        generator:
            Optional random generator used when sampling ``obs_count`` in
            split-past/future mode.

        Returns
        -------
        tuple of tensors
            Observation and remaining tensors matching the shapes dictated by
            :meth:`_get_shapes_raw`. All tensors share the same device and
            dtype as the inputs. ``None`` is returned for remainder tensors
            when the capacity is zero.
        """
        max_obs, max_rem = self._get_shapes_raw()
        device = canonical_vals.device
        dtype = canonical_vals.dtype
        batch, _ = canonical_vals.shape

        # Zero-initialised output blocks; unfilled slots stay 0 / False.
        obs_out = torch.zeros(batch, max_obs, dtype=dtype, device=device)
        obs_time = torch.zeros_like(obs_out)
        obs_mask = torch.zeros(batch, max_obs, dtype=torch.bool, device=device)

        # Remainder ("future") blocks only exist when capacity allows.
        rem_sim = rem_time = rem_mask = None
        if max_rem > 0:
            rem_sim = torch.zeros(batch, max_rem, dtype=dtype, device=device)
            rem_time = torch.zeros_like(rem_sim)
            rem_mask = torch.zeros(batch, max_rem, dtype=torch.bool, device=device)

        gen = generator if generator is not None else torch.default_generator

        for row in range(batch):
            valid_idx = canonical_mask[row].nonzero(as_tuple=True)[0]
            total_valid = valid_idx.numel()
            if total_valid == 0:
                continue

            # Decide how many of the valid points become "past" observations.
            if self.split_past_future:
                low = min(self.min_past, total_valid)
                high = min(self.max_past, total_valid)
                if self._fixed_past_obs_count is not None:
                    obs_count = min(self._fixed_past_obs_count, total_valid)
                else:
                    # Module-level helper (defined elsewhere in this file)
                    # that draws a count in [low, high], biased by
                    # ``generative_bias``.
                    obs_count = _sample_past_count_with_bias(
                        low=low,
                        high=high,
                        generative_bias=self.generative_bias,
                        generator=gen,
                        device=device,
                    )
                obs_count = min(obs_count, max_obs)
            else:
                obs_count = min(total_valid, max_obs)

            # Earliest obs_count valid indices form the observation block.
            if obs_count > 0:
                take = valid_idx[:obs_count]
                obs_out[row, :obs_count] = canonical_vals[row, take]
                obs_time[row, :obs_count] = canonical_times[row, take]
                obs_mask[row, :obs_count] = True

            # Everything after the observation window feeds the remainder.
            if rem_sim is not None:
                rem_candidates = valid_idx[obs_count:]
                rem_count = min(rem_candidates.numel(), max_rem)
                if rem_count > 0:
                    rem_idx = rem_candidates[:rem_count]
                    rem_sim[row, :rem_count] = canonical_vals[row, rem_idx]
                    rem_time[row, :rem_count] = canonical_times[row, rem_idx]
                    rem_mask[row, :rem_count] = True

        return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask
1193
+
1194
    def _align_simulation_to_canonical(
        self,
        full_simulation: Tensor,
        full_simulation_times: Tensor,
        *,
        time_scales: Tensor,
        num_obs_sampler: Optional[Callable[[int], Tensor]] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Gather the canonical samples from a simulated PK curve.

        The routine creates the canonical grid described in the configuration
        (using the provided ``time_scales``) and then performs a nearest-neighbour
        lookup on the simulated trajectory. Each grid location picks the
        closest time point from the reference simulation (the first batch row);
        the same indices are applied to every batch element so that values and
        times remain aligned across the batch. ``num_obs_sampler`` can further
        prune the resulting grid by specifying how many of those canonical
        points should remain valid for each row.

        Parameters
        ----------
        full_simulation, full_simulation_times:
            Batched tensors representing the simulated concentration curve and
            its time axis.
        time_scales:
            Two-element tensor with ``t_peak`` and ``t_half`` scaling factors.
        num_obs_sampler:
            Optional callable that samples how many canonical points should be
            retained for each batch element.

        Returns
        -------
        tuple of torch.Tensor
            The canonical values, their corresponding times, a boolean mask of
            valid entries and the (cloned) ``time_scales`` tensor. When the
            canonical capacity is zero, zero-sized tensors are returned for the
            first three entries.
        """
        device = full_simulation.device
        dtype = full_simulation.dtype
        batch, _ = full_simulation.shape
        time_steps = int(full_simulation_times.size(1))

        # Empty worker slices (B=0) and zero-step trajectories are valid edge
        # cases; return empty canonical tensors and keep shape assembly
        # delegated to _assemble_from_canonical.
        if batch == 0 or time_steps == 0:
            zero = torch.zeros(batch, 0, dtype=dtype, device=device)
            mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
            return zero, zero, mask, time_scales.clone()

        canonical_cap = self._canonical_grid_capacity()
        if canonical_cap == 0:
            # Zero capacity (e.g. max_num_obs == 0): same empty contract.
            zero = torch.zeros(batch, 0, dtype=dtype, device=device)
            mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
            return zero, zero, mask, time_scales.clone()

        # Canonical grid anchored at this batch's (t_peak, t_half) scales.
        grid = self._build_canonical_grid(
            t_peak=time_scales[0].item(),
            t_half=time_scales[1].item(),
            device=device,
            dtype=dtype,
        )[:canonical_cap]

        # Row 0 serves as the shared reference time axis for the whole batch.
        ref_times = full_simulation_times[0]
        min_time = ref_times.min()
        max_time = ref_times.max()
        # Grid points outside the simulated window are flagged invalid.
        grid_valid_mask = (grid >= min_time) & (grid <= max_time)

        # Snap grid points to nearest simulated times, sort, and collapse
        # duplicates so gather indices are strictly increasing.
        idx = torch.cdist(grid[:, None], ref_times[:, None]).argmin(dim=1)
        idx, order = idx.sort()
        grid_valid_mask = grid_valid_mask[order]
        idx, grid_valid_mask = self._deduplicate_sorted_indices(idx, grid_valid_mask)

        gather_idx = idx[None, :].expand(batch, -1)
        batch_idx = torch.arange(batch, device=device)[:, None]

        canonical_vals = full_simulation[batch_idx, gather_idx]
        canonical_times = full_simulation_times[batch_idx, gather_idx]

        # Zero out out-of-range / duplicate slots for unambiguous padding.
        invalid_slots = ~grid_valid_mask
        if invalid_slots.any():
            canonical_vals[:, invalid_slots] = 0
            canonical_times[:, invalid_slots] = 0

        # Per-row number of canonical points to keep; defaults to full cap.
        if num_obs_sampler is None:
            total_counts = torch.full((batch,), canonical_cap, dtype=torch.long, device=device)
        else:
            sampled = num_obs_sampler(batch).to(device=device).long()
            total_counts = sampled.clamp(min=0, max=canonical_cap)

        max_valid = int(grid_valid_mask.sum().item())
        if max_valid == 0:
            total_counts.zero_()
        else:
            total_counts.clamp_(max=max_valid)

        # Rank valid slots left-to-right (-1 for invalid) so each row keeps
        # exactly its first ``total_counts`` valid canonical points.
        valid_order = grid_valid_mask.long().cumsum(dim=0) - 1
        valid_order = torch.where(
            grid_valid_mask,
            valid_order,
            torch.full_like(valid_order, -1, dtype=valid_order.dtype),
        )
        canonical_mask = grid_valid_mask[None, :] & (valid_order[None, :] < total_counts[:, None])
        canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

        return canonical_vals, canonical_times, canonical_mask, time_scales.clone()
1301
+
1302
    def _align_empirical_to_canonical(
        self,
        empirical_obs: Tensor,
        empirical_times: Tensor,
        empirical_mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Project empirical observations onto the canonical grid.

        The projection normalises the empirical time axis to estimate the peak
        and half-life from the data itself. This allows harmonising real
        measurements with the canonical layout used during simulation-driven
        training.

        Parameters
        ----------
        empirical_obs, empirical_times, empirical_mask:
            Batched tensors storing empirical observations, the corresponding
            time stamps and a mask of valid entries.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Canonical values, times and boolean masks aligned to the canonical
            sampling scheme.
        """
        device = empirical_obs.device
        dtype = empirical_obs.dtype
        batch, _ = empirical_obs.shape
        canonical_cap = self._canonical_grid_capacity()

        # Pre-allocate zero-padded outputs; rows with no valid data stay zero.
        canonical_vals = torch.zeros(batch, canonical_cap, dtype=dtype, device=device)
        canonical_times = torch.zeros_like(canonical_vals)
        canonical_mask = torch.zeros(batch, canonical_cap, dtype=torch.bool, device=device)

        if canonical_cap == 0:
            return canonical_vals, canonical_times, canonical_mask

        # Each row is processed independently because peak/half-life are
        # estimated per individual from that row's own measurements.
        for row in range(batch):
            valid_idx = empirical_mask[row].nonzero(as_tuple=True)[0]
            if valid_idx.numel() == 0:
                continue

            obs_row = empirical_obs[row, valid_idx]
            time_row = empirical_times[row, valid_idx]
            # Normalise times to [0, ~1]; the floor of 1.0 avoids division by
            # values < 1 blowing up the normalised axis (and by 0 entirely).
            max_time = torch.maximum(time_row.max(), torch.tensor(1.0, device=device))
            norm_time = time_row / max_time

            # Peak = largest observed value; half-life estimated as the first
            # post-peak time at which the signal drops to half the peak.
            peak_idx = obs_row.argmax().item()
            t_peak = norm_time[peak_idx].item()
            post_times = norm_time[peak_idx:]
            post_obs = obs_row[peak_idx:]
            half_level = obs_row[peak_idx] / 2
            below_half = (post_obs <= half_level).nonzero(as_tuple=True)[0]
            if below_half.numel() == 0:
                # Never decays below half within the record: fall back to the
                # last observed time as an upper bound.
                half_time = post_times[-1].item()
            else:
                half_time = post_times[below_half[0]].item()
            # Floor keeps the grid well-defined when peak and half-time coincide.
            t_half = max(half_time - t_peak, 1e-3)

            grid = self._build_canonical_grid(
                t_peak=t_peak if t_peak > 0 else 1e-3,
                t_half=t_half,
                device=device,
                dtype=dtype,
            )[:canonical_cap].clamp(max=1.0)

            # Snap each canonical (de-normalised) grid point to the nearest
            # real measurement; duplicates are possible for sparse rows.
            actual_grid = grid * max_time
            distances = torch.cdist(actual_grid[:, None], time_row[:, None])
            nearest = distances.argmin(dim=1)

            usable = min(time_row.numel(), grid.numel())
            if usable == 0:
                continue

            canonical_vals[row, :usable] = obs_row[nearest[:usable]]
            canonical_times[row, :usable] = time_row[nearest[:usable]]
            canonical_mask[row, :usable] = True

        # Zero/negative timestamps are treated as padding and unmasked.
        canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)

        return canonical_vals, canonical_times, canonical_mask
1383
+
1384
+ def _generate_raw(
1385
+ self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
1386
+ ) -> Tuple[
1387
+ Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
1388
+ ]:
1389
+ time_scales: Optional[Tensor] = kwargs.get("time_scales")
1390
+ if time_scales is None:
1391
+ raise ValueError("time_scales must be provided for PKPeakHalfLifeStrategy")
1392
+
1393
+ canonical_vals, canonical_times, canonical_mask, rescaled = (
1394
+ self._align_simulation_to_canonical(
1395
+ full_simulation,
1396
+ full_simulation_times,
1397
+ time_scales=time_scales,
1398
+ num_obs_sampler=kwargs.get("num_obs_sampler"),
1399
+ )
1400
+ )
1401
+
1402
+ obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
1403
+ canonical_vals,
1404
+ canonical_times,
1405
+ canonical_mask,
1406
+ generator=kwargs.get("generator"),
1407
+ )
1408
+
1409
+ return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled
1410
+
1411
+ def _generate_random(
1412
+ self,
1413
+ full_simulation: Tensor,
1414
+ full_simulation_times: Tensor,
1415
+ *,
1416
+ time_scales: Tensor,
1417
+ generator: Optional[torch.Generator] = None,
1418
+ ) -> Tuple[
1419
+ Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
1420
+ ]:
1421
+ """Randomized variant of canonical observation generation.
1422
+
1423
+ Instead of fixed multipliers, the pre- and post-peak segments are
1424
+ sampled from uniform distributions bounded by the canonical limits.
1425
+ This keeps the semantic meaning of the selected points while injecting
1426
+ stochasticity that improves robustness when training amortised
1427
+ inference models.
1428
+ """
1429
+ device, dtype = full_simulation.device, full_simulation.dtype
1430
+ batch = full_simulation.size(0)
1431
+ time_steps = int(full_simulation_times.size(1))
1432
+ if batch == 0 or time_steps == 0:
1433
+ canonical_vals = torch.zeros(batch, 0, dtype=dtype, device=device)
1434
+ canonical_times = torch.zeros(batch, 0, dtype=dtype, device=device)
1435
+ canonical_mask = torch.zeros(batch, 0, dtype=torch.bool, device=device)
1436
+ obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
1437
+ canonical_vals, canonical_times, canonical_mask, generator=generator
1438
+ )
1439
+ return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, time_scales.clone()
1440
+ t_peak, t_half = time_scales[0].item(), time_scales[1].item()
1441
+
1442
+ n_pre = len(self._PEAK_PHASE_MULTIPLIERS)
1443
+ n_post = len(self._POST_PEAK_HALF_LIFE_MULTIPLIERS)
1444
+
1445
+ # Uniform samples before peak
1446
+ pre_times = torch.rand(n_pre, device=device, dtype=dtype) * t_peak
1447
+ # Always include the peak
1448
+ peak_time = torch.tensor([t_peak], device=device, dtype=dtype)
1449
+ # Uniform samples after peak
1450
+ post_times = []
1451
+ for mult in self._POST_PEAK_HALF_LIFE_MULTIPLIERS:
1452
+ t_end = t_peak + mult * t_half
1453
+ t_rand = torch.empty(1, device=device, dtype=dtype).uniform_(t_peak, t_end)
1454
+ post_times.append(t_rand)
1455
+ post_times = torch.cat(post_times, dim=0)
1456
+
1457
+ # Truncate to canonical capacity
1458
+ grid = torch.cat([pre_times, peak_time, post_times], dim=0)
1459
+ canonical_cap = self._canonical_grid_capacity()
1460
+ grid = grid[:canonical_cap]
1461
+
1462
+ # Map grid to nearest simulation points
1463
+ ref_times = full_simulation_times[0]
1464
+ idx = torch.cdist(grid[:, None], ref_times[:, None]).argmin(dim=1)
1465
+ idx, _ = idx.sort()
1466
+ valid_mask = torch.ones_like(idx, dtype=torch.bool)
1467
+ idx, valid_mask = self._deduplicate_sorted_indices(idx, valid_mask)
1468
+ gather_idx = idx[None, :].expand(batch, -1)
1469
+ batch_idx = torch.arange(batch, device=device)[:, None]
1470
+
1471
+ canonical_vals = full_simulation[batch_idx, gather_idx]
1472
+ canonical_times = full_simulation_times[batch_idx, gather_idx]
1473
+ invalid_slots = ~valid_mask
1474
+ if invalid_slots.any():
1475
+ canonical_vals[:, invalid_slots] = 0
1476
+ canonical_times[:, invalid_slots] = 0
1477
+
1478
+ canonical_mask = valid_mask[None, :].expand(batch, -1).clone()
1479
+ canonical_mask = self._drop_non_positive_times_from_mask(canonical_times, canonical_mask)
1480
+
1481
+ obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
1482
+ canonical_vals, canonical_times, canonical_mask, generator=generator
1483
+ )
1484
+ return obs_out, obs_time, obs_mask, rem_sim, rem_time, rem_mask, time_scales.clone()
1485
+
1486
+ def generate(
1487
+ self,
1488
+ full_simulation: Tensor,
1489
+ full_simulation_times: Tensor,
1490
+ **kwargs,
1491
+ ) -> Tuple[
1492
+ Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor
1493
+ ]:
1494
+ """Generate PK observations using canonical or randomized schedules.
1495
+
1496
+ With probability ``randomize_prob`` (default 0.5) the method delegates
1497
+ to :meth:`_generate_random`; otherwise the deterministic
1498
+ :meth:`_generate_raw` path is taken. Setting the keyword argument
1499
+ ``deterministic_only=True`` forces the deterministic branch regardless
1500
+ of the random draw. Both paths require the caller to provide
1501
+ ``time_scales`` specifying the peak and half-life. The method honours
1502
+ the ``add_rem`` flag by optionally returning remainder tensors.
1503
+ """
1504
+ time_scales: Optional[Tensor] = kwargs.get("time_scales")
1505
+ if time_scales is None:
1506
+ raise ValueError("time_scales must be provided for PKPeakHalfLifeStrategy")
1507
+
1508
+ deterministic_only = kwargs.pop("deterministic_only", False)
1509
+
1510
+ use_random = False
1511
+ if not deterministic_only:
1512
+ use_random = torch.rand(()) < getattr(self, "randomize_prob", 0.5)
1513
+
1514
+ if use_random:
1515
+ obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled = self._generate_random(
1516
+ full_simulation,
1517
+ full_simulation_times,
1518
+ time_scales=time_scales,
1519
+ generator=kwargs.get("generator"),
1520
+ )
1521
+ else:
1522
+ obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled = self._generate_raw(
1523
+ full_simulation,
1524
+ full_simulation_times,
1525
+ **kwargs,
1526
+ )
1527
+
1528
+ if not self.observations_config.add_rem:
1529
+ rem_sim = rem_time = rem_mask = None
1530
+
1531
+ return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask, rescaled
1532
+
1533
+ def generate_empirical(
1534
+ self,
1535
+ empirical_obs: Tensor,
1536
+ empirical_times: Tensor,
1537
+ empirical_mask: Tensor,
1538
+ *,
1539
+ generator: Optional[torch.Generator] = None,
1540
+ ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
1541
+ canonical_vals, canonical_times, canonical_mask = self._align_empirical_to_canonical(
1542
+ empirical_obs,
1543
+ empirical_times,
1544
+ empirical_mask,
1545
+ )
1546
+
1547
+ obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask = self._assemble_from_canonical(
1548
+ canonical_vals,
1549
+ canonical_times,
1550
+ canonical_mask,
1551
+ generator=generator,
1552
+ )
1553
+
1554
+ if not self.observations_config.add_rem:
1555
+ rem_sim = rem_time = rem_mask = None
1556
+
1557
+ return obs, obs_time, obs_mask, rem_sim, rem_time, rem_mask
1558
+
1559
+
1560
class FixPastTimeRandomSelectionStrategy(ObservationStrategy):
    """Randomly sample observations and split with fixed-capacity past/future slots.

    For ``split_past_future=True`` this strategy enforces the contract:
    ``obs_capacity=max_past`` and ``rem_capacity=max_num_obs-max_past``
    (subject to ``fixed_M_max=min(max_num_obs, time_num_steps)``).
    """

    def __init__(self, config: ObservationsConfig, meta_config: MetaStudyConfig):
        """Cache selection capacities and split parameters from the configs."""
        super().__init__(config, meta_config)
        # Hard cap on sampled points: cannot exceed the simulation grid length.
        time_steps = getattr(meta_config, "time_num_steps", config.max_num_obs)
        self.fixed_M_max = min(config.max_num_obs, time_steps)
        self.split_past_future = config.split_past_future
        self.max_past = config.max_past
        self.min_past = config.min_past
        self.generative_bias = config.generative_bias
        # Fraction of the study horizon used as the past/future boundary.
        self.boundary_ratio = getattr(config, "past_time_ratio", 0.1)

    def _generate_raw(self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs):
        # Delegate the uniform without-replacement sampling to the shared helper;
        # the remainder slots it returns are always None at this stage.
        return fix_past_time_random_selection(
            full_simulation=full_simulation,
            full_simulation_times=full_simulation_times,
            boundary_ratio=self.boundary_ratio,
            fixed_M_max=self.fixed_M_max,
            num_obs_sampler=kwargs.get("num_obs_sampler", None),
            generator=kwargs.get("generator", None),
        )

    def _get_shapes_raw(self) -> Tuple[int, int]:
        """Return fixed-capacity shapes for random split outputs.

        With ``split_past_future=True``:
        - ``max_obs`` is bounded by ``max_past``
        - ``max_rem`` is bounded by ``max_num_obs - max_past``

        Raises
        ------
        ValueError
            If the split is enabled without ``min_past``/``max_past``, or if
            ``fixed_M_max`` cannot accommodate ``min_past``.
        """
        if self.split_past_future:
            if self.min_past is None or self.max_past is None:
                raise ValueError(
                    "min_past and max_past must be specified when split_past_future=True"
                )
            if self.fixed_M_max < self.min_past:
                raise ValueError("fixed_M_max is smaller than the configured min_past")
            max_obs = min(self.max_past, self.fixed_M_max)
            max_rem = max(0, self.fixed_M_max - self.max_past)
        else:
            # Without a split, both capacities default to the sampling cap.
            max_obs = self.fixed_M_max
            max_rem = self.fixed_M_max

        return max_obs, max_rem

    def _split_by_boundary(
        self,
        obs: TensorType["B", "M"],
        obs_time: TensorType["B", "M"],
        obs_mask: TensorType["B", "M"],
        *,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Split sampled observations into strict past and future blocks.

        The split is boundary-based and strict:
        - Past block samples ``k`` points from ``time <= boundary`` candidates,
          where ``k`` follows ``min_past``/``max_past`` (and ``generative_bias``),
          capped by available candidates and ``K_max``.
        - When ``k > 0``, remainder receives up to ``R_cap`` points sampled
          from ``time > boundary`` only (strict future).
        - When ``k == 0``, boundary splitting is ignored for remainder and
          points are sampled from all valid candidates.

        Extra past/future candidates are ignored, and missing entries are
        padded by zeros with mask=False.
        """
        B, M = obs.shape
        # K_max: capacity of the past block [B, K_max]
        K_max = min(int(self.max_past), int(M))
        K_min = min(int(self.min_past), K_max)
        # R_cap: fixed capacity of the remainder block [B, R_cap]
        R_cap = max(0, int(M) - K_max)

        boundary = self.meta_config.time_stop * self.boundary_ratio
        # NOTE(review): torch.default_generator is the CPU default generator;
        # if ``obs`` ever lives on CUDA, randperm(generator=gen, device=...)
        # may mismatch — confirm tensors are CPU-resident here.
        gen = generator if generator is not None else torch.default_generator

        # Zero-padded fixed-capacity outputs; mask=False marks padding.
        past_obs = torch.zeros(B, K_max, dtype=obs.dtype, device=obs.device)
        past_time = torch.zeros_like(past_obs)
        past_mask = torch.zeros(B, K_max, dtype=torch.bool, device=obs.device)

        rem_obs = torch.zeros(B, R_cap, dtype=obs.dtype, device=obs.device)
        rem_time = torch.zeros_like(rem_obs)
        rem_mask = torch.zeros(B, R_cap, dtype=torch.bool, device=obs.device)

        for b in range(B):
            valid_idx = obs_mask[b].nonzero(as_tuple=True)[0]
            past_candidates = valid_idx[obs_time[b, valid_idx] <= boundary]
            future_candidates = valid_idx[obs_time[b, valid_idx] > boundary]

            # Keep candidate pools chronologically ordered before sampling.
            if past_candidates.numel() > 1:
                order = torch.argsort(obs_time[b, past_candidates])
                past_candidates = past_candidates[order]
            if future_candidates.numel() > 1:
                order = torch.argsort(obs_time[b, future_candidates])
                future_candidates = future_candidates[order]

            # Past is sampled uniformly without replacement from pre-boundary points.
            k_high = min(K_max, int(past_candidates.numel()))
            k_low = min(K_min, k_high)
            k = _sample_past_count_with_bias(
                low=int(k_low),
                high=int(k_high),
                generative_bias=self.generative_bias,
                generator=gen,
                device=obs.device,
            )
            if k > 0 and past_candidates.numel() > 0:
                chosen_offsets = torch.randperm(
                    past_candidates.numel(),
                    generator=gen,
                    device=obs.device,
                )[:k]
                chosen_past = past_candidates[chosen_offsets]
                # Re-sort the chosen subset by time for stable packing.
                chosen_order = torch.argsort(obs_time[b, chosen_past])
                chosen_past = chosen_past[chosen_order]
            else:
                # Empty selection of the same dtype/device as the candidates.
                chosen_past = past_candidates[:0]

            num_past = chosen_past.numel()
            if num_past > 0:
                past_obs[b, :num_past] = obs[b, chosen_past]
                past_time[b, :num_past] = obs_time[b, chosen_past]
                past_mask[b, :num_past] = True

            # If no past point is selected, allow remainder sampling across the
            # whole valid domain. Otherwise keep strict future-only remainder.
            rem_pool = valid_idx if num_past == 0 else future_candidates
            if rem_pool.numel() > 1:
                order = torch.argsort(obs_time[b, rem_pool])
                rem_pool = rem_pool[order]

            if R_cap <= 0 or rem_pool.numel() == 0:
                chosen_rem = rem_pool[:0]
            elif rem_pool.numel() <= R_cap:
                # Pool fits entirely: take everything, no draw needed.
                chosen_rem = rem_pool
            else:
                chosen_offsets = torch.randperm(
                    rem_pool.numel(),
                    generator=gen,
                    device=obs.device,
                )[:R_cap]
                chosen_rem = rem_pool[chosen_offsets]
                chosen_order = torch.argsort(obs_time[b, chosen_rem])
                chosen_rem = chosen_rem[chosen_order]

            r = chosen_rem.numel()
            if r > 0:
                rem_obs[b, :r] = obs[b, chosen_rem]
                rem_time[b, :r] = obs_time[b, chosen_rem]
                rem_mask[b, :r] = True

        return past_obs, past_time, past_mask, rem_obs, rem_time, rem_mask

    def generate(
        self, full_simulation: Tensor, full_simulation_times: Tensor, **kwargs
    ) -> Tuple[Tensor, ...]:
        """Sample observations and optionally split them at the time boundary.

        Returns a 7-tuple; the trailing ``None`` fills the time-scales slot
        that other strategies (e.g. the PK strategy) populate.
        """
        obs, obs_time, obs_mask, _, _, _ = self._generate_raw(
            full_simulation, full_simulation_times, **kwargs
        )
        # Zero/negative timestamps count as padding and are unmasked.
        obs_mask = self._drop_non_positive_times_from_mask(obs_time, obs_mask)

        if self.split_past_future:
            out = self._split_by_boundary(
                obs,
                obs_time,
                obs_mask,
                generator=kwargs.get("generator", None),
            )
        else:
            past_obs, past_time, past_mask = obs, obs_time, obs_mask
            rem_obs = rem_time = rem_mask = None
            out = (past_obs, past_time, past_mask, rem_obs, rem_time, rem_mask)

        if not self.observations_config.add_rem:
            out = out[:3] + (None, None, None)

        return (*out, None)
1743
+
1744
+
1745
class ObservationStrategyFactory:
    @staticmethod
    def from_config(
        obs_config: ObservationsConfig, meta_config: MetaStudyConfig
    ) -> ObservationStrategy:
        """Instantiate the observation strategy named in ``obs_config``.

        Legacy compatibility: an omitted ``type`` (covered by the dataclass
        default) and an explicit YAML ``type: null`` — loaded as ``None`` —
        both resolve to ``pk_peak_half_life``.
        """
        strategy_type = getattr(obs_config, "type", None)

        if strategy_type is None:
            key = "pk_peak_half_life"
        elif isinstance(strategy_type, str):
            trimmed = strategy_type.strip()
            if not trimmed or trimmed.lower() in {"null", "none"}:
                key = "pk_peak_half_life"
            else:
                key = trimmed.lower()
        else:
            key = str(strategy_type).strip().lower()

        pk_aliases = {"observations_pk_peak_halflife", "pk_peak_half_life"}
        random_aliases = {"fix_past_time_random_selection", "random"}

        if key in pk_aliases:
            return PKPeakHalfLifeStrategy(obs_config, meta_config)
        if key in random_aliases:
            return FixPastTimeRandomSelectionStrategy(obs_config, meta_config)
        raise ValueError(f"Unknown observation type: {strategy_type}")
sim_priors_pk/data/data_generation/observations_functions.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file contains the observation functions that create the separation
3
+ between observations and remainders, the reminder can be either future
4
+ or selected from random in betweens, or None
5
+
6
+ """
7
+ import torch
8
+ from typing import Callable, Optional, Tuple
9
+ from torchtyping import TensorType
10
+
11
def fix_past_time_random_selection(
    full_simulation: torch.Tensor,
    full_simulation_times: torch.Tensor,
    *,
    boundary_ratio: float = 0.1,
    fixed_M_max: int,
    num_obs_sampler: Optional[Callable[[int], torch.Tensor]] = None,
    generator: Optional[torch.Generator] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None]:
    """Select observation time-points uniformly without replacement.

    Each row samples indices from the simulation grid independently and
    uniformly (no replacement), then sorts the selected points by sampled
    timestamps to keep chronological ordering in the output tensors.

    Parameters
    ----------
    full_simulation, full_simulation_times:
        ``[N, S]`` tensors with simulated values and their timestamps.
    boundary_ratio:
        Unused here; accepted for signature compatibility with callers that
        forward split parameters.
    fixed_M_max:
        Number of observation slots ``M`` in the packed output (clamped to
        be non-negative).
    num_obs_sampler:
        Optional callable returning per-row observation counts; counts are
        clamped to ``[1, min(M, S)]``. When ``None`` every row is filled to
        the cap.
    generator:
        Optional RNG forwarded to ``torch.randperm``.

    Returns
    -------
    tuple
        ``(observations, observation_times, obs_mask, None, None, None)``
        where the three tensors are ``[N, M]``; rows are chronologically
        sorted and zero-padded (mask False) beyond their sampled count.
    """
    if full_simulation is None:
        return (None,) * 6

    device = full_simulation.device
    N, S = full_simulation.shape
    M = int(max(0, fixed_M_max))

    observations = torch.zeros(N, M, device=device, dtype=full_simulation.dtype)
    observation_times = torch.zeros(N, M, device=device, dtype=full_simulation_times.dtype)
    obs_mask = torch.zeros(N, M, dtype=torch.bool, device=device)

    sample_cap = min(M, S)
    if sample_cap == 0:
        return observations, observation_times, obs_mask, None, None, None

    if num_obs_sampler is None:
        num_obs = torch.full((N,), sample_cap, dtype=torch.long, device=device)
    else:
        num_obs = num_obs_sampler(N).to(device=device, dtype=torch.long).clamp(1, sample_cap)

    # Per-row sampling keeps selection uniform without replacement.
    for row in range(N):
        row_count = int(num_obs[row].item())
        if row_count <= 0:
            continue
        # BUGFIX: forward ``generator`` as-is (possibly None) instead of
        # substituting the CPU ``torch.default_generator``, which raised for
        # CUDA tensors and bypassed the per-device default RNG.
        selected = torch.randperm(S, generator=generator, device=device)[:row_count]
        if row_count > 1:
            # Order chosen simulation indices by sampled time for stable packing.
            order = torch.argsort(full_simulation_times[row, selected])
            selected = selected[order]
        observations[row, :row_count] = full_simulation[row, selected]
        observation_times[row, :row_count] = full_simulation_times[row, selected]
        obs_mask[row, :row_count] = True

    return observations, observation_times, obs_mask, None, None, None
sim_priors_pk/data/data_generation/study_population_stats.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is used for calculating summary statistics over ensembles of StudyJSONs to check that
2
+ the distribution of simulated data matches empirical data."""
3
+
4
+ from abc import ABC, abstractmethod
5
+ from typing import Dict, List
6
+
7
+ import numpy as np
8
+
9
+ from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON
10
+
11
+
12
class StudyPopulationStats(ABC):
    """Base interface for summary statistics over ensembles of StudyJSONs.

    Concrete subclasses supply per-individual and per-study metrics together
    with a cross-study aggregation step; the template method below wires the
    per-study and aggregation stages together.
    """

    @abstractmethod
    def compute_per_individual(self, ind: IndividualJSON) -> Dict[str, float]:
        """Return statistics for a single individual (e.g. extrema, counts)."""

    @abstractmethod
    def compute_per_study(self, study: StudyJSON) -> Dict[str, float]:
        """Return statistics for a single study (e.g. extrema, counts)."""

    @abstractmethod
    def aggregate(
        self,
        per_study: List[Dict[str, float]],
    ) -> Dict[str, object]:
        """Combine per-study statistics into an ensemble-level summary."""

    def compute_study_population_statistics(
        self,
        studies: List[StudyJSON],
    ) -> Dict[str, object]:
        """Run the full pipeline: per-study metrics, then aggregation."""
        study_level = list(map(self.compute_per_study, studies))
        return self.aggregate(study_level)
37
+
38
+
39
class BasicObservationStats(StudyPopulationStats):
    """Compute descriptive statistics for observation values across individuals.

    For each individual, computes:
    - nAUC: Area Under the Curve (AUC), normalized by dose, using trapezoidal rule.
    - nCmax: Maximum observed concentration, normalized by dose.
    - Tmax: Time at which Cmax occurs.
    - Nobs: Number of observations.
    - Duration: Duration of the observation period (max observation time).
    For each study, computes:
    - Mean and standard deviation of nAUC, nCmax, Tmax across individuals.
    - Mean and total number of observations (Nobs) across all individuals.
    - Total study duration (max Duration across individuals).
    Aggregates across studies to provide percentiles of each study-level statistic.
    """

    def __init__(self, alpha=0.1):
        # Kept for API compatibility with subclasses/callers; not used by the
        # computations in this class.
        self.alpha = alpha

    def compute_per_individual(self, ind: IndividualJSON) -> Dict[str, float]:
        """Return per-individual PK summary statistics.

        Raises
        ------
        ValueError
            If observation times are unsorted or mismatched, if dosing is not
            a single positive pre-observation dose, or if the route is not
            'oral' or 'iv'.
        """
        obs_vals = ind.get("observations", [])
        obs_times = ind.get("observation_times", [])
        dose = ind.get("dosing", [])
        dosing_time = ind.get("dosing_times", [])
        route = ind.get("dosing_type", [])

        if not obs_vals:
            return {"nAUC": np.nan, "nCmax": np.nan, "Tmax": np.nan, "Nobs": 0, "Duration": np.nan}

        # Check that input times are strictly increasing and match the number
        # of observations.
        if len(obs_times) != len(obs_vals) or any(
            obs_times[i] >= obs_times[i + 1] for i in range(len(obs_times) - 1)
        ):
            raise ValueError(
                "Observation times must be sorted and match the number of observations."
            )

        # Check that there is only a single positive dose
        if len(dose) != 1 or len(dosing_time) != 1 or len(route) != 1:
            raise ValueError("Only single dosing is supported in this statistic.")
        # BUGFIX: test the scalar dose value — ``np.isnan(dose)`` on a list
        # produced a 1-element array whose implicit truthiness happened to
        # work but is fragile.
        if dose[0] <= 0 or np.isnan(dose[0]) or np.isnan(dosing_time[0]):
            raise ValueError("Dose must be positive.")

        # Check that dose precedes observations
        if any(t < dosing_time[0] for t in obs_times):
            raise ValueError("Dosing time must precede observation times.")

        # calculate AUC using the trapezoidal rule:
        # - for oral dosing, add a value of 0 at dosing time
        # - for iv bolus, add the first observation at dosing time
        obs_times_trapz = dosing_time + obs_times
        if route[0] == "oral":
            obs_vals_trapz = [0.0] + obs_vals
        elif route[0] == "iv":
            obs_vals_trapz = [obs_vals[0]] + obs_vals
        else:
            raise ValueError("Only 'oral' and 'iv' dosing types are supported.")

        auc = np.trapezoid(obs_vals_trapz, obs_times_trapz) if len(obs_vals) > 0 else np.nan
        auc /= dose[0]

        # Calculate Cmax and Tmax (dose-normalized Cmax)
        Cmax_idx = np.argmax(obs_vals)
        Cmax = obs_vals[Cmax_idx]
        Tmax = obs_times[Cmax_idx]
        Cmax /= dose[0]

        return {
            "nAUC": float(auc),
            "nCmax": float(Cmax),
            "Tmax": float(Tmax),
            "Nobs": len(obs_vals),
            "Duration": np.max(obs_times),
        }

    def compute_per_study(self, study: StudyJSON) -> Dict[str, float]:
        """Aggregate per-individual statistics to study level."""
        ind_stats = [
            self.compute_per_individual(ind)
            for block in ("context", "target")
            for ind in study.get(block, [])
        ]

        # Calculate statistics (maybe a bit too much, can be simplified later)
        metrics = {
            "nAUC_mean": ("nAUC", np.mean),
            "nAUC_sd": ("nAUC", np.std),
            "nAUC_cv": ("nAUC", lambda x: np.std(x) / np.mean(x) * 100 if np.mean(x) != 0 else np.nan),
            "nCmax_mean": ("nCmax", np.mean),
            "nCmax_sd": ("nCmax", np.std),
            "nCmax_cv": ("nCmax", lambda x: np.std(x) / np.mean(x) * 100 if np.mean(x) != 0 else np.nan),
            "Tmax_mean": ("Tmax", np.mean),
            "Tmax_sd": ("Tmax", np.std),
            "Tmax_cv": ("Tmax", lambda x: np.std(x) / np.mean(x) * 100 if np.mean(x) != 0 else np.nan),
            "Nobs_mean": ("Nobs", np.mean),
            "Nobs_total": ("Nobs", np.sum),
            "Duration_max": ("Duration", np.max),
            "nID": ("Nobs", lambda x: len(x)),
        }

        # BUGFIX: the empty-study fallback previously returned unrelated keys
        # ("max_obs", ...), which broke downstream aggregation that assumes a
        # homogeneous key set across studies. Use the metric keys instead.
        if not ind_stats:
            fallback = {name: float("nan") for name in metrics}
            fallback["Nobs_total"] = 0.0
            fallback["nID"] = 0.0
            return fallback

        results = {name: func([d[key] for d in ind_stats]) for name, (key, func) in metrics.items()}

        # Ensure all values are floats for JSON-friendliness or downstream compatibility
        return {k: float(v) for k, v in results.items()}

    def aggregate(
        self,
        per_study: List[Dict[str, float]],
    ) -> Dict[str, object]:
        """Aggregate statistics across studies."""
        # Guard: an empty ensemble has no keys to enumerate (per_study[0]
        # would raise IndexError).
        if not per_study:
            return {"Nstudy": 0}

        # Calculate percentiles of study-level statistics
        percentiles = [5, 50, 95]
        summary: Dict[str, object] = {}
        for key in per_study[0].keys():
            # NaN values (e.g. empty-study placeholders) are excluded from
            # percentile computation.
            values = [s[key] for s in per_study if not np.isnan(s[key])]
            if values:
                summary[f"{key}_percentiles"] = {
                    f"P{p}": float(np.percentile(values, p)) for p in percentiles
                }
            else:
                summary[f"{key}_percentiles"] = {f"P{p}": np.nan for p in percentiles}
        summary["Nstudy"] = len(per_study)

        return summary
164
+
165
+
166
class ListedObservationStats(BasicObservationStats):
    """Variant of BasicObservationStats that returns lists of study-level
    statistics instead of percentiles.

    Useful for more detailed analyses or visualizations of the distribution
    of study-level statistics. The redundant ``__init__`` duplicating the
    parent was removed; ``alpha`` handling is inherited.
    """

    def aggregate(
        self,
        per_study: List[Dict[str, float]],
    ) -> Dict[str, object]:
        """Aggregate statistics across studies as raw lists."""
        summary: Dict[str, object] = {}
        # Guard: an empty ensemble has no keys to enumerate (per_study[0]
        # would raise IndexError).
        if per_study:
            # Collect lists of study-level statistics
            for key in per_study[0].keys():
                summary[f"{key}_list"] = [float(s[key]) for s in per_study]
        summary["Nstudy"] = len(per_study)

        return summary
sim_priors_pk/data/data_preprocessing/__init__.py ADDED
File without changes
sim_priors_pk/data/data_preprocessing/data_preprocessing_utils.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+ from torchtyping import TensorType
5
+
6
+ import torch
7
+ from torchtyping import TensorType
8
+ from typing import List,Tuple,Optional
9
+ import numpy as np
10
+
11
+ from sim_priors_pk.data.data_preprocessing.raw_to_tensors_bundles import substance_cvs_to_tensors_bundle,substances_csv_to_tensors
12
+
13
+ from typing import NamedTuple
14
+ import torch
15
+ from torchtyping import TensorType
16
+
17
class SubstanceTensorGroup(NamedTuple):
    """Tensors of a single substance, kept with a leading batch axis of size 1.

    Returned by :func:`get_substance_tensors_by_label`. Axes: I = subjects
    (padded), T = time steps (padded).
    """

    # Padded concentration values, shape [1, I, T].
    observations: TensorType[1, "I", "T"]
    # Padded observation times, shape [1, I, T].
    times: TensorType[1, "I", "T"]
    # Boolean mask, True where a time step holds a real observation.
    mask: TensorType[1, "I", "T"]
    # Boolean mask, True for real subjects, False for padded rows.
    subject_mask: TensorType[1, "I"]
22
+
23
def apply_timescale_filter(
    observations: TensorType["S", "I", "T"],
    times: TensorType["S", "I", "T"],
    masks: TensorType["S", "I", "T"],
    subject_mask: TensorType["S", "I"],
    *,
    strategy: str = "log_zscore",  # "log_zscore" | "median_fraction" | "none"
    max_abs_z: float = 2.0,        # for "log_zscore"
    tau: float = 0.4,              # for "median_fraction" (~ ln 1.5)
) -> Tuple[
    TensorType["S", "I", "T"],  # filtered observations
    TensorType["S", "I", "T"],  # filtered times
    TensorType["S", "I", "T"],  # filtered masks
    TensorType["S", "I"],       # filtered subject_mask
]:
    """Zero out and un-mask subjects whose time span is an outlier.

    Outliers are judged against the other subjects of the *same* substance:

    * ``strategy="log_zscore"``: keep subjects with |z| <= ``max_abs_z`` in
      log-span.
    * ``strategy="median_fraction"``: keep subjects within +-``tau`` of the
      median log-span.
    * ``strategy="none"`` (or any unrecognised value): return the inputs
      unchanged.

    Returns the four tensors with dropped subjects zeroed and masked out;
    the inputs themselves are never mutated.
    """
    if strategy == "none":
        return observations, times, masks, subject_mask

    # A time point is usable only if it is unpadded AND belongs to a real
    # subject.
    usable = masks.bool() & subject_mask.unsqueeze(-1)

    # Per-subject observation window, on a log scale.
    latest = times.masked_fill(~usable, float("-inf")).max(dim=2).values   # [S, I]
    earliest = times.masked_fill(~usable, float("inf")).min(dim=2).values  # [S, I]
    log_span = (latest - earliest).clamp(min=1e-12).log()                  # [S, I]

    # Decide which subjects survive.
    if strategy == "log_zscore":
        centred = log_span - log_span.mean(dim=1, keepdim=True)
        scale = log_span.std(dim=1, keepdim=True).clamp(min=1e-6)
        keep = (centred / scale).abs() <= max_abs_z                        # [S, I]
    elif strategy == "median_fraction":
        med = log_span.median(dim=1, keepdim=True).values                  # [S, 1]
        keep = (log_span >= med - tau) & (log_span <= med + tau)           # [S, I]
    else:
        # Unknown strategy: deliberately a no-op, matching "none".
        return observations, times, masks, subject_mask

    # Apply the filter on copies so the callers' tensors stay intact.
    filtered_obs = observations.clone()
    filtered_times = times.clone()
    filtered_masks = masks.clone()
    filtered_subjects = subject_mask.clone()

    dropped = ~keep & filtered_subjects.bool()
    filtered_subjects[dropped] = False
    filtered_masks[dropped] = False
    filtered_obs[dropped] = 0.0
    filtered_times[dropped] = 0.0

    return filtered_obs, filtered_times, filtered_masks, filtered_subjects
87
+
88
def plot_subjects_for_substance(
    drug_data_frame,
    substance_label: str,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
    time_strategy: str = "log_zscore",  # "log_zscore" | "median_fraction" | "none"
    max_abs_z: float = 2.,
    x_scale: str = "linear",            # "linear" (default) | "log"
    y_scale: str = "linear",            # "linear" (default) | "log"
    alpha: float = 1.0,                 # 0 <= alpha <= 1
    legend_outside: bool = True,        # park legend to the right
    figsize: Tuple[float, float] = (10, 5),  # default width x height
    save_dir: Optional[str] = None,     # if set, saves the figure here
) -> None:
    """
    Draw every subject trajectory (points + line) for *one* substance.

    Parameters
    ----------
    drug_data_frame : pandas.DataFrame
        Long-format dataframe understood by ``substance_cvs_to_tensors_bundle``.
    substance_label : str
        Substance to plot; must occur in ``drug_data_frame["substance_label"]``.
    z_score_normalization, normalize_by_max : bool, optional
        Forwarded to ``substance_cvs_to_tensors_bundle``.
    time_strategy : {"log_zscore", "median_fraction", "none"}, optional
        Outlier strategy forwarded to ``apply_timescale_filter``.
    max_abs_z : float, optional
        Z-score cutoff used when ``time_strategy="log_zscore"``.
    x_scale, y_scale : {"linear", "log"}, optional
        Axis scaling. If you pick "log", make sure data are strictly > 0
        on that axis or Matplotlib will complain.
    alpha : float in [0, 1], optional
        Transparency applied to both the line and the markers.
    legend_outside : bool, optional
        True -> legend in a separate column to the right;
        False -> legend inside plot.
    figsize : (float, float), optional
        Figure size in inches.
    save_dir : str, optional
        Directory where a PNG of the figure is written (created if missing).

    Raises
    ------
    ValueError
        If ``substance_label`` is not present in the dataframe.
    """
    # -- 1. Pull tensors -------------------------------------------------
    # BUG FIX: the two normalization flags used to be ignored in favour of a
    # hard-coded ``normalize_by_max=True``; they are now forwarded.
    data_bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    all_obs = data_bundle.observations            # [S, I, T]
    all_times = data_bundle.times                 # [S, I, T]
    all_masks = data_bundle.masks                 # [S, I, T]
    all_subj_mask = data_bundle.individuals_mask  # [S, I]
    substance_labels = list(data_bundle.substance_names)  # [S]
    mapping = data_bundle.mapping

    # -- 2. Find substance row -------------------------------------------
    # BUG FIX: the previous ``np.where`` comparison against a plain Python
    # list never matched (list == str is a scalar False), so the lookup
    # always raised. A straightforward list lookup works.
    try:
        s_idx = substance_labels.index(substance_label)
    except ValueError:
        raise ValueError(f"Substance '{substance_label}' not found.") from None

    obs = all_obs[s_idx]                     # [I, T]
    times = all_times[s_idx]                 # [I, T]
    step_mask = all_masks[s_idx].bool()      # [I, T]
    subj_mask = all_subj_mask[s_idx].bool()  # [I]

    # -- 3. Filter time series -------------------------------------------
    # apply_timescale_filter expects batched [S, I, T] / [S, I] inputs.
    obs_b, times_b, step_mask_b, subj_mask_b = apply_timescale_filter(
        observations=obs.unsqueeze(0),        # [1, I, T]
        times=times.unsqueeze(0),             # [1, I, T]
        masks=step_mask.unsqueeze(0),         # [1, I, T]
        subject_mask=subj_mask.unsqueeze(0),  # [1, I]
        strategy=time_strategy,
        max_abs_z=max_abs_z,
        tau=0.4,
    )
    # Remove the batch dim again.
    obs, times = obs_b[0], times_b[0]
    step_mask, subj_mask = step_mask_b[0], subj_mask_b[0]

    # -- 4. Plot one line per *real* subject -----------------------------
    fig, ax = plt.subplots(figsize=figsize)
    for i in range(obs.shape[0]):  # iterate subjects (I)
        if not subj_mask[i]:
            continue  # skip padded rows

        valid = step_mask[i]  # True -> real sample
        ax.plot(
            times[i][valid].cpu(),
            obs[i][valid].cpu(),
            marker="o",
            alpha=alpha,
            label=f"subject {i}",
        )

    # -- 5. Styling ------------------------------------------------------
    ax.set_title(f"All subjects – {substance_label}")
    ax.set_xlabel("Time (normalised per substance)")
    ax.set_ylabel("Observation")

    # Axis scales
    ax.set_xscale(x_scale)
    ax.set_yscale(y_scale)

    # Legend placement
    if legend_outside:
        # ncol=1 -> vertical list; bbox_to_anchor shifts legend fully outside
        ax.legend(
            loc="center left",
            bbox_to_anchor=(1.02, 0.5),
            borderaxespad=0.0,
            frameon=False,
        )
        plt.tight_layout(rect=[0, 0, 0.82, 1])  # leave room on the right
    else:
        ax.legend(frameon=False)
        plt.tight_layout()

    # Save figure if path is given
    if save_dir is not None:
        from pathlib import Path
        study_name = mapping[substance_label]["study_name"]
        index = mapping[substance_label]["index"]
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        filepath = Path(save_dir) / f"{study_name}_{substance_label}_{index}.png"
        fig.savefig(filepath, bbox_inches="tight", dpi=300)

    plt.show()
217
+
218
def substances_with_min_timesteps(
    drug_data_frame,
    min_timesteps: int = 140,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
) -> List[str]:
    """
    Return the list of substance labels whose **best** subject has
    >= `min_timesteps` valid observations.

    Parameters
    ----------
    drug_data_frame : pandas.DataFrame
        Same dataframe you already pass to ``substance_cvs_to_tensors_bundle``.
    min_timesteps : int, default = 140
        Threshold on the number of valid (unpadded) time points.
    z_score_normalization, normalize_by_max : bool, default = False
        Forwarded to ``substance_cvs_to_tensors_bundle`` (the count of valid
        points is unaffected by value normalization).

    Returns
    -------
    List[str]
        Substance labels that satisfy the criterion.
    """
    # BUG FIX: ``substance_cvs_to_tensors_bundle`` returns an
    # ``EmpiricalSubstanceTensorBundle`` with ten fields; the old 6-name
    # tuple unpacking raised "too many values to unpack" on every call.
    # Use attribute access instead.
    data_bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    all_masks = data_bundle.masks                     # [S, I, T], 1 = real step
    all_subjects_mask = data_bundle.individuals_mask  # [S, I],    1 = real subject
    substance_labels = data_bundle.substance_names    # list of S labels

    # Zero out padded subjects, then count valid steps per subject:
    # counts[s, i] = number of valid time points of subject i in substance s.
    valid_masks = all_masks.bool() & all_subjects_mask.bool().unsqueeze(-1)
    counts = valid_masks.sum(dim=2)  # [S, I]

    # Best subject per substance vs. the threshold.
    max_counts = counts.max(dim=1).values       # [S]
    qualifying = max_counts >= min_timesteps    # [S] bool

    # BUG FIX: ``substance_names`` is already a Python list, so no
    # ``.tolist()`` call on it.
    return [
        label
        for label, keep in zip(substance_labels, qualifying.tolist())
        if keep
    ]
280
+
281
def get_substance_tensors_by_label(
    drug_data_frame,
    substance_label: str,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
) -> SubstanceTensorGroup:
    """
    Return the tensors of one substance while preserving an S=1 batch shape.

    Shapes:
        observations : [1, I, T]
        times        : [1, I, T]
        mask         : [1, I, T]
        subject_mask : [1, I]

    Raises ``ValueError`` when the label is unknown.
    """
    bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    # Locate the requested substance among the bundle's labels.
    try:
        s_idx = list(bundle.substance_names).index(substance_label)
    except ValueError:
        raise ValueError(f"Substance label '{substance_label}' not found.")

    # Slice out the substance and re-add the batch dimension.
    return SubstanceTensorGroup(
        observations=bundle.observations[s_idx].unsqueeze(0),         # [1, I, T]
        times=bundle.times[s_idx].unsqueeze(0),                       # [1, I, T]
        mask=bundle.masks[s_idx].unsqueeze(0).bool(),                 # [1, I, T]
        subject_mask=bundle.individuals_mask[s_idx].unsqueeze(0).bool(),  # [1, I]
    )
321
+
sim_priors_pk/data/data_preprocessing/raw_to_tensors_bundles.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Here we define the functions requiered to process the data
3
+
4
+ https://pk-db.com/
5
+
6
+ """
7
+ import torch
8
+ import numpy as np
9
+ import pandas as pd
10
+ from typing import Tuple
11
+ from torchtyping import TensorType
12
+ from typing import NamedTuple, List, Dict
13
+ from torchtyping import TensorType
14
+ from typing import Dict, Tuple, List
15
+ import numpy as np
16
+ import torch
17
+ from torchtyping import TensorType
18
+
19
+ from typing import Dict, Tuple, List, Optional
20
+ import numpy as np
21
+ import torch
22
+ from torchtyping import TensorType
23
+
24
# Fixed per-substance dose lookup used when the dataframe carries no explicit
# dosing information. Units are mg per g per the variable name — presumably
# mg of drug per g of body weight from the Lenuzza dosing protocol;
# TODO(review): confirm units against the original source.
lenuzza_doses_mg_per_g = {
    "memantine": 0.005,
    "omeprazole": 0.010,
    "repaglinide": 0.00025,
    "rosuvastatin": 0.005,
    "tolbutamide": 0.010,
    "dextromethorphan": 0.018,
    "digoxin": 0.00025,
    "paracetamol": 0.060,
    "caffeine": 0.073,
    "midazolam": 0.004,
    "paraxanthine": 0.073,
    "dextrorphan": 0.018,
}
38
+
39
class EmpiricalSubstanceTensorBundle(NamedTuple):
    """Padded per-substance tensors plus metadata.

    Built by :func:`substance_cvs_to_tensors_bundle`. Axis meanings:
    S = substances, I = subjects (padded to the max per substance),
    T = time steps (padded to the max per subject).
    """

    observations: TensorType["S", "I", "T"]  # padded concentration values (padding = 0)
    # Padded observation times (padding = 0).
    # NOTE(review): older comments claimed these are normalized to [0, 1],
    # but the builder applies no time normalization — confirm downstream use.
    times: TensorType["S", "I", "T"]
    masks: TensorType["S", "I", "T"]  # 1 = observed, 0 = missing or padded
    individuals_mask: TensorType["S", "I"]  # 1 = real subject, 0 = padded row
    study_names: List[str]  # [S] -> one study name per substance
    individuals_names: List[List[str]]  # [S][I] -> subject name per padded subject ("" for padding)
    substance_names: List[str]  # [S] substance_label entries
    mapping: Dict[str, Dict[str, object]]  # substance_label -> {"index", "study_name"}
    dosing_amounts: TensorType["S", "I"]  # dose mg/g per subject
    dosing_route_types: TensorType["S", "I"]  # route type index per subject
50
+
51
def map_substance_to_index_and_study(
    drug_data_frame
) -> dict[str, dict[str, object]]:
    """
    Map each substance_label to its index (in ``np.unique`` order) and the
    study_name of the first row where that label appears.

    Returns
    -------
    dict: {
        "substance_label": {
            "index": int,
            "study_name": str
        },
        ...
    }
    """
    unique_labels = np.unique(drug_data_frame["substance_label"].values)

    lookup: dict[str, dict[str, object]] = {}
    for position, substance in enumerate(unique_labels):
        matching_rows = drug_data_frame[drug_data_frame["substance_label"] == substance]
        lookup[substance] = {
            "index": position,
            "study_name": matching_rows["study_name"].iloc[0],
        }

    return lookup
81
+
82
def substances_csv_to_tensors(drug_data_frame, substance_label='omeprazole'):
    """
    Extract per-subject time series for one substance, pad them to a common
    length, and return them as tensors.

    Params:
        drug_data_frame (pd.DataFrame): Input DataFrame with columns
            'substance_label', 'subject_name', 'time', 'value'
            (optionally 'substance_name').
        substance_label (str): The substance label to filter by. Defaults to 'omeprazole'.

    Returns:
        observations (torch.Tensor): Padded observation values, shape [num_subjects, max_time].
        observations_times (torch.Tensor): Padded time points, shape [num_subjects, max_time].
        observations_mask (torch.Tensor): 1 for real data, 0 for padding, shape [num_subjects, max_time].
        dosing_amounts (torch.Tensor): Dose amount per subject [num_subjects].
        dosing_route_types (torch.Tensor): Route type index per subject [num_subjects].
    """
    # Restrict to the rows of the requested substance.
    subset = drug_data_frame[drug_data_frame['substance_label'] == substance_label]

    per_subject_times = []
    per_subject_values = []
    dose_per_subject = []
    route_per_subject = []

    for _, subject_rows in subset.groupby('subject_name'):
        # Chronological order within each subject.
        ordered = subject_rows.sort_values('time')
        per_subject_times.append(ordered['time'].values.astype(np.float32))
        per_subject_values.append(ordered['value'].values.astype(np.float32))

        # Dose lookup key: the explicit substance name when present,
        # otherwise the label itself.
        if 'substance_name' in subject_rows.columns:
            name_lower = str(subject_rows['substance_name'].iloc[0]).lower()
        else:
            name_lower = str(substance_label).lower()

        # First dictionary entry whose key occurs in the name wins;
        # 0.5 is the fallback dose.
        dose = next(
            (amount for drug, amount in lenuzza_doses_mg_per_g.items() if drug in name_lower),
            0.5,
        )
        dose_per_subject.append(dose)
        route_per_subject.append(0)  # oral

    # Common padded length across subjects.
    longest = max((len(t) for t in per_subject_times), default=0)

    padded_times, padded_values, row_masks = [], [], []
    for t_arr, v_arr in zip(per_subject_times, per_subject_values):
        n_real = len(t_arr)
        tail = longest - n_real

        # Zero-pad to the common length.
        padded_times.append(np.pad(t_arr, (0, tail), mode='constant', constant_values=0))
        padded_values.append(np.pad(v_arr, (0, tail), mode='constant', constant_values=0))

        # 1 for real samples, 0 for padding.
        valid = np.zeros(longest, dtype=np.float32)
        valid[:n_real] = 1
        row_masks.append(valid)

    observations = torch.tensor(padded_values, dtype=torch.float32)       # [P, T]
    observations_times = torch.tensor(padded_times, dtype=torch.float32)  # [P, T]
    observations_mask = torch.tensor(row_masks, dtype=torch.float32)      # [P, T]

    dosing_amounts = torch.tensor(dose_per_subject, dtype=torch.float32)  # [P]
    dosing_route_types = torch.tensor(route_per_subject, dtype=torch.long)  # [P]

    return observations, observations_times, observations_mask, dosing_amounts, dosing_route_types
163
+
164
def substance_dict_to_tensors(
    selected_series: Optional[Dict[str, Dict[str, List[float]]]],
    hidden_series: Optional[Dict[str, Dict[str, List[float]]]],
) -> Tuple[
    Optional[TensorType["N_sel", "T"]], Optional[TensorType["N_sel", "T"]], Optional[TensorType["N_sel", "T"]],
    Optional[TensorType["N_hid", "T"]], Optional[TensorType["N_hid", "T"]], Optional[TensorType["N_hid", "T"]],
]:
    """
    Convert two dictionaries of time series (typically coming from the
    frontend payload) into padded tensors that share one maximum length.

    Args:
        selected_series: Mapping subject_name -> {'timepoints': [...], 'values': [...]}.
        hidden_series: Mapping subject_name -> {'timepoints': [...], 'values': [...]}.

    Returns:
        sel_obs, sel_times, sel_mask: [N_sel, T] or None when the input is empty.
        hid_obs, hid_times, hid_mask: [N_hid, T] or None when the input is empty.
    """

    def _sorted_arrays(series: Dict[str, Dict[str, List[float]]]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        # Per subject: float32 arrays sorted by time.
        ts, vs = [], []
        for payload in series.values():
            t = np.array(payload['timepoints'], dtype=np.float32)
            v = np.array(payload['values'], dtype=np.float32)
            order = np.argsort(t)
            ts.append(t[order])
            vs.append(v[order])
        return ts, vs

    def _pad_stack(ts: List[np.ndarray], vs: List[np.ndarray], width: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Zero-pad every row to `width` and build a validity mask.
        obs_rows, time_rows, mask_rows = [], [], []
        for t, v in zip(ts, vs):
            tail = width - len(t)
            time_rows.append(np.pad(t, (0, tail), mode='constant', constant_values=0))
            obs_rows.append(np.pad(v, (0, tail), mode='constant', constant_values=0))
            row_mask = np.ones(width, dtype=np.float32)
            row_mask[len(t):] = 0
            mask_rows.append(row_mask)
        return (
            torch.tensor(obs_rows, dtype=torch.float32),   # [N, T]
            torch.tensor(time_rows, dtype=torch.float32),  # [N, T]
            torch.tensor(mask_rows, dtype=torch.float32),  # [N, T]
        )

    sel_ts, sel_vs = _sorted_arrays(selected_series) if selected_series else ([], [])
    hid_ts, hid_vs = _sorted_arrays(hidden_series) if hidden_series else ([], [])

    # Both groups are padded to one shared maximum length.
    shared_len = max(
        max((len(t) for t in sel_ts), default=0),
        max((len(t) for t in hid_ts), default=0),
    )

    sel_out = _pad_stack(sel_ts, sel_vs, shared_len) if sel_ts else (None, None, None)
    hid_out = _pad_stack(hid_ts, hid_vs, shared_len) if hid_ts else (None, None, None)

    return (*sel_out, *hid_out)
241
+
242
def substance_cvs_to_tensors_bundle(
    drug_data_frame: pd.DataFrame,
    **kwargs
) -> EmpiricalSubstanceTensorBundle:
    """
    Build one padded :class:`EmpiricalSubstanceTensorBundle` from a
    long-format dataframe, grouping rows by ``substance_label``.

    Per substance, the per-subject series come from
    ``substances_csv_to_tensors`` ([P, T] tensors); NaN observations are
    zeroed and removed from the mask; finally every substance is padded to
    the common ``[S, I, T]`` shape (I = max subjects, T = max time steps).

    NOTE(review): ``**kwargs`` (callers pass e.g. ``z_score_normalization``
    and ``normalize_by_max``) are accepted but never used — no value or
    time normalization happens in this function, despite what an older
    docstring claimed. Confirm whether that step was dropped intentionally.

    Returns
    -------
    EmpiricalSubstanceTensorBundle
        observations / times / masks : [S, I, T] (zero-padded),
        individuals_mask : [S, I] (1 = real subject),
        substance_names, study_names, individuals_names, mapping,
        dosing_amounts : [S, I], dosing_route_types : [S, I].
    """
    # Local re-imports kept from the original code; the module already
    # imports numpy and torch at the top level.
    import numpy as np
    import torch
    import torch.nn.functional as F

    substance_labels = np.unique(drug_data_frame["substance_label"].values)
    mapping = map_substance_to_index_and_study(drug_data_frame)

    substance_observations = []
    substance_times = []
    substance_masks = []
    subject_masks = []
    substance_doses = []
    substance_routes = []

    study_names_per_substance = []
    subject_names_per_substance = []

    # Global padding targets across all substances.
    max_time_steps = 0
    max_subjects = 0

    for substance_label in substance_labels:
        # df_sub is only used for the metadata below; the full frame is
        # handed to substances_csv_to_tensors, which re-filters internally.
        df_sub = drug_data_frame[drug_data_frame["substance_label"] == substance_label]
        obs, times, masks, doses, routes = substances_csv_to_tensors(
            drug_data_frame, substance_label=substance_label
        )
        # obs, times, masks: [P, T]

        # Invalid (NaN) observations: drop from the mask, then zero the value.
        valid_obs_mask = ~torch.isnan(obs)
        masks = masks.bool() & valid_obs_mask
        obs = obs.nan_to_num(nan=0.0)

        max_time_steps = max(max_time_steps, obs.shape[1])
        max_subjects = max(max_subjects, obs.shape[0])

        # --- Metadata collection ---
        # groupby(...).first() yields one row per subject; its index order
        # matches the subject order produced by substances_csv_to_tensors.
        grouped = df_sub.groupby("subject_name").first()
        subject_names = list(grouped.index)
        study_name = grouped["study_name"].iloc[0] if len(grouped) > 0 else ""

        study_names_per_substance.append(study_name)
        subject_names_per_substance.append(subject_names)

        substance_observations.append(obs)
        substance_times.append(times)
        substance_masks.append(masks)
        subject_masks.append(torch.ones(obs.shape[0], dtype=torch.float32))  # [P]
        substance_doses.append(doses)
        substance_routes.append(routes)

    # Padding pass: bring every substance to [max_subjects, max_time_steps].
    all_observations, all_times, all_masks, all_subjects_mask = [], [], [], []
    all_doses, all_routes = [], []

    for obs, time, mask, subj_mask, subj_names, doses, routes in zip(
        substance_observations,
        substance_times,
        substance_masks,
        subject_masks,
        subject_names_per_substance,
        substance_doses,
        substance_routes,
    ):
        pad_subjects = max_subjects - obs.shape[0]
        pad_timesteps = max_time_steps - obs.shape[1]

        # F.pad pads the last dim first: (left, right) for T, then for I.
        obs_padded = F.pad(obs, (0, pad_timesteps, 0, pad_subjects))   # [I, T]
        time_padded = F.pad(time, (0, pad_timesteps, 0, pad_subjects))  # [I, T]
        mask_padded = F.pad(mask, (0, pad_timesteps, 0, pad_subjects))  # [I, T]
        subj_mask_padded = F.pad(subj_mask, (0, pad_subjects))  # [I]
        dose_padded = F.pad(doses, (0, pad_subjects))  # [I]
        route_padded = F.pad(routes, (0, pad_subjects))  # [I]
        # In-place extension: this mutates the list already stored in
        # subject_names_per_substance, which is what gets returned below.
        subj_names += [""] * pad_subjects  # [I] -> pad with ""

        all_observations.append(obs_padded)
        all_times.append(time_padded)
        all_masks.append(mask_padded)
        all_subjects_mask.append(subj_mask_padded)
        all_doses.append(dose_padded)
        all_routes.append(route_padded)

    return EmpiricalSubstanceTensorBundle(
        observations=torch.stack(all_observations),  # [S, I, T]
        times=torch.stack(all_times),  # [S, I, T]
        masks=torch.stack(all_masks),  # [S, I, T]
        individuals_mask=torch.stack(all_subjects_mask),  # [S, I]
        substance_names=list(substance_labels),  # [S]
        mapping=mapping,
        study_names=study_names_per_substance,  # [S]
        individuals_names=subject_names_per_substance,  # [S][I]
        dosing_amounts=torch.stack(all_doses),  # [S, I]
        dosing_route_types=torch.stack(all_routes)  # [S, I]
    )
sim_priors_pk/data/data_preprocessing/tensors_to_databatch.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility for initializing :class:`AICMECompartmentsDataBatch` objects.
2
+
3
+ This small helper is primarily used in older preprocessing scripts. It takes
4
+ precomputed observation tensors and wraps them into a minimal
5
+ ``AICMECompartmentsDataBatch`` where only the context fields are populated.
6
+ All other entries are set to empty tensors or placeholders so that the
7
+ resulting object conforms to the new metadata interface.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+
14
+ from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch
15
+
16
+
17
def initialize_aicme_batch(
    observations: torch.Tensor,
    observations_times: torch.Tensor,
    observations_mask: torch.Tensor,
) -> AICMECompartmentsDataBatch:
    """Wrap raw ``[I, T]`` tensors into a minimal :class:`AICMECompartmentsDataBatch`.

    Only the context fields are filled; every target field is ``None`` or an
    empty placeholder so the result still conforms to the metadata interface.

    Parameters
    ----------
    observations:
        Tensor of shape ``[I, T]`` containing concentration values.
    observations_times:
        Tensor of shape ``[I, T]`` with the corresponding time points.
    observations_mask:
        Boolean tensor of shape ``[I, T]`` indicating valid entries.

    Returns
    -------
    AICMECompartmentsDataBatch
        Batch with ``B=1`` where all context fields are populated and the
        remaining fields are placeholders (zeros or empty strings).
    """
    n_subjects = observations.shape[0]

    # Lift to batch size B=1; observations and times additionally gain a
    # trailing feature dimension of size 1.
    ctx_obs = observations[None, ..., None]             # [1, I, T, 1]
    ctx_obs_time = observations_times[None, ..., None]  # [1, I, T, 1]
    ctx_obs_mask = observations_mask[None, ...]         # [1, I, T]

    return AICMECompartmentsDataBatch(
        target_obs=None,
        target_obs_time=None,
        target_obs_mask=None,
        target_rem_sim=None,
        target_rem_sim_time=None,
        target_rem_sim_mask=None,
        target_dosing_amounts=torch.zeros(1, 0),
        target_dosing_route_types=torch.zeros(1, 0, dtype=torch.long),
        context_obs=ctx_obs,
        context_obs_time=ctx_obs_time,
        context_obs_mask=ctx_obs_mask,
        context_rem_sim=None,
        context_rem_sim_time=None,
        context_rem_sim_mask=None,
        context_dosing_amounts=torch.zeros(1, n_subjects),
        context_dosing_route_types=torch.zeros(1, n_subjects, dtype=torch.long),
        study_name=[""],
        context_subject_name=[[""] * n_subjects],
        target_subject_name=[[]],
        substance_name=[""],
        time_scales=None,
        is_empirical=False,
    )
72
+
sim_priors_pk/data/datasets/aicme_batch.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Batch structures shared between synthetic and empirical pipelines."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import namedtuple
6
+ from typing import List, NamedTuple
7
+
8
+ import torch
9
+ from torchtyping import TensorType
10
+
11
# Lightweight record of the tensor shapes making up one batch: batch size,
# context/target individual counts and the kept/remaining observation counts
# on each side.
ShapeConfig = namedtuple(
    "ShapeConfig",
    "batch_size c_individuals num_obs_c remaining_obs_c "
    "t_individuals num_obs_t remaining_obs_t",
)
23
+
24
+
25
class AICMECompartmentsDataBatch(NamedTuple):
    """Container aggregating context and target trajectories.

    The tuple carries tensors describing observed measurements, simulated
    remainders, dosing metadata and masking utilities used across both the
    synthetic simulation pipeline and the empirical JSON tooling.

    Field order is significant: callers construct instances positionally
    (e.g. ``AICMECompartmentsDataBatch(*values)``), so never reorder fields.
    Shape legend: ``B`` batch, ``c_ind``/``t_ind`` context/target
    individuals, ``num_obs_*`` kept observations, ``rem_obs_*`` remaining
    (held-out) observations.
    """

    # max_num_individuals-max_n_new_individuals = n_c_individuals
    # --- target side: observed (past) segment -------------------------------
    target_obs: TensorType["B", "t_ind", "num_obs_t", 1]
    target_obs_time: TensorType["B", "t_ind", "num_obs_t", 1]
    target_obs_mask: TensorType["B", "t_ind", "num_obs_t"]

    # --- target side: simulated remainder (future) segment ------------------
    target_rem_sim: TensorType["B", "t_ind", "rem_obs_t", 1]
    target_rem_sim_time: TensorType["B", "t_ind", "rem_obs_t", 1]
    target_rem_sim_mask: TensorType["B", "t_ind", "rem_obs_t"]

    # --- context side: observed segment -------------------------------------
    context_obs: TensorType["B", "c_ind", "num_obs_c", 1]
    context_obs_time: TensorType["B", "c_ind", "num_obs_c", 1]
    context_obs_mask: TensorType["B", "c_ind", "num_obs_c"]

    # --- context side: simulated remainder segment --------------------------
    context_rem_sim: TensorType["B", "c_ind", "rem_obs_c", 1]
    context_rem_sim_time: TensorType["B", "c_ind", "rem_obs_c", 1]
    context_rem_sim_mask: TensorType["B", "c_ind", "rem_obs_c"]

    # Dosing information (one scalar amount and one route id per individual)
    target_dosing_amounts: TensorType["B", "t_ind"]
    target_dosing_route_types: TensorType["B", "t_ind"]
    context_dosing_amounts: TensorType["B", "c_ind"]
    context_dosing_route_types: TensorType["B", "c_ind"]

    # Masks over padded individuals (1 = real individual, 0 = padding slot)
    mask_context_individuals: TensorType["B", "c_ind"]
    mask_target_individuals: TensorType["B", "t_ind"]

    # Tracking metadata for provenance / debugging
    study_name: List[str]
    """Study identifier for each element in the batch (length ``B``)."""
    context_subject_name: List[List[str]]
    """Names of context individuals: shape ``[B][c_ind]``."""
    target_subject_name: List[List[str]]
    """Names of target individuals: shape ``[B][t_ind]``."""
    substance_name: List[str]
    """Drug or compound names corresponding to each study (length ``B``)."""

    # Meta information
    time_scales: TensorType["B", 2]  # shape : [B,2]
    is_empirical: bool = False  # True ⇢ empirical CSV, False ⇢ simulation

    @property
    def mask_individuals(self) -> TensorType["B", "c_ind"]:
        """Alias for backward compatibility; returns ``mask_context_individuals``."""

        return self.mask_context_individuals

    def detach_all(self) -> "AICMECompartmentsDataBatch":
        """Detaches all tensor fields from the computation graph.

        Non-tensor fields (strings, bools) are passed through unchanged.
        """

        return AICMECompartmentsDataBatch(
            *(t.detach() if isinstance(t, torch.Tensor) else t for t in self)
        )

    def log_transform(self) -> "AICMECompartmentsDataBatch":
        """Applies log transformation to observation and remainder tensors.

        Deprecated for training: log scaling is now expected to be handled by
        ``PKScaler`` (for example via ``value_method="log"`` or
        ``value_method="log_and_max"``).
        Kept for backward compatibility with older utilities.
        """

        transformed_tensors = []
        # Only the value tensors are log-transformed; times, masks, dosing
        # and metadata fields are copied as-is.
        for name, tensor in zip(self._fields, self):
            if name in [
                "target_obs",
                "target_rem_sim",
                "context_obs",
                "context_rem_sim",
            ]:
                # 1e-6 epsilon avoids log(0) on exact-zero concentrations.
                transformed_tensors.append(torch.log(tensor + 1e-6))
            else:
                transformed_tensors.append(tensor)
        return AICMECompartmentsDataBatch(*transformed_tensors)

    def to_device(self, device: torch.device) -> "AICMECompartmentsDataBatch":
        """Moves all tensor fields to the specified device (leaves strings untouched)."""

        return AICMECompartmentsDataBatch(
            *(t.to(device) if isinstance(t, torch.Tensor) else t for t in self)
        )

    def to(self, device: torch.device | str) -> "AICMECompartmentsDataBatch":
        """PyTorch-style alias delegating to :meth:`to_device`.

        Several generic utilities expect batch-like objects to implement
        ``.to(device)``. Exposing this alias keeps the explicit
        ``to_device(...)`` API while allowing those utilities to move the full
        databatch onto the target device safely.
        """

        return self.to_device(torch.device(device))

    def to_reconstruct_type(self) -> "AICMECompartmentsDataBatch":
        """
        Return a new databatch where the target trajectories are reconstructed
        by concatenating observed and remainder segments, then right-padding
        so that the target has the same time dimension as the context.
        The context is left untouched.

        NOTE(review): assumes observed entries occupy the first ``o_len``
        positions of each target row (i.e. masks are left-packed) — confirm
        against the observation strategies that build these batches.
        """

        B, Ic, Tc, _ = self.context_obs.shape  # context time dimension is reference
        _, It, _, _ = self.target_obs.shape

        T_max = Tc  # max length for padding

        # allocate reconstructed tensors
        Xt_full = torch.zeros(
            B, It, T_max, 1, dtype=self.target_obs.dtype, device=self.target_obs.device
        )
        Tt_full = torch.zeros(
            B, It, T_max, 1, dtype=self.target_obs_time.dtype, device=self.target_obs_time.device
        )
        Mt_full = torch.zeros(B, It, T_max, dtype=torch.bool, device=self.target_obs_mask.device)

        # fill with observed + remainder segments
        for b in range(B):
            for i in range(It):
                o_len = int(self.target_obs_mask[b, i].sum().item())
                r_len = int(self.target_rem_sim_mask[b, i].sum().item())
                total = o_len + r_len
                if total == 0:
                    continue
                Xt_full[b, i, :o_len] = self.target_obs[b, i, :o_len]
                Xt_full[b, i, o_len:total] = self.target_rem_sim[b, i, :r_len]
                Tt_full[b, i, :o_len] = self.target_obs_time[b, i, :o_len]
                Tt_full[b, i, o_len:total] = self.target_rem_sim_time[b, i, :r_len]
                Mt_full[b, i, :total] = True

        return self._replace(
            target_obs=Xt_full,
            target_obs_time=Tt_full,
            target_obs_mask=Mt_full,
        )
sim_priors_pk/data/datasets/aicme_datasets.py ADDED
@@ -0,0 +1,1874 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import tempfile
4
+ import warnings
5
+ from dataclasses import replace
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Sequence, Tuple
8
+
9
+ import lightning.pytorch as pl
10
+ import torch
11
+ from torch import Tensor
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from torch.utils.data.dataloader import default_collate
14
+
15
+ from sim_priors_pk import data_dir
16
+ from sim_priors_pk.config_classes.node_pk_config import NodePKExperimentConfig
17
+ from sim_priors_pk.data.data_generation.compartment_models_management import (
18
+ prepare_full_simulation,
19
+ prepare_full_simulation_list_with_repeated_targets as prepare_full_simulation_list_with_repeated_targets_backend,
20
+ prepare_full_simulation_with_repeated_targets,
21
+ )
22
+ from sim_priors_pk.data.data_generation.observations_classes import (
23
+ ObservationStrategyFactory,
24
+ )
25
+ from sim_priors_pk.data.datasets.aicme_batch import (
26
+ AICMECompartmentsDataBatch,
27
+ )
28
+ from sim_priors_pk.utils.tensors_operations import ensure_mask_or_empty, ensure_tensor_or_empty
29
+
30
+
31
def ensure_min_valid(mask, min_length):
    """Guarantee at least ``min_length`` ones along the last axis of ``mask``.

    Rows already containing ``min_length`` or more valid entries are left
    untouched. Deficient rows are rebuilt: existing valid positions are kept
    and extra positions are promoted (chosen via a small random jitter) until
    exactly ``min_length`` entries are set.

    NOTE(review): assumes a floating-point mask — ``torch.rand_like`` fails on
    boolean tensors; confirm callers pass float masks.
    """
    counts = mask.sum(dim=-1, keepdim=True)
    deficient = counts < min_length

    if not deficient.any():
        return mask

    # The jitter (< 1) breaks ties among the zero entries so the extra valid
    # slots are picked at random, while existing ones (value >= 1) always win.
    scores = mask + torch.rand_like(mask) * 0.01
    _, chosen = torch.topk(scores, k=min_length, dim=-1, sorted=True)

    rebuilt = torch.zeros_like(mask)
    rebuilt.scatter_(-1, chosen, 1.0)

    # Only deficient rows adopt the rebuilt mask; the rest keep the original.
    return torch.where(deficient, rebuilt, mask)
52
+
53
+
54
def is_valid_simulation(sim: torch.Tensor, upper: float = 10.0) -> bool:
    """Check that a simulated trajectory is numerically usable.

    A simulation is valid when every value is finite, non-negative and
    strictly below ``upper`` (default ``10.0``, preserving the original
    hard-coded threshold while allowing other ranges).

    The explicit ``bool(...)`` conversion fixes the original annotation
    mismatch: chaining tensor ``.all()`` results with ``and`` returned a
    0-dim tensor, not a Python ``bool``.
    """
    return bool(
        torch.isfinite(sim).all() and (sim >= 0).all() and (sim < upper).all()
    )
57
+
58
+
59
def _stack_one_perm(
    batches: Sequence["AICMECompartmentsDataBatch"],
) -> "AICMECompartmentsDataBatch":
    """Merge several per-sample batches into one batch, field by field.

    String metadata fields are concatenated as plain Python lists; every
    other field is stacked along a new batch dimension via
    :func:`default_collate`.
    """
    flat_string_fields = {"study_name", "substance_name"}
    nested_string_fields = {"context_subject_name", "target_subject_name"}

    collated = []
    for field_name in AICMECompartmentsDataBatch._fields:
        values = [getattr(batch, field_name) for batch in batches]

        if field_name in flat_string_fields:
            # Flatten per-batch name lists into one list of strings.
            flattened = []
            for value in values:
                if isinstance(value, (list, tuple)):
                    flattened.extend(map(str, value))
                elif isinstance(value, str):
                    flattened.append(value)
                else:
                    raise TypeError(f"Unexpected type for {field_name}: {type(value)}")
            collated.append(flattened)
        elif field_name in nested_string_fields:
            # Keep one inner list per individual group, copied defensively.
            nested = []
            for value in values:
                if not isinstance(value, (list, tuple)):
                    raise TypeError(f"Unexpected type for {field_name}: {type(value)}")
                nested.extend(list(inner) for inner in value)
            collated.append(nested)
        else:
            collated.append(default_collate(values))

    return AICMECompartmentsDataBatch(*collated)
91
+
92
+
93
+ def _collate_aicme_batches(batch_list):
94
+ """
95
+ Handles:
96
+ - [B] of AICMECompartmentsDataBatch → returns one collated batch
97
+ - [B][P] of AICMECompartmentsDataBatch → returns list of P collated batches
98
+ """
99
+ if not batch_list:
100
+ return batch_list
101
+
102
+ first = batch_list[0]
103
+
104
+ # Case 1: flat list of AICME batches
105
+ if hasattr(first, "_fields"): # NamedTuple-like
106
+ return _stack_one_perm(batch_list)
107
+
108
+ # Case 2: nested [B][P]
109
+ if isinstance(first, (list, tuple)) and hasattr(first[0], "_fields"):
110
+ # transpose [B][P] -> [P][B]
111
+ transposed = list(zip(*batch_list))
112
+ return [_stack_one_perm(list(group)) for group in transposed]
113
+
114
+ # If we reach here and elements are Tensors, do NOT recurse further.
115
+ if torch.is_tensor(first):
116
+ raise TypeError(
117
+ "Got a list of tensors instead of AICMECompartmentsDataBatch. "
118
+ "Check that your Dataset returns AICMECompartmentsDataBatch, not raw tensors."
119
+ )
120
+
121
+ raise TypeError(
122
+ f"Unexpected element type in batch_list: {type(first)}. "
123
+ "Expected AICMECompartmentsDataBatch or list thereof."
124
+ )
125
+
126
+
127
def split_individuals_tensor_batch(
    full_tensor_a: torch.Tensor,
    full_tensor_b: torch.Tensor,
    full_tensor_c: Optional[torch.Tensor],
    n_of_target_individuals: int,
    seed: Optional[int] = None,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    """Randomly split individuals (dim 0) into context and target groups.

    Parameters
    ----------
    full_tensor_a, full_tensor_b:
        Tensors indexed by individual along dim 0 (e.g. values and times).
    full_tensor_c:
        Optional third tensor split identically (e.g. a mask); may be None.
    n_of_target_individuals:
        Number of individuals sampled without replacement as targets.
    seed:
        Optional seed. NOTE: this reseeds the *global* ``random`` module,
        which callers relying on its state should be aware of.

    Returns
    -------
    ``(context_a, context_b, context_c, target_a, target_b, target_c)``.
    When ``n_of_target_individuals == 0`` the three target entries are
    ``None`` (the original return annotation wrongly declared them
    non-optional; fixed here).
    """
    num_individuals = full_tensor_a.shape[0]
    if seed is not None:
        random.seed(seed)

    if n_of_target_individuals == 0:
        return full_tensor_a, full_tensor_b, full_tensor_c, None, None, None

    all_indices = list(range(num_individuals))
    target_indices = random.sample(all_indices, n_of_target_individuals)
    context_indices = [i for i in all_indices if i not in target_indices]

    context_a = full_tensor_a[context_indices]
    context_b = full_tensor_b[context_indices]
    context_c = full_tensor_c[context_indices] if full_tensor_c is not None else None

    target_a = full_tensor_a[target_indices]
    target_b = full_tensor_b[target_indices]
    target_c = full_tensor_c[target_indices] if full_tensor_c is not None else None

    return context_a, context_b, context_c, target_a, target_b, target_c
161
+
162
+
163
def list_of_databath_to_device(
    batch_list: List[AICMECompartmentsDataBatch],
    device: torch.device | str,
) -> List[AICMECompartmentsDataBatch]:
    """Move every batch in ``batch_list`` to ``device``.

    Parameters
    ----------
    batch_list:
        List of :class:`AICMECompartmentsDataBatch` objects.
    device:
        Target device.

    Note: the name ("databath") is a historical typo kept for backward
    compatibility with existing callers.
    """
    moved: List[AICMECompartmentsDataBatch] = []
    for databatch in batch_list:
        moved.append(databatch.to_device(device))
    return moved
177
+
178
+
179
def build_reconstruction_db(
    db: AICMECompartmentsDataBatch,
) -> AICMECompartmentsDataBatch:
    """
    Reconstruct the target trajectories by concatenating observed and remainder
    segments, then right-padding so that the target has the same time dimension
    as the context. The context is left untouched.

    Returns a new AICMECompartmentsDataBatch.
    """
    # The original body was a line-for-line duplicate of
    # ``AICMECompartmentsDataBatch.to_reconstruct_type``; delegate to the
    # method so the two implementations cannot drift apart.
    return db.to_reconstruct_type()
222
+
223
+
224
+ class AICMECompartmentsDataset(Dataset):
225
+ """Dataset generating synthetic PK batches for AICME models.
226
+
227
+ Target observation strategies should already divide past and future
228
+ observations (``split_past_future=True``).
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ model_config: NodePKExperimentConfig,
234
+ ctx_fn,
235
+ tgt_fn,
236
+ number_of_process=1000,
237
+ *,
238
+ store_in_tempfile: bool = False,
239
+ keep_tempfile: bool = False,
240
+ recreate_tempfile: bool = False,
241
+ tempfile_path: str | None = None,
242
+ show_progress: bool = True,
243
+ split: str = "",
244
+ use_shared_target_dosing: bool = False,
245
+ shared_target_n_targets: int = 100,
246
+ ):
247
+ self.mix_data_config = model_config.mix_data
248
+ self.meta_study_config = model_config.meta_study
249
+ self.meta_dosing_config = model_config.dosing
250
+ self.number_of_process = number_of_process
251
+ # ``n_of_permutations`` specifies how many shuffled versions of the
252
+ # context/target split are generated for a single simulation.
253
+ # ``n_of_databatches`` is a deprecated alias kept for backward
254
+ # compatibility and mirrors ``n_of_permutations``.
255
+ self.n_of_permutations = model_config.mix_data.n_of_permutations
256
+ self.n_of_databatches = self.n_of_permutations # deprecated alias
257
+ self.n_of_target_individuals = int(model_config.mix_data.n_of_target_individuals)
258
+ if self.n_of_target_individuals < 0:
259
+ raise ValueError("n_of_target_individuals must be >= 0")
260
+
261
+ # `num_individuals_range` controls context individuals only.
262
+ self.min_context_individuals = int(self.meta_study_config.num_individuals_range[0])
263
+ self.max_context_individuals = int(self.meta_study_config.num_individuals_range[-1])
264
+ if self.min_context_individuals < 0:
265
+ raise ValueError("meta_study.num_individuals_range minimum must be >= 0")
266
+ if self.max_context_individuals < self.min_context_individuals:
267
+ raise ValueError("meta_study.num_individuals_range must satisfy max >= min")
268
+
269
+ # Fixed total capacity used by downstream consumers.
270
+ self.max_individuals = self.max_context_individuals + self.n_of_target_individuals
271
+
272
+ self.context_fn = ctx_fn
273
+ self.target_fn = tgt_fn
274
+ self.store_in_tempfile = store_in_tempfile
275
+ self.keep_tempfile = keep_tempfile
276
+ self.recreate_tempfile = recreate_tempfile
277
+ self.show_progress = True
278
+ self._tmpfile_path: List[str] | None = None
279
+ self._loaded_data = None
280
+ self.run_id = getattr(model_config, "run_index", 0)
281
+ self.model_name = model_config.name_str
282
+
283
+ if self.store_in_tempfile:
284
+ self._prepare_tempfile_data(tempfile_path=tempfile_path, split=split)
285
+
286
+ self.use_shared_target_dosing = use_shared_target_dosing
287
+ self.shared_target_n_targets = shared_target_n_targets
288
+
289
+ def __del__(self):
290
+ if (
291
+ self.store_in_tempfile
292
+ and not self.keep_tempfile
293
+ and self._tmpfile_path
294
+ and os.path.exists(self._tmpfile_path)
295
+ ):
296
+ os.remove(self._tmpfile_path)
297
+
298
    def __len__(self):
        # ``number_of_process`` caps how many synthetic simulations the
        # DataLoader will request; it effectively sets the epoch length for
        # this on-the-fly generator.
        return self.number_of_process  # Arbitrary large number to simulate infinite data
300
+
301
    def _prepare_tempfile_data(self, *, tempfile_path: str | None, split: str) -> None:
        """Handle creation and (re)generation of the temporary data file.

        With ``tempfile_path=None`` an anonymous ``.pt`` tempfile is created;
        otherwise the path (a string, or a tuple of path parts relative to
        ``data_dir``) is decorated with the model name, split and run id.
        The dataset is (re)generated and saved whenever ``recreate_tempfile``
        is set or the file does not yet exist.
        """
        if tempfile_path is None:
            # delete=False keeps the file alive after close; cleanup happens
            # in ``__del__`` depending on ``keep_tempfile``.
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
            self._tmpfile_path = tmp.name
            tmp.close()
        else:
            # Allow both Tuple paths from YAML and plain strings
            if isinstance(tempfile_path, (tuple, list)):
                base_path = os.path.join(data_dir, *tempfile_path)
            else:
                base_path = tempfile_path

            dirname = os.path.dirname(base_path)
            basename = os.path.basename(base_path)

            # Encode model, split and run id into the cache filename so
            # different configurations never collide on disk.
            suffix = f"_{self.model_name}_{split}"
            if self.run_id is not None:
                suffix += f"_run{self.run_id}"
            new_basename = basename + suffix + ".tr"

            self._tmpfile_path = os.path.join(dirname, new_basename)

        if self.recreate_tempfile or not os.path.exists(self._tmpfile_path):
            print("RECREATING DATASET!")
            iterator = range(self.number_of_process)
            if self.show_progress:
                # Lazy import: tqdm is only needed when a progress bar is shown.
                from tqdm.auto import tqdm

                iterator = tqdm(iterator, desc="Generating AICME data")
            data = [self._generate_item(i) for i in iterator]
            torch.save(data, self._tmpfile_path)
333
+
334
+ def split_simulations(
335
+ self, full_simulation, full_simulation_times
336
+ ) -> Tuple[
337
+ torch.Tensor,
338
+ torch.Tensor,
339
+ Optional[torch.Tensor],
340
+ Optional[torch.Tensor],
341
+ list[int],
342
+ list[int],
343
+ ]:
344
+ """
345
+ From the full simulation, randomly select `n_of_target_individuals` as targets and keep the rest as context.
346
+ If `n_of_target_individuals == 0`, returns None for the target fields.
347
+ """
348
+ n_of_target_individuals = self.n_of_target_individuals
349
+ num_individuals = full_simulation.shape[0]
350
+
351
+ if n_of_target_individuals == 0:
352
+ context_simulation = full_simulation
353
+ context_simulation_times = full_simulation_times
354
+ return (
355
+ context_simulation,
356
+ context_simulation_times,
357
+ None,
358
+ None,
359
+ list(range(num_individuals)),
360
+ [],
361
+ )
362
+
363
+ if num_individuals < n_of_target_individuals:
364
+ raise ValueError(
365
+ "Simulation contains fewer individuals than requested targets: "
366
+ f"num_individuals={num_individuals}, "
367
+ f"n_of_target_individuals={n_of_target_individuals}."
368
+ )
369
+
370
+ # Randomly select indices for target individuals
371
+ target_indices = random.sample(range(num_individuals), n_of_target_individuals)
372
+ context_indices = [i for i in range(num_individuals) if i not in target_indices]
373
+
374
+ # Split the simulations, times, and masks
375
+ target_simulation = full_simulation[target_indices]
376
+ target_simulation_times = full_simulation_times[target_indices]
377
+ context_simulation = full_simulation[context_indices]
378
+ context_simulation_times = full_simulation_times[context_indices]
379
+
380
+ return (
381
+ context_simulation,
382
+ context_simulation_times,
383
+ target_simulation,
384
+ target_simulation_times,
385
+ context_indices,
386
+ target_indices,
387
+ )
388
+
389
+ def _build_generation_meta_study_config(self):
390
+ """Return a meta-study config where totals include fixed target individuals.
391
+
392
+ The user-facing ``meta_study.num_individuals_range`` represents context
393
+ individuals only. For raw simulation generation, we therefore sample
394
+ ``context + n_of_target_individuals`` total individuals.
395
+ """
396
+ total_min = self.min_context_individuals + self.n_of_target_individuals
397
+ total_max = self.max_context_individuals + self.n_of_target_individuals
398
+
399
+ if getattr(self.meta_study_config, "simple_mode", False):
400
+ total_individuals = random.randint(total_min, total_max)
401
+ return replace(
402
+ self.meta_study_config,
403
+ num_individuals=total_individuals,
404
+ num_individuals_range=(total_individuals, total_individuals),
405
+ )
406
+
407
+ return replace(
408
+ self.meta_study_config,
409
+ num_individuals_range=(total_min, total_max),
410
+ )
411
+
412
    def __getitem__(self, idx):
        """Return one item: either cached data or a freshly generated batch list.

        In tempfile mode the whole cache is lazily loaded on first access and
        indices are re-mapped per distributed rank so ranks see disjoint
        (wrapped) slices. Otherwise a batch is generated on the fly via
        ``_generate_item`` or the shared-target variant.
        """
        if self.store_in_tempfile:
            if self._loaded_data is None:
                # Lazy load so DataLoader workers each read the file once.
                self._loaded_data = torch.load(self._tmpfile_path, weights_only=False)
            # If in distributed mode, adjust the index based on process rank/world size
            if torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                total_len = len(self._loaded_data)
                # Compute adjusted indices for this rank
                adjusted_idx = idx * world_size + rank
                if adjusted_idx >= total_len:
                    # If we would go out of bounds, wrap around to get a valid index
                    adjusted_idx = adjusted_idx % total_len
                return self._loaded_data[adjusted_idx]
            return self._loaded_data[idx]

        if self.use_shared_target_dosing:
            return self._generate_item_sample_target_dosing(
                idx, n_targets=self.shared_target_n_targets
            )
        return self._generate_item(idx)
434
+
435
    def _generate_item(self, idx) -> List[AICMECompartmentsDataBatch]:
        """Generate a list of ``AICMECompartmentsDataBatch`` objects.

        Each element corresponds to one permutation of the context/target split.
        Target observations are generated using ``target_fn``, which is expected
        to divide past and future observations.

        Note: ``idx`` is not used to seed the simulation here; every call
        produces an independent random simulation.
        """
        # One raw simulation is shared by all permutations below; only the
        # context/target split and the sampled observations differ per loop.
        (
            full_simulation,
            full_simulation_times,
            dosing_amounts,
            dosing_routes,
            time_points,
            time_scales,
        ) = prepare_full_simulation(
            self._build_generation_meta_study_config(),
            self.meta_dosing_config,
        )

        list_of_databatches: List[AICMECompartmentsDataBatch] = []
        for _ in range(self.n_of_permutations):
            # Split into context and target
            (
                context_simulation,
                context_simulation_times,
                target_simulation,
                target_simulation_times,
                context_indices,
                target_indices,
            ) = self.split_simulations(full_simulation, full_simulation_times)

            context_observations = self._safe_generate(
                self.context_fn,
                context_simulation,
                context_simulation_times,
                time_scales=time_scales,
            )

            target_observations = self._safe_generate(
                self.target_fn,
                target_simulation,
                target_simulation_times,
                time_scales=time_scales,
            )

            (
                context_obs,  # [c_ind, num_obs_c, 1]
                context_obs_time,  # [c_ind, num_obs_c, 1]
                context_obs_mask,  # [c_ind, num_obs_c]
                context_rem_sim,  # [c_ind, rem_obs_c, 1]
                context_rem_sim_time,  # [c_ind, rem_obs_c, 1]
                context_rem_sim_mask,  # [c_ind, rem_obs_c]
                context_time_scales,
            ) = context_observations

            (
                target_obs,  # [t_ind, num_obs_t, 1]
                target_obs_time,  # [t_ind, num_obs_t, 1]
                target_obs_mask,  # [t_ind, num_obs_t]
                target_rem_sim,  # [t_ind, rem_obs_t, 1]
                target_rem_sim_time,  # [t_ind, rem_obs_t, 1]
                target_rem_sim_mask,  # [t_ind, rem_obs_t]
                target_time_scales,
            ) = target_observations

            # Use provided time scales or fall back to simulation defaults
            # (context scales win over target scales, which win over the raw
            # simulation's scales).
            ts = (
                context_time_scales
                if context_time_scales is not None
                else target_time_scales
                if target_time_scales is not None
                else time_scales
            )

            batch = self._build_padded_batch(
                context_obs,
                context_obs_time,
                context_obs_mask,
                context_rem_sim,
                context_rem_sim_time,
                context_rem_sim_mask,
                dosing_amounts[context_indices],
                dosing_routes[context_indices],
                target_obs,
                target_obs_time,
                target_obs_mask,
                target_rem_sim,
                target_rem_sim_time,
                target_rem_sim_mask,
                # Dosing is only indexable when targets exist.
                dosing_amounts[target_indices] if len(target_indices) > 0 else None,
                dosing_routes[target_indices] if len(target_indices) > 0 else None,
                ts,
            )
            list_of_databatches.append(batch)

        return list_of_databatches
531
+
532
    def _generate_item_sample_target_dosing(
        self,
        idx: int,
        n_targets: int = 100,
        different_dosing: bool = False,
    ):
        """Generate one batch where targets share (or vary) a dosing regimen.

        Uses the repeated-targets simulation backend, so the same underlying
        subject dynamics are replicated ``n_targets`` times; with
        ``different_dosing=True`` each replica gets its own dosing.
        Returns a single-element list to match ``_generate_item``'s
        list-of-permutations contract.
        """
        (
            context_sim,
            context_times,
            target_sim,
            target_times,
            dosing_amounts_ctx,
            dosing_routes_ctx,
            dosing_amounts_tgt,
            dosing_routes_tgt,
            time_points,
            time_scales,
        ) = prepare_full_simulation_with_repeated_targets(
            self.meta_study_config,
            self.meta_dosing_config,
            n_targets,
            different_dosing=different_dosing,
            idx=idx,
        )

        # Observations
        context_obs_pack = self._safe_generate(
            self.context_fn, context_sim, context_times, time_scales=time_scales
        )
        target_obs_pack = self._safe_generate(
            self.target_fn, target_sim, target_times, time_scales=time_scales
        )

        (
            context_obs,
            context_obs_time,
            context_obs_mask,
            context_rem_sim,
            context_rem_sim_time,
            context_rem_sim_mask,
            context_time_scales,
        ) = context_obs_pack

        (
            target_obs,
            target_obs_time,
            target_obs_mask,
            target_rem_sim,
            target_rem_sim_time,
            target_rem_sim_mask,
            target_time_scales,
        ) = target_obs_pack

        # Prefer context scales, then target scales, then simulation defaults.
        ts = (
            context_time_scales
            if context_time_scales is not None
            else (target_time_scales or time_scales)
        )

        # Build batch
        batch = self._build_padded_batch(
            # context
            context_obs,
            context_obs_time,
            context_obs_mask,
            context_rem_sim,
            context_rem_sim_time,
            context_rem_sim_mask,
            dosing_amounts_ctx,
            dosing_routes_ctx,
            # target
            target_obs,
            target_obs_time,
            target_obs_mask,
            target_rem_sim,
            target_rem_sim_time,
            target_rem_sim_mask,
            dosing_amounts_tgt,
            dosing_routes_tgt,
            # time scales
            ts=ts,
            # target capacity must cover the replicated targets, not the
            # default n_of_target_individuals.
            target_capacity=n_targets,
        )

        return [batch]
617
+
618
+ # ------------------------------------------------------------------ #
619
+ # utilities
620
+ # ------------------------------------------------------------------ #
621
+
622
    def _build_padded_batch(
        self,
        ctx_obs: Tensor,  # [c_ind, num_obs_c]
        ctx_time: Tensor,  # [c_ind, num_obs_c]
        ctx_mask: Tensor,  # [c_ind, num_obs_c]
        ctx_rem: Optional[Tensor],  # [c_ind, rem_obs_c] | None
        ctx_rem_time: Optional[Tensor],  # [c_ind, rem_obs_c] | None
        ctx_rem_mask: Optional[Tensor],  # [c_ind, rem_obs_c] | None
        ctx_dose: Tensor,  # [c_ind]
        ctx_route: Tensor,  # [c_ind]
        tgt_obs: Optional[Tensor],  # [t_ind, num_obs_t] | None
        tgt_time: Optional[Tensor],  # [t_ind, num_obs_t] | None
        tgt_mask: Optional[Tensor],  # [t_ind, num_obs_t] | None
        tgt_rem: Optional[Tensor],  # [t_ind, rem_obs_t] | None
        tgt_rem_time: Optional[Tensor],  # [t_ind, rem_obs_t] | None
        tgt_rem_mask: Optional[Tensor],  # [t_ind, rem_obs_t] | None
        tgt_dose: Optional[Tensor],  # [t_ind] | None
        tgt_route: Optional[Tensor],  # [t_ind] | None
        ts: Tensor,  # [B(=1), 2]
        *,
        target_capacity: Optional[
            int
        ] = None,  # ← NEW (optional). If None, use self.n_of_target_individuals
    ) -> AICMECompartmentsDataBatch:
        """Pad context and target tensors then pack them into a batch.

        Every input is padded (or truncated) along the individuals axis via
        ``_pad_first_dim`` so that context tensors have exactly
        ``self.max_context_individuals`` rows and target tensors have
        ``target_capacity`` rows (falling back to
        ``self.n_of_target_individuals`` when ``target_capacity`` is None).
        ``None`` inputs are replaced by empty placeholder tensors/masks through
        ``ensure_tensor_or_empty`` / ``ensure_mask_or_empty``.

        Boolean per-individual validity masks (``mask_context_individuals`` /
        ``mask_target_individuals``) mark the rows that held real data before
        padding. The batch is flagged ``is_empirical=False`` with empty
        study/subject/substance names — this builder is for synthetic data.
        """

        max_c = self.max_context_individuals  # (unchanged)
        max_t = (
            target_capacity if target_capacity is not None else self.n_of_target_individuals
        )  # ← ONLY CHANGE

        # ── target padding (unchanged) ─────────────────────────────────────────
        # Observations/times get a trailing feature axis of size 1 before
        # padding; masks stay 2-D.
        t_obs_p = self._pad_first_dim(
            ensure_tensor_or_empty(
                tgt_obs.unsqueeze(-1) if tgt_obs is not None else None, (1, 1, 1)
            ),  # to [t_ind, Tt, 1]
            max_t,
        )
        t_time_p = self._pad_first_dim(
            ensure_tensor_or_empty(
                tgt_time.unsqueeze(-1) if tgt_time is not None else None, (1, 1, 1)
            ),  # to [t_ind, Tt, 1]
            max_t,
        )
        t_mask_p = self._pad_first_dim(
            ensure_mask_or_empty(
                tgt_mask if tgt_mask is not None else None, (1, 1)
            ),  # to [t_ind, Tt]
            max_t,
        )
        # Remaining-simulation placeholders reuse the observed first-dim size
        # so shapes stay consistent when only the "rem" pieces are missing.
        t_rem_p = self._pad_first_dim(
            ensure_tensor_or_empty(
                tgt_rem.unsqueeze(-1) if tgt_rem is not None else None, (t_obs_p.size(0), 1, 1)
            ),  # [t_ind, Rt,1]
            max_t,
        )
        t_rem_time_p = self._pad_first_dim(
            ensure_tensor_or_empty(
                tgt_rem_time.unsqueeze(-1) if tgt_rem_time is not None else None,
                (t_obs_p.size(0), 1, 1),
            ),
            max_t,
        )
        t_rem_mask_p = self._pad_first_dim(
            ensure_mask_or_empty(
                tgt_rem_mask if tgt_rem_mask is not None else None, (t_obs_p.size(0), 1)
            ),
            max_t,
        )
        t_dose_p = self._pad_first_dim(
            ensure_tensor_or_empty(tgt_dose if tgt_dose is not None else None, (1,)),  # [t_ind]
            max_t,
        )
        # Route types are categorical indices, hence the .long() cast.
        t_route_p = self._pad_first_dim(
            ensure_tensor_or_empty(tgt_route if tgt_route is not None else None, (1,)),  # [t_ind]
            max_t,
        ).long()

        # ── context padding (unchanged) ────────────────────────────────────────
        # Context tensors are assumed non-None, so no ensure_* fallback here.
        c_obs_p = self._pad_first_dim(ctx_obs, max_c).unsqueeze(-1)  # [c_ind, Tc, 1]
        c_time_p = self._pad_first_dim(ctx_time, max_c).unsqueeze(-1)  # [c_ind, Tc, 1]
        c_mask_p = self._pad_first_dim(ctx_mask, max_c)  # [c_ind, Tc]
        c_rem_p = self._pad_first_dim(
            ensure_tensor_or_empty(
                ctx_rem.unsqueeze(-1) if ctx_rem is not None else None, (ctx_obs.size(0), 1, 1)
            ),
            max_c,
        )
        c_rem_time_p = self._pad_first_dim(
            ensure_tensor_or_empty(
                ctx_rem_time.unsqueeze(-1) if ctx_rem_time is not None else None,
                (ctx_obs.size(0), 1, 1),
            ),
            max_c,
        )
        c_rem_mask_p = self._pad_first_dim(
            ensure_mask_or_empty(
                ctx_rem_mask if ctx_rem_mask is not None else None, (ctx_obs.size(0), 1)
            ),
            max_c,
        )
        c_dose_p = self._pad_first_dim(ctx_dose, max_c)  # [c_ind]
        c_route_p = self._pad_first_dim(ctx_route, max_c).long()  # [c_ind]

        # Per-individual validity masks: True for rows that carried real data.
        total_c = ctx_obs.size(0)
        mask_c_inds = torch.zeros(self.max_context_individuals, dtype=torch.bool)
        mask_c_inds[:total_c] = True

        total_t = tgt_obs.size(0) if tgt_obs is not None else 0
        mask_t_inds = torch.zeros(
            max_t, dtype=torch.bool
        )  # ← use max_t here (unchanged logic, just variable)
        mask_t_inds[:total_t] = True

        return AICMECompartmentsDataBatch(
            target_obs=t_obs_p,
            target_obs_time=t_time_p,
            target_obs_mask=t_mask_p,
            target_rem_sim=t_rem_p,
            target_rem_sim_time=t_rem_time_p,
            target_rem_sim_mask=t_rem_mask_p,
            target_dosing_amounts=t_dose_p,
            target_dosing_route_types=t_route_p,
            context_obs=c_obs_p,
            context_obs_time=c_time_p,
            context_obs_mask=c_mask_p,
            context_rem_sim=c_rem_p,
            context_rem_sim_time=c_rem_time_p,
            context_rem_sim_mask=c_rem_mask_p,
            context_dosing_amounts=c_dose_p,
            context_dosing_route_types=c_route_p,
            mask_context_individuals=mask_c_inds,
            mask_target_individuals=mask_t_inds,
            study_name=[""],
            context_subject_name=[[""] * max_c],
            target_subject_name=[[""] * max_t],  # ← still uses max_t
            substance_name=[""],
            time_scales=ts,
            is_empirical=False,
        )
762
+
763
+ @staticmethod
764
+ def _safe_generate(strategy, sim, times, **kw):
765
+ """
766
+ Call ObservationStrategy.generate() only when `sim` is not None.
767
+ Returns a 7-tuple of Nones otherwise.
768
+ """
769
+ if sim is None:
770
+ return (None, None, None, None, None, None, None)
771
+
772
+ for _ in range(10): # retries, like old manager
773
+ out = strategy.generate(sim, times, **kw)
774
+ if out[0] is not None: # got a non-empty slice
775
+ return out
776
+ raise RuntimeError(
777
+ "Unable to generate non-empty observations "
778
+ "after 10 attempts – check strategy parameters."
779
+ )
780
+
781
+ @staticmethod
782
+ def _pad_first_dim(t: torch.Tensor, size: int) -> torch.Tensor:
783
+ """Pad tensor along the first dimension up to ``size``.
784
+
785
+ Parameters
786
+ ----------
787
+ t : TensorType["I", *Ts]
788
+ Input tensor where ``I`` may be smaller than ``size``.
789
+ size : int
790
+ Desired first-dimension size after padding.
791
+
792
+ Returns
793
+ -------
794
+ TensorType["size", *Ts]
795
+ Tensor padded with zeros (or ``False`` for bool tensors) so that the
796
+ first dimension equals ``size``. If ``t`` already has ``size`` or
797
+ more elements along the first dimension, it is truncated.
798
+ """
799
+
800
+ current = t.size(0)
801
+ if current >= size:
802
+ return t[:size]
803
+
804
+ pad_shape = (size - current, *t.shape[1:])
805
+ pad_value = False if t.dtype == torch.bool else 0.0
806
+ padding = torch.full(pad_shape, pad_value, dtype=t.dtype, device=t.device)
807
+ return torch.cat([t, padding], dim=0)
808
+
809
+
810
+ class AICMECompartmentsDataModule(pl.LightningDataModule):
811
+ """LightningDataModule for synthetic PK simulation data."""
812
+
813
+ # Empirical target batches always use the legacy PK observation strategy
814
+ # with a fixed capacity profile, independent from synthetic target config.
815
+ _EMPIRICAL_TARGET_MAX_NUM_OBS = 15
816
+ _EMPIRICAL_TARGET_MIN_PAST = 0
817
+ _EMPIRICAL_TARGET_MAX_PAST = 5
818
+
819
+ def __init__(
820
+ self,
821
+ model_config: NodePKExperimentConfig,
822
+ ):
823
+ super().__init__()
824
+ self.model_config = model_config
825
+ self.context_config = model_config.context_observations
826
+ self.target_config = model_config.target_observations
827
+ self.meta_config = model_config.meta_study
828
+ self.data_config = model_config.mix_data
829
+ self.study_config = model_config.meta_study
830
+ self.num_workers = model_config.train.num_workers
831
+ self.persistent_workers = model_config.train.persistent_workers
832
+ self.shuffle_val = getattr(model_config.train, "shuffle_val", True)
833
+ self.train_size = self.data_config.train_size
834
+ self.val_size = self.data_config.val_size
835
+ self.test_size = self.data_config.test_size
836
+ self.batch_size = model_config.train.batch_size
837
+ self._prepared = False
838
+ # Cached shape parameters for empirical batch builders
839
+ self.max_individuals: int | None = None
840
+ self.max_observations: int | None = None
841
+ self.max_remaining: int | None = None
842
+ self.empirical_target_config = None
843
+ self.empirical_target_strategy = None
844
+ self.empirical_test_batches: Dict[str, List["AICMECompartmentsDataBatch"]] = {}
845
+ self.empirical_test_batches_no_heldout: Dict[str, List["AICMECompartmentsDataBatch"]] = {}
846
+
847
+ def prepare_data(self):
848
+ # Use this method to download or prepare data if needed.
849
+ # This is called only once and on a single GPU.
850
+ # Here the Observation Manager Also Handles Empirical Data
851
+ tempfile_path = getattr(self.data_config, "tempfile_path", None)
852
+ if tempfile_path:
853
+ temp_dir = Path(data_dir).joinpath(*tempfile_path)
854
+ else:
855
+ temp_dir = Path(data_dir) / "preprocessed"
856
+ temp_dir.mkdir(parents=True, exist_ok=True)
857
+
858
+ self.context_strategy = ObservationStrategyFactory.from_config(
859
+ self.context_config,
860
+ self.meta_config,
861
+ )
862
+ self.target_strategy = ObservationStrategyFactory.from_config(
863
+ self.target_config,
864
+ self.meta_config,
865
+ )
866
+ # Empirical target path: enforce legacy PK strategy and fixed capacities.
867
+ # This is intentionally decoupled from synthetic target strategy settings.
868
+ self.empirical_target_config = replace(
869
+ self.target_config,
870
+ type=None,
871
+ split_past_future=True,
872
+ max_num_obs=self._EMPIRICAL_TARGET_MAX_NUM_OBS,
873
+ min_past=self._EMPIRICAL_TARGET_MIN_PAST,
874
+ max_past=self._EMPIRICAL_TARGET_MAX_PAST,
875
+ )
876
+ self.empirical_target_strategy = ObservationStrategyFactory.from_config(
877
+ self.empirical_target_config,
878
+ self.meta_config,
879
+ )
880
+ self.train_dataset = AICMECompartmentsDataset(
881
+ self.model_config,
882
+ ctx_fn=self.context_strategy,
883
+ tgt_fn=self.target_strategy,
884
+ number_of_process=self.train_size,
885
+ store_in_tempfile=self.data_config.store_in_tempfile,
886
+ keep_tempfile=self.data_config.keep_tempfile,
887
+ recreate_tempfile=self.data_config.recreate_tempfile,
888
+ tempfile_path=self.data_config.tempfile_path,
889
+ show_progress=self.data_config.tqdm_progress,
890
+ split="train",
891
+ )
892
+ self.val_dataset = AICMECompartmentsDataset(
893
+ self.model_config,
894
+ ctx_fn=self.context_strategy,
895
+ tgt_fn=self.target_strategy,
896
+ number_of_process=self.val_size,
897
+ store_in_tempfile=self.data_config.store_in_tempfile,
898
+ keep_tempfile=self.data_config.keep_tempfile,
899
+ recreate_tempfile=self.data_config.recreate_tempfile,
900
+ tempfile_path=self.data_config.tempfile_path,
901
+ show_progress=self.data_config.tqdm_progress,
902
+ split="val",
903
+ )
904
+ self.test_dataset = AICMECompartmentsDataset(
905
+ self.model_config,
906
+ ctx_fn=self.context_strategy,
907
+ tgt_fn=self.target_strategy,
908
+ number_of_process=self.test_size,
909
+ store_in_tempfile=self.data_config.store_in_tempfile,
910
+ keep_tempfile=self.data_config.keep_tempfile,
911
+ recreate_tempfile=self.data_config.recreate_tempfile,
912
+ tempfile_path=self.data_config.tempfile_path,
913
+ show_progress=self.data_config.tqdm_progress,
914
+ split="test",
915
+ )
916
+ # Record shapes for empirical builders
917
+ ctx_obs, ctx_rem = self.context_strategy.get_shapes()
918
+ tgt_obs, tgt_rem = self.target_strategy.get_shapes()
919
+ self.max_observations = max(ctx_obs, tgt_obs)
920
+ self.max_remaining = max(ctx_rem, tgt_rem)
921
+ self.max_individuals = max(
922
+ self.train_dataset.max_context_individuals,
923
+ self.train_dataset.n_of_target_individuals,
924
+ )
925
+ self._prepared = True
926
+ self._empirical_loaded = False
927
+
928
+ # Preload empirical datasets during prepare_data so they are available
929
+ # before training callbacks query them.
930
+ # In DDP, keep network/download activity on rank 0 only.
931
+ if self._is_global_zero_process():
932
+ self._load_empirical_test_batches()
933
+ self._empirical_loaded = True
934
+
935
+ def setup(self, stage=None):
936
+ # Use this method to split data into train, validation, and test sets.
937
+ # This is called on every GPU.
938
+ if not self._prepared:
939
+ self.prepare_data()
940
+
941
+ def train_dataloader(self):
942
+ # Returns the training dataloader.
943
+ num_workers, persistent_workers = self._resolve_dataloader_workers()
944
+ return DataLoader(
945
+ self.train_dataset,
946
+ batch_size=self.batch_size,
947
+ shuffle=True,
948
+ num_workers=num_workers,
949
+ persistent_workers=persistent_workers,
950
+ collate_fn=_collate_aicme_batches,
951
+ )
952
+
953
+ def val_dataloader(self):
954
+ # Returns the validation dataloader.
955
+ num_workers, persistent_workers = self._resolve_dataloader_workers()
956
+ return DataLoader(
957
+ self.val_dataset,
958
+ batch_size=self.batch_size,
959
+ shuffle=self.shuffle_val,
960
+ num_workers=num_workers,
961
+ persistent_workers=persistent_workers,
962
+ collate_fn=_collate_aicme_batches,
963
+ )
964
+
965
+ def test_dataloader(self):
966
+ # Optional: Returns the test dataloader.
967
+ # If you don't have a test set, you can omit this method.
968
+ num_workers, persistent_workers = self._resolve_dataloader_workers()
969
+ return DataLoader(
970
+ self.test_dataset,
971
+ batch_size=self.batch_size,
972
+ shuffle=False,
973
+ num_workers=num_workers,
974
+ persistent_workers=persistent_workers,
975
+ collate_fn=_collate_aicme_batches,
976
+ )
977
+
978
+ def obtain_shapes(self) -> Tuple[int, int, int]:
979
+ """Expose dataset shape parameters for empirical batching.
980
+
981
+ Returns
982
+ -------
983
+ Tuple[int, int, int]
984
+ ``(max_individuals, max_observations, max_remaining)`` as used by
985
+ :class:`AICMECompartmentsDataset`.
986
+ """
987
+
988
+ if not self._prepared:
989
+ self.prepare_data()
990
+
991
+ assert self.max_individuals is not None
992
+ assert self.max_observations is not None
993
+ assert self.max_remaining is not None
994
+ return (
995
+ self.max_individuals,
996
+ self.max_observations,
997
+ self.max_remaining,
998
+ )
999
+
1000
+ def _resolve_dataloader_workers(self) -> Tuple[int, bool]:
1001
+ """Return DataLoader worker settings that are safe for single-process runs."""
1002
+ num_workers = max(0, int(self.num_workers))
1003
+ persistent_workers = self.persistent_workers and num_workers > 0
1004
+ return num_workers, persistent_workers
1005
+
1006
+ @staticmethod
1007
+ def _is_global_zero_process() -> bool:
1008
+ """Return True for rank 0 (or single-process execution)."""
1009
+
1010
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
1011
+ return torch.distributed.get_rank() == 0
1012
+ return True
1013
+
1014
+ def _load_empirical_test_batches(self) -> None:
1015
+ """Download and cache empirical Hugging Face datasets for evaluation."""
1016
+
1017
+ from sim_priors_pk.data.data_empirical import load_empirical_hf_batches_as_dm
1018
+
1019
+ datasets = getattr(self.data_config, "test_empirical_datasets", [])
1020
+ self.empirical_test_batches = {}
1021
+ self.empirical_test_batches_no_heldout = {}
1022
+ if not datasets:
1023
+ return
1024
+
1025
+ for repo_id in datasets:
1026
+ try:
1027
+ batches = load_empirical_hf_batches_as_dm(
1028
+ repo_id,
1029
+ meta_dosing=self.model_config.dosing,
1030
+ datamodule=self,
1031
+ held_out=True,
1032
+ )
1033
+ except Exception as exc: # noqa: BLE001 - surface download issues
1034
+ warnings.warn(
1035
+ f"Failed to load empirical dataset '{repo_id}': {exc}",
1036
+ stacklevel=2,
1037
+ )
1038
+ continue
1039
+
1040
+ if not batches:
1041
+ warnings.warn(
1042
+ f"No empirical batches returned for dataset '{repo_id}'",
1043
+ stacklevel=2,
1044
+ )
1045
+ continue
1046
+
1047
+ self.empirical_test_batches[repo_id] = batches
1048
+ try:
1049
+ no_heldout_batches = load_empirical_hf_batches_as_dm(
1050
+ repo_id,
1051
+ meta_dosing=self.model_config.dosing,
1052
+ datamodule=self,
1053
+ held_out=False,
1054
+ )
1055
+ except Exception as exc: # noqa: BLE001 - surface download issues
1056
+ warnings.warn(
1057
+ f"Failed to load no-heldout empirical dataset '{repo_id}': {exc}",
1058
+ stacklevel=2,
1059
+ )
1060
+ continue
1061
+
1062
+ if not no_heldout_batches:
1063
+ warnings.warn(
1064
+ f"No no-heldout empirical batches returned for dataset '{repo_id}'",
1065
+ stacklevel=2,
1066
+ )
1067
+ continue
1068
+
1069
+ self.empirical_test_batches_no_heldout[repo_id] = no_heldout_batches
1070
+
1071
+ def get_empirical_test_batches(
1072
+ self,
1073
+ *,
1074
+ no_heldout: bool = False,
1075
+ device: Optional[torch.device | str] = None,
1076
+ ) -> Dict[str, List["AICMECompartmentsDataBatch"]]:
1077
+ """Return cached empirical batches keyed by Hugging Face dataset id.
1078
+
1079
+ Parameters
1080
+ ----------
1081
+ no_heldout:
1082
+ If ``True``, return batches where all empirical individuals remain
1083
+ in context (no held-out target). If ``False`` (default), return the
1084
+ leave-one-out batches.
1085
+ device:
1086
+ Optional device where returned batches should live. When provided,
1087
+ returned batches are moved to ``device`` without mutating the
1088
+ internal cache.
1089
+ """
1090
+
1091
+ # Safety fallback for direct/manual datamodule usage.
1092
+ if not getattr(self, "_empirical_loaded", False):
1093
+ if not self._prepared:
1094
+ self.prepare_data()
1095
+ elif self._is_global_zero_process():
1096
+ self._load_empirical_test_batches()
1097
+ self._empirical_loaded = True
1098
+
1099
+ batch_map = (
1100
+ self.empirical_test_batches_no_heldout if no_heldout else self.empirical_test_batches
1101
+ )
1102
+
1103
+ if device is None:
1104
+ return batch_map
1105
+
1106
+ return {
1107
+ repo_id: list_of_databath_to_device(batch_list, device)
1108
+ for repo_id, batch_list in batch_map.items()
1109
+ }
1110
+
1111
+ def get_empirical_batches(
1112
+ self,
1113
+ *,
1114
+ split: str,
1115
+ empirical_name: Optional[str],
1116
+ device: Optional[torch.device | str] = None,
1117
+ ) -> List["AICMECompartmentsDataBatch"]:
1118
+ """Return one empirical batch list using scheduler-oriented split aliases.
1119
+
1120
+ Supported split aliases:
1121
+ - ``empirical_heldout``: leave-one-out empirical targets
1122
+ - ``empirical_no_heldout``: all empirical individuals remain in context
1123
+ """
1124
+
1125
+ normalized_split = str(split).strip().lower()
1126
+ if normalized_split == "empirical_heldout":
1127
+ batch_map = self.get_empirical_test_batches(no_heldout=False, device=device)
1128
+ elif normalized_split == "empirical_no_heldout":
1129
+ batch_map = self.get_empirical_test_batches(no_heldout=True, device=device)
1130
+ else:
1131
+ raise ValueError(
1132
+ f"Unsupported empirical split alias '{split}'. "
1133
+ "Expected 'empirical_heldout' or 'empirical_no_heldout'."
1134
+ )
1135
+
1136
+ if empirical_name is None:
1137
+ raise ValueError("`empirical_name` must be provided for empirical scheduler tasks.")
1138
+ try:
1139
+ return batch_map[str(empirical_name)]
1140
+ except KeyError as exc:
1141
+ raise ValueError(
1142
+ f"No empirical batches found for split='{split}' and empirical_name='{empirical_name}'."
1143
+ ) from exc
1144
+
1145
+ @staticmethod
1146
+ def _normalize_substance_name(name: object) -> str:
1147
+ """Normalize substance names for robust matching."""
1148
+
1149
+ return "".join(ch.lower() for ch in str(name) if ch.isalnum())
1150
+
1151
+ def select_empirical_batch_list(
1152
+ self,
1153
+ dataset_key: Optional[str] = None,
1154
+ *,
1155
+ no_heldout: bool = False,
1156
+ ) -> Tuple[Optional[str], List["AICMECompartmentsDataBatch"]]:
1157
+ """Select one empirical dataset batch list for plotting/evaluation.
1158
+
1159
+ Parameters
1160
+ ----------
1161
+ dataset_key:
1162
+ Explicit dataset key to use. If missing or unknown, the first
1163
+ non-empty dataset in cache is selected.
1164
+ no_heldout:
1165
+ Whether to read from the no-heldout cache.
1166
+
1167
+ Returns
1168
+ -------
1169
+ Tuple[Optional[str], List[AICMECompartmentsDataBatch]]
1170
+ Selected dataset key (or ``None`` if unavailable) and batch list.
1171
+ """
1172
+
1173
+ empirical_batches = self.get_empirical_test_batches(no_heldout=no_heldout)
1174
+ if dataset_key is not None and dataset_key in empirical_batches:
1175
+ selected_key = dataset_key
1176
+ batch_list = empirical_batches[dataset_key]
1177
+ else:
1178
+ selected_key = None
1179
+ batch_list = None
1180
+ for repo_id, batches in empirical_batches.items():
1181
+ if batches:
1182
+ selected_key = repo_id
1183
+ batch_list = batches
1184
+ break
1185
+
1186
+ if not batch_list:
1187
+ label = "no-heldout" if no_heldout else "heldout"
1188
+ raise RuntimeError(f"No empirical {label} batches available for predictive plotting.")
1189
+
1190
+ return selected_key, batch_list
1191
+
1192
+ def describe_empirical_test_batches(
1193
+ self,
1194
+ empirical_batches: Optional[Dict[str, List["AICMECompartmentsDataBatch"]]] = None,
1195
+ *,
1196
+ no_heldout: bool = False,
1197
+ batch_index: int = 0,
1198
+ print_available: bool = True,
1199
+ ) -> Tuple[List[str], List[str]]:
1200
+ """Describe empirical test batches and return available studies/drugs.
1201
+
1202
+ This helper is designed to be called after
1203
+ :meth:`get_empirical_test_batches` in notebook/script workflows.
1204
+
1205
+ Parameters
1206
+ ----------
1207
+ empirical_batches:
1208
+ Optional pre-fetched empirical batches (typically from
1209
+ :meth:`get_empirical_test_batches`). If ``None``, batches are
1210
+ fetched internally.
1211
+ no_heldout:
1212
+ Whether to describe no-heldout batches.
1213
+ batch_index:
1214
+ Batch index to inspect within each dataset. Default is ``0``.
1215
+ print_available:
1216
+ If ``True``, print available datasets/studies/drugs.
1217
+
1218
+ Returns
1219
+ -------
1220
+ Tuple[List[str], List[str]]
1221
+ Unique available study names and drug names from the selected
1222
+ ``batch_index`` across datasets.
1223
+ """
1224
+
1225
+ batch_map = empirical_batches
1226
+ if batch_map is None:
1227
+ batch_map = self.get_empirical_test_batches(no_heldout=no_heldout)
1228
+
1229
+ if batch_index < 0:
1230
+ raise ValueError("batch_index must be non-negative")
1231
+
1232
+ available_studies: List[str] = []
1233
+ available_drugs: List[str] = []
1234
+ seen_studies: set[str] = set()
1235
+ seen_drugs: set[str] = set()
1236
+
1237
+ if print_available:
1238
+ label = "no_heldout=True" if no_heldout else "heldout"
1239
+ print(f"Available empirical datasets ({label}):", list(batch_map.keys()))
1240
+
1241
+ for repo_id, batch_list in batch_map.items():
1242
+ if print_available:
1243
+ print(f"Dataset '{repo_id}' contains {len(batch_list)} empirical batch(es).")
1244
+ if batch_index >= len(batch_list):
1245
+ if print_available:
1246
+ print(
1247
+ f" Skipping dataset '{repo_id}': batch_index={batch_index} "
1248
+ f"is out of range."
1249
+ )
1250
+ continue
1251
+
1252
+ batch = batch_list[batch_index]
1253
+ studies, drugs = self.describe_empirical_batch(batch, print_available=False)
1254
+ for study in studies:
1255
+ if study not in seen_studies:
1256
+ seen_studies.add(study)
1257
+ available_studies.append(study)
1258
+ for drug in drugs:
1259
+ if drug not in seen_drugs:
1260
+ seen_drugs.add(drug)
1261
+ available_drugs.append(drug)
1262
+
1263
+ if print_available:
1264
+ print(f" Batch {batch_index} studies:", studies)
1265
+ print(f" Batch {batch_index} drugs:", drugs)
1266
+
1267
+ if print_available:
1268
+ print("Available studies:", available_studies)
1269
+ print("Available drugs:", available_drugs)
1270
+
1271
+ return available_studies, available_drugs
1272
+
1273
+ @staticmethod
1274
+ def describe_empirical_batch(
1275
+ batch: "AICMECompartmentsDataBatch",
1276
+ *,
1277
+ print_available: bool = True,
1278
+ ) -> Tuple[List[str], List[str]]:
1279
+ """Return display-ready study and substance names for a batch.
1280
+
1281
+ Parameters
1282
+ ----------
1283
+ batch:
1284
+ Empirical batch to inspect.
1285
+ print_available:
1286
+ If ``True``, print available studies and drugs to stdout.
1287
+ """
1288
+
1289
+ studies = [str(name) if name else f"study_{i}" for i, name in enumerate(batch.study_name)]
1290
+ drugs = [
1291
+ str(name) if name else f"substance_{i}" for i, name in enumerate(batch.substance_name)
1292
+ ]
1293
+
1294
+ if print_available:
1295
+ print("Available studies in selected batch:", studies)
1296
+ print("Available drugs in selected batch:", drugs)
1297
+
1298
+ return studies, drugs
1299
+
1300
+ @staticmethod
1301
+ def slice_single_substance_batch(
1302
+ batch: "AICMECompartmentsDataBatch",
1303
+ b_idx: int,
1304
+ ) -> "AICMECompartmentsDataBatch":
1305
+ """Extract one substance entry from a multi-substance batch.
1306
+
1307
+ Parameters
1308
+ ----------
1309
+ batch:
1310
+ Batch with leading batch dimension ``B``.
1311
+ b_idx:
1312
+ Substance index along ``B``.
1313
+
1314
+ Returns
1315
+ -------
1316
+ AICMECompartmentsDataBatch
1317
+ Single-substance batch with tensors sliced to ``B=1``.
1318
+ """
1319
+
1320
+ if b_idx < 0 or b_idx >= len(batch.substance_name):
1321
+ raise IndexError(
1322
+ f"Substance index {b_idx} is out of range for batch size "
1323
+ f"{len(batch.substance_name)}."
1324
+ )
1325
+
1326
+ values = []
1327
+ for field_name in batch._fields:
1328
+ value = getattr(batch, field_name)
1329
+ if isinstance(value, torch.Tensor):
1330
+ # Keep tensor rank stable by preserving a singleton leading B axis.
1331
+ values.append(value[b_idx : b_idx + 1])
1332
+ elif field_name in {
1333
+ "study_name",
1334
+ "substance_name",
1335
+ "context_subject_name",
1336
+ "target_subject_name",
1337
+ }:
1338
+ values.append([value[b_idx]])
1339
+ else:
1340
+ values.append(value)
1341
+ return batch.__class__(*values)
1342
+
1343
+ @classmethod
1344
+ def slice_single_substance_batch_by_name(
1345
+ cls,
1346
+ batch: "AICMECompartmentsDataBatch",
1347
+ substance_name: str,
1348
+ ) -> "AICMECompartmentsDataBatch":
1349
+ """Extract one substance entry by matching drug name."""
1350
+
1351
+ _, available_drugs = cls.describe_empirical_batch(batch, print_available=False)
1352
+ norm_target = cls._normalize_substance_name(substance_name)
1353
+ matches = [
1354
+ i
1355
+ for i, name in enumerate(available_drugs)
1356
+ if cls._normalize_substance_name(name) == norm_target
1357
+ ]
1358
+ if not matches:
1359
+ raise ValueError(
1360
+ f"Selected drug '{substance_name}' not found in heldout batch. "
1361
+ f"Choose from: {available_drugs}"
1362
+ )
1363
+ return cls.slice_single_substance_batch(batch, matches[0])
1364
+
1365
    def select_empirical_drug_batch(
        self,
        empirical_batches: Dict[str, List["AICMECompartmentsDataBatch"]],
        selected_drug: str,
        *,
        permutation_indexes: Optional[int | Sequence[int]] = None,
        print_selection: bool = True,
    ) -> Tuple[
        "AICMECompartmentsDataBatch | List[AICMECompartmentsDataBatch]",
        str,
        str,
    ]:
        """Select one drug from empirical batches, optionally across permutations.

        Scans datasets/batches in order and stops at the first batch that
        contains ``selected_drug`` (after name normalization); all drugs seen
        up to that point are accumulated for the error message when nothing
        matches.

        Parameters
        ----------
        empirical_batches:
            Mapping returned by :meth:`get_empirical_test_batches`.
        selected_drug:
            Drug name to match across all empirical batches.
        permutation_indexes:
            Optional permutation index or list of permutation indices within
            the selected empirical dataset's batch list. When ``None``
            (default), legacy behaviour applies and the first matching
            single-substance batch is returned. When a list/tuple is given, a
            list of single-substance batches in the requested permutation
            order is returned instead.
        print_selection:
            If ``True``, print where the match was found.

        Returns
        -------
        Tuple
            ``(batch_or_batches, selected_study, selected_drug_name)`` —
            a single batch when ``permutation_indexes`` is ``None``/int, a
            list of batches when it is a list/tuple.

        Raises
        ------
        ValueError
            If ``permutation_indexes`` is empty or contains duplicates, if
            the drug is absent from a requested permutation, or if the drug
            is not found at all.
        IndexError
            If a requested permutation index is out of range.
        """

        norm_target = self._normalize_substance_name(selected_drug)
        requested_permutations: Optional[List[int]]
        # A list/tuple request changes the return type to a list of batches.
        return_many = isinstance(permutation_indexes, (list, tuple))
        if permutation_indexes is None:
            requested_permutations = None
        elif return_many:
            if len(permutation_indexes) == 0:
                raise ValueError("'permutation_indexes' must not be empty.")
            requested_permutations = [int(idx) for idx in permutation_indexes]
            if len(set(requested_permutations)) != len(requested_permutations):
                raise ValueError("'permutation_indexes' must contain unique indices.")
        else:
            # Single int request: still resolved via the permutation path.
            requested_permutations = [int(permutation_indexes)]

        # Ordered accumulator of every drug seen, used in the final error.
        all_available_drugs: List[str] = []
        seen_drugs: set[str] = set()

        for repo_id, batch_list in empirical_batches.items():
            for batch_index, batch in enumerate(batch_list):
                _, available_drugs = self.describe_empirical_batch(batch, print_available=False)
                for drug in available_drugs:
                    if drug not in seen_drugs:
                        seen_drugs.add(drug)
                        all_available_drugs.append(drug)

                matches = [
                    i
                    for i, name in enumerate(available_drugs)
                    if self._normalize_substance_name(name) == norm_target
                ]
                if matches:
                    if requested_permutations is None:
                        # Legacy path: first matching batch wins.
                        selected_batches: List[AICMECompartmentsDataBatch] = [
                            self.slice_single_substance_batch(batch, matches[0])
                        ]
                        chosen_permutations = [batch_index]
                    else:
                        # Explicit permutations: re-resolve the drug in each
                        # requested permutation of THIS dataset.
                        selected_batches = []
                        chosen_permutations = requested_permutations
                        for permutation_index in requested_permutations:
                            if permutation_index < 0 or permutation_index >= len(batch_list):
                                raise IndexError(
                                    f"Permutation index {permutation_index} is out of range for "
                                    f"dataset '{repo_id}' with {len(batch_list)} permutations."
                                )

                            perm_batch = batch_list[permutation_index]
                            _, perm_drugs = self.describe_empirical_batch(
                                perm_batch, print_available=False
                            )
                            perm_matches = [
                                i
                                for i, name in enumerate(perm_drugs)
                                if self._normalize_substance_name(name) == norm_target
                            ]
                            if not perm_matches:
                                raise ValueError(
                                    f"Selected drug '{selected_drug}' was not found in dataset "
                                    f"'{repo_id}' at permutation index {permutation_index}."
                                )
                            selected_batches.append(
                                self.slice_single_substance_batch(perm_batch, perm_matches[0])
                            )

                    # Report study/drug names from the first selected batch.
                    studies, drugs = self.describe_empirical_batch(
                        selected_batches[0], print_available=False
                    )
                    selected_study = studies[0]
                    selected_name = drugs[0]
                    if print_selection:
                        print("Selected empirical dataset key:", repo_id)
                        if len(chosen_permutations) == 1:
                            print("Selected empirical batch index:", chosen_permutations[0])
                        else:
                            print("Selected empirical batch indexes:", chosen_permutations)
                        print("Selected study:", selected_study)
                        print("Selected drug:", selected_name)
                    if return_many:
                        return selected_batches, selected_study, selected_name
                    return selected_batches[0], selected_study, selected_name

        raise ValueError(
            f"Selected drug '{selected_drug}' was not found in empirical batches. "
            f"Choose from: {all_available_drugs}"
        )
1480
+
1481
+ def _select_strategy(self, who: str):
1482
+ """Return the observation strategy requested via ``who``.
1483
+
1484
+ Parameters
1485
+ ----------
1486
+ who:
1487
+ Either ``"target"`` or ``"context"``.
1488
+
1489
+ Returns
1490
+ -------
1491
+ ObservationStrategy
1492
+ The strategy matching the requested role.
1493
+ """
1494
+
1495
+ if who == "target":
1496
+ return self.target_strategy
1497
+ if who == "context":
1498
+ return self.context_strategy
1499
+ raise ValueError("'who' must be either 'target' or 'context'.")
1500
+
1501
+ def _select_strategies(self, who: str) -> List[object]:
1502
+ """Return strategy list for the requested role.
1503
+
1504
+ For ``who='target'`` this includes both synthetic and empirical target
1505
+ strategies so past-selection overrides remain consistent when empirical
1506
+ batches are generated from the datamodule.
1507
+ """
1508
+
1509
+ if who == "context":
1510
+ return [self.context_strategy]
1511
+ if who == "target":
1512
+ strategies: List[object] = [self.target_strategy]
1513
+ empirical_target_strategy = getattr(self, "empirical_target_strategy", None)
1514
+ if empirical_target_strategy is not None:
1515
+ strategies.append(empirical_target_strategy)
1516
+ # Keep order stable while avoiding duplicate objects.
1517
+ deduped: List[object] = []
1518
+ seen_ids: set[int] = set()
1519
+ for strategy in strategies:
1520
+ strategy_id = id(strategy)
1521
+ if strategy_id in seen_ids:
1522
+ continue
1523
+ seen_ids.add(strategy_id)
1524
+ deduped.append(strategy)
1525
+ return deduped
1526
+ raise ValueError("'who' must be either 'target' or 'context'.")
1527
+
1528
+ def fix_past_selection(self, fix_past_value: int, *, who: str = "target") -> None:
1529
+ """Force a fixed number of past observations for the selected strategy.
1530
+
1531
+ The override is only applied for strategies with ``split_past_future``
1532
+ enabled; for others the call is ignored.
1533
+ """
1534
+
1535
+ if not self._prepared:
1536
+ self.prepare_data()
1537
+
1538
+ for strategy in self._select_strategies(who):
1539
+ if hasattr(strategy, "fix_past_selection"):
1540
+ strategy.fix_past_selection(fix_past_value)
1541
+ # Reset lazy-load flag so empirical data is reloaded with new strategy settings
1542
+ self._empirical_loaded = False
1543
+
1544
+ def release_past_selection(self, *, who: str = "target") -> None:
1545
+ """Restore the default past sampling behaviour for the given strategy."""
1546
+
1547
+ if not self._prepared:
1548
+ self.prepare_data()
1549
+
1550
+ for strategy in self._select_strategies(who):
1551
+ if hasattr(strategy, "release_past_selection"):
1552
+ strategy.release_past_selection()
1553
+ # Reset lazy-load flag so empirical data is reloaded with restored strategy settings
1554
+ self._empirical_loaded = False
1555
+
1556
+ def set_shared_target_dosing(self, enable: bool = True, n_targets: int = 100) -> None:
1557
+ """Enable/disable shared target dosing across all datasets.
1558
+
1559
+ Parameters
1560
+ ----------
1561
+ enable : bool
1562
+ Whether to enable shared-target dosing.
1563
+ n_targets : int
1564
+ Number of target individuals to sample when enabled.
1565
+ """
1566
+ self.use_shared_target_dosing = enable
1567
+ self.shared_target_n_targets = n_targets
1568
+
1569
+ for ds in (
1570
+ getattr(self, "train_dataset", None),
1571
+ getattr(self, "val_dataset", None),
1572
+ getattr(self, "test_dataset", None),
1573
+ ):
1574
+ if ds is not None:
1575
+ ds.use_shared_target_dosing = enable
1576
+ ds.shared_target_n_targets = n_targets
1577
+
1578
+ def unset_shared_target_dosing(self) -> None:
1579
+ """Disable shared target dosing and restore default behaviour."""
1580
+ self.set_shared_target_dosing(False)
1581
+
1582
+ @staticmethod
1583
+ def _add_batch_dim_to_synthetic_batch(
1584
+ batch: AICMECompartmentsDataBatch,
1585
+ ) -> AICMECompartmentsDataBatch:
1586
+ """Add a leading ``B=1`` axis to tensor fields missing batch dimension."""
1587
+
1588
+ values: list = []
1589
+ for name, value in zip(batch._fields, batch):
1590
+ if not isinstance(value, torch.Tensor):
1591
+ values.append(value)
1592
+ continue
1593
+
1594
+ if name == "time_scales":
1595
+ # ``time_scales`` is often already [B, 2] while other fields are
1596
+ # emitted as [I, ...] by dataset-level generation.
1597
+ values.append(value.unsqueeze(0) if value.dim() == 1 else value)
1598
+ continue
1599
+
1600
+ values.append(value.unsqueeze(0))
1601
+
1602
+ return AICMECompartmentsDataBatch(*values)
1603
+
1604
+ def _generate_synthetic_list_with_repeated_target(
1605
+ self,
1606
+ *,
1607
+ shared_context_pack: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor],
1608
+ target_sim: Tensor,
1609
+ target_times: Tensor,
1610
+ target_dosing_amounts: Tensor,
1611
+ target_dosing_routes: Tensor,
1612
+ base_time_scales: Tensor,
1613
+ num_targets: int,
1614
+ ) -> AICMECompartmentsDataBatch:
1615
+ """Package one list element from shared context and dosing-specific targets.
1616
+
1617
+ Parameters
1618
+ ----------
1619
+ shared_context_pack:
1620
+ Context tensors generated once and reused across all list elements.
1621
+ target_sim:
1622
+ Target simulation for one dosing condition with shape ``[n_targets, T]``.
1623
+ target_times:
1624
+ Target simulation times with shape ``[n_targets, T]``.
1625
+ target_dosing_amounts:
1626
+ Target dosing amounts with shape ``[n_targets]``.
1627
+ target_dosing_routes:
1628
+ Target dosing route types with shape ``[n_targets]``.
1629
+ base_time_scales:
1630
+ Simulation-level time scales from the sampler.
1631
+ num_targets:
1632
+ Number of target individuals capacity for this synthetic sample.
1633
+ """
1634
+ (
1635
+ context_obs,
1636
+ context_obs_time,
1637
+ context_obs_mask,
1638
+ context_rem_sim,
1639
+ context_rem_sim_time,
1640
+ context_rem_sim_mask,
1641
+ dosing_amounts_ctx,
1642
+ dosing_routes_ctx,
1643
+ ) = shared_context_pack
1644
+
1645
+ target_obs_pack = self.train_dataset._safe_generate(
1646
+ self.train_dataset.target_fn,
1647
+ target_sim,
1648
+ target_times,
1649
+ time_scales=base_time_scales,
1650
+ )
1651
+ (
1652
+ target_obs,
1653
+ target_obs_time,
1654
+ target_obs_mask,
1655
+ target_rem_sim,
1656
+ target_rem_sim_time,
1657
+ target_rem_sim_mask,
1658
+ _target_time_scales,
1659
+ ) = target_obs_pack
1660
+
1661
+ # Keep the time-scale metadata aligned with the shared context payload.
1662
+ ts = base_time_scales
1663
+
1664
+ return self.train_dataset._build_padded_batch(
1665
+ # context (shared across all list elements)
1666
+ context_obs,
1667
+ context_obs_time,
1668
+ context_obs_mask,
1669
+ context_rem_sim,
1670
+ context_rem_sim_time,
1671
+ context_rem_sim_mask,
1672
+ dosing_amounts_ctx,
1673
+ dosing_routes_ctx,
1674
+ # target (specific to one repeated-dosing condition)
1675
+ target_obs,
1676
+ target_obs_time,
1677
+ target_obs_mask,
1678
+ target_rem_sim,
1679
+ target_rem_sim_time,
1680
+ target_rem_sim_mask,
1681
+ target_dosing_amounts,
1682
+ target_dosing_routes,
1683
+ # time scales
1684
+ ts=ts,
1685
+ target_capacity=num_targets,
1686
+ )
1687
+
1688
+ def prepare_full_simulation_list_with_repeated_targets(
1689
+ self,
1690
+ num_targets: int,
1691
+ batch_index: int = 0,
1692
+ num_of_different_dosages: int = 1,
1693
+ device: Optional[torch.device | str] = None,
1694
+ ) -> List["AICMECompartmentsDataBatch"]:
1695
+ """Build one shared context and ``L`` repeated-target dosing batches.
1696
+
1697
+ This helper is responsible for context creation exactly once, then
1698
+ looping over ``num_of_different_dosages`` target dosing conditions.
1699
+ Packaging is delegated to
1700
+ :meth:`_generate_synthetic_list_with_repeated_target`.
1701
+
1702
+ Parameters
1703
+ ----------
1704
+ num_targets:
1705
+ Number of target individuals per dosing condition.
1706
+ batch_index:
1707
+ Synthetic sample index used by the simulation backend.
1708
+ num_of_different_dosages:
1709
+ Number of target dosing conditions ``L``.
1710
+ device:
1711
+ Optional device where returned batches should live.
1712
+ """
1713
+
1714
+ if num_targets < 0:
1715
+ raise ValueError("num_targets must be non-negative")
1716
+ if batch_index < 0:
1717
+ raise ValueError("batch_index must be non-negative")
1718
+ if num_of_different_dosages < 0:
1719
+ raise ValueError("num_of_different_dosages must be non-negative")
1720
+
1721
+ if not self._prepared:
1722
+ self.prepare_data()
1723
+
1724
+ (
1725
+ context_sim,
1726
+ context_times,
1727
+ dosing_amounts_ctx,
1728
+ dosing_routes_ctx,
1729
+ target_simulations,
1730
+ target_times_list,
1731
+ target_dosing_amounts_list,
1732
+ target_dosing_routes_list,
1733
+ _time_points,
1734
+ time_scales,
1735
+ ) = prepare_full_simulation_list_with_repeated_targets_backend(
1736
+ self.meta_config,
1737
+ self.model_config.dosing,
1738
+ n_targets=num_targets,
1739
+ num_of_different_dosages=num_of_different_dosages,
1740
+ idx=batch_index,
1741
+ )
1742
+
1743
+ # Build context once and reuse verbatim for all list elements.
1744
+ context_obs_pack = self.train_dataset._safe_generate(
1745
+ self.train_dataset.context_fn,
1746
+ context_sim,
1747
+ context_times,
1748
+ time_scales=time_scales,
1749
+ )
1750
+ (
1751
+ context_obs,
1752
+ context_obs_time,
1753
+ context_obs_mask,
1754
+ context_rem_sim,
1755
+ context_rem_sim_time,
1756
+ context_rem_sim_mask,
1757
+ context_time_scales,
1758
+ ) = context_obs_pack
1759
+
1760
+ shared_context_pack = (
1761
+ context_obs,
1762
+ context_obs_time,
1763
+ context_obs_mask,
1764
+ context_rem_sim,
1765
+ context_rem_sim_time,
1766
+ context_rem_sim_mask,
1767
+ dosing_amounts_ctx,
1768
+ dosing_routes_ctx,
1769
+ )
1770
+
1771
+ base_time_scales = context_time_scales if context_time_scales is not None else time_scales
1772
+
1773
+ synthetic_batches: List[AICMECompartmentsDataBatch] = []
1774
+ for target_sim, target_times, target_dosing_amounts, target_dosing_routes in zip(
1775
+ target_simulations,
1776
+ target_times_list,
1777
+ target_dosing_amounts_list,
1778
+ target_dosing_routes_list,
1779
+ ):
1780
+ synthetic_batches.append(
1781
+ self._generate_synthetic_list_with_repeated_target(
1782
+ shared_context_pack=shared_context_pack,
1783
+ target_sim=target_sim,
1784
+ target_times=target_times,
1785
+ target_dosing_amounts=target_dosing_amounts,
1786
+ target_dosing_routes=target_dosing_routes,
1787
+ base_time_scales=base_time_scales,
1788
+ num_targets=num_targets,
1789
+ )
1790
+ )
1791
+
1792
+ if device is None:
1793
+ return synthetic_batches
1794
+
1795
+ return list_of_databath_to_device(synthetic_batches, device)
1796
+
1797
+ def generate_synthetic_with_repeated_target(
1798
+ self,
1799
+ num_targets: int,
1800
+ batch_index: int = 0,
1801
+ different_dosing: bool = False,
1802
+ device: Optional[torch.device | str] = None,
1803
+ ) -> List["AICMECompartmentsDataBatch"]:
1804
+ """Generate one synthetic batch list with configurable target dosing.
1805
+
1806
+ The generated sample follows the current datamodule ``data_config`` and
1807
+ observation strategies while overriding only the number of target
1808
+ individuals. The returned tensors include an explicit leading batch
1809
+ dimension ``B=1`` to match dataloader outputs.
1810
+
1811
+ Parameters
1812
+ ----------
1813
+ num_targets:
1814
+ Number of target individuals to include in the generated synthetic
1815
+ sample.
1816
+ batch_index:
1817
+ Dataset index used by the internal synthetic generator.
1818
+ different_dosing:
1819
+ If ``False`` (default), target individuals share one repeated dosing
1820
+ configuration.
1821
+ If ``True``, each target individual receives an independent dosing
1822
+ sample drawn from the same dosing distribution as context.
1823
+ device:
1824
+ Optional device where returned batches should live. When provided,
1825
+ returned batches are moved to ``device``.
1826
+
1827
+ Returns
1828
+ -------
1829
+ List[AICMECompartmentsDataBatch]
1830
+ A list containing a single synthetic databatch.
1831
+ """
1832
+
1833
+ if num_targets < 0:
1834
+ raise ValueError("num_targets must be non-negative")
1835
+ if batch_index < 0:
1836
+ raise ValueError("batch_index must be non-negative")
1837
+
1838
+ if not self._prepared:
1839
+ self.prepare_data()
1840
+
1841
+ batches = self.train_dataset._generate_item_sample_target_dosing(
1842
+ batch_index,
1843
+ n_targets=num_targets,
1844
+ different_dosing=different_dosing,
1845
+ )
1846
+ batch_list = [self._add_batch_dim_to_synthetic_batch(batch) for batch in batches]
1847
+
1848
+ if device is None:
1849
+ return batch_list
1850
+
1851
+ return list_of_databath_to_device(batch_list, device)
1852
+
1853
+ def generate_synthetic_list_of_repeated_target(
1854
+ self,
1855
+ num_targets: int,
1856
+ batch_index: int = 0,
1857
+ num_of_different_dosages: int = 1,
1858
+ device: Optional[torch.device | str] = None,
1859
+ ) -> List["AICMECompartmentsDataBatch"]:
1860
+ """Generate a list of synthetic batches sharing one context.
1861
+
1862
+ The returned list has length ``num_of_different_dosages``. Context
1863
+ fields are identical across all elements, while targets are regenerated
1864
+ per element using repeated dosing within each element.
1865
+ """
1866
+
1867
+ synthetic_batches = self.prepare_full_simulation_list_with_repeated_targets(
1868
+ num_targets=num_targets,
1869
+ batch_index=batch_index,
1870
+ num_of_different_dosages=num_of_different_dosages,
1871
+ device=device,
1872
+ )
1873
+ batch_list = [self._add_batch_dim_to_synthetic_batch(batch) for batch in synthetic_batches]
1874
+ return batch_list
sim_priors_pk/data/extra/compartment_models_vectorized.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ def sample_individual_configs_vectorized(study_config):
5
+ """
6
+ Vectorizes the sampling of parameters for a population of individuals.
7
+
8
+ Parameters
9
+ ----------
10
+ study_config : StudyConfig
11
+ Contains the study settings and distribution parameters.
12
+
13
+ Returns
14
+ -------
15
+ config_dict : dict
16
+ Dictionary containing the vectorized parameters and time-magnitudes.
17
+ Keys:
18
+ 'k_a', 'k_e', 'V': Tensors of shape (N,)
19
+ 'k_1p', 'k_p1': Tensors of shape (N, P)
20
+ 'k_a_tmag', 'k_e_tmag', 'V_tmag': Scalars
21
+ 'k_1p_tmag', 'k_p1_tmag': Tensors of shape (P,)
22
+ 'num_peripherals': int
23
+ """
24
+ N = study_config.num_individuals
25
+ P = study_config.num_peripherals
26
+
27
+ # Sample the central parameters as tensors of shape (N,)
28
+ k_a = torch.from_numpy(np.random.lognormal(study_config.log_k_a_mean, study_config.log_k_a_std, size=N)).float()
29
+ k_e = torch.from_numpy(np.random.lognormal(study_config.log_k_e_mean, study_config.log_k_e_std, size=N)).float()
30
+ V = torch.from_numpy(np.random.lognormal(study_config.log_V_mean, study_config.log_V_std, size=N)).float()
31
+
32
+ # Sample the peripheral parameters as tensors of shape (N, P)
33
+ k_1p = []
34
+ k_p1 = []
35
+ for i in range(P):
36
+ k_1p_i = torch.from_numpy(np.random.lognormal(study_config.log_k_1p_mean[i],
37
+ study_config.log_k_1p_std[i], size=N)).float()
38
+ k_p1_i = torch.from_numpy(np.random.lognormal(study_config.log_k_p1_mean[i],
39
+ study_config.log_k_p1_std[i], size=N)).float()
40
+ k_1p.append(k_1p_i)
41
+ k_p1.append(k_p1_i)
42
+ # Stack along the peripheral dimension: shape becomes (N, P)
43
+ k_1p = torch.stack(k_1p, dim=1)
44
+ k_p1 = torch.stack(k_p1, dim=1)
45
+
46
+ # Pack time-magnitudes (assumed scalars for central parameters and lists for peripherals)
47
+ k_a_tmag = study_config.k_a_tmag # scalar
48
+ k_e_tmag = study_config.k_e_tmag # scalar
49
+ V_tmag = study_config.V_tmag # scalar
50
+ # For peripherals, we assume the study_config gives lists/arrays of length P.
51
+ k_1p_tmag = torch.tensor(study_config.k_1p_tmag).float() # shape (P,)
52
+ k_p1_tmag = torch.tensor(study_config.k_p1_tmag).float() # shape (P,)
53
+
54
+ config_dict = {
55
+ 'k_a': k_a,
56
+ 'k_e': k_e,
57
+ 'V': V,
58
+ 'k_1p': k_1p,
59
+ 'k_p1': k_p1,
60
+ 'k_a_tmag': k_a_tmag,
61
+ 'k_e_tmag': k_e_tmag,
62
+ 'V_tmag': V_tmag,
63
+ 'k_1p_tmag': k_1p_tmag,
64
+ 'k_p1_tmag': k_p1_tmag,
65
+ 'num_peripherals': P,
66
+ }
67
+ return config_dict
68
+
69
+ import torch
70
+
71
+ def compute_rates(config, t):
72
+ """
73
+ Computes the dynamic rates for all individuals at a given time t.
74
+
75
+ Parameters
76
+ ----------
77
+ config : dict
78
+ Dictionary returned by sample_individual_configs_vectorized.
79
+ t : float or torch.Tensor
80
+ Current time point.
81
+
82
+ Returns
83
+ -------
84
+ k_a, k_e, V : torch.Tensor
85
+ Tensors of shape (N,).
86
+ k_1p, k_p1 : torch.Tensor
87
+ Tensors of shape (N, P).
88
+ """
89
+ # Ensure t is a tensor
90
+ if not isinstance(t, torch.Tensor):
91
+ t = torch.tensor(t, dtype=config['k_a_tmag'].dtype, device=config['k_a_tmag'].device)
92
+
93
+ k_a = config['k_a'] * torch.exp(-config['k_a_tmag'] * t)
94
+ k_e = config['k_e'] * torch.exp(-config['k_e_tmag'] * t)
95
+ V = config['V'] * torch.exp(-config['V_tmag'] * t)
96
+
97
+ # Use broadcasting for peripheral compartments
98
+ k_1p = config['k_1p'] * torch.exp(-config['k_1p_tmag'] * t)
99
+ k_p1 = config['k_p1'] * torch.exp(-config['k_p1_tmag'] * t)
100
+
101
+ return k_a, k_e, V, k_1p, k_p1
102
+
103
+ def ode_func(t_val, y, config):
104
+ """
105
+ ODE function using vectorized rate computations.
106
+
107
+ Parameters
108
+ ----------
109
+ t_val : torch.Tensor
110
+ Current time point.
111
+ y : torch.Tensor
112
+ Current state, shape (N, M) where M = 2 + num_peripherals.
113
+ config : dict
114
+ Vectorized individual configuration dictionary.
115
+
116
+ Returns
117
+ -------
118
+ dy_dt : torch.Tensor
119
+ Time derivative of y, shape (N, M).
120
+ """
121
+ # Get the dynamic rates for all individuals at time t_val.
122
+ k_a, k_e, _, k_1p, k_p1 = compute_rates(config, t_val)
123
+ N = y.size(0)
124
+ P = config['num_peripherals']
125
+ M = 2 + P
126
+
127
+ # Build the ODE rate matrix A(t) in a vectorized fashion
128
+ A_all = torch.zeros((N, M, M), dtype=torch.float32)
129
+ A_all[:, 0, 0] = -k_a # Loss from gut
130
+ A_all[:, 1, 0] = k_a # Transfer gut -> central
131
+ A_all[:, 1, 1] = -k_e - k_1p.sum(dim=1) # Loss from central and distribution to peripherals
132
+ A_all[:, 1, 2:2+P] = k_p1 # Transfer central -> peripherals
133
+ A_all[:, 2:2+P, 1] = k_1p # Transfer peripherals -> central
134
+ # Peripheral compartments clearance:
135
+ for i in range(P):
136
+ A_all[:, 2 + i, 2 + i] = -k_p1[:, i]
137
+
138
+ # Compute dy/dt = A_all @ y for each individual.
139
+ dy_dt = torch.bmm(A_all, y.unsqueeze(-1)).squeeze(-1)
140
+ return dy_dt
141
+
142
+ def sample_study_vectorized(study_config, dosing_config, t, solver_method="rk4"):
143
+ """
144
+ Simulates the pharmacokinetic study using vectorized individual configurations.
145
+
146
+ Parameters
147
+ ----------
148
+ study_config : StudyConfig
149
+ Contains global study settings and distribution parameters.
150
+ dosing_config : DosingConfig
151
+ Contains dosing information.
152
+ t : torch.Tensor
153
+ Time points at which the simulation is evaluated.
154
+
155
+ Returns
156
+ -------
157
+ full_simulation : torch.Tensor
158
+ Concentration profiles (N, len(t)).
159
+ full_times : torch.Tensor
160
+ Time points replicated for each individual.
161
+ """
162
+ from torchdiffeq import odeint
163
+
164
+ # Get the vectorized configuration dictionary
165
+ config = sample_individual_configs_vectorized(study_config)
166
+ N = study_config.num_individuals
167
+ P = study_config.num_peripherals
168
+ M = 2 + P
169
+
170
+ # Initial conditions: dose in the gut (first compartment), zeros elsewhere.
171
+ y0 = torch.zeros((N, M), dtype=torch.float32)
172
+ y0[:, 0] = dosing_config.dose
173
+
174
+ def wrapped_ode(t_val, y):
175
+ return ode_func(t_val, y, config)
176
+
177
+ # Solve the ODE system for all individuals in batch
178
+ y = odeint(wrapped_ode, y0, t, method=solver_method)
179
+ # Extract central compartment (index 1) for each individual
180
+ full_simulation = y[:, :, 1].T
181
+ full_times = t.unsqueeze(0).repeat(N, 1)
182
+ return full_simulation, full_times
sim_priors_pk/data/extra/kernels.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gpytorch
3
+
4
+ def create_kernel(config):
5
+ kernel_params = config.kernel_params
6
+ if 'type' not in kernel_params:
7
+ raise ValueError("Kernel type must be specified in kernel_params")
8
+ if kernel_params['type'] == 'RBF':
9
+ kernel = gpytorch.kernels.RBFKernel(ard_num_dims=config.input_dim, requires_grad=False)
10
+ kernel_params_ = kernel_params.get('params', {})
11
+ kernel_length_scale = kernel_params_["raw_lengthscale"]
12
+ kernel_length_scale = torch.tensor([kernel_length_scale] * config.input_dim)
13
+ kernel.initialize(raw_lengthscale=kernel_length_scale)
14
+ return kernel
15
+ raise ValueError(f"Unsupported kernel type: {kernel_params['type']}")
16
+
17
+
18
+ def create_kernel_mix(kernel_params,input_dim=1):
19
+ if 'type' not in kernel_params:
20
+ raise ValueError("Kernel type must be specified in kernel_params")
21
+ if kernel_params['type'] == 'RBF':
22
+ kernel = gpytorch.kernels.RBFKernel(ard_num_dims=input_dim, requires_grad=False)
23
+ kernel_params_ = kernel_params.get('params', {})
24
+ kernel_length_scale = kernel_params_["raw_lengthscale"]
25
+ kernel_length_scale = torch.tensor([kernel_length_scale] * input_dim)
26
+ kernel.initialize(raw_lengthscale=kernel_length_scale)
27
+ return kernel
28
+ raise ValueError(f"Unsupported kernel type: {kernel_params['type']}")
sim_priors_pk/hub_runtime/README.md ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hub Runtime Bundle
2
+
3
+ This directory contains the parallel Hugging Face export path for
4
+ consumer-facing model bundles.
5
+
6
+ The existing training export remains unchanged:
7
+
8
+ - native export: `BasicLightningExperiment._push_model_to_hub(...)`
9
+ - runtime export: `push_loaded_model_runtime_bundle(...)`
10
+
11
+ The runtime export is intended for users who should be able to load a model
12
+ from the Hugging Face Hub through `transformers` without installing the local
13
+ `sim_priors_pk` package.
14
+
15
+ ## Important Constraint
16
+
17
+ The consumer entrypoint is `transformers`, but `transformers` alone is **not**
18
+ enough today.
19
+
20
+ These runtime bundles still execute PyTorch-based custom code and reconstruct
21
+ the internal PK architecture, so the user needs the runtime Python
22
+ dependencies, but not a local checkout of this repository.
23
+
24
+ Reliable consumer install:
25
+
26
+ ```bash
27
+ pip install torch transformers huggingface_hub lightning datasets pandas torchtyping gpytorch pot torchdiffeq torchsde ruamel.yaml pyyaml
28
+ ```
29
+
30
+ What the consumer does **not** need:
31
+
32
+ - `pip install sim_priors_pk`
33
+ - a local clone of this repository
34
+ - access to the training checkpoint directory
35
+
36
+ ## Consumer Workflow
37
+
38
+ Use the runtime repo, not the native training-artifact repo.
39
+
40
+ ```python
41
+ from transformers import AutoModel
42
+
43
+ model = AutoModel.from_pretrained(
44
+ "your-org/your-model-runtime",
45
+ trust_remote_code=True,
46
+ )
47
+ ```
48
+
49
+ Then call the stable runtime task API:
50
+
51
+ ```python
52
+ outputs = model.run_task(
53
+ task="generate", # or "predict"
54
+ studies=studies, # one StudyJSON or a list[StudyJSON]
55
+ num_samples=8,
56
+ )
57
+ ```
58
+
59
+ The return payload is:
60
+
61
+ ```python
62
+ {
63
+ "task": "generate",
64
+ "io_schema_version": "studyjson-v1",
65
+ "model_info": {...},
66
+ "results": [
67
+ {
68
+ "input_index": 0,
69
+ "samples": [study_json_0, study_json_1, ...],
70
+ }
71
+ ],
72
+ }
73
+ ```
74
+
75
+ ## Generate Example
76
+
77
+ ```python
78
+ from transformers import AutoModel
79
+
80
+ model = AutoModel.from_pretrained(
81
+ "your-org/your-model-runtime",
82
+ trust_remote_code=True,
83
+ )
84
+
85
+ studies = [
86
+ {
87
+ "context": [
88
+ {
89
+ "name_id": "ctx_0",
90
+ "observations": [0.2, 0.5, 0.3],
91
+ "observation_times": [0.5, 1.0, 2.0],
92
+ "dosing": [1.0],
93
+ "dosing_type": ["oral"],
94
+ "dosing_times": [0.0],
95
+ "dosing_name": ["oral"],
96
+ }
97
+ ],
98
+ "target": [],
99
+ "meta_data": {
100
+ "study_name": "demo",
101
+ "substance_name": "drug_x",
102
+ },
103
+ }
104
+ ]
105
+
106
+ outputs = model.run_task(
107
+ task="generate",
108
+ studies=studies,
109
+ num_samples=4,
110
+ )
111
+
112
+ generated_studies = outputs["results"][0]["samples"]
113
+ ```
114
+
115
+ ## Predict Example
116
+
117
+ ```python
118
+ from transformers import AutoModel
119
+
120
+ model = AutoModel.from_pretrained(
121
+ "your-org/your-model-runtime",
122
+ trust_remote_code=True,
123
+ )
124
+
125
+ predict_studies = [
126
+ {
127
+ "context": [
128
+ {
129
+ "name_id": "ctx_0",
130
+ "observations": [0.2, 0.5, 0.3],
131
+ "observation_times": [0.5, 1.0, 2.0],
132
+ "dosing": [1.0],
133
+ "dosing_type": ["oral"],
134
+ "dosing_times": [0.0],
135
+ "dosing_name": ["oral"],
136
+ }
137
+ ],
138
+ "target": [
139
+ {
140
+ "name_id": "tgt_0",
141
+ "observations": [0.25, 0.31],
142
+ "observation_times": [0.5, 1.0],
143
+ "remaining": [0.0, 0.0, 0.0],
144
+ "remaining_times": [2.0, 4.0, 8.0],
145
+ "dosing": [1.0],
146
+ "dosing_type": ["oral"],
147
+ "dosing_times": [0.0],
148
+ "dosing_name": ["oral"],
149
+ }
150
+ ],
151
+ "meta_data": {
152
+ "study_name": "demo",
153
+ "substance_name": "drug_x",
154
+ },
155
+ }
156
+ ]
157
+
158
+ outputs = model.run_task(
159
+ task="predict",
160
+ studies=predict_studies,
161
+ num_samples=4,
162
+ )
163
+
164
+ prediction_samples = outputs["results"][0]["samples"]
165
+ ```
166
+
167
+ ## Producer Workflow
168
+
169
+ To publish a runtime repo from a locally loaded experiment:
170
+
171
+ ```python
172
+ from sim_priors_pk.hub_runtime import push_loaded_model_runtime_bundle
173
+
174
+ runtime_repo_id = push_loaded_model_runtime_bundle(
175
+ experiment=experiment,
176
+ model_card_path=["hf_model_cards", "AICME-PK_Readme.md"],
177
+ )
178
+ ```
179
+
180
+ By default this creates a separate repo:
181
+
182
+ ```text
183
+ <namespace>/<hf_model_name>-runtime
184
+ ```
185
+
186
+ That keeps the native training artifact export and the consumer runtime export
187
+ separate.
sim_priors_pk/hub_runtime/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public helpers for the parallel Hugging Face runtime bundle path."""
2
+
3
+ from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
4
+ from sim_priors_pk.hub_runtime.modeling_sim_priors_pk import PKHubModel
5
+ from sim_priors_pk.hub_runtime.runtime_bundle import (
6
+ RuntimeBundleArtifacts,
7
+ build_runtime_bundle_dir,
8
+ default_runtime_repo_id,
9
+ push_loaded_model_runtime_bundle,
10
+ )
11
+
12
+ __all__ = [
13
+ "PKHubConfig",
14
+ "PKHubModel",
15
+ "RuntimeBundleArtifacts",
16
+ "build_runtime_bundle_dir",
17
+ "default_runtime_repo_id",
18
+ "push_loaded_model_runtime_bundle",
19
+ ]
sim_priors_pk/hub_runtime/configuration_sim_priors_pk.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face configuration for self-contained PK runtime bundles."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+ from sim_priors_pk.hub_runtime.runtime_contract import STUDY_JSON_IO_VERSION
10
+
11
+
12
+ class PKHubConfig(PretrainedConfig):
13
+ """Public Hub config describing a consumer-facing PK runtime bundle."""
14
+
15
+ model_type = "sim_priors_pk"
16
+
17
+ def __init__(
18
+ self,
19
+ architecture_name: Optional[str] = None,
20
+ experiment_type: str = "nodepk",
21
+ experiment_config: Optional[Dict[str, Any]] = None,
22
+ builder_config: Optional[Dict[str, Any]] = None,
23
+ supported_tasks: Optional[List[str]] = None,
24
+ default_task: Optional[str] = None,
25
+ io_schema_version: str = STUDY_JSON_IO_VERSION,
26
+ original_repo_id: Optional[str] = None,
27
+ runtime_repo_id: Optional[str] = None,
28
+ **kwargs,
29
+ ) -> None:
30
+ super().__init__(**kwargs)
31
+ self.architecture_name = architecture_name
32
+ self.experiment_type = experiment_type
33
+ self.experiment_config = dict(experiment_config or {})
34
+ self.builder_config = dict(builder_config or {})
35
+ self.supported_tasks = list(supported_tasks or [])
36
+ self.default_task = default_task or (self.supported_tasks[0] if self.supported_tasks else None)
37
+ self.io_schema_version = io_schema_version
38
+ self.original_repo_id = original_repo_id
39
+ self.runtime_repo_id = runtime_repo_id
40
+
41
+
42
+ __all__ = ["PKHubConfig"]
sim_priors_pk/hub_runtime/modeling_sim_priors_pk.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face AutoModel wrapper for consumer-facing PK runtime bundles."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, Optional, Sequence, Union
6
+
7
+ import torch
8
+ from transformers import PreTrainedModel
9
+
10
+ from sim_priors_pk.data.data_empirical.json_schema import StudyJSON
11
+ from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
12
+ from sim_priors_pk.hub_runtime.runtime_contract import (
13
+ RuntimeBuilderConfig,
14
+ build_batch_from_studies,
15
+ infer_supported_tasks,
16
+ instantiate_backbone_from_hub_config,
17
+ normalize_studies_input,
18
+ split_runtime_samples,
19
+ validate_studies_for_task,
20
+ )
21
+ from sim_priors_pk.models.amortized_inference.generative_pk import (
22
+ NewGenerativeMixin,
23
+ NewPredictiveMixin,
24
+ )
25
+
26
+
27
+ class PKHubModel(PreTrainedModel):
28
+ """Thin wrapper exposing a stable StudyJSON runtime API on top of PK models."""
29
+
30
+ config_class = PKHubConfig
31
+ base_model_prefix = "backbone"
32
+
33
+ def __init__(self, config: PKHubConfig, backbone: Optional[torch.nn.Module] = None) -> None:
34
+ super().__init__(config)
35
+ self.backbone = backbone if backbone is not None else instantiate_backbone_from_hub_config(config)
36
+ self.backbone.eval()
37
+
38
+ def forward(self, *args, **kwargs):
39
+ """Delegate raw forward calls to the wrapped PK backbone."""
40
+
41
+ return self.backbone(*args, **kwargs)
42
+
43
+ @property
44
+ def supported_tasks(self) -> Sequence[str]:
45
+ """Tasks supported by this runtime model."""
46
+
47
+ return tuple(getattr(self.config, "supported_tasks", []) or infer_supported_tasks(self.backbone))
48
+
49
+ @torch.inference_mode()
50
+ def run_task(
51
+ self,
52
+ *,
53
+ task: str,
54
+ studies: Union[StudyJSON, Sequence[StudyJSON]],
55
+ num_samples: int = 1,
56
+ **kwargs: Any,
57
+ ) -> Dict[str, Any]:
58
+ """Run the public StudyJSON inference contract for the requested task."""
59
+
60
+ supported_tasks = list(self.supported_tasks)
61
+ if task not in supported_tasks:
62
+ raise ValueError(
63
+ f"Unsupported task {task!r}. Supported tasks: {supported_tasks or 'none'}."
64
+ )
65
+ if int(num_samples) < 1:
66
+ raise ValueError("num_samples must be >= 1.")
67
+
68
+ canonical_studies = normalize_studies_input(studies)
69
+ builder_config = RuntimeBuilderConfig.from_dict(self.config.builder_config)
70
+ validate_studies_for_task(canonical_studies, task=task, builder_config=builder_config)
71
+
72
+ experiment_config_payload = getattr(self.config, "experiment_config", {})
73
+ meta_dosing_payload = experiment_config_payload.get("dosing", {})
74
+ batch = build_batch_from_studies(
75
+ canonical_studies,
76
+ builder_config=builder_config,
77
+ meta_dosing=self.backbone.meta_dosing.__class__(**meta_dosing_payload)
78
+ if meta_dosing_payload
79
+ else self.backbone.meta_dosing,
80
+ )
81
+ batch = batch.to(self.device)
82
+
83
+ if task == "generate":
84
+ if not isinstance(self.backbone, NewGenerativeMixin):
85
+ raise ValueError(f"Backbone {type(self.backbone).__name__} does not support generate.")
86
+ output_studies = self.backbone.sample_new_individuals_to_studyjson(
87
+ batch,
88
+ sample_size=int(num_samples),
89
+ num_steps=kwargs.get("num_steps"),
90
+ )
91
+ elif task == "predict":
92
+ if not isinstance(self.backbone, NewPredictiveMixin):
93
+ raise ValueError(f"Backbone {type(self.backbone).__name__} does not support predict.")
94
+ output_studies = self.backbone.sample_individual_prediction_from_batch_list_to_studyjson(
95
+ [batch],
96
+ sample_size=int(num_samples),
97
+ )[0]
98
+ else:
99
+ raise ValueError(f"Unsupported task {task!r}.")
100
+
101
+ results = [
102
+ {
103
+ "input_index": index,
104
+ "samples": split_runtime_samples(task, study),
105
+ }
106
+ for index, study in enumerate(output_studies)
107
+ ]
108
+
109
+ return {
110
+ "task": task,
111
+ "io_schema_version": self.config.io_schema_version,
112
+ "model_info": {
113
+ "architecture_name": self.config.architecture_name,
114
+ "experiment_type": self.config.experiment_type,
115
+ "supported_tasks": supported_tasks,
116
+ "runtime_repo_id": self.config.runtime_repo_id,
117
+ "original_repo_id": self.config.original_repo_id,
118
+ },
119
+ "results": results,
120
+ }
121
+
122
+
123
+ __all__ = ["PKHubModel"]
sim_priors_pk/hub_runtime/runtime_bundle.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Manual export path for consumer-facing Hugging Face runtime bundles."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ import shutil
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from tempfile import TemporaryDirectory
10
+ from typing import Optional, Sequence
11
+
12
+ import torch
13
+ from huggingface_hub import HfApi, create_repo
14
+
15
+ from sim_priors_pk import config_dir, project_dir
16
+ from sim_priors_pk.hub_runtime.configuration_sim_priors_pk import PKHubConfig
17
+ from sim_priors_pk.hub_runtime.modeling_sim_priors_pk import PKHubModel
18
+ from sim_priors_pk.hub_runtime.runtime_contract import (
19
+ build_runtime_config_payload,
20
+ resolve_model_card_text,
21
+ runtime_readme_text,
22
+ )
23
+
24
# Filenames for the remote-code entrypoints copied to the bundle root so that
# transformers' auto_map can import them (see build_runtime_bundle_dir).
ROOT_CONFIGURATION_FILENAME = "configuration_sim_priors_pk.py"
ROOT_MODELING_FILENAME = "modeling_sim_priors_pk.py"
# Matches Hugging Face access tokens ("hf_" + 20 or more alphanumerics).
_HF_TOKEN_PATTERN = re.compile(r"hf_[A-Za-z0-9]{20,}")
# Match quoted assignments to COMET_API_KEY / HF_KEYS so their values can be
# redacted in place while keeping the assignment syntactically valid.
_COMET_KEY_ASSIGNMENT_PATTERN = re.compile(r"(COMET_API_KEY\s*=\s*)(['\"]).*?\2")
_HF_KEY_ASSIGNMENT_PATTERN = re.compile(r"(HF_KEYS\s*=\s*)(['\"]).*?\2")
29
+
30
+
31
@dataclass
class RuntimeBundleArtifacts:
    """Return metadata for a staged runtime bundle."""

    # Local directory where the bundle was staged.
    bundle_dir: Path
    # Hub repo id the bundle is (or will be) pushed to.
    runtime_repo_id: str
    # Legacy/native repo id this bundle was derived from, if known.
    original_repo_id: Optional[str]
    # Path to the README.md written into the bundle.
    readme_path: Path
39
+
40
+
41
def default_runtime_repo_id(experiment, *, suffix: str = "-runtime") -> str:
    """Resolve the default runtime bundle repo id for a loaded experiment.

    The id is ``<hub-user>/<hf_model_name><suffix>``, where the user name is
    looked up from the experiment's Hugging Face token.

    Raises:
        RuntimeError: if the experiment config or the Hugging Face token is missing.
    """

    exp_config = getattr(experiment, "exp_config", None)
    token = getattr(experiment, "hf_token", None)
    if exp_config is None:
        raise RuntimeError("Experiment config is not loaded.")
    if token is None:
        raise RuntimeError(
            "No Hugging Face token available. Set hugging_face_token in the config or KEYS.txt."
        )

    account_name = HfApi().whoami(token=token)["name"]
    return f"{account_name}/{exp_config.hf_model_name}{suffix}"
53
+
54
+
55
+ def _default_original_repo_id(experiment) -> Optional[str]:
56
+ """Infer the legacy/native Hub repo id if enough metadata is available."""
57
+
58
+ if getattr(experiment, "exp_config", None) is None:
59
+ return None
60
+ if getattr(experiment, "hf_token", None) is None:
61
+ return None
62
+ user = HfApi().whoami(token=experiment.hf_token)["name"]
63
+ return f"{user}/{experiment.exp_config.hf_model_name}"
64
+
65
+
66
+ def _validate_loaded_experiment(experiment) -> None:
67
+ """Ensure the loaded experiment has the minimum state needed for manual export."""
68
+
69
+ if getattr(experiment, "model", None) is None:
70
+ raise RuntimeError("Experiment model is not loaded.")
71
+ if getattr(experiment, "exp_config", None) is None:
72
+ raise RuntimeError("Experiment config is not loaded.")
73
+ if getattr(experiment, "experiment_dir", None) is None:
74
+ raise RuntimeError("Experiment directory is required before pushing.")
75
+ if getattr(experiment, "hf_token", None) is None:
76
+ raise RuntimeError(
77
+ "No Hugging Face token available. Set hugging_face_token in the config or KEYS.txt."
78
+ )
79
+
80
+
81
def _copy_runtime_support_files(bundle_dir: Path) -> None:
    """Copy the local package and root remote-code entrypoints into the bundle.

    After copying, secrets are scrubbed from text files and the bundle is
    re-scanned to guarantee no token-like strings remain.
    """

    package_src = project_dir / "sim_priors_pk"
    package_dst = bundle_dir / "sim_priors_pk"
    # Mirror the whole package so the bundle is importable stand-alone; skip bytecode caches.
    shutil.copytree(package_src, package_dst, dirs_exist_ok=True, ignore=shutil.ignore_patterns("__pycache__"))

    # The auto_map entries in the Hub config point at these root-level files.
    root_config_src = package_src / "hub_runtime" / ROOT_CONFIGURATION_FILENAME
    root_modeling_src = package_src / "hub_runtime" / ROOT_MODELING_FILENAME
    shutil.copy2(root_config_src, bundle_dir / ROOT_CONFIGURATION_FILENAME)
    shutil.copy2(root_modeling_src, bundle_dir / ROOT_MODELING_FILENAME)

    # Optional project-level extras; silently skipped when absent.
    for extra_name in ("requirements.txt", "LICENSE"):
        extra_src = project_dir / extra_name
        if extra_src.is_file():
            shutil.copy2(extra_src, bundle_dir / extra_name)

    # Redact secrets in copied text files, then fail hard if any HF token survived.
    _scrub_runtime_bundle_secrets(bundle_dir)
    _validate_no_hf_secrets(bundle_dir)
100
+
101
+
102
def _scrub_runtime_bundle_secrets(bundle_dir: Path) -> None:
    """Remove token-like secrets from copied source files before Hub upload.

    Rewrites HF tokens, COMET_API_KEY and HF_KEYS assignments in place, and
    replaces ``sim_priors_pk/utils/__init__.py`` wholesale with a redacted stub.
    """

    candidate_files = [
        *bundle_dir.rglob("*.py"),
        *bundle_dir.rglob("*.md"),
        *bundle_dir.rglob("*.txt"),
        *bundle_dir.rglob("*.json"),
    ]
    for path in candidate_files:
        try:
            text = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            # Binary-ish file with a text suffix; leave untouched.
            continue

        updated = text
        updated = _HF_TOKEN_PATTERN.sub("hf_REDACTED", updated)
        # \1 keeps the assignment prefix, \2 re-uses the original quote character.
        updated = _COMET_KEY_ASSIGNMENT_PATTERN.sub(r"\1\2REDACTED\2", updated)
        updated = _HF_KEY_ASSIGNMENT_PATTERN.sub(r"\1\2REDACTED\2", updated)

        # utils/__init__.py carries machine-specific paths and keys; replace it
        # entirely with an inert stub rather than trying to scrub line by line.
        if path.as_posix().endswith("sim_priors_pk/utils/__init__.py"):
            updated = (
                "PASCAL_BASE_DIR = ''\n"
                "NERSC_BASE_DIR = ''\n"
                "NERSC_EXPERIMENT_DIR = ''\n"
                "COMET_API_KEY = 'REDACTED'\n"
                "HF_KEYS = 'REDACTED'\n"
                "WORKSPACE = ''\n"
                "PROJECT = ''\n"
            )

        # Only touch files that actually changed to preserve mtimes elsewhere.
        if updated != text:
            path.write_text(updated, encoding="utf-8")
135
+
136
+
137
def _validate_no_hf_secrets(bundle_dir: Path) -> None:
    """Fail fast if token-like Hugging Face secrets remain after scrubbing.

    Raises:
        RuntimeError: listing the offending files (relative to ``bundle_dir``)
            when any scannable text file still matches the HF token pattern.
    """

    scannable_suffixes = {".py", ".md", ".txt", ".json"}
    offenders: list[str] = []
    for candidate in bundle_dir.rglob("*"):
        if not candidate.is_file() or candidate.suffix not in scannable_suffixes:
            continue
        try:
            content = candidate.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            # Undecodable file; the scrubber skipped it too.
            continue
        if _HF_TOKEN_PATTERN.search(content):
            offenders.append(str(candidate.relative_to(bundle_dir)))

    if offenders:
        raise RuntimeError(
            "Refusing to upload runtime bundle because token-like Hugging Face secrets "
            f"remain after scrubbing: {offenders}"
        )
158
+
159
+
160
def build_runtime_bundle_dir(
    *,
    experiment,
    bundle_dir: Path,
    model_card_path: Optional[Sequence[str]] = None,
    hf_repo_id: Optional[str] = None,
    original_repo_id: Optional[str] = None,
) -> RuntimeBundleArtifacts:
    """Stage a self-contained runtime bundle in ``bundle_dir`` without uploading it.

    Args:
        experiment: Loaded experiment holding ``model``, ``exp_config``,
            ``experiment_dir`` and ``hf_token``.
        bundle_dir: Target directory; created if missing.
        model_card_path: Path segments (under ``config_dir``) to the base model
            card; falls back to the config's ``hf_model_card_path``.
        hf_repo_id: Runtime repo id override; resolved from the Hub user otherwise.
        original_repo_id: Native repo id override; inferred when possible.

    Returns:
        RuntimeBundleArtifacts describing the staged bundle.
    """

    _validate_loaded_experiment(experiment)
    bundle_dir.mkdir(parents=True, exist_ok=True)

    runtime_repo_id = hf_repo_id or default_runtime_repo_id(experiment)
    native_repo_id = original_repo_id or _default_original_repo_id(experiment)

    # Explicit argument wins over the config's card path; both are path segments.
    normalized_model_card_path = tuple(
        model_card_path
        if model_card_path is not None
        else getattr(experiment.exp_config, "hf_model_card_path", ("hf_model_cards", "README.md"))
    )
    local_model_card_path = Path(config_dir).joinpath(*normalized_model_card_path)
    base_model_card = resolve_model_card_text(local_model_card_path)

    runtime_payload = build_runtime_config_payload(
        backbone=experiment.model,
        exp_config=experiment.exp_config,
        original_repo_id=native_repo_id,
        runtime_repo_id=runtime_repo_id,
    )
    # auto_map lets transformers load the remote-code classes from the bundle root
    # (ROOT_*_FILENAME minus the ".py" suffix).
    runtime_config = PKHubConfig(
        **runtime_payload,
        auto_map={
            "AutoConfig": f"{ROOT_CONFIGURATION_FILENAME[:-3]}.PKHubConfig",
            "AutoModel": f"{ROOT_MODELING_FILENAME[:-3]}.PKHubModel",
        },
        architectures=["PKHubModel"],
    )

    # Serialize weights on CPU so the bundle loads on machines without GPUs.
    runtime_model = PKHubModel(runtime_config, backbone=experiment.model)
    state_dict = {name: tensor.detach().cpu() for name, tensor in runtime_model.state_dict().items()}
    torch.save(state_dict, bundle_dir / "pytorch_model.bin")
    runtime_config.save_pretrained(str(bundle_dir))

    _copy_runtime_support_files(bundle_dir)

    readme_text = runtime_readme_text(
        base_model_card=base_model_card,
        runtime_repo_id=runtime_repo_id,
        original_repo_id=native_repo_id,
        supported_tasks=runtime_config.supported_tasks,
        default_task=runtime_config.default_task,
    )
    readme_path = bundle_dir / "README.md"
    readme_path.write_text(readme_text, encoding="utf-8")

    return RuntimeBundleArtifacts(
        bundle_dir=bundle_dir,
        runtime_repo_id=runtime_repo_id,
        original_repo_id=native_repo_id,
        readme_path=readme_path,
    )
222
+
223
+
224
def push_loaded_model_runtime_bundle(
    experiment,
    model_card_path: Optional[Sequence[str]] = None,
    hf_repo_id: Optional[str] = None,
    alias_name: str = "runtime_bundle_hf",
    commit_message: str = "manual runtime bundle push",
    *,
    original_repo_id: Optional[str] = None,
    exist_ok: bool = True,
) -> str:
    """Build and upload the consumer-facing runtime bundle for a loaded experiment.

    The bundle is staged in a temporary directory under
    ``<experiment_dir>/<alias_name>`` and uploaded as a single commit.

    Returns:
        The runtime repo id the bundle was pushed to.
    """

    _validate_loaded_experiment(experiment)
    runtime_repo_id = hf_repo_id or default_runtime_repo_id(experiment)
    # Create the target repo up front so upload_folder has somewhere to commit.
    create_repo(runtime_repo_id, exist_ok=exist_ok, token=experiment.hf_token)

    bundle_root = Path(experiment.experiment_dir) / alias_name
    bundle_root.mkdir(parents=True, exist_ok=True)

    # Stage inside a TemporaryDirectory so partial bundles are cleaned up on error.
    with TemporaryDirectory(dir=str(bundle_root), prefix="hf_runtime_bundle_") as temp_dir:
        staged_dir = Path(temp_dir)
        build_runtime_bundle_dir(
            experiment=experiment,
            bundle_dir=staged_dir,
            model_card_path=model_card_path,
            hf_repo_id=runtime_repo_id,
            original_repo_id=original_repo_id,
        )

        api = HfApi(token=experiment.hf_token)
        api.upload_folder(
            folder_path=str(staged_dir),
            repo_id=runtime_repo_id,
            commit_message=commit_message,
            token=experiment.hf_token,
        )

    return runtime_repo_id
262
+
263
+
264
# Public surface of the manual runtime-bundle export module.
__all__ = [
    "RuntimeBundleArtifacts",
    "build_runtime_bundle_dir",
    "default_runtime_repo_id",
    "push_loaded_model_runtime_bundle",
]
sim_priors_pk/hub_runtime/runtime_contract.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared runtime-contract helpers for consumer-facing Hub bundles.
2
+
3
+ This module is imported both by the local exporter and by the copied package
4
+ inside the generated Hugging Face runtime bundle. Keep dependencies limited to
5
+ modules that are already required for model inference.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from copy import deepcopy
11
+ from dataclasses import asdict, dataclass
12
+ from pathlib import Path
13
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, get_args, get_origin
14
+
15
+ import torch
16
+ from transformers import PretrainedConfig
17
+
18
+ from sim_priors_pk.config_classes.data_config import (
19
+ MetaDosingConfig,
20
+ MetaStudyConfig,
21
+ MixDataConfig,
22
+ ObservationsConfig,
23
+ SimpleMetaStudyConfig,
24
+ )
25
+ from sim_priors_pk.config_classes.diffusion_pk_config import DiffusionPKExperimentConfig
26
+ from sim_priors_pk.config_classes.flow_pk_config import FlowPKExperimentConfig, VectorFieldPKConfig
27
+ from sim_priors_pk.config_classes.node_pk_config import (
28
+ EncoderDecoderNetworkConfig,
29
+ NodePKExperimentConfig,
30
+ )
31
+ from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
32
+ from sim_priors_pk.config_classes.training_config import TrainingConfig
33
+ from sim_priors_pk.data.data_empirical.builder import EmpiricalBatchConfig, JSON2AICMEBuilder
34
+ from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON, canonicalize_study
35
+ from sim_priors_pk.data.data_generation.observations_classes import ObservationStrategyFactory
36
+ from sim_priors_pk.models import get_model_class
37
+ from sim_priors_pk.models.amortized_inference.generative_pk import (
38
+ NewGenerativeMixin,
39
+ NewPredictiveMixin,
40
+ )
41
+
42
# Backbone class names the v1 runtime bundle knows how to wrap; enforced by
# validate_runtime_architecture before export.
SUPPORTED_RUNTIME_ARCHITECTURES = {
    "AICMEPK",
    "ContextVAEPK",
    "FlowPK",
    "PredictionPK",
}
# Version tag of the StudyJSON input/output contract stored in the Hub config.
STUDY_JSON_IO_VERSION = "studyjson-v1"
+
50
+
51
@dataclass
class RuntimeBuilderConfig:
    """Fixed builder capacities serialized into the Hub runtime config."""

    max_context_individuals: int
    max_target_individuals: int
    max_context_observations: int
    max_target_observations: int
    max_context_remaining: int
    max_target_remaining: int

    def to_dict(self) -> Dict[str, int]:
        """Return a JSON-serializable representation."""

        return asdict(self)

    @classmethod
    def from_dict(cls, payload: Mapping[str, Any]) -> "RuntimeBuilderConfig":
        """Instantiate the builder capacities from serialized config payload.

        Raises:
            KeyError: if any capacity field is missing from ``payload``.
        """

        capacity_fields = (
            "max_context_individuals",
            "max_target_individuals",
            "max_context_observations",
            "max_target_observations",
            "max_context_remaining",
            "max_target_remaining",
        )
        return cls(**{name: int(payload[name]) for name in capacity_fields})

    def to_empirical_batch_config(self, *, max_databatch_size: int) -> EmpiricalBatchConfig:
        """Translate runtime capacities to the builder used by StudyJSON IO."""

        # The undirected caps are the per-side maxima; context/target caps pass through.
        individual_cap = max(self.max_context_individuals, self.max_target_individuals)
        observation_cap = max(self.max_context_observations, self.max_target_observations)
        remaining_cap = max(self.max_context_remaining, self.max_target_remaining)
        return EmpiricalBatchConfig(
            max_databatch_size=int(max_databatch_size),
            max_individuals=individual_cap,
            max_observations=observation_cap,
            max_remaining=remaining_cap,
            max_context_individuals=self.max_context_individuals,
            max_target_individuals=self.max_target_individuals,
            max_context_observations=self.max_context_observations,
            max_target_observations=self.max_target_observations,
            max_context_remaining=self.max_context_remaining,
            max_target_remaining=self.max_target_remaining,
        )
+ )
95
+
96
+
97
+ def _coerce_annotation(annotation: Any, value: Any) -> Any:
98
+ """Best-effort coercion of JSON-loaded values into dataclass field types."""
99
+
100
+ if value is None:
101
+ return None
102
+
103
+ origin = get_origin(annotation)
104
+ args = get_args(annotation)
105
+
106
+ if origin is Union:
107
+ non_none = [arg for arg in args if arg is not type(None)]
108
+ for candidate in non_none:
109
+ if candidate in (dict, Dict, Any, Mapping):
110
+ continue
111
+ try:
112
+ return _coerce_annotation(candidate, value)
113
+ except Exception:
114
+ continue
115
+ return value
116
+
117
+ if origin in (list, List, Sequence):
118
+ (inner_type,) = args if args else (Any,)
119
+ return [_coerce_annotation(inner_type, item) for item in value]
120
+
121
+ if origin in (tuple,):
122
+ if not args:
123
+ return tuple(value)
124
+ if len(args) == 2 and args[1] is Ellipsis:
125
+ return tuple(_coerce_annotation(args[0], item) for item in value)
126
+ return tuple(_coerce_annotation(inner, item) for inner, item in zip(args, value))
127
+
128
+ if origin in (dict, Dict, Mapping):
129
+ return dict(value)
130
+
131
+ if annotation is Any:
132
+ return value
133
+
134
+ if annotation is MetaStudyConfig and isinstance(value, Mapping) and value.get("simple_mode"):
135
+ return SimpleMetaStudyConfig(**dict(value))
136
+
137
+ if hasattr(annotation, "__dataclass_fields__") and isinstance(value, Mapping):
138
+ kwargs = {}
139
+ for field_name, field_def in annotation.__dataclass_fields__.items():
140
+ if field_name in value:
141
+ kwargs[field_name] = _coerce_annotation(field_def.type, value[field_name])
142
+ return annotation(**kwargs)
143
+
144
+ return value
145
+
146
+
147
def _rebuild_node_config(payload: Mapping[str, Any]) -> NodePKExperimentConfig:
    """Reconstruct a ``NodePKExperimentConfig`` from serialized dict content.

    Missing keys fall back to the same defaults used at training time; nested
    sections are coerced back into their dataclass types.
    """

    return NodePKExperimentConfig(
        experiment_type=str(payload.get("experiment_type", "nodepk")).lower(),
        name_str=str(payload.get("name_str", "NodePK")),
        comet_ai_key=payload.get("comet_ai_key"),
        experiment_name=str(payload.get("experiment_name", "node_pk_compartments")),
        hugging_face_token=payload.get("hugging_face_token"),
        upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)),
        hf_model_name=str(payload.get("hf_model_name", "NodePK_runtime")),
        hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))),
        tags=list(payload.get("tags", [])),
        # NOTE: "indentifier" spelling matches the config class's field name.
        experiment_indentifier=payload.get("experiment_indentifier"),
        my_results_path=payload.get("my_results_path"),
        experiment_dir=payload.get("experiment_dir"),
        verbose=bool(payload.get("verbose", False)),
        run_index=int(payload.get("run_index", 0)),
        debug_test=bool(payload.get("debug_test", False)),
        # Nested sections are rebuilt via the generic annotation-driven coercion.
        network=_coerce_annotation(EncoderDecoderNetworkConfig, payload.get("network", {})),
        mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})),
        context_observations=_coerce_annotation(
            ObservationsConfig, payload.get("context_observations", {})
        ),
        target_observations=_coerce_annotation(
            ObservationsConfig, payload.get("target_observations", {})
        ),
        meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})),
        dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})),
        train=_coerce_annotation(TrainingConfig, payload.get("train", {})),
    )
+ )
178
+
179
+
180
def _rebuild_flow_config(payload: Mapping[str, Any]) -> FlowPKExperimentConfig:
    """Reconstruct a ``FlowPKExperimentConfig`` from serialized dict content.

    Mirrors :func:`_rebuild_node_config` with flow-specific extras
    (``flow_num_steps``, ``vector_field``, ``source_process``).
    """

    return FlowPKExperimentConfig(
        experiment_type=str(payload.get("experiment_type", "flowpk")).lower(),
        name_str=str(payload.get("name_str", "FlowPK")),
        comet_ai_key=payload.get("comet_ai_key"),
        experiment_name=str(payload.get("experiment_name", "flow_pk_compartments")),
        hugging_face_token=payload.get("hugging_face_token"),
        upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)),
        hf_model_name=str(payload.get("hf_model_name", "FlowPK_runtime")),
        hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))),
        tags=list(payload.get("tags", [])),
        experiment_indentifier=payload.get("experiment_indentifier"),
        my_results_path=payload.get("my_results_path"),
        experiment_dir=payload.get("experiment_dir"),
        verbose=bool(payload.get("verbose", False)),
        run_index=int(payload.get("run_index", 0)),
        debug_test=bool(payload.get("debug_test", False)),
        flow_num_steps=int(payload.get("flow_num_steps", 50)),
        vector_field=_coerce_annotation(VectorFieldPKConfig, payload.get("vector_field", {})),
        source_process=_coerce_annotation(SourceProcessConfig, payload.get("source_process", {})),
        mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})),
        context_observations=_coerce_annotation(
            ObservationsConfig, payload.get("context_observations", {})
        ),
        target_observations=_coerce_annotation(
            ObservationsConfig, payload.get("target_observations", {})
        ),
        meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})),
        dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})),
        train=_coerce_annotation(TrainingConfig, payload.get("train", {})),
    )
+ )
213
+
214
+
215
def _rebuild_diffusion_config(payload: Mapping[str, Any]) -> DiffusionPKExperimentConfig:
    """Reconstruct a ``DiffusionPKExperimentConfig`` from serialized dict content.

    Mirrors :func:`_rebuild_node_config` with diffusion-specific extras
    (``diffusion_type``, ``predict_gaussian_noise``, ``source_process``).
    """

    return DiffusionPKExperimentConfig(
        experiment_type=str(payload.get("experiment_type", "diffusionpk")).lower(),
        name_str=str(payload.get("name_str", "ContinuousDiffusionPK")),
        diffusion_type=str(payload.get("diffusion_type", "continuous")),
        comet_ai_key=payload.get("comet_ai_key"),
        experiment_name=str(payload.get("experiment_name", "diffusion_pk_compartments")),
        hugging_face_token=payload.get("hugging_face_token"),
        upload_to_hf_hub=bool(payload.get("upload_to_hf_hub", False)),
        hf_model_name=str(payload.get("hf_model_name", "DiffusionPK_runtime")),
        hf_model_card_path=tuple(payload.get("hf_model_card_path", ("hf_model_cards", "README.md"))),
        tags=list(payload.get("tags", [])),
        experiment_indentifier=payload.get("experiment_indentifier"),
        my_results_path=payload.get("my_results_path"),
        experiment_dir=payload.get("experiment_dir"),
        verbose=bool(payload.get("verbose", False)),
        run_index=int(payload.get("run_index", 0)),
        debug_test=bool(payload.get("debug_test", False)),
        predict_gaussian_noise=bool(payload.get("predict_gaussian_noise", True)),
        network=_coerce_annotation(EncoderDecoderNetworkConfig, payload.get("network", {})),
        source_process=_coerce_annotation(SourceProcessConfig, payload.get("source_process", {})),
        mix_data=_coerce_annotation(MixDataConfig, payload.get("mix_data", {})),
        context_observations=_coerce_annotation(
            ObservationsConfig, payload.get("context_observations", {})
        ),
        target_observations=_coerce_annotation(
            ObservationsConfig, payload.get("target_observations", {})
        ),
        meta_study=_coerce_annotation(MetaStudyConfig, payload.get("meta_study", {})),
        dosing=_coerce_annotation(MetaDosingConfig, payload.get("dosing", {})),
        train=_coerce_annotation(TrainingConfig, payload.get("train", {})),
    )
+ )
249
+
250
+
251
+ def rebuild_experiment_config(
252
+ payload: Mapping[str, Any],
253
+ ) -> Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig]:
254
+ """Rebuild the serialized experiment config stored in the Hub config."""
255
+
256
+ experiment_type = str(payload.get("experiment_type", "nodepk")).lower()
257
+ if experiment_type == "nodepk":
258
+ return _rebuild_node_config(payload)
259
+ if experiment_type == "flowpk":
260
+ return _rebuild_flow_config(payload)
261
+ if experiment_type == "diffusionpk":
262
+ return _rebuild_diffusion_config(payload)
263
+ raise ValueError(f"Unsupported experiment_type for runtime bundle: {experiment_type!r}.")
264
+
265
+
266
+ def compute_runtime_builder_config(
267
+ exp_config: Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig],
268
+ ) -> RuntimeBuilderConfig:
269
+ """Compute fixed empirical StudyJSON capacities from the experiment config."""
270
+
271
+ context_strategy = ObservationStrategyFactory.from_config(
272
+ exp_config.context_observations,
273
+ exp_config.meta_study,
274
+ )
275
+ target_strategy = ObservationStrategyFactory.from_config(
276
+ exp_config.target_observations,
277
+ exp_config.meta_study,
278
+ )
279
+ ctx_obs_cap, ctx_rem_cap = context_strategy.get_shapes()
280
+ tgt_obs_cap, tgt_rem_cap = target_strategy.get_shapes()
281
+
282
+ max_context_individuals = int(exp_config.meta_study.num_individuals_range[-1])
283
+ max_target_individuals = int(getattr(exp_config.mix_data, "n_of_target_individuals", 1))
284
+ if max_target_individuals < 0:
285
+ raise ValueError("n_of_target_individuals must be >= 0 for Hub runtime export.")
286
+
287
+ return RuntimeBuilderConfig(
288
+ max_context_individuals=max_context_individuals,
289
+ max_target_individuals=max_target_individuals,
290
+ max_context_observations=int(ctx_obs_cap),
291
+ max_target_observations=int(tgt_obs_cap),
292
+ max_context_remaining=int(ctx_rem_cap),
293
+ max_target_remaining=int(tgt_rem_cap),
294
+ )
295
+
296
+
297
+ def infer_supported_tasks(backbone: torch.nn.Module) -> List[str]:
298
+ """Infer the public task surface supported by the wrapped model."""
299
+
300
+ tasks: List[str] = []
301
+ if isinstance(backbone, NewGenerativeMixin):
302
+ tasks.append("generate")
303
+ if isinstance(backbone, NewPredictiveMixin):
304
+ tasks.append("predict")
305
+ return tasks
306
+
307
+
308
+ def validate_runtime_architecture(backbone: torch.nn.Module) -> str:
309
+ """Ensure the loaded architecture is supported by the runtime bundle v1."""
310
+
311
+ architecture_name = backbone.__class__.__name__
312
+ if architecture_name not in SUPPORTED_RUNTIME_ARCHITECTURES:
313
+ raise ValueError(
314
+ "Runtime Hub export only supports "
315
+ f"{sorted(SUPPORTED_RUNTIME_ARCHITECTURES)}, got {architecture_name!r}."
316
+ )
317
+ return architecture_name
318
+
319
+
320
+ def build_runtime_config_payload(
321
+ *,
322
+ backbone: torch.nn.Module,
323
+ exp_config: Union[NodePKExperimentConfig, FlowPKExperimentConfig, DiffusionPKExperimentConfig],
324
+ original_repo_id: Optional[str],
325
+ runtime_repo_id: Optional[str],
326
+ ) -> Dict[str, Any]:
327
+ """Build the serializable fields stored in the Hub config."""
328
+
329
+ architecture_name = validate_runtime_architecture(backbone)
330
+ supported_tasks = infer_supported_tasks(backbone)
331
+ if not supported_tasks:
332
+ raise ValueError(f"Model {architecture_name!r} does not expose runtime tasks.")
333
+
334
+ builder_config = compute_runtime_builder_config(exp_config)
335
+ return {
336
+ "architecture_name": architecture_name,
337
+ "experiment_type": str(getattr(exp_config, "experiment_type", "nodepk")).lower(),
338
+ "experiment_config": asdict(exp_config),
339
+ "builder_config": builder_config.to_dict(),
340
+ "supported_tasks": supported_tasks,
341
+ "default_task": supported_tasks[0],
342
+ "io_schema_version": STUDY_JSON_IO_VERSION,
343
+ "original_repo_id": original_repo_id,
344
+ "runtime_repo_id": runtime_repo_id,
345
+ }
346
+
347
+
348
+ def instantiate_backbone_from_hub_config(config: PretrainedConfig) -> torch.nn.Module:
349
+ """Rebuild the internal PK model represented by the public Hub wrapper."""
350
+
351
+ experiment_config_payload = getattr(config, "experiment_config", None)
352
+ if not isinstance(experiment_config_payload, Mapping):
353
+ raise ValueError("Hub config is missing the serialized experiment_config payload.")
354
+ exp_config = rebuild_experiment_config(experiment_config_payload)
355
+ model_cls = get_model_class(exp_config)
356
+ backbone = model_cls(exp_config)
357
+ backbone.eval()
358
+ return backbone
359
+
360
+
361
+ def normalize_studies_input(
362
+ studies: Union[StudyJSON, Sequence[StudyJSON]],
363
+ ) -> List[StudyJSON]:
364
+ """Normalize runtime input to a mutable list of canonicalized studies."""
365
+
366
+ if isinstance(studies, Mapping):
367
+ raw_studies = [dict(studies)]
368
+ else:
369
+ raw_studies = [dict(study) for study in studies]
370
+ return [canonicalize_study(study, drop_tgt_too_few=False) for study in raw_studies]
371
+
372
+
373
+ def validate_studies_for_task(
374
+ studies: Sequence[StudyJSON],
375
+ *,
376
+ task: str,
377
+ builder_config: RuntimeBuilderConfig,
378
+ ) -> None:
379
+ """Validate task semantics and reject inputs that exceed runtime capacities."""
380
+
381
+ for study_idx, study in enumerate(studies):
382
+ context = list(study.get("context", []))
383
+ target = list(study.get("target", []))
384
+
385
+ if task == "generate":
386
+ if not context:
387
+ raise ValueError("`generate` requires at least one context individual per study.")
388
+ if target:
389
+ raise ValueError("`generate` expects target to be empty in the input StudyJSON.")
390
+ elif task == "predict":
391
+ if not target:
392
+ raise ValueError("`predict` requires at least one target individual per study.")
393
+ else:
394
+ raise ValueError(f"Unsupported task {task!r}.")
395
+
396
+ if len(context) > builder_config.max_context_individuals:
397
+ raise ValueError(
398
+ f"Study {study_idx} exceeds context individual capacity "
399
+ f"({len(context)} > {builder_config.max_context_individuals})."
400
+ )
401
+ if len(target) > builder_config.max_target_individuals:
402
+ raise ValueError(
403
+ f"Study {study_idx} exceeds target individual capacity "
404
+ f"({len(target)} > {builder_config.max_target_individuals})."
405
+ )
406
+
407
+ _validate_individual_block(
408
+ study_idx=study_idx,
409
+ block_name="context",
410
+ individuals=context,
411
+ max_observations=builder_config.max_context_observations,
412
+ max_remaining=builder_config.max_context_remaining,
413
+ )
414
+ _validate_individual_block(
415
+ study_idx=study_idx,
416
+ block_name="target",
417
+ individuals=target,
418
+ max_observations=builder_config.max_target_observations,
419
+ max_remaining=builder_config.max_target_remaining,
420
+ )
421
+
422
+
423
+ def _validate_individual_block(
424
+ *,
425
+ study_idx: int,
426
+ block_name: str,
427
+ individuals: Sequence[IndividualJSON],
428
+ max_observations: int,
429
+ max_remaining: int,
430
+ ) -> None:
431
+ """Reject studies that would otherwise be truncated by the empirical builder."""
432
+
433
+ for ind_idx, individual in enumerate(individuals):
434
+ obs_len = len(individual.get("observations", []))
435
+ rem_len = len(individual.get("remaining", []))
436
+ if obs_len > max_observations:
437
+ raise ValueError(
438
+ f"Study {study_idx} {block_name}[{ind_idx}] exceeds observation capacity "
439
+ f"({obs_len} > {max_observations})."
440
+ )
441
+ if rem_len > max_remaining:
442
+ raise ValueError(
443
+ f"Study {study_idx} {block_name}[{ind_idx}] exceeds remaining capacity "
444
+ f"({rem_len} > {max_remaining})."
445
+ )
446
+
447
+
448
+ def build_batch_from_studies(
449
+ studies: Sequence[StudyJSON],
450
+ *,
451
+ builder_config: RuntimeBuilderConfig,
452
+ meta_dosing: MetaDosingConfig,
453
+ ):
454
+ """Convert canonical studies into the internal PK databatch representation."""
455
+
456
+ builder = JSON2AICMEBuilder(
457
+ builder_config.to_empirical_batch_config(max_databatch_size=max(1, len(studies)))
458
+ )
459
+ return builder.build_one_aicmebatch(list(studies), meta_dosing)
460
+
461
+
462
+ def split_runtime_samples(task: str, study: StudyJSON) -> List[StudyJSON]:
463
+ """Convert model-specific StudyJSON outputs into per-sample StudyJSONs."""
464
+
465
+ if task == "generate":
466
+ return _split_generate_samples(study)
467
+ if task == "predict":
468
+ return _split_predict_samples(study)
469
+ raise ValueError(f"Unsupported task {task!r}.")
470
+
471
+
472
def _split_generate_samples(study: StudyJSON) -> List[StudyJSON]:
    """Split generated target individuals into one StudyJSON per sample.

    Each returned study keeps an independent deep copy of the context and
    meta data, paired with exactly one generated target individual.
    """

    targets = list(study.get("target", []))
    if not targets:
        # Nothing to split — return an independent copy of the whole study.
        return [deepcopy(study)]

    return [
        {
            "context": deepcopy(study.get("context", [])),
            "target": [deepcopy(target)],
            "meta_data": deepcopy(study.get("meta_data", {})),
        }
        for target in targets
    ]
489
+
490
+
491
def _split_predict_samples(study: StudyJSON) -> List[StudyJSON]:
    """Split target prediction samples into one StudyJSON per sample index.

    Each returned study keeps the full context and every target individual,
    but each target carries exactly one ``prediction_samples`` entry — the
    one at the shared sample index. Targets without any samples are carried
    through unchanged.

    Raises:
        ValueError: if targets expose differing non-zero sample counts.
    """

    targets = list(study.get("target", []))
    if not targets:
        # Nothing to split — return an independent copy of the whole study.
        return [deepcopy(study)]

    # Number of per-sample studies = the largest sample count across targets.
    sample_count = 0
    for target in targets:
        sample_count = max(sample_count, len(target.get("prediction_samples", [])))
    if sample_count == 0:
        return [deepcopy(study)]

    split: List[StudyJSON] = []
    for sample_idx in range(sample_count):
        target_block: List[IndividualJSON] = []
        for target in targets:
            target_copy: IndividualJSON = deepcopy(target)
            samples = list(target.get("prediction_samples", []))
            if samples:
                # Every target that has samples must cover every sample index.
                if sample_idx >= len(samples):
                    raise ValueError(
                        "All target individuals must expose the same number of prediction samples."
                    )
                target_copy["prediction_samples"] = [deepcopy(samples[sample_idx])]
            target_block.append(target_copy)

        split.append(
            {
                "context": deepcopy(study.get("context", [])),
                "target": target_block,
                "meta_data": deepcopy(study.get("meta_data", {})),
            }
        )
    return split
526
+
527
+
528
def runtime_readme_text(
    *,
    base_model_card: str,
    runtime_repo_id: str,
    original_repo_id: Optional[str],
    supported_tasks: Sequence[str],
    default_task: str,
) -> str:
    """Compose the README uploaded with the consumer-facing runtime bundle.

    Args:
        base_model_card: Existing model-card markdown; the usage section is
            appended after it.
        runtime_repo_id: Hub id of the runtime bundle repository.
        original_repo_id: Hub id of the native training repo, or None if it
            was not recorded.
        supported_tasks: Task names advertised in the README.
        default_task: Task name used in the generated usage example.

    Returns:
        The full README markdown text.
    """

    original_line = (
        f"- Native training/artifact repo: `{original_repo_id}`"
        if original_repo_id
        else "- Native training/artifact repo: not recorded"
    )
    tasks_literal = ", ".join(f"`{task}`" for task in supported_tasks)

    # NOTE: doubled braces ({{ }}) render as literal braces in the f-string
    # so the embedded Python examples show plain dict syntax.
    usage = f"""

## Runtime Bundle

This repository is the consumer-facing runtime bundle for this PK model.

- Runtime repo: `{runtime_repo_id}`
{original_line}
- Supported tasks: {tasks_literal}
- Default task: `{default_task}`
- Load path: `AutoModel.from_pretrained(..., trust_remote_code=True)`

### Installation

You do **not** need to install `sim_priors_pk` to use this runtime bundle.

`transformers` is the public loading entrypoint, but `transformers` alone is
not sufficient because this is a PyTorch model with custom runtime code. A
reliable consumer environment is:

```bash
pip install torch transformers huggingface_hub lightning datasets pandas torchtyping gpytorch pot torchdiffeq torchsde ruamel.yaml pyyaml
```

### Python Usage

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("{runtime_repo_id}", trust_remote_code=True)

studies = [
    {{
        "context": [
            {{
                "name_id": "ctx_0",
                "observations": [0.2, 0.5, 0.3],
                "observation_times": [0.5, 1.0, 2.0],
                "dosing": [1.0],
                "dosing_type": ["oral"],
                "dosing_times": [0.0],
                "dosing_name": ["oral"],
            }}
        ],
        "target": [],
        "meta_data": {{"study_name": "demo", "substance_name": "drug_x"}},
    }}
]

outputs = model.run_task(
    task="{default_task}",
    studies=studies,
    num_samples=4,
)
print(outputs["results"][0]["samples"])
```

### Predictive Sampling

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("{runtime_repo_id}", trust_remote_code=True)

predict_studies = [
    {{
        "context": [
            {{
                "name_id": "ctx_0",
                "observations": [0.2, 0.5, 0.3],
                "observation_times": [0.5, 1.0, 2.0],
                "dosing": [1.0],
                "dosing_type": ["oral"],
                "dosing_times": [0.0],
                "dosing_name": ["oral"],
            }}
        ],
        "target": [
            {{
                "name_id": "tgt_0",
                "observations": [0.25, 0.31],
                "observation_times": [0.5, 1.0],
                "remaining": [0.0, 0.0, 0.0],
                "remaining_times": [2.0, 4.0, 8.0],
                "dosing": [1.0],
                "dosing_type": ["oral"],
                "dosing_times": [0.0],
                "dosing_name": ["oral"],
            }}
        ],
        "meta_data": {{"study_name": "demo", "substance_name": "drug_x"}},
    }}
]

outputs = model.run_task(
    task="predict",
    studies=predict_studies,
    num_samples=4,
)
print(outputs["results"][0]["samples"][0]["target"][0]["prediction_samples"])
```

### Notes

- `trust_remote_code=True` is required because this model uses custom Hugging Face Hub runtime code.
- The consumer API is `transformers` + `run_task(...)`; the consumer does not need a local clone of this repository.
- This runtime bundle is intentionally separate from the native training export so you can evaluate both distribution paths in parallel.
"""

    return base_model_card.rstrip() + "\n" + usage.strip() + "\n"
655
+
656
+
657
def resolve_model_card_text(model_card_path: Path) -> str:
    """Read and validate the model card that seeds the runtime README.

    Raises:
        FileNotFoundError: if *model_card_path* is not an existing file.
    """

    if model_card_path.is_file():
        return model_card_path.read_text(encoding="utf-8")
    raise FileNotFoundError(f"Model card not found at: {model_card_path}")
sim_priors_pk/metrics/__init__.py ADDED
File without changes
sim_priors_pk/metrics/pk_metrics.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import List,Tuple
4
+ from matplotlib import pyplot as plt
5
+ from torchtyping import TensorType, patch_typeguard
6
+ from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch
7
+ from scipy import stats
8
+ import torch
9
+ from typing import Tuple
10
+
11
+ Tensor = torch.Tensor # for brevity – keep your own alias if you prefer
12
+ import os
13
+
14
def ensure_folder_exists(folder_name: str) -> None:
    """Create *folder_name* (including missing parents) if it does not exist.

    Prints a status line either way so interactive callers can see what
    happened.

    Parameters
    ----------
    folder_name : str
        Path of the directory to create.
    """
    if not os.path.exists(folder_name):
        # exist_ok=True guards against the race where another process creates
        # the directory between the exists() check and makedirs().
        os.makedirs(folder_name, exist_ok=True)
        print(f"✅ Created folder: {folder_name}")
    else:
        print(f"📁 Folder already exists: {folder_name}")
20
+
21
def combine_samples(
    samples_list: list[TensorType["S", "B", "I", "T", 1]]
) -> TensorType["S", "P", "T"]:
    """Stack per-permutation sample tensors into one [S, P, T] tensor.

    Each entry of ``samples_list`` (length P) is expected to have shape
    [S, B, I, T, 1] with B == I == 1; the singleton batch, individual and
    channel axes are dropped and the P entries are stacked along axis 1.
    """
    # Keep only the [S, T] slice of each tensor (B=1, I=1, last dim=1).
    flat_slices = [sample[:, 0, 0, :, 0] for sample in samples_list]
    # The new "permutation" axis P sits between S and T.
    return torch.stack(flat_slices, dim=1)
42
+
43
def extract_context_by_mask(
    db: AICMECompartmentsDataBatch
) -> Tuple[
    List[TensorType["n_i"]],  # context observations per individual
    List[TensorType["n_i"]]   # context times per individual
]:
    """Collect the masked context observations/times of a B=1 databatch.

    Expects:
        db.context_obs:      [1, c_ind, num_obs_c, 1]
        db.context_obs_time: [1, c_ind, num_obs_c, 1]
        db.context_obs_mask: [1, c_ind, num_obs_c]

    Returns two lists of length c_ind; entry i holds only the observations
    (resp. times) whose mask entry is truthy.
    """
    B, c_ind, num_obs_c, one = db.context_obs.shape
    assert B == 1 and one == 1, f"Expected B=1 and last dim=1, got B={B}, last={one}"

    # [1, c_ind, num_obs_c, 1] -> [c_ind, num_obs_c]
    obs_grid = db.context_obs.squeeze(0).squeeze(-1)
    time_grid = db.context_obs_time.squeeze(0).squeeze(-1)
    mask_grid = db.context_obs_mask.squeeze(0)

    obs_list: List[torch.Tensor] = []
    time_list: List[torch.Tensor] = []
    for obs_row, time_row, mask_row in zip(obs_grid, time_grid, mask_grid):
        keep = mask_row.bool()
        obs_list.append(obs_row[keep])
        time_list.append(time_row[keep])

    return obs_list, time_list
80
+
81
def compute_pd(
    y_obs : TensorType["I", "T"],      # observed data
    y_sim : TensorType["S", "I", "T"], # S simulated datasets
    mask  : TensorType["I", "T"],      # True/1 = valid obs
) -> TensorType["I", "T"]:             # pd, NaN where mask == 0
    """
    Prediction discrepancy (pd) — Eq. (4) Comets et al. 2008.

    NOTICE THAT THERE IS NO BATCH INDEX, this works only on individual
    substances.

    Parameters
    ----------
    y_obs : [I, T] observed values (padding entries are ignored via `mask`)
    y_sim : [S, I, T] S Monte-Carlo replicates generated from the model
    mask  : [I, T] binary mask — True at valid observation points

    Returns
    -------
    pd : [I, T] empirical CDF value at (i,j); NaN where mask==0
    """
    num_sims, num_ind, num_times = y_sim.shape
    assert y_obs.shape == (num_ind, num_times), "y_obs must be [I,T]"
    assert mask.shape == (num_ind, num_times), "mask must be [I,T]"

    # Indicator δ = 1[y_sim < y_obs]; y_obs broadcasts over the S axis.
    below_obs = (y_sim < y_obs.unsqueeze(0)).float()  # [S,I,T]

    # Averaging over simulations yields the empirical CDF at each (i,j).
    pd = below_obs.mean(dim=0)  # [I,T]

    # Flag padded entries with NaN so callers can tell them apart.
    nan_fill = torch.full_like(pd, float("nan"))
    return torch.where(mask.bool(), pd, nan_fill)
119
+
120
def sample_covariance_manual_torch(
    X: TensorType["S", "Tv"]  # simulations for one subject, S×Tᵥ
):
    """Unbiased sample covariance and mean of the rows of *X*.

    Returns
    -------
    cov      : [Tᵥ, Tᵥ] unbiased (S-1 denominator) covariance matrix
    mean_vec : [Tᵥ] column-wise mean
    """
    num_samples = X.shape[0]
    mean_vec = X.mean(dim=0)               # [Tᵥ]
    centered = X - mean_vec                # [S, Tᵥ]
    cov = centered.t() @ centered / (num_samples - 1)  # [Tᵥ, Tᵥ]
    return cov, mean_vec
132
+
133
def whiten_manual_torch_old(
    X: TensorType["S", "Tv"],  # data to whiten
    eps: float = 1e-8          # ridge for numerical safety
):
    """Manual whitening via the eigendecomposition of the sample covariance.

    Returns the whitened data, the whitening matrix W = Σ^{-1/2}, and the
    sample mean.
    """
    cov, mean_vec = sample_covariance_manual_torch(X)  # Σ, μ
    # Ridge keeps the eigendecomposition numerically safe.
    ridge = eps * torch.eye(cov.size(0), device=X.device)
    eigvals, eigvecs = torch.linalg.eigh(cov + ridge)
    inv_sqrt_diag = torch.diag(torch.rsqrt(eigvals))   # diag(1/√λ)
    whitening = eigvecs @ inv_sqrt_diag @ eigvecs.t()  # Σ^{-1/2}
    whitened = (X - mean_vec) @ whitening
    return whitened, whitening, mean_vec
147
+
148
def compute_npde_full_old(
    y_obs: TensorType["I", "T"],
    y_sim: TensorType["S", "I", "T"],
    mask : TensorType["I", "T"],
    eps  : float = 1e-8
) -> TensorType["I", "T"]:
    """
    Full NPDE with within-subject decorrelation (Σ^{-1/2}).

    NOTICE THAT THERE IS NO BATCH INDEX, this works only on individual
    substances.

    Parameters
    ----------
    y_obs : [I,T] observations (padding allowed)
    y_sim : [S,I,T] S Monte-Carlo replicates
    mask  : [I,T] True/1 = valid time-points

    Returns
    -------
    npde : [I,T]; NaN at padded points and where whitening degraded.
    """
    S, I, T = y_sim.shape
    N01 = torch.distributions.Normal(0.0, 1.0)
    out = torch.full_like(y_obs, float("nan"))  # result placeholder

    for i in range(I):
        # ---- select the irregular grid for subject i -------------------
        valid_idx = mask[i].bool()
        if not valid_idx.any():
            continue  # nothing to do

        y_i_obs = y_obs[i, valid_idx]      # [Tᵥ]
        y_i_sim = y_sim[:, i, valid_idx]   # [S,Tᵥ]

        # ---- whitening -------------------------------------------------
        # BUG FIX: whiten_manual_torch returns four values
        # (X_white, W, mean, ok); the previous three-way unpack raised
        # ValueError at runtime.
        y_i_sim_white, W, mean_vec, ok = whiten_manual_torch(y_i_sim, eps)
        if W is None:
            # Whitening degraded → leave this subject's entries as NaN.
            out[i, valid_idx] = float("nan")
            continue

        # same transform for the single observation vector
        y_i_obs_white = (y_i_obs - mean_vec) @ W  # [Tᵥ]

        # ---- empirical CDF on whitened scale (Eq. 4) -------------------
        delta = (y_i_sim_white < y_i_obs_white).float()  # [S,Tᵥ]
        pde = delta.mean(dim=0)                          # [Tᵥ]

        # ---- edge-case rule (Eq. 6): keep pde strictly inside (0,1) ----
        one_over_S = 1.0 / S
        pde = torch.where(pde == 0, torch.full_like(pde, one_over_S), pde)
        pde = torch.where(pde == 1, torch.full_like(pde, 1 - one_over_S), pde)

        # ---- NPDE (Eq. 7): inverse-normal transform --------------------
        npde = N01.icdf(pde)  # [Tᵥ]

        # ---- write back to full-size tensor ----------------------------
        out[i, valid_idx] = npde

    return out
205
+
206
+ # ---------------------------------------------------------------------
207
+ # 1. Robust whitening
208
+ # ---------------------------------------------------------------------
209
def whiten_manual_torch(
    X: Tensor,                   # [S, Tᵥ]
    eps: float = 1e-8,
    max_attempts: int = 5,
    base_jitter: float = 1e-6
) -> Tuple[Tensor, torch.Tensor | None, Tensor, bool]:
    """Robust whitening of the rows of *X* with escalating ridge jitter.

    Returns
    -------
    X_white : [S,Tᵥ]          whitened simulations
    W       : [Tᵥ,Tᵥ] | None  Σ^{-½} (None ⇒ degraded to diagonal)
    mean    : [Tᵥ]            sample mean
    ok      : bool            True if the full Σ^{-½} was used
    """
    num_samples, num_times = X.shape
    data64 = X.double()
    mean = data64.mean(dim=0)
    centered = data64 - mean
    cov = (centered.T @ centered) / (num_samples - 1)
    identity = torch.eye(num_times, dtype=data64.dtype, device=X.device)

    transform = None
    for attempt in range(max_attempts):
        # Ridge grows by a factor of 10 per failed attempt.
        ridge = eps + base_jitter * (10.0 ** attempt)
        try:
            eigvals, eigvecs = torch.linalg.eigh(cov + ridge * identity)
        except RuntimeError:
            continue  # eigh failed → retry with a bigger jitter
        if torch.any(eigvals <= 0):
            continue  # not positive definite → retry with a bigger jitter
        transform = eigvecs @ torch.diag(torch.rsqrt(eigvals)) @ eigvecs.T
        break

    ok = transform is not None
    if not ok:
        # Final fallback: decorrelation failed, whiten each dimension by
        # its own variance only.
        diag_var = cov.diag().clamp_min(eps)
        transform = torch.diag(torch.rsqrt(diag_var))

    whitened = (centered @ transform).float()
    return whitened, transform.float() if ok else None, mean.float(), ok
252
+
253
+
254
+ # ---------------------------------------------------------------------
255
+ # 2. NPDE with an *output* validity mask
256
+ # ---------------------------------------------------------------------
257
+
258
def compute_npde_full(
    y_obs: TensorType["I", "T"],
    y_sim: TensorType["S", "I", "T"],
    mask : TensorType["I", "T"],
    eps  : float = 1e-8,
) -> Tuple[TensorType["I", "T"], TensorType["I", "T"]]:
    """
    Full NPDE with within-subject decorrelation (Σ^{-1/2}).

    NOTICE THAT THERE IS NO BATCH INDEX, this works only on individual
    substances.

    Args
    ------
    y_obs : [I,T] observations (padding allowed)
    y_sim : [S,I,T] S Monte-Carlo replicates
    mask  : [I,T] True/1 = valid time-points

    Returns
    -------
    npde       : [I,T] – same shape as `y_obs`; NaN where not computed
    valid_mask : [I,T] – True where npde is statistically valid (the user
                 mask, minus subjects where whitening degraded)
    """
    S, I, T = y_sim.shape
    N01 = torch.distributions.Normal(0.0, 1.0)

    npde_out = torch.full_like(y_obs, float("nan"))
    valid_out = mask.clone().bool()  # start with the user mask

    for i in range(I):
        # ---- select the irregular grid for subject i -------------------
        valid_idx = mask[i].bool()
        if not valid_idx.any():
            # No valid points at all → whole row invalid.
            valid_out[i] = False
            continue

        y_i_obs = y_obs[i, valid_idx]      # [Tᵥ]
        y_i_sim = y_sim[:, i, valid_idx]   # [S,Tᵥ]

        # ---- whitening (robust Σ^{-1/2} with jitter fallback) ----------
        y_i_sim_white, W, mean_vec, ok = whiten_manual_torch(y_i_sim, eps)

        if not ok:  # whitening failed → invalidate this subject's points
            valid_out[i, valid_idx] = False
            continue

        # same transform for the single observation vector
        y_i_obs_white = (y_i_obs - mean_vec) @ W

        # ---- empirical CDF on whitened scale (Eq. 4) -------------------
        delta = (y_i_sim_white < y_i_obs_white).float()
        pde = delta.mean(dim=0)

        # ---- edge-case rule (Eq. 6): keep pde strictly inside (0,1) ----
        one_over_S = 1.0 / S
        pde = torch.where(pde == 0, torch.full_like(pde, one_over_S), pde)
        pde = torch.where(pde == 1, torch.full_like(pde, 1 - one_over_S), pde)

        # ---- NPDE (Eq. 7): inverse-normal transform --------------------
        npde = N01.icdf(pde)
        npde_out[i, valid_idx] = npde

    return npde_out, valid_out
321
+
322
+
323
def compute_npde_in_batch(
    y_obs: TensorType["B", "I", "T"],
    y_sim: TensorType["S", "B", "I", "T"],
    mask: TensorType["B", "I", "T"],
    eps: float = 1e-8,
) -> TensorType["B", "I", "T"]:
    """Compute NPDE for each element in a batch.

    Parameters
    ----------
    y_obs : [B, I, T] Observed values per batch item (context observations).
    y_sim : [S, B, I, T] Simulated predictions.
    mask  : [B, I, T] Validity mask for observations.

    Returns
    -------
    Tensor of shape [B, I, T] with NPDE values (NaN where invalid).
    """
    B = y_obs.size(0)
    results = []
    for b in range(B):
        # BUG FIX: compute_npde_full returns (npde, valid_mask); the old
        # code appended the whole tuple, so torch.stack raised a TypeError.
        npde_b, _ = compute_npde_full(y_obs[b], y_sim[:, b], mask[b], eps)
        results.append(npde_b)
    return torch.stack(results, dim=0)
347
+
348
def shapiro_wilk_normality(npde: TensorType["T"]) -> Tuple[float, float]:
    """Return Shapiro-Wilk normality test statistic and p-value for a 1-D tensor.

    Non-finite entries (NaN/Inf padding) are dropped before the test.
    """
    finite_vals = npde[torch.isfinite(npde)]
    result = stats.shapiro(finite_vals.detach().cpu().numpy())
    return float(result[0]), float(result[1])
353
+
354
def qq_plot(npde: TensorType["T"], train:bool =False, epoch:str|int = "na", **kwargs) -> str | None:
    """
    Generate and optionally save/show a QQ plot of NPDE values.

    Non-finite entries are dropped before plotting.

    Args:
        npde: Tensor containing NPDE values.
        train: If True, save the plot to ``qq_plot_epoch_{epoch}.png``;
            if False (the default), show the plot interactively instead.
        epoch: Epoch label used in the saved file name.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        File path if saved, None otherwise.
    """
    npde_np = npde[torch.isfinite(npde)].detach().cpu().numpy()

    fig = plt.figure()
    stats.probplot(npde_np, dist="norm", plot=plt)

    if train:
        # Save (instead of displaying) when called from a training loop.
        path = f"qq_plot_epoch_{epoch}.png"
        fig.savefig(path, bbox_inches="tight")
        plt.close(fig)
        return path
    else:
        plt.show()
        return None
380
+
381
# NOTE(review): the name looks like a typo of "vpc_from_sample" — kept as-is
# because callers may already reference it.
def vcp_from_sample(model,databatch_list,empirical_databatch,train=False):
    """
    Build a VPC plot from fresh model samples.

    In order to have a shape [S,I,T] vs [I,T] the model concatenates all
    samples for each held-out individual, which are of shape [S,B=1,I=1,T,1]
    (held-out sample) -> [S,I,T] (required by vpc).
    """
    # 30 Monte-Carlo samples per held-out individual, on a shared time grid.
    samples_list = model.sample(databatch_list,use_unique_times=True,num_samples=30)
    # pair[0] = observations, pair[1] = times for each held-out individual.
    combined_observation = combine_samples([pair[0] for pair in samples_list])
    combined_times = combine_samples([pair[1] for pair in samples_list])
    print(combined_observation.shape)
    # All individuals share the same grid, so any row of times will do.
    simulation_times = combined_times[0,0,:]
    print(simulation_times.shape)
    patients, patients_time = extract_context_by_mask(empirical_databatch)
    img = vpc(simulation_times, combined_observation, patients, patients_time,train=train)
    return img
395
+
396
def vpc_from_empirical(databatch_list,databatch_list_context,model,train=False,image_name="vpc.png",samples_number=100,y_scale=None):
    """Build a VPC plot by sampling new individuals against empirical targets.

    Args:
        databatch_list: list of databatch tuples; element 0 of each tuple is
            expected to expose ``target_obs`` / ``target_obs_time``.
        databatch_list_context: list whose first entry provides the context
            observation grid/mask used as the simulation time axis.
        model: model exposing ``sample_new_individual(db_tuple, n)``.
        train: forwarded to ``vpc`` (save vs. show).
        image_name, samples_number, y_scale: forwarded plot/sampling options.
    """
    aicme = databatch_list_context[0]
    # Observed target trajectories and their times, one array per batch tuple.
    patients = [db_tuple[0].target_obs.cpu().detach().numpy() for db_tuple in databatch_list]
    patients_time = [db_tuple[0].target_obs_time.cpu().detach().numpy() for db_tuple in databatch_list]
    # Individual with the most valid context observations defines the grid.
    max_time_index = aicme.context_obs_mask.sum(axis=2).squeeze().argmax()
    all_samples_times = aicme.context_obs_time[0,max_time_index,aicme.context_obs_mask[0,max_time_index]]
    all_samples = []
    for db_tuple in databatch_list:
        samples,samples_time = model.sample_new_individual(db_tuple,samples_number)
        all_samples.append(samples)
    # Concatenate individuals along dim=2, then keep only that individual's
    # valid time points so samples align with all_samples_times.
    all_samples = torch.cat(all_samples,dim=2).squeeze()
    all_samples = all_samples[:,:,aicme.context_obs_mask[0,max_time_index]]
    vpc(all_samples_times, all_samples, patients, patients_time,train=train,image_name=image_name,y_scale=y_scale)
409
+
410
def vpc(test_time, MetaStudies, patients, patients_time, train=True, image_name="vpc.png", y_scale=None):
    """
    Generate a Visual Predictive Check (VPC) plot with PyTorch tensor inputs.

    Parameters:
    - test_time: 1D PyTorch tensor of fixed time points for simulated data (shape [T])
    - MetaStudies: 3D PyTorch tensor of simulated data (shape [M, P, T])
    - patients: List of 1D PyTorch tensors, each with observed concentrations
    - patients_time: List of 1D PyTorch tensors, each with corresponding times
    - train: If True, save plot to `image_name` and return its path;
      else show it and return None
    - image_name: File name to save the image if train=True
    - y_scale: Set to "log" for log-scale y-axis; None for linear

    Returns:
        `image_name` when train=True, otherwise None.
    """
    # Collapse singleton dims so test_time is strictly 1-D.
    if len(test_time.shape) > 1:
        test_time = test_time.squeeze()

    test_time_np = test_time.detach().cpu().numpy()
    MetaStudies_np = MetaStudies.detach().cpu().numpy()

    percentiles = [5, 25, 50, 75, 95]
    # Percentiles across the P axis, per study → [5, M, T];
    # then the median across studies (M) → [5, T].
    sim_percentiles = np.percentile(MetaStudies_np, percentiles, axis=1)  # [5, M, T]
    sim_percentiles_agg = np.percentile(sim_percentiles, 50, axis=1)  # [5, T]

    p05, p25, p50, p75, p95 = sim_percentiles_agg

    plt.figure(figsize=(10, 6))
    plt.fill_between(test_time_np, p05, p95, color='blue', alpha=0.2, label='5th-95th Percentile')
    plt.fill_between(test_time_np, p25, p75, color='blue', alpha=0.4, label='25th-75th Percentile')
    plt.plot(test_time_np, p50, color='blue', label='Median (50th Percentile)')

    # Overlay the observed patient data on top of the simulated bands.
    for obs, times in zip(patients, patients_time):
        plt.scatter(times, obs, color='red', alpha=0.6, s=20)

    plt.xlabel('Time (hours)')
    plt.ylabel('Concentration (g/L)')
    if y_scale == "log":
        plt.yscale('log')

    plt.title('Visual Predictive Check (VPC)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    if train:
        plt.savefig(image_name)
        plt.close()
        return image_name
    else:
        plt.show()
        plt.close()
460
+
461
def get_unique_target_times(
    db_list: List[AICMECompartmentsDataBatch]
) -> TensorType[1, 1, "U", 1]:
    """
    Given P databatches, each with
        .target_obs_time: [B, t_ind, num_obs_t, 1]
    returns a tensor of shape [1, 1, U, 1] containing the sorted unique
    times across *all* batches and *all* target time points.

    NOTE: padded entries are NOT filtered out here, so any padding time
    value (e.g. 0) will appear among the unique times.

    Args:
        db_list: list of length P of AICMECompartmentsDataBatch

    Returns:
        unique_times: Tensor of shape [1, 1, U, 1], where U is the number of
            unique target-observation times across every batch. The two
            leading singleton axes make the result broadcastable against
            [B, I, T, 1]-shaped time tensors.
    """
    # 1) Flatten each batch's times:
    #    db.target_obs_time.squeeze(-1).reshape(-1) has shape [B * t_ind * num_obs_t]
    flat_times = [
        db.target_obs_time.squeeze(-1).reshape(-1)  # [B * t_ind * num_obs_t]
        for db in db_list
    ]
    # 2) Concatenate all P batches → [(P * B * t_ind * num_obs_t)]
    all_times = torch.cat(flat_times, dim=0)

    # 3) Compute sorted unique values → [U]
    unique = torch.unique(all_times)

    # 4) Return as a broadcastable column vector → [1, 1, U, 1]
    return unique.unsqueeze(-1).unsqueeze(0).unsqueeze(0)  # TensorType[1,1,"U", 1]
sim_priors_pk/metrics/quantiles_coverage.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torchtyping import TensorType
5
+
6
+
7
def compute_predictive_quantiles(
    pred_values: TensorType["B", "S", "T", 1],
    pred_mask: TensorType["B", "S", "T"] | TensorType["B", "T"],
    alpha: float,
) -> Tuple[
    TensorType["B", "T", 1],
    TensorType["B", "T", 1],
]:
    """
    Compute lower and upper predictive quantiles (α/2, 1−α/2)
    across stochastic samples, supporting both shared and
    per-sample (per-individual) masks.

    Parameters
    ----------
    pred_values : [B, S, T, 1]
        Predicted sample trajectories.
    pred_mask : [B, T] or [B, S, T]
        Boolean mask marking valid time points.
        - If [B, T]: same mask for all samples.
        - If [B, S, T]: individual-specific masks.
        NOTE(review): indexing below requires a bool dtype — an int mask
        would be interpreted as gather indices. Confirm at call sites.
    alpha : float
        Significance level (e.g. 0.05 for 90% interval).

    Returns
    -------
    q_low, q_high : [B, T, 1]
        Predictive lower and upper quantile envelopes; 0.0 at time points
        with no valid samples.
    """
    B, S, T, _ = pred_values.shape
    device = pred_values.device

    # --- normalize mask shape: broadcast a shared [B,T] mask over S ---
    if pred_mask.ndim == 2:
        pred_mask = pred_mask.unsqueeze(1).repeat(1, S, 1)  # [B,S,T]

    q_low_list = []
    q_high_list = []

    for b in range(B):
        q_low_b = torch.zeros(T, device=device)
        q_high_b = torch.zeros(T, device=device)

        # for each time index, only include valid samples
        for t_idx in range(T):
            valid_s = pred_mask[b, :, t_idx]
            if valid_s.any():
                vals = pred_values[b, valid_s, t_idx, 0]
                q_low_b[t_idx] = vals.quantile(alpha / 2)
                q_high_b[t_idx] = vals.quantile(1 - alpha / 2)
            else:
                # leave zeros (or NaN if preferred)
                q_low_b[t_idx] = 0.0
                q_high_b[t_idx] = 0.0

        q_low_list.append(q_low_b.unsqueeze(-1))
        q_high_list.append(q_high_b.unsqueeze(-1))

    q_low = torch.stack(q_low_list, dim=0)   # [B,T,1]
    q_high = torch.stack(q_high_list, dim=0)  # [B,T,1]

    return q_low, q_high
69
+
70
+
71
def interpolate_quantiles_to_obs_times(
    q_low: TensorType["B", "Tpred", 1],
    q_high: TensorType["B", "Tpred", 1],
    pred_times: TensorType["B", "Tpred", 1],
    pred_mask: TensorType["B", "Tpred"],
    real_times: TensorType["B", "I", "Treal", 1],
    real_mask: TensorType["B", "I", "Treal"],
) -> Tuple[
    TensorType["B", "I", "Treal", 1],
    TensorType["B", "I", "Treal", 1],
]:
    """
    Interpolate predictive quantile bands (q_low, q_high) to the irregular
    observation times of real data.

    Parameters
    ----------
    q_low, q_high : TensorType["B", "Tpred", 1]
        Predictive lower and upper quantile curves at distinct time grid points.
    pred_times : TensorType["B", "Tpred", 1]
        Time grid corresponding to the quantile curves.
        NOTE(review): `torch.searchsorted` below assumes the valid entries
        are sorted ascending — confirm upstream ordering.
    pred_mask : TensorType["B", "Tpred"]
        Boolean mask marking valid predictive times per batch.
    real_times : TensorType["B", "I", "Treal", 1]
        Observation times for each individual and batch.
    real_mask : TensorType["B", "I", "Treal"]
        Mask indicating valid observed time points (bool — `~` is applied).

    Returns
    -------
    q_low_interp, q_high_interp : Tuple[
        TensorType["B", "I", "Treal", 1],
        TensorType["B", "I", "Treal", 1],
    ]
        Interpolated quantile band values at each observed time, padded
        where invalid.

    Notes
    -----
    - Uses linear interpolation between nearest predictive time knots.
    - Out-of-range times are clamped to the nearest boundary quantile.
    - Invalid (masked) observations are returned as zeros.
    - Batches with fewer than two valid predictive points yield all zeros.
    """

    B, I, Treal, _ = real_times.shape
    device = real_times.device

    q_low_interp_list, q_high_interp_list = [], []

    for b in range(B):
        # Extract valid predictive points for this batch
        valid_mask_b = pred_mask[b]  # [Tpred]
        valid_T = valid_mask_b.sum().item()
        if valid_T < 2:
            # Degenerate case: not enough points for interpolation
            q_low_interp_list.append(torch.zeros(I, Treal, 1, device=device))
            q_high_interp_list.append(torch.zeros(I, Treal, 1, device=device))
            continue

        t_pred = pred_times[b, valid_mask_b, 0]  # [T_b]
        ql = q_low[b, valid_mask_b, 0]           # [T_b]
        qh = q_high[b, valid_mask_b, 0]          # [T_b]

        # For each individual, interpolate its observation times
        q_low_i, q_high_i = [], []
        for i in range(I):
            t_obs = real_times[b, i, :, 0]  # [Treal]
            valid_obs = real_mask[b, i]     # [Treal]

            # Clamp obs times into predictive range
            t_clamped = torch.clamp(t_obs, t_pred.min(), t_pred.max())

            # Use searchsorted to find bracketing indices
            idx_right = torch.searchsorted(t_pred, t_clamped)
            idx_left = (idx_right - 1).clamp(min=0)
            idx_right = idx_right.clamp(max=valid_T - 1)

            # Gather times and values for interpolation
            t_L, t_R = t_pred[idx_left], t_pred[idx_right]
            ql_L, ql_R = ql[idx_left], ql[idx_right]
            qh_L, qh_R = qh[idx_left], qh[idx_right]

            # Linear weights; denom clamp avoids division by zero when an
            # observation falls exactly on a knot (idx_left == idx_right).
            denom = (t_R - t_L).clamp(min=1e-8)
            w_R = (t_clamped - t_L) / denom
            w_L = 1.0 - w_R

            ql_interp = w_L * ql_L + w_R * ql_R
            qh_interp = w_L * qh_L + w_R * qh_R

            # Zero out invalid times
            ql_interp = ql_interp.masked_fill(~valid_obs, 0.0)
            qh_interp = qh_interp.masked_fill(~valid_obs, 0.0)

            q_low_i.append(ql_interp.unsqueeze(-1))
            q_high_i.append(qh_interp.unsqueeze(-1))

        q_low_interp_list.append(torch.stack(q_low_i, dim=0))   # [I, Treal, 1]
        q_high_interp_list.append(torch.stack(q_high_i, dim=0))  # [I, Treal, 1]

    q_low_interp = torch.stack(q_low_interp_list, dim=0)   # [B, I, Treal, 1]
    q_high_interp = torch.stack(q_high_interp_list, dim=0)  # [B, I, Treal, 1]

    return q_low_interp, q_high_interp
+
175
+
176
def compute_time_weighted_coverage(
    real_values: TensorType["B", "I", "Treal", 1],
    real_times: TensorType["B", "I", "Treal", 1],
    real_mask: TensorType["B", "I", "Treal"],
    q_low_interp: TensorType["B", "I", "Treal", 1],
    q_high_interp: TensorType["B", "I", "Treal", 1],
    reduce: bool = True,
) -> TensorType["B"]:
    """
    Compute time-weighted coverage fraction of observations within predictive bands.

    Each observation is weighted by the time interval Δt it spans (normalized
    per batch element), so densely sampled regions do not dominate the metric.

    Parameters
    ----------
    real_values : TensorType["B", "I", "Treal", 1]
        Observed values per batch element and individual.
    real_times : TensorType["B", "I", "Treal", 1]
        Observation times corresponding to ``real_values``.
    real_mask : TensorType["B", "I", "Treal"]
        Boolean mask of valid observation points.
    q_low_interp, q_high_interp : TensorType["B", "I", "Treal", 1]
        Lower/upper predictive quantiles interpolated to the observation times.
    reduce : bool, optional (default True)
        If True, return the Δt-weighted coverage summed per batch element,
        shape [B]. If False, return the per-point weighted contributions,
        shape [B, I, Treal]; their sum over (I, Treal) equals the reduced value.

    Returns
    -------
    TensorType["B"] or TensorType["B", "I", "Treal"]
        Coverage in [0, 1]; shape depends on ``reduce``.
    """
    # Point-wise inclusion test, restricted to valid observations. [B, I, Treal]
    covered = (real_values >= q_low_interp) & (real_values <= q_high_interp)
    covered = covered.squeeze(-1) & real_mask

    # Δt along the time axis; prepending the first time makes the first interval 0.
    dt = torch.diff(real_times, dim=2, prepend=real_times[:, :, :1])
    dt = dt.squeeze(-1) * real_mask  # zero out invalid points  [B, I, Treal]
    dt_sum = dt.sum(dim=(1, 2), keepdim=True).clamp(min=1e-8)  # avoid div-by-zero
    weights = dt / dt_sum  # normalized time weights per batch element

    contributions = covered.float() * weights  # [B, I, Treal]
    if not reduce:
        # FIX: `reduce` was previously accepted but ignored; expose the
        # unreduced per-point contributions for diagnostics.
        return contributions
    return contributions.sum(dim=(1, 2))  # [B]
199
+
200
+
201
def compute_interval_score(
    real_values: TensorType["B", "I", "Treal", 1],
    real_times: TensorType["B", "I", "Treal", 1],
    real_mask: TensorType["B", "I", "Treal"],
    q_low_interp: TensorType["B", "I", "Treal", 1],
    q_high_interp: TensorType["B", "I", "Treal", 1],
    alpha: float,
) -> TensorType["B"]:
    """
    Compute the time-weighted interval score (Gneiting & Raftery, 2007).

    The score rewards narrow predictive bands and penalizes each observation
    falling outside the band by 2/alpha times its miss distance. Observations
    are weighted by the time interval Δt they span, normalized per batch element.
    """
    band_width = (q_high_interp - q_low_interp).abs()
    undershoot = (q_low_interp - real_values).clamp(min=0)  # observation below band
    overshoot = (real_values - q_high_interp).clamp(min=0)  # observation above band

    pointwise = band_width + (2 / alpha) * (undershoot + overshoot)
    pointwise = pointwise.squeeze(-1) * real_mask  # [B, I, Treal]

    # Δt weights normalized over individuals and time, per batch element.
    dt = torch.diff(real_times, dim=2, prepend=real_times[:, :, :1]).squeeze(-1)
    dt = dt * real_mask
    normalizer = dt.sum(dim=(1, 2), keepdim=True).clamp(min=1e-8)
    weights = dt / normalizer

    # Weighted mean per batch element.
    return (pointwise * weights).sum(dim=(1, 2))
228
+
229
+
230
def compute_percentile_coverage(
    pred_values,
    pred_times,
    pred_mask,
    real_values,
    real_times,
    real_mask,
    alpha: float = 0.05,
):
    """
    Compute predictive interval coverage and interval score between predicted
    and observed trajectories.

    Evaluates how well a stochastic predictive model captures the true (real)
    observations within its predictive uncertainty bands, by chaining:

    1. :func:`compute_predictive_quantiles` — lower/upper predictive quantiles.
    2. :func:`interpolate_quantiles_to_obs_times` — align quantiles to the
       observation times.
    3. :func:`compute_time_weighted_coverage` / :func:`compute_interval_score` —
       Δt-weighted coverage fraction and proper scoring rule.

    Parameters
    ----------
    pred_values : TensorType["B", "S", "T_pred", 1]
        Stochastic predictions per batch element `B` and sample `S`,
        typically obtained by sampling the model multiple times.
    pred_times : TensorType["B", "T_pred", 1]
        Prediction time grid per batch (shared across stochastic samples).
    pred_mask : TensorType["B", "T_pred"]
        Boolean mask of valid prediction time steps.
    real_values : TensorType["B", "I", "T_real", 1]
        Observed values per batch element and individual.
    real_times : TensorType["B", "I", "T_real", 1]
        Observation times corresponding to `real_values`.
    real_mask : TensorType["B", "I", "T_real"]
        Boolean mask of valid observed time points.
    alpha : float, optional (default = 0.05)
        Significance level of the central predictive interval, e.g.
        α = 0.05 → 90% interval (quantiles 0.025/0.975),
        α = 0.10 → 80% interval (quantiles 0.05/0.95).
        Smaller α yields wider, more conservative intervals.

    Returns
    -------
    dict[str, TensorType["B"]]
        ``"coverage"`` — Δt-weighted fraction of observations inside the band
        (≈ 1−α for a well-calibrated model);
        ``"interval_score"`` — proper interval score (lower is sharper and
        better calibrated).

    References
    ----------
    Gneiting, T. & Raftery, A. E. (2007). *Strictly Proper Scoring Rules,
    Prediction, and Estimation*. JASA, 102(477), 359-378.
    """
    lower, upper = compute_predictive_quantiles(pred_values, pred_mask, alpha)
    lower_at_obs, upper_at_obs = interpolate_quantiles_to_obs_times(
        lower, upper, pred_times, pred_mask, real_times, real_mask
    )

    return {
        "coverage": compute_time_weighted_coverage(
            real_values, real_times, real_mask, lower_at_obs, upper_at_obs
        ),
        "interval_score": compute_interval_score(
            real_values, real_times, real_mask, lower_at_obs, upper_at_obs, alpha
        ),
    }
sim_priors_pk/metrics/sampling_quality.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluate sampling quality of a model based on Visual Predictive Checks or Normalized Prediction Distribution Errors (NPDEs).
2
+ # Input for both evaluations: a StudyJSON object containing the observed data and a List[StudyJSON] containing replicates of simulated data from the model.
3
+ # This way, both neural networks and NLME models can be evaluated using the same code, as long as they can produce the required StudyJSON objects.
4
+
5
+ from typing import List, Optional, Sequence
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ from scipy.stats import chi2, norm, shapiro, ttest_1samp
11
+
12
+ from sim_priors_pk.data.data_empirical.json_schema import IndividualJSON, StudyJSON
13
+
14
+
15
def json_to_dataframe(study_json: StudyJSON) -> pd.DataFrame:
    """
    Flatten a StudyJSON object into a long-format pandas DataFrame.

    Both the "context" and "target" sections are traversed; each individual's
    observation series becomes a set of rows.

    Args:
        study_json (StudyJSON): The StudyJSON object to convert.

    Returns:
        pd.DataFrame: Columns ["Type", "ID", "Time", "Value"]; an empty frame
        with the same columns when the study contains no individuals.
    """
    pieces = []

    for section in ["context", "target"]:
        individuals = study_json.get(section, [])

        for idx, individual in enumerate(individuals):
            # Identifier preference: explicit name_id, then _id, then a
            # deterministic positional fallback.
            ident = individual.get("name_id") or individual.get("_id") or f"{section}_{idx}"

            pieces.append(
                pd.DataFrame(
                    {
                        "Type": section,
                        "ID": str(ident),  # ensure it's a string
                        "Time": individual["observation_times"],
                        "Value": individual["observations"],
                    }
                )
            )

    if not pieces:
        return pd.DataFrame(columns=["Type", "ID", "Time", "Value"])
    return pd.concat(pieces, ignore_index=True)
48
+
49
+
50
def json_list_to_dataframe(study_list: List[StudyJSON]) -> pd.DataFrame:
    """
    Flatten several replicate StudyJSON objects into one long DataFrame.

    Args:
        study_list (List[StudyJSON]): Replicate studies to convert.

    Returns:
        pd.DataFrame: Columns ["Type", "ID", "Time", "Value", "Replicate"],
        where "Replicate" is the index of the study within ``study_list``;
        an empty DataFrame when the list is empty.
    """
    tagged = []

    for rep, study in enumerate(study_list):
        frame = json_to_dataframe(study)
        frame["Replicate"] = rep
        tagged.append(frame)

    if not tagged:
        return pd.DataFrame()
    return pd.concat(tagged, ignore_index=True)
69
+
70
+
71
def validate_npde_vpc_inputs(
    data: pd.DataFrame, simulations: pd.DataFrame, differentTimesError: bool = True
) -> None:
    """
    Validate the inputs for NPDE / VPC calculation.

    Args:
        data (pd.DataFrame): Observed data with columns ["Type", "ID", "Time", "Value"].
        simulations (pd.DataFrame): Simulated data with columns
            ["Type", "ID", "Time", "Value", "Replicate"] — the single long-format
            frame produced by :func:`json_list_to_dataframe` (all replicates
            stacked), not a list of frames.
        differentTimesError (bool): Whether to raise if observation times differ
            between individuals (default: True).

    Returns:
        None: If the inputs are valid.

    Raises:
        ValueError: If the (Type, ID, Time) key structure of observations and
            simulations differs, or (when ``differentTimesError`` is True)
            individuals were observed at different time points.
    """
    key_cols = ["Type", "ID", "Time"]

    # The simulations must cover exactly the observed (Type, ID, Time) keys;
    # replicates collapse onto the same keys via drop_duplicates.
    obs_keys = data[key_cols].drop_duplicates().sort_values(key_cols).reset_index(drop=True)
    pred_keys = simulations[key_cols].drop_duplicates().sort_values(key_cols).reset_index(drop=True)

    if not obs_keys.equals(pred_keys):
        raise ValueError("Observations and predictions are not structurally identical.")

    if differentTimesError:
        # More than one distinct (sorted) time vector across individuals means
        # the time grids disagree. ("> 1" instead of "!= 1" so an empty frame,
        # where nunique() == 0, does not raise spuriously.)
        if (data.groupby("ID")["Time"].apply(lambda x: tuple(sorted(x))).nunique()) > 1:
            raise ValueError("Observation times differ between individuals.")

    return None
101
+
102
+
103
def compute_npde_data(data: StudyJSON, simulations: List[StudyJSON]) -> np.ndarray:
    """
    Compute Normalized Prediction Distribution Errors (NPDEs).

    For every observation, the empirical CDF of the simulated replicate values
    at the same (Type, ID, Time) key is evaluated at the observed value,
    truncated away from 0 and 1, then mapped through the standard-normal
    inverse CDF.

    Args:
        data (StudyJSON): The observed data in StudyJSON
        simulations (List[StudyJSON]): A list of StudyJSON objects representing
            simulated data from the model.

    Returns:
        np.ndarray: An array of NPDE values, one per observation.
    """
    # Flatten observed and simulated studies into long-format frames and
    # verify they share the same key structure.
    obs_df = json_to_dataframe(data)
    sim_df = json_list_to_dataframe(simulations)
    validate_npde_vpc_inputs(obs_df, sim_df, differentTimesError=False)

    key_cols = ["Type", "ID", "Time"]

    # One column per replicate, aligned to the observations by key.
    sims_wide = sim_df.pivot(index=key_cols, columns="Replicate", values="Value")
    obs_indexed = obs_df.set_index(key_cols)
    combined = sims_wide.join(obs_indexed["Value"].rename("Observed"))

    replicate_cols = sims_wide.columns
    sim_matrix = combined[replicate_cols].values  # [n_obs, n_replicates]
    obs_vector = combined["Observed"].values  # [n_obs]

    # Truncated empirical CDF of each observation under the simulated
    # distribution (shifted by half a step so it never hits 0 or 1),
    # then the normal inverse CDF gives the NPDE.
    n_rep = len(replicate_cols)
    pde = (sim_matrix <= obs_vector[:, None]).sum(axis=1) / (n_rep + 1) + 0.5 / (n_rep + 1)
    return norm.ppf(pde)
142
+
143
+
144
def npde_plot(npde_values: np.ndarray) -> None:
    """
    Draw a quantile-quantile plot of NPDE values against the standard normal.

    Args:
        npde_values (np.ndarray): An array of NPDE values to plot.

    Returns:
        None
    """
    plt.figure(figsize=(6, 6))
    plt.title("Q-Q Plot of NPDE Values")
    plt.xlabel("Theoretical Quantiles")
    plt.ylabel("Empirical Quantiles")
    # Empirical quantiles are simply the sorted NPDEs; theoretical quantiles
    # use the (i + 1) / (n + 1) plotting positions of N(0, 1).
    empirical = np.sort(npde_values)
    count = len(npde_values)
    theoretical = norm.ppf((np.arange(count) + 1) / (count + 1))
    plt.plot(theoretical, empirical, marker="o", linestyle="")
    # Reference line: perfect agreement with the standard normal.
    plt.plot(theoretical, theoretical, color="red", linestyle="--")
    plt.grid()
    plt.show()
164
+
165
+
166
def npde_pvalues(npde_values: np.ndarray) -> dict:
    """
    Calculate p-values based on the theoretical N(0,1) distribution of NPDE values.

    Args:
        npde_values (np.ndarray): An array of NPDE values to summarize.

    Returns:
        dict: A dictionary containing p-values for different tests applied to the NPDE values:
            - "mean": The (one-sample) t-test for zero mean of the NPDE values.
            - "variance": The (one-sample) chi-squared test for unit variance of the NPDE values.
            - "normality": The Shapiro-Wilk test for normality of the NPDE values.
    """

    # variance test not implemented in scipy, so we calculate the p-value manually based on
    # the chi-squared distribution of the sample variance under the null hypothesis of unit
    # variance: (n - 1) * s^2 ~ chi2(n - 1).
    n = len(npde_values)
    sample_var = np.var(npde_values, ddof=1)
    chi2_stat = (n - 1) * sample_var
    p_lower = chi2.cdf(chi2_stat, df=n - 1)
    # Use the survival function rather than 1 - cdf: sf keeps full floating-point
    # precision in the upper tail, where 1 - cdf catastrophically cancels.
    p_upper = chi2.sf(chi2_stat, df=n - 1)
    p_var = 2 * min(p_lower, p_upper)  # two-sided p-value, always <= 1

    return {
        "mean": ttest_1samp(npde_values, 0).pvalue,  # type: ignore
        "variance": p_var,
        "normality": shapiro(npde_values).pvalue,
    }
194
+
195
+
196
+ def compute_vpc_data(
197
+ data: StudyJSON,
198
+ simulations: Sequence[StudyJSON],
199
+ quantiles: List[float] = [0.05, 0.5, 0.95],
200
+ confidence: float = 0.9,
201
+ n_bins: Optional[int] = None,
202
+ binning: str = "equal_count", # "equal_count" or "equal_width"
203
+ ) -> pd.DataFrame:
204
+ """
205
+ Compute data for a Visual Predictive Check (VPC) plot for the given StudyJSON and a list of simulated StudyJSONs.
206
+
207
+ Args:
208
+ data (StudyJSON): The observed data in StudyJSON
209
+ simulations (List[StudyJSON]): A list of simulated data in StudyJSON format.
210
+ quantiles (List[float]): The quantiles to display in the VPC plot (default: [0.05, 0.5, 0.95]).
211
+ confidence (float): The confidence level for the prediction intervals (default: 0.9).
212
+ Returns:
213
+ pd.DataFrame: A DataFrame containing the VPC data.
214
+ """
215
+
216
+ observed_values = json_to_dataframe(data)
217
+ predicted_values = json_list_to_dataframe(simulations)
218
+
219
+ alpha_low = (1 - confidence) / 2
220
+ alpha_high = 1 - alpha_low
221
+
222
+ # --------------------------------
223
+ # Binning (if requested OR if needed)
224
+ # --------------------------------
225
+ if n_bins is not None:
226
+ validate_npde_vpc_inputs(observed_values, predicted_values, differentTimesError=False)
227
+
228
+ match binning:
229
+ case "equal_count":
230
+ observed_values["TimeBin"] = pd.qcut(
231
+ observed_values["Time"], q=n_bins, duplicates="drop"
232
+ )
233
+
234
+ # Use same bin edges for predicted
235
+ bins = observed_values["TimeBin"].cat.categories
236
+ predicted_values["TimeBin"] = pd.cut(predicted_values["Time"], bins=bins)
237
+
238
+ case "equal_width":
239
+ tmin = observed_values["Time"].min()
240
+ tmax = observed_values["Time"].max()
241
+ bins = np.linspace(tmin, tmax, n_bins + 1)
242
+
243
+ observed_values["TimeBin"] = pd.cut(
244
+ observed_values["Time"], bins=bins, include_lowest=True
245
+ )
246
+ predicted_values["TimeBin"] = pd.cut(
247
+ predicted_values["Time"], bins=bins, include_lowest=True
248
+ )
249
+
250
+ case _:
251
+ raise ValueError("binning must be 'equal_width' or 'equal_count'")
252
+
253
+ # Use bin midpoint for plotting
254
+ bin_midpoints = (
255
+ observed_values.groupby("TimeBin", observed=False)["Time"].mean().rename("Time")
256
+ )
257
+
258
+ # Replace Time with bin midpoint
259
+ observed_values["Time"] = observed_values["TimeBin"].map(bin_midpoints)
260
+ predicted_values["Time"] = predicted_values["TimeBin"].map(bin_midpoints)
261
+
262
+ # Drop bin column
263
+ observed_values = observed_values.drop(columns="TimeBin")
264
+ predicted_values = predicted_values.drop(columns="TimeBin")
265
+
266
+ else:
267
+ validate_npde_vpc_inputs(observed_values, predicted_values, differentTimesError=True) # type: ignore
268
+
269
+ # --------------------------------
270
+ # Quantile calculation
271
+ # --------------------------------
272
+ vpc_obs = (
273
+ observed_values.groupby("Time")["Value"]
274
+ .quantile(quantiles) # type: ignore
275
+ .rename("Obs")
276
+ .reset_index()
277
+ .rename(columns={"level_1": "Quantile"})
278
+ )
279
+
280
+ vpc_pred = (
281
+ predicted_values.groupby(["Time", "Replicate"])["Value"]
282
+ .quantile(quantiles) # type: ignore
283
+ .rename("SimQuantile")
284
+ .reset_index()
285
+ .rename(columns={"level_2": "Quantile"})
286
+ .groupby(["Time", "Quantile"])["SimQuantile"]
287
+ .quantile([alpha_low, alpha_high]) # type: ignore
288
+ .rename("VPC")
289
+ .reset_index()
290
+ .rename(columns={"level_2": "PI"})
291
+ .pivot(index=["Time", "Quantile"], columns="PI", values="VPC")
292
+ .reset_index()
293
+ .rename(columns={alpha_low: "LowerPred", alpha_high: "UpperPred"})
294
+ )
295
+
296
+ vpc_data = vpc_obs.merge(vpc_pred, on=["Time", "Quantile"], how="left")
297
+
298
+ return vpc_data
299
+
300
+
301
def vpc_plot(vpc_data: pd.DataFrame, ax=None, log_y: bool = False):
    """
    Create a Visual Predictive Check (VPC) plot for the given VPC data.

    Args:
        vpc_data (pd.DataFrame): VPC data with columns
            ["Time", "Quantile", "Obs", "LowerPred", "UpperPred"]
            (as produced by :func:`compute_vpc_data`), containing exactly
            three distinct quantiles.
        ax: Optional matplotlib axis to plot on. If None, a new figure and axis will be created.
        log_y: Whether to use a logarithmic scale for the y-axis (default: False).

    Returns:
        The matplotlib axis the VPC was drawn on.

    Raises:
        ValueError: If the data does not contain exactly 3 quantiles, or a log
            scale is requested while non-positive values are present.
    """

    quantiles = np.sort(vpc_data["Quantile"].unique())

    # Enforce exactly 3 quantiles (lower band, median, upper band)
    if len(quantiles) != 3:
        raise ValueError(f"Expected exactly 3 quantiles, got {len(quantiles)}: {quantiles}")

    # Default axis management
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Log-scale option
    if log_y:
        # Safety check: log scale requires strictly positive values
        y_cols = ["Obs", "LowerPred", "UpperPred"]
        if (vpc_data[y_cols] <= 0).any().any():
            raise ValueError("Log scale requested but non-positive values detected.")
        ax.set_yscale("log")

    # Color scheme: lower, median, upper
    colors = ["tab:blue", "tab:orange", "tab:blue"]

    # Map sorted quantiles to colors
    q_to_color = dict(zip(quantiles, colors))

    # Plot observed quantiles and the matching simulated prediction bands.
    # (FIX: a second, unsorted read of the quantiles used to overwrite the
    # sorted array here, making the plotting/legend order nondeterministic;
    # iterate the sorted quantiles instead.)
    for q in quantiles:
        subset = vpc_data[vpc_data["Quantile"] == q]

        color = q_to_color[q]
        is_median = np.isclose(q, 0.5)

        ax.plot(
            subset["Time"],
            subset["Obs"],
            marker="o",
            color=color,
            linewidth=2 if is_median else 1,
            label=f"Observed {q:.0%}",
        )

        ax.fill_between(
            subset["Time"],
            subset["LowerPred"],
            subset["UpperPred"],
            color=color,
            alpha=0.25,
            label=f"Simulated {q:.0%} PI",
        )

    # Keep legend outside the plotting area to avoid occluding trajectories.
    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.18),
        ncol=3,
        frameon=False,
    )
    ax.figure.subplots_adjust(bottom=0.25)
    return ax
373
+
374
+
375
if __name__ == "__main__":
    # Example usage: two individuals observed at shared time points, with two
    # simulated replicates of the same study.
    observed_data = StudyJSON(
        context=[
            IndividualJSON(name_id="1", observation_times=[0, 1, 2], observations=[10, 20, 30]),
            IndividualJSON(name_id="2", observation_times=[0, 1, 2], observations=[11, 21, 31]),
        ]
    )

    simulated_data = [
        StudyJSON(
            context=[
                IndividualJSON(name_id="1", observation_times=[0, 1, 2], observations=[12, 22, 32]),
                IndividualJSON(name_id="2", observation_times=[0, 1, 2], observations=[13, 21, 30]),
            ]
        ),
        StudyJSON(
            context=[
                IndividualJSON(name_id="1", observation_times=[0, 1, 2], observations=[8, 18, 28]),
                IndividualJSON(name_id="2", observation_times=[0, 1, 2], observations=[11, 19, 27]),
            ]
        ),
    ]
    # Convert to dataframes for visualization (optional)
    observed_values = json_to_dataframe(observed_data)
    simulated_values = json_list_to_dataframe(simulated_data)

    # FIX: the original called nonexistent helpers (validate_npde_inputs,
    # calculate_npde, create_vpc_data) and would crash with NameError; use the
    # functions actually defined in this module.
    validate_npde_vpc_inputs(observed_values, simulated_values)
    npde_results = compute_npde_data(observed_data, simulated_data)

    print("NPDE Results:", npde_results)

    vpc_data = compute_vpc_data(observed_data, simulated_data)

    vpc_plot(vpc_data)
    # vpc_plot only draws onto an axis; show the figure when run as a script.
    plt.show()