FangDai committed on
Commit
bf41494
·
verified ·
1 Parent(s): bea5a4b

Upload 4 files

Browse files
Files changed (4) hide show
  1. requirements.txt +6 -0
  2. run_sample.ipynb +883 -0
  3. test.py +173 -0
  4. train_patient_model.py +576 -0
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ torchvision>=0.10.0
3
+ scikit-learn>=0.24.2
4
+ pillow>=8.0.0
5
+ numpy>=1.19.5
6
+ pyradiomics>=3.0.1
run_sample.ipynb ADDED
@@ -0,0 +1,883 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "application/javascript": "\nvar cell = this.closest('.cell');\nif (cell) {\n cell.classList.remove('output_scroll');\n}\n",
11
+ "text/plain": [
12
+ "<IPython.core.display.Javascript object>"
13
+ ]
14
+ },
15
+ "metadata": {},
16
+ "output_type": "display_data"
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ "\n",
23
+ "Bad key paths in file /export/home/daifang/.config/matplotlib/matplotlibrc, line 3 ('paths: /export/home/daifang/fonts/arial/')\n",
24
+ "You probably need to get an updated matplotlibrc file from\n",
25
+ "https://github.com/matplotlib/matplotlib/blob/v3.3.4/matplotlibrc.template\n",
26
+ "or from the matplotlib source distribution\n"
27
+ ]
28
+ },
29
+ {
30
+ "name": "stdout",
31
+ "output_type": "stream",
32
+ "text": [
33
+ "Using device: cuda:1\n",
34
+ "\n",
35
+ "[Epoch 001] Train Loss=4.903 | Val Loss=4.777\n",
36
+ " All | Subtype AUC=0.610 | TNM AUC=0.571 | DFS C-index=0.631 | OS C-index=0.454\n",
37
+ " Immune | Subtype AUC=0.571 | TNM AUC=0.585 | DFS C-index=0.627 | OS C-index=0.478\n",
38
+ " Chemo | Subtype AUC=0.675 | TNM AUC=0.559 | DFS C-index=0.588 | OS C-index=0.476\n",
39
+ " ✓ Best model updated\n",
40
+ "\n",
41
+ "[Epoch 002] Train Loss=4.770 | Val Loss=4.596\n",
42
+ " All | Subtype AUC=0.678 | TNM AUC=0.683 | DFS C-index=0.606 | OS C-index=0.536\n",
43
+ " Immune | Subtype AUC=0.691 | TNM AUC=0.654 | DFS C-index=0.585 | OS C-index=0.591\n",
44
+ " Chemo | Subtype AUC=0.634 | TNM AUC=0.729 | DFS C-index=0.639 | OS C-index=0.456\n",
45
+ " ✓ Best model updated\n",
46
+ "\n",
47
+ "[Epoch 003] Train Loss=4.551 | Val Loss=4.404\n",
48
+ " All | Subtype AUC=0.698 | TNM AUC=0.678 | DFS C-index=0.592 | OS C-index=0.450\n",
49
+ " Immune | Subtype AUC=0.667 | TNM AUC=0.686 | DFS C-index=0.635 | OS C-index=0.480\n",
50
+ " Chemo | Subtype AUC=0.762 | TNM AUC=0.687 | DFS C-index=0.541 | OS C-index=0.379\n",
51
+ " ✓ Best model updated\n",
52
+ "\n",
53
+ "[Epoch 004] Train Loss=4.299 | Val Loss=4.076\n",
54
+ " All | Subtype AUC=0.763 | TNM AUC=0.753 | DFS C-index=0.550 | OS C-index=0.597\n",
55
+ " Immune | Subtype AUC=0.720 | TNM AUC=0.745 | DFS C-index=0.554 | OS C-index=0.606\n",
56
+ " Chemo | Subtype AUC=0.824 | TNM AUC=0.735 | DFS C-index=0.510 | OS C-index=0.589\n",
57
+ " ✓ Best model updated\n",
58
+ "\n",
59
+ "[Epoch 005] Train Loss=3.962 | Val Loss=3.941\n",
60
+ " All | Subtype AUC=0.693 | TNM AUC=0.703 | DFS C-index=0.640 | OS C-index=0.658\n",
61
+ " Immune | Subtype AUC=0.804 | TNM AUC=0.714 | DFS C-index=0.607 | OS C-index=0.675\n",
62
+ " Chemo | Subtype AUC=0.539 | TNM AUC=0.694 | DFS C-index=0.694 | OS C-index=0.621\n",
63
+ " ✓ Best model updated\n",
64
+ "\n",
65
+ "[Epoch 006] Train Loss=3.911 | Val Loss=3.840\n",
66
+ " All | Subtype AUC=0.810 | TNM AUC=0.636 | DFS C-index=0.591 | OS C-index=0.549\n",
67
+ " Immune | Subtype AUC=0.840 | TNM AUC=0.657 | DFS C-index=0.552 | OS C-index=0.551\n",
68
+ " Chemo | Subtype AUC=0.771 | TNM AUC=0.618 | DFS C-index=0.580 | OS C-index=0.577\n",
69
+ " ✓ Best model updated\n",
70
+ "\n",
71
+ "[Epoch 007] Train Loss=3.823 | Val Loss=3.876\n",
72
+ " All | Subtype AUC=0.800 | TNM AUC=0.722 | DFS C-index=0.645 | OS C-index=0.666\n",
73
+ " Immune | Subtype AUC=0.833 | TNM AUC=0.764 | DFS C-index=0.604 | OS C-index=0.654\n",
74
+ " Chemo | Subtype AUC=0.750 | TNM AUC=0.668 | DFS C-index=0.671 | OS C-index=0.605\n",
75
+ "\n",
76
+ "[Epoch 008] Train Loss=3.779 | Val Loss=3.794\n",
77
+ " All | Subtype AUC=0.689 | TNM AUC=0.711 | DFS C-index=0.654 | OS C-index=0.616\n",
78
+ " Immune | Subtype AUC=0.729 | TNM AUC=0.729 | DFS C-index=0.680 | OS C-index=0.682\n",
79
+ " Chemo | Subtype AUC=0.637 | TNM AUC=0.725 | DFS C-index=0.624 | OS C-index=0.565\n",
80
+ " ✓ Best model updated\n",
81
+ "\n",
82
+ "[Epoch 009] Train Loss=3.725 | Val Loss=3.630\n",
83
+ " All | Subtype AUC=0.856 | TNM AUC=0.718 | DFS C-index=0.692 | OS C-index=0.677\n",
84
+ " Immune | Subtype AUC=0.897 | TNM AUC=0.791 | DFS C-index=0.713 | OS C-index=0.675\n",
85
+ " Chemo | Subtype AUC=0.785 | TNM AUC=0.652 | DFS C-index=0.678 | OS C-index=0.677\n",
86
+ " ✓ Best model updated\n",
87
+ "\n",
88
+ "[Epoch 010] Train Loss=3.706 | Val Loss=3.646\n",
89
+ " All | Subtype AUC=0.749 | TNM AUC=0.773 | DFS C-index=0.659 | OS C-index=0.585\n",
90
+ " Immune | Subtype AUC=0.817 | TNM AUC=0.723 | DFS C-index=0.638 | OS C-index=0.596\n",
91
+ " Chemo | Subtype AUC=0.649 | TNM AUC=0.855 | DFS C-index=0.710 | OS C-index=0.540\n",
92
+ "\n",
93
+ "[Epoch 011] Train Loss=3.694 | Val Loss=3.690\n",
94
+ " All | Subtype AUC=0.791 | TNM AUC=0.769 | DFS C-index=0.637 | OS C-index=0.542\n",
95
+ " Immune | Subtype AUC=0.812 | TNM AUC=0.731 | DFS C-index=0.677 | OS C-index=0.593\n",
96
+ " Chemo | Subtype AUC=0.787 | TNM AUC=0.812 | DFS C-index=0.604 | OS C-index=0.488\n",
97
+ "\n",
98
+ "[Epoch 012] Train Loss=3.600 | Val Loss=3.469\n",
99
+ " All | Subtype AUC=0.875 | TNM AUC=0.744 | DFS C-index=0.663 | OS C-index=0.662\n",
100
+ " Immune | Subtype AUC=0.896 | TNM AUC=0.754 | DFS C-index=0.671 | OS C-index=0.753\n",
101
+ " Chemo | Subtype AUC=0.887 | TNM AUC=0.749 | DFS C-index=0.635 | OS C-index=0.573\n",
102
+ " ✓ Best model updated\n",
103
+ "\n",
104
+ "[Epoch 013] Train Loss=3.641 | Val Loss=3.486\n",
105
+ " All | Subtype AUC=0.870 | TNM AUC=0.701 | DFS C-index=0.659 | OS C-index=0.649\n",
106
+ " Immune | Subtype AUC=0.912 | TNM AUC=0.701 | DFS C-index=0.708 | OS C-index=0.722\n",
107
+ " Chemo | Subtype AUC=0.811 | TNM AUC=0.724 | DFS C-index=0.600 | OS C-index=0.512\n",
108
+ "\n",
109
+ "[Epoch 014] Train Loss=3.554 | Val Loss=3.560\n",
110
+ " All | Subtype AUC=0.796 | TNM AUC=0.672 | DFS C-index=0.670 | OS C-index=0.656\n",
111
+ " Immune | Subtype AUC=0.773 | TNM AUC=0.739 | DFS C-index=0.685 | OS C-index=0.753\n",
112
+ " Chemo | Subtype AUC=0.818 | TNM AUC=0.600 | DFS C-index=0.659 | OS C-index=0.556\n",
113
+ "\n",
114
+ "[Epoch 015] Train Loss=3.490 | Val Loss=3.608\n",
115
+ " All | Subtype AUC=0.815 | TNM AUC=0.730 | DFS C-index=0.620 | OS C-index=0.622\n",
116
+ " Immune | Subtype AUC=0.850 | TNM AUC=0.846 | DFS C-index=0.657 | OS C-index=0.635\n",
117
+ " Chemo | Subtype AUC=0.752 | TNM AUC=0.590 | DFS C-index=0.565 | OS C-index=0.629\n",
118
+ "\n",
119
+ "[Epoch 016] Train Loss=3.559 | Val Loss=3.347\n",
120
+ " All | Subtype AUC=0.793 | TNM AUC=0.782 | DFS C-index=0.655 | OS C-index=0.636\n",
121
+ " Immune | Subtype AUC=0.889 | TNM AUC=0.781 | DFS C-index=0.643 | OS C-index=0.659\n",
122
+ " Chemo | Subtype AUC=0.673 | TNM AUC=0.809 | DFS C-index=0.663 | OS C-index=0.605\n",
123
+ " ✓ Best model updated\n",
124
+ "\n",
125
+ "[Epoch 017] Train Loss=3.530 | Val Loss=3.405\n",
126
+ " All | Subtype AUC=0.840 | TNM AUC=0.731 | DFS C-index=0.645 | OS C-index=0.611\n",
127
+ " Immune | Subtype AUC=0.832 | TNM AUC=0.738 | DFS C-index=0.688 | OS C-index=0.654\n",
128
+ " Chemo | Subtype AUC=0.847 | TNM AUC=0.728 | DFS C-index=0.643 | OS C-index=0.597\n",
129
+ "\n",
130
+ "[Epoch 018] Train Loss=3.545 | Val Loss=3.434\n",
131
+ " All | Subtype AUC=0.912 | TNM AUC=0.729 | DFS C-index=0.660 | OS C-index=0.595\n",
132
+ " Immune | Subtype AUC=0.979 | TNM AUC=0.677 | DFS C-index=0.702 | OS C-index=0.585\n",
133
+ " Chemo | Subtype AUC=0.804 | TNM AUC=0.767 | DFS C-index=0.671 | OS C-index=0.605\n",
134
+ "\n",
135
+ "[Epoch 019] Train Loss=3.605 | Val Loss=3.394\n",
136
+ " All | Subtype AUC=0.807 | TNM AUC=0.722 | DFS C-index=0.655 | OS C-index=0.664\n",
137
+ " Immune | Subtype AUC=0.858 | TNM AUC=0.736 | DFS C-index=0.674 | OS C-index=0.701\n",
138
+ " Chemo | Subtype AUC=0.713 | TNM AUC=0.768 | DFS C-index=0.675 | OS C-index=0.645\n",
139
+ "\n",
140
+ "[Epoch 020] Train Loss=3.386 | Val Loss=3.381\n",
141
+ " All | Subtype AUC=0.784 | TNM AUC=0.757 | DFS C-index=0.674 | OS C-index=0.646\n",
142
+ " Immune | Subtype AUC=0.796 | TNM AUC=0.732 | DFS C-index=0.696 | OS C-index=0.727\n",
143
+ " Chemo | Subtype AUC=0.758 | TNM AUC=0.818 | DFS C-index=0.675 | OS C-index=0.565\n",
144
+ "\n",
145
+ "[Epoch 021] Train Loss=3.505 | Val Loss=3.388\n",
146
+ " All | Subtype AUC=0.767 | TNM AUC=0.836 | DFS C-index=0.679 | OS C-index=0.632\n",
147
+ " Immune | Subtype AUC=0.800 | TNM AUC=0.877 | DFS C-index=0.641 | OS C-index=0.696\n",
148
+ " Chemo | Subtype AUC=0.663 | TNM AUC=0.772 | DFS C-index=0.714 | OS C-index=0.544\n",
149
+ "\n",
150
+ "[Epoch 022] Train Loss=3.577 | Val Loss=3.618\n",
151
+ " All | Subtype AUC=0.723 | TNM AUC=0.683 | DFS C-index=0.641 | OS C-index=0.569\n",
152
+ " Immune | Subtype AUC=0.711 | TNM AUC=0.592 | DFS C-index=0.699 | OS C-index=0.619\n",
153
+ " Chemo | Subtype AUC=0.745 | TNM AUC=0.765 | DFS C-index=0.635 | OS C-index=0.524\n",
154
+ "\n",
155
+ "[Epoch 023] Train Loss=3.420 | Val Loss=3.244\n",
156
+ " All | Subtype AUC=0.874 | TNM AUC=0.752 | DFS C-index=0.654 | OS C-index=0.653\n",
157
+ " Immune | Subtype AUC=0.912 | TNM AUC=0.839 | DFS C-index=0.652 | OS C-index=0.701\n",
158
+ " Chemo | Subtype AUC=0.833 | TNM AUC=0.629 | DFS C-index=0.639 | OS C-index=0.625\n",
159
+ " ✓ Best model updated\n",
160
+ "\n",
161
+ "[Epoch 024] Train Loss=3.525 | Val Loss=3.595\n",
162
+ " All | Subtype AUC=0.775 | TNM AUC=0.754 | DFS C-index=0.600 | OS C-index=0.610\n",
163
+ " Immune | Subtype AUC=0.733 | TNM AUC=0.758 | DFS C-index=0.641 | OS C-index=0.690\n",
164
+ " Chemo | Subtype AUC=0.800 | TNM AUC=0.773 | DFS C-index=0.631 | OS C-index=0.589\n",
165
+ "\n",
166
+ "[Epoch 025] Train Loss=3.571 | Val Loss=3.480\n",
167
+ " All | Subtype AUC=0.820 | TNM AUC=0.763 | DFS C-index=0.599 | OS C-index=0.627\n",
168
+ " Immune | Subtype AUC=0.825 | TNM AUC=0.742 | DFS C-index=0.557 | OS C-index=0.627\n",
169
+ " Chemo | Subtype AUC=0.837 | TNM AUC=0.796 | DFS C-index=0.616 | OS C-index=0.613\n",
170
+ "\n",
171
+ "[Epoch 026] Train Loss=3.508 | Val Loss=3.209\n",
172
+ " All | Subtype AUC=0.874 | TNM AUC=0.765 | DFS C-index=0.724 | OS C-index=0.633\n",
173
+ " Immune | Subtype AUC=0.853 | TNM AUC=0.804 | DFS C-index=0.713 | OS C-index=0.688\n",
174
+ " Chemo | Subtype AUC=0.903 | TNM AUC=0.719 | DFS C-index=0.788 | OS C-index=0.573\n",
175
+ " ✓ Best model updated\n",
176
+ "\n",
177
+ "[Epoch 027] Train Loss=3.417 | Val Loss=3.468\n",
178
+ " All | Subtype AUC=0.836 | TNM AUC=0.674 | DFS C-index=0.685 | OS C-index=0.672\n",
179
+ " Immune | Subtype AUC=0.834 | TNM AUC=0.707 | DFS C-index=0.713 | OS C-index=0.701\n",
180
+ " Chemo | Subtype AUC=0.873 | TNM AUC=0.604 | DFS C-index=0.608 | OS C-index=0.589\n",
181
+ "\n",
182
+ "[Epoch 028] Train Loss=3.443 | Val Loss=3.314\n",
183
+ " All | Subtype AUC=0.817 | TNM AUC=0.734 | DFS C-index=0.653 | OS C-index=0.670\n",
184
+ " Immune | Subtype AUC=0.875 | TNM AUC=0.738 | DFS C-index=0.604 | OS C-index=0.719\n",
185
+ " Chemo | Subtype AUC=0.744 | TNM AUC=0.745 | DFS C-index=0.694 | OS C-index=0.593\n",
186
+ "\n",
187
+ "[Epoch 029] Train Loss=3.451 | Val Loss=3.259\n",
188
+ " All | Subtype AUC=0.782 | TNM AUC=0.819 | DFS C-index=0.708 | OS C-index=0.637\n",
189
+ " Immune | Subtype AUC=0.814 | TNM AUC=0.811 | DFS C-index=0.638 | OS C-index=0.688\n",
190
+ " Chemo | Subtype AUC=0.750 | TNM AUC=0.833 | DFS C-index=0.776 | OS C-index=0.552\n",
191
+ "\n",
192
+ "[Epoch 030] Train Loss=3.556 | Val Loss=3.322\n",
193
+ " All | Subtype AUC=0.880 | TNM AUC=0.762 | DFS C-index=0.650 | OS C-index=0.662\n",
194
+ " Immune | Subtype AUC=0.884 | TNM AUC=0.756 | DFS C-index=0.660 | OS C-index=0.727\n",
195
+ " Chemo | Subtype AUC=0.863 | TNM AUC=0.773 | DFS C-index=0.584 | OS C-index=0.548\n",
196
+ "\n",
197
+ "[Epoch 031] Train Loss=3.493 | Val Loss=3.233\n",
198
+ " All | Subtype AUC=0.857 | TNM AUC=0.764 | DFS C-index=0.705 | OS C-index=0.638\n",
199
+ " Immune | Subtype AUC=0.894 | TNM AUC=0.772 | DFS C-index=0.663 | OS C-index=0.688\n",
200
+ " Chemo | Subtype AUC=0.800 | TNM AUC=0.755 | DFS C-index=0.757 | OS C-index=0.585\n",
201
+ "\n",
202
+ "[Epoch 032] Train Loss=3.483 | Val Loss=3.330\n",
203
+ " All | Subtype AUC=0.894 | TNM AUC=0.743 | DFS C-index=0.616 | OS C-index=0.595\n",
204
+ " Immune | Subtype AUC=0.917 | TNM AUC=0.792 | DFS C-index=0.613 | OS C-index=0.567\n",
205
+ " Chemo | Subtype AUC=0.869 | TNM AUC=0.644 | DFS C-index=0.561 | OS C-index=0.597\n",
206
+ "\n",
207
+ "[Epoch 033] Train Loss=3.447 | Val Loss=3.307\n",
208
+ " All | Subtype AUC=0.824 | TNM AUC=0.846 | DFS C-index=0.657 | OS C-index=0.660\n",
209
+ " Immune | Subtype AUC=0.936 | TNM AUC=0.821 | DFS C-index=0.635 | OS C-index=0.740\n",
210
+ " Chemo | Subtype AUC=0.655 | TNM AUC=0.898 | DFS C-index=0.671 | OS C-index=0.577\n",
211
+ "\n",
212
+ "[Epoch 034] Train Loss=3.396 | Val Loss=3.207\n",
213
+ " All | Subtype AUC=0.779 | TNM AUC=0.746 | DFS C-index=0.679 | OS C-index=0.693\n",
214
+ " Immune | Subtype AUC=0.747 | TNM AUC=0.774 | DFS C-index=0.641 | OS C-index=0.719\n",
215
+ " Chemo | Subtype AUC=0.855 | TNM AUC=0.703 | DFS C-index=0.584 | OS C-index=0.605\n",
216
+ " ✓ Best model updated\n",
217
+ "\n",
218
+ "[Epoch 035] Train Loss=3.362 | Val Loss=3.279\n",
219
+ " All | Subtype AUC=0.802 | TNM AUC=0.756 | DFS C-index=0.666 | OS C-index=0.667\n",
220
+ " Immune | Subtype AUC=0.861 | TNM AUC=0.826 | DFS C-index=0.749 | OS C-index=0.709\n",
221
+ " Chemo | Subtype AUC=0.733 | TNM AUC=0.658 | DFS C-index=0.580 | OS C-index=0.609\n",
222
+ "\n",
223
+ "[Epoch 036] Train Loss=3.468 | Val Loss=3.541\n",
224
+ " All | Subtype AUC=0.737 | TNM AUC=0.736 | DFS C-index=0.614 | OS C-index=0.601\n",
225
+ " Immune | Subtype AUC=0.764 | TNM AUC=0.743 | DFS C-index=0.579 | OS C-index=0.635\n",
226
+ " Chemo | Subtype AUC=0.725 | TNM AUC=0.736 | DFS C-index=0.616 | OS C-index=0.565\n",
227
+ "\n",
228
+ "[Epoch 037] Train Loss=3.259 | Val Loss=3.274\n",
229
+ " All | Subtype AUC=0.883 | TNM AUC=0.793 | DFS C-index=0.654 | OS C-index=0.630\n",
230
+ " Immune | Subtype AUC=0.857 | TNM AUC=0.751 | DFS C-index=0.660 | OS C-index=0.675\n",
231
+ " Chemo | Subtype AUC=0.905 | TNM AUC=0.837 | DFS C-index=0.655 | OS C-index=0.569\n",
232
+ "\n",
233
+ "[Epoch 038] Train Loss=3.328 | Val Loss=3.333\n",
234
+ " All | Subtype AUC=0.809 | TNM AUC=0.772 | DFS C-index=0.655 | OS C-index=0.630\n",
235
+ " Immune | Subtype AUC=0.861 | TNM AUC=0.813 | DFS C-index=0.652 | OS C-index=0.646\n",
236
+ " Chemo | Subtype AUC=0.758 | TNM AUC=0.719 | DFS C-index=0.647 | OS C-index=0.589\n",
237
+ "\n",
238
+ "[Epoch 039] Train Loss=3.409 | Val Loss=3.216\n",
239
+ " All | Subtype AUC=0.871 | TNM AUC=0.768 | DFS C-index=0.679 | OS C-index=0.625\n",
240
+ " Immune | Subtype AUC=0.882 | TNM AUC=0.744 | DFS C-index=0.783 | OS C-index=0.703\n",
241
+ " Chemo | Subtype AUC=0.865 | TNM AUC=0.823 | DFS C-index=0.624 | OS C-index=0.544\n",
242
+ "\n",
243
+ "[Epoch 040] Train Loss=3.421 | Val Loss=3.211\n",
244
+ " All | Subtype AUC=0.883 | TNM AUC=0.694 | DFS C-index=0.737 | OS C-index=0.680\n",
245
+ " Immune | Subtype AUC=0.901 | TNM AUC=0.700 | DFS C-index=0.674 | OS C-index=0.701\n",
246
+ " Chemo | Subtype AUC=0.870 | TNM AUC=0.688 | DFS C-index=0.741 | OS C-index=0.573\n",
247
+ "\n",
248
+ "[Epoch 041] Train Loss=3.318 | Val Loss=3.326\n",
249
+ " All | Subtype AUC=0.821 | TNM AUC=0.778 | DFS C-index=0.680 | OS C-index=0.678\n",
250
+ " Immune | Subtype AUC=0.855 | TNM AUC=0.817 | DFS C-index=0.613 | OS C-index=0.722\n",
251
+ " Chemo | Subtype AUC=0.781 | TNM AUC=0.724 | DFS C-index=0.769 | OS C-index=0.661\n",
252
+ "\n",
253
+ "[Epoch 042] Train Loss=3.203 | Val Loss=3.307\n",
254
+ " All | Subtype AUC=0.825 | TNM AUC=0.778 | DFS C-index=0.677 | OS C-index=0.603\n",
255
+ " Immune | Subtype AUC=0.851 | TNM AUC=0.770 | DFS C-index=0.708 | OS C-index=0.609\n",
256
+ " Chemo | Subtype AUC=0.850 | TNM AUC=0.815 | DFS C-index=0.675 | OS C-index=0.597\n",
257
+ "\n",
258
+ "⏹ Early stopping triggered\n",
259
+ "\n",
260
+ "Running inference with best model...\n",
261
+ "train | loss=3.330 | Immune=105 Chemo=75\n",
262
+ "val | loss=3.458 | Immune=34 Chemo=26\n",
263
+ "test | loss=3.877 | Immune=32 Chemo=28\n"
264
+ ]
265
+ },
266
+ {
267
+ "name": "stderr",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.\n",
271
+ "findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.\n",
272
+ "findfont: Font family ['Arial'] not found. Falling back to DejaVu Sans.\n"
273
+ ]
274
+ },
275
+ {
276
+ "name": "stdout",
277
+ "output_type": "stream",
278
+ "text": [
279
+ "✔ Figure 7 generated (DFS/OS KM + HR) for Immune/Chemo.\n"
280
+ ]
281
+ }
282
+ ],
283
+ "source": [
284
+ "import sys\n",
285
+ "sys.path.insert(0, \"/export/home/daifang/lunghospital/MM-DLS-master/MM-DLS-master\")\n",
286
+ "# main.py\n",
287
+ "import os\n",
288
+ "import sys\n",
289
+ "import numpy as np\n",
290
+ "import torch\n",
291
+ "import torch.nn as nn\n",
292
+ "from torch.utils.data import DataLoader, random_split\n",
293
+ "\n",
294
+ "from sklearn.metrics import roc_auc_score, accuracy_score\n",
295
+ "from sklearn.preprocessing import label_binarize\n",
296
+ "\n",
297
+ "import pandas as pd\n",
298
+ "import matplotlib.pyplot as plt\n",
299
+ "from lifelines import KaplanMeierFitter, CoxPHFitter\n",
300
+ "from lifelines.statistics import multivariate_logrank_test\n",
301
+ "from lifelines.utils import concordance_index\n",
302
+ "from sklearn.metrics import brier_score_loss\n",
303
+ "from scipy.stats import norm\n",
304
+ "\n",
305
+ "\n",
306
+ "\n",
307
+ "# =========================================================\n",
308
+ "# Project path (IMPORTANT for Jupyter / HPC)\n",
309
+ "# =========================================================\n",
310
+ "PROJECT_ROOT = os.path.abspath(\".\")\n",
311
+ "if PROJECT_ROOT not in sys.path:\n",
312
+ " sys.path.insert(0, PROJECT_ROOT)\n",
313
+ "\n",
314
+ "# =========================================================\n",
315
+ "# imports: mm_dls/ \n",
316
+ "# =========================================================\n",
317
+ "def _import_modules():\n",
318
+ "\n",
319
+ " from mm_dls.HierMM_DLS import HierMM_DLS\n",
320
+ " from mm_dls.FakePatientDataset import FakePatientDataset\n",
321
+ " from mm_dls.CoxphLoss import CoxPHLoss\n",
322
+ " return HierMM_DLS, FakePatientDataset, CoxPHLoss\n",
323
+ "\n",
324
+ "\n",
325
+ "HierMM_DLS, FakePatientDataset, CoxPHLoss = _import_modules()\n",
326
+ "\n",
327
+ "\n",
328
+ "# =========================\n",
329
+ "# Training configuration\n",
330
+ "# =========================\n",
331
+ "EPOCHS = 300\n",
332
+ "PATIENCE = 8\n",
333
+ "BATCH_SIZE = 4\n",
334
+ "LR = 1e-4\n",
335
+ "WEIGHT_DECAY = 1e-5\n",
336
+ "\n",
337
+ "# =========================\n",
338
+ "# Task definition\n",
339
+ "# =========================\n",
340
+ "NUM_SUBTYPES = 2 # e.g., LUAD vs LUSC\n",
341
+ "NUM_TNM = 3 # Stage I–II / III / IV\n",
342
+ "\n",
343
+ "# =========================\n",
344
+ "# Image settings\n",
345
+ "# =========================\n",
346
+ "N_SLICES = 30 # max slices per patient\n",
347
+ "IMG_SIZE = 224\n",
348
+ "\n",
349
+ "\n",
350
+ "SAVE_DIR = \"./results\"\n",
351
+ "FIG_DIR = \"./figures\"\n",
352
+ "os.makedirs(SAVE_DIR, exist_ok=True)\n",
353
+ "os.makedirs(FIG_DIR, exist_ok=True)\n",
354
+ "\n",
355
+ "# -------------------------\n",
356
+ "# GPU (force cuda:1)\n",
357
+ "# -------------------------\n",
358
+ "assert torch.cuda.is_available(), \"CUDA not available\"\n",
359
+ "DEVICE = torch.device(\"cuda:1\")\n",
360
+ "torch.cuda.set_device(DEVICE)\n",
361
+ "print(\"Using device:\", DEVICE)\n",
362
+ "\n",
363
+ "\n",
364
+ "# =========================================================\n",
365
+ "# Core utils\n",
366
+ "# =========================================================\n",
367
+ "def _sigmoid(x):\n",
368
+ " return 1 / (1 + np.exp(-x))\n",
369
+ "\n",
370
+ "def _ensure_numpy(x):\n",
371
+ " if isinstance(x, torch.Tensor):\n",
372
+ " return x.detach().cpu().numpy()\n",
373
+ " return x\n",
374
+ "\n",
375
+ "def _risk_to_groups(risk, q=(1/3, 2/3), labels=(\"Low\", \"Mediate\", \"High\")):\n",
376
+ " \"\"\"\n",
377
+ " Convert continuous risk into 3 groups by tertiles.\n",
378
+ " \"\"\"\n",
379
+ " r = np.asarray(risk).reshape(-1)\n",
380
+ " t1, t2 = np.quantile(r, q[0]), np.quantile(r, q[1])\n",
381
+ " out = np.full(len(r), labels[1], dtype=object)\n",
382
+ " out[r <= t1] = labels[0]\n",
383
+ " out[r >= t2] = labels[2]\n",
384
+ " return out\n",
385
+ "\n",
386
+ "def _evaluate_survival_metrics(time, event, risk, time_point=30):\n",
387
+ " \"\"\"\n",
388
+ " C-index + Brier at a fixed time point.\n",
389
+ " risk: higher => earlier event, so use -risk in concordance_index.\n",
390
+ " \"\"\"\n",
391
+ " time = np.asarray(time).reshape(-1)\n",
392
+ " event = np.asarray(event).reshape(-1).astype(int)\n",
393
+ " risk = np.asarray(risk).reshape(-1)\n",
394
+ "\n",
395
+ " c_index = concordance_index(time, -risk, event)\n",
396
+ "\n",
397
+ " # Brier: predict survival at time_point using a monotonic transform of risk (proxy)\n",
398
+ " # This is a \"proxy\" survival probability for demo/debug; replace with proper survival model if needed.\n",
399
+ " y_true = (time > time_point).astype(int) # 1 means survived beyond time_point\n",
400
+ " # map risk into [0,1] survival prob proxy: higher risk => lower survival prob\n",
401
+ " y_prob = 1 - (risk - risk.min()) / (risk.max() - risk.min() + 1e-8)\n",
402
+ " brier = brier_score_loss(y_true, y_prob)\n",
403
+ "\n",
404
+ " return float(c_index), float(brier)\n",
405
+ "\n",
406
+ "\n",
407
+ "# =========================================================\n",
408
+ "# One epoch (train / eval)\n",
409
+ "# =========================================================\n",
410
+ "def run_epoch_verbose(model, loader, optimizer, device, train=True):\n",
411
+ " ce = nn.CrossEntropyLoss()\n",
412
+ " bce = nn.BCEWithLogitsLoss(reduction=\"none\")\n",
413
+ " cox = CoxPHLoss()\n",
414
+ "\n",
415
+ " model.train() if train else model.eval()\n",
416
+ "\n",
417
+ " losses = []\n",
418
+ "\n",
419
+ " # classification\n",
420
+ " sub_y_all, sub_s_all = [], []\n",
421
+ " tnm_y_all, tnm_s_all = [], []\n",
422
+ " treat_all = []\n",
423
+ "\n",
424
+ " # survival (cox risk + time/event)\n",
425
+ " dfs_r_all, dfs_t_all, dfs_e_all = [], [], []\n",
426
+ " os_r_all, os_t_all, os_e_all = [], [], []\n",
427
+ "\n",
428
+ " # survival 1y/3y/5y logits (optional save)\n",
429
+ " dfs_log_all, os_log_all = [], []\n",
430
+ "\n",
431
+ " for batch in loader:\n",
432
+ " # NOTE: dataset must return 19 items including treatment\n",
433
+ " if len(batch) != 19:\n",
434
+ " raise ValueError(f\"Batch length mismatch: expected 19, got {len(batch)}. \"\n",
435
+ " f\"Please ensure Dataset __getitem__ returns treatment as the 19th item.\")\n",
436
+ "\n",
437
+ " (\n",
438
+ " pid, lesion, space, rad, pet, cli,\n",
439
+ " y_sub, y_tnm,\n",
440
+ " dfs_t, dfs_e,\n",
441
+ " os_t, os_e,\n",
442
+ " dfs1, dfs3, dfs5,\n",
443
+ " os1, os3, os5,\n",
444
+ " treatment\n",
445
+ " ) = batch\n",
446
+ "\n",
447
+ " lesion, space = lesion.to(device), space.to(device)\n",
448
+ " rad, pet, cli = rad.to(device), pet.to(device), cli.to(device)\n",
449
+ " y_sub, y_tnm = y_sub.to(device), y_tnm.to(device)\n",
450
+ " dfs_t, dfs_e = dfs_t.to(device), dfs_e.to(device)\n",
451
+ " os_t, os_e = os_t.to(device), os_e.to(device)\n",
452
+ " treatment = treatment.to(device)\n",
453
+ "\n",
454
+ " dfs_y = torch.stack([dfs1, dfs3, dfs5], dim=1).to(device)\n",
455
+ " os_y = torch.stack([os1, os3, os5 ], dim=1).to(device)\n",
456
+ "\n",
457
+ " with torch.set_grad_enabled(train):\n",
458
+ " sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(\n",
459
+ " lesion, space, rad, pet, cli\n",
460
+ " )\n",
461
+ "\n",
462
+ " loss = (\n",
463
+ " ce(sub_l, y_sub) +\n",
464
+ " ce(tnm_l, y_tnm) +\n",
465
+ " cox(dfs_r, dfs_t, dfs_e) +\n",
466
+ " cox(os_r, os_t, os_e) +\n",
467
+ " bce(dfs_log, dfs_y).mean() +\n",
468
+ " bce(os_log, os_y ).mean()\n",
469
+ " )\n",
470
+ "\n",
471
+ " if train:\n",
472
+ " optimizer.zero_grad()\n",
473
+ " loss.backward()\n",
474
+ " optimizer.step()\n",
475
+ "\n",
476
+ " losses.append(loss.item())\n",
477
+ "\n",
478
+ " # ----- Collect predictions -----\n",
479
+ " sub_prob = torch.softmax(sub_l, dim=1)[:, 1] # subtype prob\n",
480
+ " tnm_prob = torch.softmax(tnm_l, dim=1) # [B,3]\n",
481
+ "\n",
482
+ " sub_s_all.append(_ensure_numpy(sub_prob))\n",
483
+ " sub_y_all.append(_ensure_numpy(y_sub))\n",
484
+ "\n",
485
+ " tnm_s_all.append(_ensure_numpy(tnm_prob))\n",
486
+ " tnm_y_all.append(_ensure_numpy(y_tnm))\n",
487
+ "\n",
488
+ " treat_all.append(_ensure_numpy(treatment))\n",
489
+ "\n",
490
+ " # survival\n",
491
+ " dfs_r_all.append(_ensure_numpy(dfs_r))\n",
492
+ " dfs_t_all.append(_ensure_numpy(dfs_t))\n",
493
+ " dfs_e_all.append(_ensure_numpy(dfs_e))\n",
494
+ "\n",
495
+ " os_r_all.append(_ensure_numpy(os_r))\n",
496
+ " os_t_all.append(_ensure_numpy(os_t))\n",
497
+ " os_e_all.append(_ensure_numpy(os_e))\n",
498
+ "\n",
499
+ " dfs_log_all.append(_ensure_numpy(dfs_log))\n",
500
+ " os_log_all.append(_ensure_numpy(os_log))\n",
501
+ "\n",
502
+ " return (\n",
503
+ " float(np.mean(losses)),\n",
504
+ "\n",
505
+ " np.concatenate(sub_y_all),\n",
506
+ " np.concatenate(sub_s_all),\n",
507
+ "\n",
508
+ " np.concatenate(tnm_y_all),\n",
509
+ " np.concatenate(tnm_s_all),\n",
510
+ "\n",
511
+ " np.concatenate(treat_all),\n",
512
+ "\n",
513
+ " np.concatenate(dfs_r_all),\n",
514
+ " np.concatenate(dfs_t_all),\n",
515
+ " np.concatenate(dfs_e_all),\n",
516
+ "\n",
517
+ " np.concatenate(os_r_all),\n",
518
+ " np.concatenate(os_t_all),\n",
519
+ " np.concatenate(os_e_all),\n",
520
+ "\n",
521
+ " np.concatenate(dfs_log_all, axis=0), # [N,3]\n",
522
+ " np.concatenate(os_log_all, axis=0), # [N,3]\n",
523
+ " )\n",
524
+ "\n",
525
+ "\n",
526
+ "# =========================================================\n",
527
+ "# Evaluation by cohort (classification + survival)\n",
528
+ "# =========================================================\n",
529
+ "def evaluate_by_treatment(sub_y, sub_s, tnm_y, tnm_s, treat,\n",
530
+ " dfs_r, dfs_t, dfs_e, os_r, os_t, os_e):\n",
531
+ " results = {}\n",
532
+ "\n",
533
+ " cohorts = {\n",
534
+ " \"All\": np.ones_like(treat, dtype=bool),\n",
535
+ " \"Immune\": treat == 0,\n",
536
+ " \"Chemo\": treat == 1,\n",
537
+ " }\n",
538
+ "\n",
539
+ " for name, mask in cohorts.items():\n",
540
+ " if mask.sum() < 10:\n",
541
+ " continue\n",
542
+ "\n",
543
+ " res = {}\n",
544
+ "\n",
545
+ " # Subtype (binary)\n",
546
+ " res[\"Subtype_AUC\"] = roc_auc_score(sub_y[mask], sub_s[mask])\n",
547
+ " res[\"Subtype_ACC\"] = accuracy_score(sub_y[mask], (sub_s[mask] > 0.5).astype(int))\n",
548
+ "\n",
549
+ " # TNM (multiclass macro AUC + ACC)\n",
550
+ " tnm_bin = label_binarize(tnm_y[mask], classes=[0, 1, 2])\n",
551
+ " res[\"TNM_AUC_macro\"] = roc_auc_score(\n",
552
+ " tnm_bin, tnm_s[mask], average=\"macro\", multi_class=\"ovr\"\n",
553
+ " )\n",
554
+ " res[\"TNM_ACC\"] = accuracy_score(\n",
555
+ " tnm_y[mask], np.argmax(tnm_s[mask], axis=1)\n",
556
+ " )\n",
557
+ "\n",
558
+ " # Survival\n",
559
+ " dfs_c, dfs_b = _evaluate_survival_metrics(dfs_t[mask], dfs_e[mask], dfs_r[mask], time_point=30)\n",
560
+ " os_c, os_b = _evaluate_survival_metrics(os_t[mask], os_e[mask], os_r[mask], time_point=30)\n",
561
+ "\n",
562
+ " res[\"DFS_C_index\"] = dfs_c\n",
563
+ " res[\"DFS_Brier_30m\"] = dfs_b\n",
564
+ " res[\"OS_C_index\"] = os_c\n",
565
+ " res[\"OS_Brier_30m\"] = os_b\n",
566
+ "\n",
567
+ " results[name] = res\n",
568
+ "\n",
569
+ " return results\n",
570
+ "\n",
571
+ "\n",
572
+ "# =========================================================\n",
573
+ "# Figure 7: KM + HR (per cohort, per endpoint)\n",
574
+ "# =========================================================\n",
575
+ "def plot_km_curve_with_hr(df, title, save_prefix):\n",
576
+ " \"\"\"\n",
577
+ " df must contain columns: time, event, group (Low/Mediate/High)\n",
578
+ " \"\"\"\n",
579
+ " kmf = KaplanMeierFitter()\n",
580
+ " fig, ax = plt.subplots(figsize=(8, 6), facecolor=\"white\")\n",
581
+ " ax.set_facecolor(\"white\")\n",
582
+ "\n",
583
+ " colors = {\"Low\": \"#91c7ae\", \"Mediate\": \"#f7b977\", \"High\": \"#d87c7c\"}\n",
584
+ " groups = [\"Low\", \"Mediate\", \"High\"]\n",
585
+ "\n",
586
+ " # plot KM\n",
587
+ " lines = {}\n",
588
+ " at_risk_table = []\n",
589
+ " times = np.arange(0, 70, 10)\n",
590
+ "\n",
591
+ " for g in groups:\n",
592
+ " m = df[\"group\"] == g\n",
593
+ " if m.sum() == 0:\n",
594
+ " continue\n",
595
+ "\n",
596
+ " kmf.fit(df.loc[m, \"time\"], event_observed=df.loc[m, \"event\"], label=g)\n",
597
+ " kmf.plot_survival_function(\n",
598
+ " ax=ax, ci_show=True, linewidth=2, color=colors[g], marker=\"+\"\n",
599
+ " )\n",
600
+ " lines[g] = ax.get_lines()[-1]\n",
601
+ "\n",
602
+ " at_risk_table.append([np.sum(df.loc[m, \"time\"] >= t) for t in times])\n",
603
+ "\n",
604
+ " # legend\n",
605
+ " handles = [lines[g] for g in groups if g in lines]\n",
606
+ " labels = [\"Low\", \"Medium\", \"High\"][:len(handles)]\n",
607
+ " ax.legend(handles, labels, title=\"Groups\", loc=\"upper right\",\n",
608
+ " frameon=True, framealpha=0.5, fontsize=12, title_fontsize=12)\n",
609
+ "\n",
610
+ " # at risk numbers (optional, matches your style)\n",
611
+ " if len(at_risk_table) == 3:\n",
612
+ " low, mid, high = at_risk_table\n",
613
+ " for i, t in enumerate(times):\n",
614
+ " ax.text(t, -0.38, str(low[i]), color=\"#207f4c\", fontsize=14, ha=\"center\")\n",
615
+ " ax.text(t, -0.48, str(mid[i]), color=\"#fca106\", fontsize=14, ha=\"center\")\n",
616
+ " ax.text(t, -0.58, str(high[i]), color=\"#cc163a\", fontsize=14, ha=\"center\")\n",
617
+ "\n",
618
+ " ax.text(-1, -0.28, \"Number at risk\", color=\"black\", ha=\"center\", fontsize=14)\n",
619
+ " ax.text(-10, -0.38, \"Low\", color=\"#207f4c\", fontsize=14)\n",
620
+ " ax.text(-10, -0.48, \"Medium\", color=\"#fca106\", fontsize=14)\n",
621
+ " ax.text(-10, -0.58, \"High\", color=\"#cc163a\", fontsize=14)\n",
622
+ "\n",
623
+ " # Cox HR + p-values\n",
624
+ " df2 = df.copy()\n",
625
+ " df2[\"group_code\"] = df2[\"group\"].map({\"Low\": 0, \"Mediate\": 1, \"High\": 2})\n",
626
+ " cph = CoxPHFitter()\n",
627
+ " cph.fit(df2[[\"time\", \"event\", \"group_code\"]], duration_col=\"time\", event_col=\"event\")\n",
628
+ "\n",
629
+ " coef = float(cph.params_[\"group_code\"])\n",
630
+ " se = float(cph.standard_errors_[\"group_code\"])\n",
631
+ "\n",
632
+ " hr_med_vs_low = np.exp(coef * 1)\n",
633
+ " hr_high_vs_low = np.exp(coef * 2)\n",
634
+ "\n",
635
+ " z_med = (coef * 1) / se\n",
636
+ " p_med = 2 * (1 - norm.cdf(abs(z_med)))\n",
637
+ "\n",
638
+ " z_high = (coef * 2) / se\n",
639
+ " p_high = 2 * (1 - norm.cdf(abs(z_high)))\n",
640
+ "\n",
641
+ " # logrank\n",
642
+ " res_lr = multivariate_logrank_test(df2[\"time\"], df2[\"group\"], df2[\"event\"])\n",
643
+ "\n",
644
+ " # C-index + brier (proxy)\n",
645
+ " c_index, brier = _evaluate_survival_metrics(df2[\"time\"].values, df2[\"event\"].values,\n",
646
+ " df2[\"group_code\"].values, time_point=30)\n",
647
+ "\n",
648
+ " ax.text(25, 0.46, f\"P(log-rank)={res_lr.p_value:.3f}\", fontsize=12)\n",
649
+ " ax.text(25, 0.36, f\"C-index={c_index:.3f}\", fontsize=12)\n",
650
+ " ax.text(25, 0.26, f\"Brier(30m)={brier:.3f}\", fontsize=12)\n",
651
+ " ax.text(25, 0.16, f\"HR Intermediate vs Low = {hr_med_vs_low:.2f}, P={p_med:.3f}\", fontsize=12)\n",
652
+ " ax.text(25, 0.06, f\"HR High vs Low = {hr_high_vs_low:.2f}, P={p_high:.3f}\", fontsize=12)\n",
653
+ "\n",
654
+ " # cosmetics\n",
655
+ " ax.spines[\"top\"].set_visible(False)\n",
656
+ " ax.spines[\"right\"].set_visible(False)\n",
657
+ " ax.set_title(title, fontsize=14)\n",
658
+ " ax.set_xlabel(\"Time since treatment start (months)\", fontsize=14)\n",
659
+ " ax.set_ylabel(\"Survival probability\", fontsize=14)\n",
660
+ " ax.set_ylim(0, 1.05)\n",
661
+ " ax.grid(alpha=0.3)\n",
662
+ "\n",
663
+ " plt.tight_layout()\n",
664
+ " plt.savefig(save_prefix + \".png\", dpi=600, bbox_inches=\"tight\")\n",
665
+ " plt.savefig(save_prefix + \".pdf\", dpi=600, bbox_inches=\"tight\")\n",
666
+ " plt.close()\n",
667
+ " return save_prefix\n",
668
+ "\n",
669
+ "\n",
670
+ "def generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=(\"val\", \"test\")):\n",
671
+ " \"\"\"\n",
672
+ " Load saved dfs/os arrays and generate KM+HR for Immune/Chemo separately.\n",
673
+ " \"\"\"\n",
674
+ " os.makedirs(fig_dir, exist_ok=True)\n",
675
+ "\n",
676
+ " for split in which_split:\n",
677
+ " # load arrays\n",
678
+ " trt = np.load(os.path.join(result_dir, f\"treatment_{split}.npy\"))\n",
679
+ "\n",
680
+ " dfs_r = np.load(os.path.join(result_dir, f\"dfs_{split}_risk.npy\"))\n",
681
+ " dfs_t = np.load(os.path.join(result_dir, f\"dfs_{split}_time.npy\"))\n",
682
+ " dfs_e = np.load(os.path.join(result_dir, f\"dfs_{split}_event.npy\"))\n",
683
+ "\n",
684
+ " os_r = np.load(os.path.join(result_dir, f\"os_{split}_risk.npy\"))\n",
685
+ " os_t = np.load(os.path.join(result_dir, f\"os_{split}_time.npy\"))\n",
686
+ " os_e = np.load(os.path.join(result_dir, f\"os_{split}_event.npy\"))\n",
687
+ "\n",
688
+ " for cohort_name, mask in {\n",
689
+ " \"Immune\": trt == 0,\n",
690
+ " \"Chemo\": trt == 1\n",
691
+ " }.items():\n",
692
+ " if mask.sum() < 20:\n",
693
+ " print(f\"[Figure7] Skip {split}-{cohort_name}: too few samples ({mask.sum()})\")\n",
694
+ " continue\n",
695
+ "\n",
696
+ " # DFS groups\n",
697
+ " dfs_group = _risk_to_groups(dfs_r[mask])\n",
698
+ " df_dfs = pd.DataFrame({\n",
699
+ " \"time\": dfs_t[mask],\n",
700
+ " \"event\": dfs_e[mask].astype(int),\n",
701
+ " \"group\": dfs_group\n",
702
+ " })\n",
703
+ "\n",
704
+ " # OS groups\n",
705
+ " os_group = _risk_to_groups(os_r[mask])\n",
706
+ " df_os = pd.DataFrame({\n",
707
+ " \"time\": os_t[mask],\n",
708
+ " \"event\": os_e[mask].astype(int),\n",
709
+ " \"group\": os_group\n",
710
+ " })\n",
711
+ "\n",
712
+ " # save CSV (optional, for reproducibility)\n",
713
+ " df_dfs.to_csv(os.path.join(result_dir, f\"dfs_{split}_{cohort_name}.csv\"), index=False)\n",
714
+ " df_os.to_csv(os.path.join(result_dir, f\"os_{split}_{cohort_name}.csv\"), index=False)\n",
715
+ "\n",
716
+ " # plot\n",
717
+ " plot_km_curve_with_hr(\n",
718
+ " df_dfs,\n",
719
+ " title=f\"Disease-Free Survival (DFS) — Kaplan-Meier Curves\\n{cohort_name} {split} set (n={mask.sum()})\",\n",
720
+ " save_prefix=os.path.join(fig_dir, f\"Figure7_DFS_{cohort_name}_{split}\")\n",
721
+ " )\n",
722
+ " plot_km_curve_with_hr(\n",
723
+ " df_os,\n",
724
+ " title=f\"Overall Survival (OS) — Kaplan-Meier Curves\\n{cohort_name} {split} set (n={mask.sum()})\",\n",
725
+ " save_prefix=os.path.join(fig_dir, f\"Figure7_OS_{cohort_name}_{split}\")\n",
726
+ " )\n",
727
+ "\n",
728
+ " print(\"✔ Figure 7 generated (DFS/OS KM + HR) for Immune/Chemo.\")\n",
729
+ "\n",
730
+ "\n",
731
+ "# =========================================================\n",
732
+ "# Main\n",
733
+ "# =========================================================\n",
734
+ "def main():\n",
735
+ " # -------------------------\n",
736
+ " # Dataset (must return treatment as 19th item)\n",
737
+ " # -------------------------\n",
738
+ " from mm_dls.PatientDataset import PatientDataset\n",
739
+ "\n",
740
+ " dataset = PatientDataset(\n",
741
+ " data_root=\"/path/to/DATA_ROOT\",\n",
742
+ " clinical_csv=\"/path/to/clinical.csv\",\n",
743
+ " radiomics_npy=\"/path/to/radiomics.npy\",\n",
744
+ " pet_npy=\"/path/to/pet.npy\",\n",
745
+ " n_slices=N_SLICES,\n",
746
+ " img_size=IMG_SIZE,\n",
747
+ " )\n",
748
+ "\n",
749
+ "\n",
750
+ " n_train = int(0.6 * len(dataset))\n",
751
+ " n_val = int(0.2 * len(dataset))\n",
752
+ " n_test = len(dataset) - n_train - n_val\n",
753
+ "\n",
754
+ " train_set, val_set, test_set = random_split(dataset, [n_train, n_val, n_test])\n",
755
+ "\n",
756
+ " loaders = {\n",
757
+ " \"train\": DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=4),\n",
758
+ " \"val\": DataLoader(val_set, BATCH_SIZE, shuffle=False, num_workers=4),\n",
759
+ " \"test\": DataLoader(test_set, BATCH_SIZE, shuffle=False, num_workers=4),\n",
760
+ " }\n",
761
+ "\n",
762
+ " # -------------------------\n",
763
+ " # Model\n",
764
+ " # -------------------------\n",
765
+ " model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)\n",
766
+ " optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
767
+ "\n",
768
+ " best_val_loss = 1e9\n",
769
+ " wait = 0\n",
770
+ "\n",
771
+ " # -------------------------\n",
772
+ " # Training\n",
773
+ " # -------------------------\n",
774
+ " for epoch in range(1, EPOCHS + 1):\n",
775
+ " tr = run_epoch_verbose(model, loaders[\"train\"], optimizer, DEVICE, train=True)\n",
776
+ " va = run_epoch_verbose(model, loaders[\"val\"], optimizer, DEVICE, train=False)\n",
777
+ "\n",
778
+ " tr_loss = tr[0]\n",
779
+ " va_loss = va[0]\n",
780
+ "\n",
781
+ " # unpack val for metrics\n",
782
+ " _, sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e, _, _ = va\n",
783
+ " metrics = evaluate_by_treatment(sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e)\n",
784
+ "\n",
785
+ " print(f\"\\n[Epoch {epoch:03d}] Train Loss={tr_loss:.3f} | Val Loss={va_loss:.3f}\")\n",
786
+ " for k, v in metrics.items():\n",
787
+ " print(\n",
788
+ " f\" {k:7s} | \"\n",
789
+ " f\"Subtype AUC={v['Subtype_AUC']:.3f} | \"\n",
790
+ " f\"TNM AUC={v['TNM_AUC_macro']:.3f} | \"\n",
791
+ " f\"DFS C-index={v['DFS_C_index']:.3f} | \"\n",
792
+ " f\"OS C-index={v['OS_C_index']:.3f}\"\n",
793
+ " )\n",
794
+ "\n",
795
+ " # early stopping\n",
796
+ " if va_loss < best_val_loss:\n",
797
+ " best_val_loss = va_loss\n",
798
+ " wait = 0\n",
799
+ " torch.save(model.state_dict(), os.path.join(SAVE_DIR, \"best_model.pt\"))\n",
800
+ " print(\" ✓ Best model updated\")\n",
801
+ " else:\n",
802
+ " wait += 1\n",
803
+ " if wait >= PATIENCE:\n",
804
+ " print(\"\\n⏹ Early stopping triggered\")\n",
805
+ " break\n",
806
+ "\n",
807
+ " # -------------------------\n",
808
+ " # Inference (best model)\n",
809
+ " # -------------------------\n",
810
+ " print(\"\\nRunning inference with best model...\")\n",
811
+ " model.load_state_dict(torch.load(os.path.join(SAVE_DIR, \"best_model.pt\"), map_location=DEVICE))\n",
812
+ "\n",
813
+ " for split in [\"train\", \"val\", \"test\"]:\n",
814
+ " out = run_epoch_verbose(model, loaders[split], optimizer, DEVICE, train=False)\n",
815
+ " (\n",
816
+ " loss,\n",
817
+ " sy, ss,\n",
818
+ " ty, ts,\n",
819
+ " trt,\n",
820
+ " dfs_r, dfs_t, dfs_e,\n",
821
+ " os_r, os_t, os_e,\n",
822
+ " dfs_log, os_log\n",
823
+ " ) = out\n",
824
+ "\n",
825
+ " # classification\n",
826
+ " np.save(os.path.join(SAVE_DIR, f\"subtype_{split}_labels.npy\"), sy)\n",
827
+ " np.save(os.path.join(SAVE_DIR, f\"subtype_{split}_scores.npy\"), ss)\n",
828
+ " np.save(os.path.join(SAVE_DIR, f\"tnm_{split}_labels.npy\"), ty)\n",
829
+ " np.save(os.path.join(SAVE_DIR, f\"tnm_{split}_scores.npy\"), ts)\n",
830
+ " np.save(os.path.join(SAVE_DIR, f\"treatment_{split}.npy\"), trt)\n",
831
+ "\n",
832
+ " # survival (cox risk + time/event)\n",
833
+ " np.save(os.path.join(SAVE_DIR, f\"dfs_{split}_risk.npy\"), dfs_r)\n",
834
+ " np.save(os.path.join(SAVE_DIR, f\"dfs_{split}_time.npy\"), dfs_t)\n",
835
+ " np.save(os.path.join(SAVE_DIR, f\"dfs_{split}_event.npy\"), dfs_e)\n",
836
+ "\n",
837
+ " np.save(os.path.join(SAVE_DIR, f\"os_{split}_risk.npy\"), os_r)\n",
838
+ " np.save(os.path.join(SAVE_DIR, f\"os_{split}_time.npy\"), os_t)\n",
839
+ " np.save(os.path.join(SAVE_DIR, f\"os_{split}_event.npy\"), os_e)\n",
840
+ "\n",
841
+ " # 1y/3y/5y logits (optional, for AUC at specific horizons)\n",
842
+ " np.save(os.path.join(SAVE_DIR, f\"dfs_{split}_logits_1y3y5y.npy\"), dfs_log)\n",
843
+ " np.save(os.path.join(SAVE_DIR, f\"os_{split}_logits_1y3y5y.npy\"), os_log)\n",
844
+ "\n",
845
+ " print(f\"{split:5s} | loss={loss:.3f} | Immune={np.sum(trt==0)} Chemo={np.sum(trt==1)}\")\n",
846
+ "\n",
847
+ " print(\"\\n✓ Inference completed. Results saved.\")\n",
848
+ "\n",
849
+ " # -------------------------\n",
850
+ " # Figure: Immune/Chemo KM + HR\n",
851
+ " # -------------------------\n",
852
+ " print(\"\\nGenerating Figure (KM + HR) ...\")\n",
853
+ " generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=(\"val\", \"test\"))\n",
854
+ " print(\"✓ Figure done. Files saved under ./figures\")\n",
855
+ "\n",
856
+ "\n",
857
+ "if __name__ == \"__main__\":\n",
858
+ " main()\n"
859
+ ]
860
+ }
861
+ ],
862
+ "metadata": {
863
+ "kernelspec": {
864
+ "display_name": "VGG",
865
+ "language": "python",
866
+ "name": "python3"
867
+ },
868
+ "language_info": {
869
+ "codemirror_mode": {
870
+ "name": "ipython",
871
+ "version": 3
872
+ },
873
+ "file_extension": ".py",
874
+ "mimetype": "text/x-python",
875
+ "name": "python",
876
+ "nbconvert_exporter": "python",
877
+ "pygments_lexer": "ipython3",
878
+ "version": "3.6.8"
879
+ }
880
+ },
881
+ "nbformat": 4,
882
+ "nbformat_minor": 2
883
+ }
test.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # test_mm_dls.py
2
+ # =========================================================
3
+ # 🔍 Minimal test for MM-DLS pipeline
4
+ # - CUDA
5
+ # - forward / loss
6
+ # - pandas / lifelines (GLIBCXX check)
7
+ # =========================================================
8
+
9
+ import os
10
+ import sys
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.data import DataLoader, Subset
15
+
16
+ import pandas as pd
17
+ from lifelines import KaplanMeierFitter
18
+ from lifelines.utils import concordance_index
19
+
20
+ # ---------------------------------------------------------
21
+ # Project path
22
+ # ---------------------------------------------------------
23
+ PROJECT_ROOT = os.path.abspath(".")
24
+ if PROJECT_ROOT not in sys.path:
25
+ sys.path.insert(0, PROJECT_ROOT)
26
+
27
+ # ---------------------------------------------------------
28
+ # Imports from mm_dls
29
+ # ---------------------------------------------------------
30
+ from mm_dls.HierMM_DLS import HierMM_DLS
31
+ from mm_dls.CoxphLoss import CoxPHLoss
32
+ from mm_dls.PatientDataset import PatientDataset
33
+
34
+
35
+ # =========================================================
36
+ # Basic config (VERY SMALL)
37
+ # =========================================================
38
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39
+ print("Using device:", DEVICE)
40
+
41
+ BATCH_SIZE = 2
42
+ NUM_SUBTYPES = 2
43
+ NUM_TNM = 3
44
+ N_SLICES = 30
45
+ IMG_SIZE = 224
46
+
47
+
48
+ # =========================================================
49
+ # Test Dataset Loader
50
+ # =========================================================
51
def get_test_loader():
    """Build a tiny DataLoader over the first (at most) 8 patients.

    Reuses the training PatientDataset configuration; the paths are
    placeholders and must point at real data before running.

    Returns:
        torch.utils.data.DataLoader yielding small deterministic batches.
    """
    full_dataset = PatientDataset(
        data_root="/path/to/DATA_ROOT",
        clinical_csv="/path/to/clinical.csv",
        radiomics_npy="/path/to/radiomics.npy",
        pet_npy="/path/to/pet.npy",
        n_slices=N_SLICES,
        img_size=IMG_SIZE,
    )

    # Keep only the first 8 samples (fewer if the dataset is smaller).
    n_keep = min(8, len(full_dataset))
    small_set = Subset(full_dataset, list(range(n_keep)))

    return DataLoader(
        small_set,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
    )
72
+
73
+
74
+ # =========================================================
75
+ # One forward + loss
76
+ # =========================================================
77
def test_forward_and_loss():
    """Run one forward pass of HierMM_DLS on a small real batch and
    assemble the combined multi-task loss (no backward pass).

    Any shape/dtype mismatch between the dataset, the model heads and
    the loss functions raises here, which is the point of the test.
    """
    print("\n[TEST] Forward + Loss")

    loader = get_test_loader()
    model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)

    # One criterion per head: CE for subtype/TNM classification,
    # BCE-with-logits for the 1y/3y/5y horizon logits, and Cox partial
    # likelihood for the DFS/OS risk scores.
    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss()
    cox = CoxPHLoss()

    model.eval()

    for batch in loader:
        # Dataset contract: exactly 19 items, with treatment last.
        assert len(batch) == 19, f"Dataset must return 19 items, got {len(batch)}"

        (
            pid, lesion, space, rad, pet, cli,
            y_sub, y_tnm,
            dfs_t, dfs_e,
            os_t, os_e,
            dfs1, dfs3, dfs5,
            os1, os3, os5,
            treatment
        ) = batch

        # Move model inputs and supervision targets to the test device.
        lesion, space = lesion.to(DEVICE), space.to(DEVICE)
        rad, pet, cli = rad.to(DEVICE), pet.to(DEVICE), cli.to(DEVICE)
        y_sub, y_tnm = y_sub.to(DEVICE), y_tnm.to(DEVICE)
        dfs_t, dfs_e = dfs_t.to(DEVICE), dfs_e.to(DEVICE)
        os_t, os_e = os_t.to(DEVICE), os_e.to(DEVICE)

        # Stack the per-horizon binary labels into [B, 3] targets.
        dfs_y = torch.stack([dfs1, dfs3, dfs5], dim=1).to(DEVICE)
        os_y = torch.stack([os1, os3, os5], dim=1).to(DEVICE)

        with torch.no_grad():
            sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(
                lesion, space, rad, pet, cli
            )

        # Sum of all task losses; exercises every head at once.
        loss = (
            ce(sub_l, y_sub) +
            ce(tnm_l, y_tnm) +
            cox(dfs_r, dfs_t, dfs_e) +
            cox(os_r, os_t, os_e) +
            bce(dfs_log, dfs_y) +
            bce(os_log, os_y)
        )

        print(" ✓ Forward OK | Loss =", float(loss))
        break  # one batch is enough for a smoke test
127
+
128
+
129
+ # =========================================================
130
+ # Test pandas + lifelines (GLIBCXX killer)
131
+ # =========================================================
132
def test_pandas_lifelines():
    """Smoke-test pandas + lifelines imports and basic survival calls.

    Exists mainly to surface binary-compatibility problems (e.g. the
    GLIBCXX import failure) before the real pipeline runs.
    """
    print("\n[TEST] pandas + lifelines")

    # Tiny synthetic survival cohort.
    time = np.array([10, 12, 8, 20, 15, 25])
    event = np.array([1, 1, 0, 1, 0, 0])
    risk = np.array([0.9, 0.8, 0.2, 1.2, 0.3, 0.4])

    # DataFrame construction exercises pandas.
    df = pd.DataFrame({
        "time": time,
        "event": event,
        "risk": risk
    })
    print(" pandas OK:", df.shape)

    # Concordance: higher risk should mean earlier events, hence -risk.
    cidx = concordance_index(df["time"], -df["risk"], df["event"])
    print(" C-index =", round(cidx, 3))

    # Kaplan-Meier fit plus survival probability at t = 10.
    kmf = KaplanMeierFitter()
    kmf.fit(df["time"], event_observed=df["event"])
    surv_10 = kmf.predict(10)

    print(" KM survival@10 =", float(surv_10))
    print(" ✓ lifelines OK")
160
+
161
+
162
+ # =========================================================
163
+ # Main
164
+ # =========================================================
165
def _run_all_tests():
    """Execute every smoke test in order; any failure raises immediately."""
    print("\n==============================")
    print(" MM-DLS TEST START ")
    print("==============================")

    test_forward_and_loss()
    test_pandas_lifelines()

    print("\n✅ ALL TESTS PASSED")


if __name__ == "__main__":
    _run_all_tests()
train_patient_model.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.insert(0, "/export/home/daifang/lunghospital/MM-DLS-master/MM-DLS-master")
3
+ # main.py
4
+ import os
5
+ import sys
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import DataLoader, random_split
10
+
11
+ from sklearn.metrics import roc_auc_score, accuracy_score
12
+ from sklearn.preprocessing import label_binarize
13
+
14
+ import pandas as pd
15
+ import matplotlib.pyplot as plt
16
+ from lifelines import KaplanMeierFitter, CoxPHFitter
17
+ from lifelines.statistics import multivariate_logrank_test
18
+ from lifelines.utils import concordance_index
19
+ from sklearn.metrics import brier_score_loss
20
+ from scipy.stats import norm
21
+
22
+
23
+
24
+ # =========================================================
25
+ # Project path (IMPORTANT for Jupyter / HPC)
26
+ # =========================================================
27
+ PROJECT_ROOT = os.path.abspath(".")
28
+ if PROJECT_ROOT not in sys.path:
29
+ sys.path.insert(0, PROJECT_ROOT)
30
+
31
+ # =========================================================
32
+ # imports: mm_dls/
33
+ # =========================================================
34
def _import_modules():
    """Import the mm_dls components lazily and hand them back as a tuple.

    Deferring the imports keeps sys.path setup above effective and makes
    any path problem surface with a clear traceback at call time.

    Returns:
        (HierMM_DLS, FakePatientDataset, CoxPHLoss) classes.
    """
    from mm_dls.CoxphLoss import CoxPHLoss
    from mm_dls.FakePatientDataset import FakePatientDataset
    from mm_dls.HierMM_DLS import HierMM_DLS

    return HierMM_DLS, FakePatientDataset, CoxPHLoss
40
+
41
+
42
+ HierMM_DLS, FakePatientDataset, CoxPHLoss = _import_modules()
43
+
44
+
45
+ # =========================
46
+ # Training configuration
47
+ # =========================
48
+ EPOCHS = 300
49
+ PATIENCE = 8
50
+ BATCH_SIZE = 4
51
+ LR = 1e-4
52
+ WEIGHT_DECAY = 1e-5
53
+
54
+ # =========================
55
+ # Task definition
56
+ # =========================
57
+ NUM_SUBTYPES = 2 # e.g., LUAD vs LUSC
58
+ NUM_TNM = 3 # Stage I–II / III / IV
59
+
60
+ # =========================
61
+ # Image settings
62
+ # =========================
63
+ N_SLICES = 30 # max slices per patient
64
+ IMG_SIZE = 224
65
+
66
+
67
+ SAVE_DIR = "./results"
68
+ FIG_DIR = "./figures"
69
+ os.makedirs(SAVE_DIR, exist_ok=True)
70
+ os.makedirs(FIG_DIR, exist_ok=True)
71
+
72
+ # -------------------------
73
+ # GPU (force cuda:1)
74
+ # -------------------------
75
+ assert torch.cuda.is_available(), "CUDA not available"
76
+ DEVICE = torch.device("cuda:1")
77
+ torch.cuda.set_device(DEVICE)
78
+ print("Using device:", DEVICE)
79
+
80
+
81
+ # =========================================================
82
+ # Core utils
83
+ # =========================================================
84
+ def _sigmoid(x):
85
+ return 1 / (1 + np.exp(-x))
86
+
87
+ def _ensure_numpy(x):
88
+ if isinstance(x, torch.Tensor):
89
+ return x.detach().cpu().numpy()
90
+ return x
91
+
92
+ def _risk_to_groups(risk, q=(1/3, 2/3), labels=("Low", "Mediate", "High")):
93
+ """
94
+ Convert continuous risk into 3 groups by tertiles.
95
+ """
96
+ r = np.asarray(risk).reshape(-1)
97
+ t1, t2 = np.quantile(r, q[0]), np.quantile(r, q[1])
98
+ out = np.full(len(r), labels[1], dtype=object)
99
+ out[r <= t1] = labels[0]
100
+ out[r >= t2] = labels[2]
101
+ return out
102
+
103
def _evaluate_survival_metrics(time, event, risk, time_point=30):
    """
    C-index + Brier at a fixed time point.
    risk: higher => earlier event, so use -risk in concordance_index.

    Args:
        time: event/censoring times, flattened to 1-D.
        event: event indicators (1 = event observed), flattened to 1-D ints.
        risk: continuous risk scores, flattened to 1-D.
        time_point: horizon for the Brier proxy (same unit as ``time``).

    Returns:
        (c_index, brier) as plain floats.
    """
    time = np.asarray(time).reshape(-1)
    event = np.asarray(event).reshape(-1).astype(int)
    risk = np.asarray(risk).reshape(-1)

    # Harrell's C: concordance_index expects higher score => longer
    # survival, hence the sign flip on risk.
    c_index = concordance_index(time, -risk, event)

    # Brier: predict survival at time_point using a monotonic transform of risk (proxy)
    # This is a "proxy" survival probability for demo/debug; replace with proper survival model if needed.
    y_true = (time > time_point).astype(int)  # 1 means survived beyond time_point
    # map risk into [0,1] survival prob proxy: higher risk => lower survival prob
    # NOTE(review): the min-max scaling ignores censoring, so this Brier
    # value is a rough diagnostic, not a calibrated score — confirm
    # before quoting it in results. The +1e-8 guards a constant-risk batch.
    y_prob = 1 - (risk - risk.min()) / (risk.max() - risk.min() + 1e-8)
    brier = brier_score_loss(y_true, y_prob)

    return float(c_index), float(brier)
122
+
123
+
124
+ # =========================================================
125
+ # One epoch (train / eval)
126
+ # =========================================================
127
def run_epoch_verbose(model, loader, optimizer, device, train=True):
    """Run one full pass over ``loader`` and collect every prediction.

    In train mode performs optimizer steps; in eval mode only forwards.
    All per-batch outputs are accumulated on CPU as NumPy arrays and
    concatenated at the end.

    Returns a 14-tuple:
        (mean_loss,
         subtype labels, subtype scores,
         tnm labels, tnm scores,
         treatment codes,
         dfs risk/time/event,
         os risk/time/event,
         dfs 1y/3y/5y logits [N,3], os 1y/3y/5y logits [N,3])

    Raises:
        ValueError: if a batch does not contain exactly 19 items.
    """
    # One criterion per head; BCE is elementwise here and reduced with
    # .mean() below so the three horizons weigh equally.
    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss(reduction="none")
    cox = CoxPHLoss()

    model.train() if train else model.eval()

    losses = []

    # classification
    sub_y_all, sub_s_all = [], []
    tnm_y_all, tnm_s_all = [], []
    treat_all = []

    # survival (cox risk + time/event)
    dfs_r_all, dfs_t_all, dfs_e_all = [], [], []
    os_r_all, os_t_all, os_e_all = [], [], []

    # survival 1y/3y/5y logits (optional save)
    dfs_log_all, os_log_all = [], []

    for batch in loader:
        # NOTE: dataset must return 19 items including treatment
        if len(batch) != 19:
            raise ValueError(f"Batch length mismatch: expected 19, got {len(batch)}. "
                             f"Please ensure Dataset __getitem__ returns treatment as the 19th item.")

        (
            pid, lesion, space, rad, pet, cli,
            y_sub, y_tnm,
            dfs_t, dfs_e,
            os_t, os_e,
            dfs1, dfs3, dfs5,
            os1, os3, os5,
            treatment
        ) = batch

        # Move model inputs and targets to the training device.
        lesion, space = lesion.to(device), space.to(device)
        rad, pet, cli = rad.to(device), pet.to(device), cli.to(device)
        y_sub, y_tnm = y_sub.to(device), y_tnm.to(device)
        dfs_t, dfs_e = dfs_t.to(device), dfs_e.to(device)
        os_t, os_e = os_t.to(device), os_e.to(device)
        treatment = treatment.to(device)

        # Stack the per-horizon binary labels into [B, 3] targets.
        dfs_y = torch.stack([dfs1, dfs3, dfs5], dim=1).to(device)
        os_y = torch.stack([os1, os3, os5], dim=1).to(device)

        # Gradients only when training; eval pass stays grad-free.
        with torch.set_grad_enabled(train):
            sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(
                lesion, space, rad, pet, cli
            )

            # Unweighted sum of all task losses.
            loss = (
                ce(sub_l, y_sub) +
                ce(tnm_l, y_tnm) +
                cox(dfs_r, dfs_t, dfs_e) +
                cox(os_r, os_t, os_e) +
                bce(dfs_log, dfs_y).mean() +
                bce(os_log, os_y).mean()
            )

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        losses.append(loss.item())

        # ----- Collect predictions -----
        sub_prob = torch.softmax(sub_l, dim=1)[:, 1]  # subtype prob (positive class)
        tnm_prob = torch.softmax(tnm_l, dim=1)        # [B,3]

        sub_s_all.append(_ensure_numpy(sub_prob))
        sub_y_all.append(_ensure_numpy(y_sub))

        tnm_s_all.append(_ensure_numpy(tnm_prob))
        tnm_y_all.append(_ensure_numpy(y_tnm))

        treat_all.append(_ensure_numpy(treatment))

        # survival
        dfs_r_all.append(_ensure_numpy(dfs_r))
        dfs_t_all.append(_ensure_numpy(dfs_t))
        dfs_e_all.append(_ensure_numpy(dfs_e))

        os_r_all.append(_ensure_numpy(os_r))
        os_t_all.append(_ensure_numpy(os_t))
        os_e_all.append(_ensure_numpy(os_e))

        dfs_log_all.append(_ensure_numpy(dfs_log))
        os_log_all.append(_ensure_numpy(os_log))

    return (
        float(np.mean(losses)),

        np.concatenate(sub_y_all),
        np.concatenate(sub_s_all),

        np.concatenate(tnm_y_all),
        np.concatenate(tnm_s_all),

        np.concatenate(treat_all),

        np.concatenate(dfs_r_all),
        np.concatenate(dfs_t_all),
        np.concatenate(dfs_e_all),

        np.concatenate(os_r_all),
        np.concatenate(os_t_all),
        np.concatenate(os_e_all),

        np.concatenate(dfs_log_all, axis=0),  # [N,3]
        np.concatenate(os_log_all, axis=0),   # [N,3]
    )
241
+
242
+
243
+ # =========================================================
244
+ # Evaluation by cohort (classification + survival)
245
+ # =========================================================
246
def evaluate_by_treatment(sub_y, sub_s, tnm_y, tnm_s, treat,
                          dfs_r, dfs_t, dfs_e, os_r, os_t, os_e):
    """Compute classification + survival metrics overall and per cohort.

    Cohorts are defined by the treatment code (0 = Immune, 1 = Chemo,
    per the mask definitions below); a cohort with fewer than 10
    samples is skipped because AUC / C-index are unstable there.

    Returns:
        dict mapping cohort name -> dict of metric name -> float.
    """
    results = {}

    cohorts = {
        "All": np.ones_like(treat, dtype=bool),
        "Immune": treat == 0,
        "Chemo": treat == 1,
    }

    for name, mask in cohorts.items():
        # Too few samples for meaningful metrics — skip this cohort.
        if mask.sum() < 10:
            continue

        res = {}

        # Subtype (binary): AUC on the positive-class score, ACC at 0.5.
        res["Subtype_AUC"] = roc_auc_score(sub_y[mask], sub_s[mask])
        res["Subtype_ACC"] = accuracy_score(sub_y[mask], (sub_s[mask] > 0.5).astype(int))

        # TNM (multiclass macro AUC + ACC), one-vs-rest binarization.
        tnm_bin = label_binarize(tnm_y[mask], classes=[0, 1, 2])
        res["TNM_AUC_macro"] = roc_auc_score(
            tnm_bin, tnm_s[mask], average="macro", multi_class="ovr"
        )
        res["TNM_ACC"] = accuracy_score(
            tnm_y[mask], np.argmax(tnm_s[mask], axis=1)
        )

        # Survival: Harrell's C-index plus the proxy Brier score at 30
        # (months, presumably — confirm the time unit of the dataset).
        dfs_c, dfs_b = _evaluate_survival_metrics(dfs_t[mask], dfs_e[mask], dfs_r[mask], time_point=30)
        os_c, os_b = _evaluate_survival_metrics(os_t[mask], os_e[mask], os_r[mask], time_point=30)

        res["DFS_C_index"] = dfs_c
        res["DFS_Brier_30m"] = dfs_b
        res["OS_C_index"] = os_c
        res["OS_Brier_30m"] = os_b

        results[name] = res

    return results
287
+
288
+
289
+ # =========================================================
290
+ # Figure 7: KM + HR (per cohort, per endpoint)
291
+ # =========================================================
292
def plot_km_curve_with_hr(df, title, save_prefix):
    """Draw Kaplan-Meier curves per risk group with HR/p annotations.

    ``df`` must contain columns: time, event, group (Low/Mediate/High).
    The figure shows per-group KM curves with confidence bands, a
    number-at-risk table, the multivariate log-rank p-value, C-index,
    proxy Brier score and Cox hazard ratios, then is saved as PNG + PDF.

    Args:
        df: DataFrame with ``time``, ``event`` (0/1) and ``group`` columns.
        title: figure title.
        save_prefix: output path without extension.

    Returns:
        ``save_prefix`` (for chaining / logging).
    """
    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=(8, 6), facecolor="white")
    ax.set_facecolor("white")

    colors = {"Low": "#91c7ae", "Mediate": "#f7b977", "High": "#d87c7c"}
    groups = ["Low", "Mediate", "High"]

    # plot KM, one curve per non-empty group
    lines = {}
    at_risk_table = []
    times = np.arange(0, 70, 10)  # at-risk table grid (months, presumably)

    for g in groups:
        m = df["group"] == g
        if m.sum() == 0:
            continue

        kmf.fit(df.loc[m, "time"], event_observed=df.loc[m, "event"], label=g)
        kmf.plot_survival_function(
            ax=ax, ci_show=True, linewidth=2, color=colors[g], marker="+"
        )
        # The most recently added line is this group's survival curve.
        lines[g] = ax.get_lines()[-1]

        at_risk_table.append([np.sum(df.loc[m, "time"] >= t) for t in times])

    # legend ("Mediate" is displayed as "Medium")
    handles = [lines[g] for g in groups if g in lines]
    labels = ["Low", "Medium", "High"][:len(handles)]
    ax.legend(handles, labels, title="Groups", loc="upper right",
              frameon=True, framealpha=0.5, fontsize=12, title_fontsize=12)

    # number-at-risk table, drawn only when all three groups are present
    if len(at_risk_table) == 3:
        low, mid, high = at_risk_table
        for i, t in enumerate(times):
            ax.text(t, -0.38, str(low[i]), color="#207f4c", fontsize=14, ha="center")
            ax.text(t, -0.48, str(mid[i]), color="#fca106", fontsize=14, ha="center")
            ax.text(t, -0.58, str(high[i]), color="#cc163a", fontsize=14, ha="center")

        ax.text(-1, -0.28, "Number at risk", color="black", ha="center", fontsize=14)
        ax.text(-10, -0.38, "Low", color="#207f4c", fontsize=14)
        ax.text(-10, -0.48, "Medium", color="#fca106", fontsize=14)
        ax.text(-10, -0.58, "High", color="#cc163a", fontsize=14)

    # Cox HR + p-values: single ordinal covariate (0/1/2), i.e. a trend
    # model, so HR(High vs Low) = exp(2*coef).
    # NOTE(review): the trend coding assumes a log-linear effect across
    # groups; a factor coding would estimate the two HRs independently.
    df2 = df.copy()
    df2["group_code"] = df2["group"].map({"Low": 0, "Mediate": 1, "High": 2})
    cph = CoxPHFitter()
    cph.fit(df2[["time", "event", "group_code"]], duration_col="time", event_col="event")

    coef = float(cph.params_["group_code"])
    se = float(cph.standard_errors_["group_code"])

    hr_med_vs_low = np.exp(coef * 1)
    hr_high_vs_low = np.exp(coef * 2)

    # Wald test. For a scaled contrast k*coef the SE scales by the same
    # factor k, so z = (k*coef)/(k*se) = coef/se for every k — the old
    # z_high = (2*coef)/se doubled z and understated the p-value.
    z = coef / se
    p_trend = 2 * (1 - norm.cdf(abs(z)))
    p_med = p_trend
    p_high = p_trend

    # multivariate log-rank across all three groups
    res_lr = multivariate_logrank_test(df2["time"], df2["group"], df2["event"])

    # C-index + Brier (proxy) using the ordinal group code as risk score
    c_index, brier = _evaluate_survival_metrics(df2["time"].values, df2["event"].values,
                                                df2["group_code"].values, time_point=30)

    # annotation block (data coordinates)
    ax.text(25, 0.46, f"P(log-rank)={res_lr.p_value:.3f}", fontsize=12)
    ax.text(25, 0.36, f"C-index={c_index:.3f}", fontsize=12)
    ax.text(25, 0.26, f"Brier(30m)={brier:.3f}", fontsize=12)
    ax.text(25, 0.16, f"HR Intermediate vs Low = {hr_med_vs_low:.2f}, P={p_med:.3f}", fontsize=12)
    ax.text(25, 0.06, f"HR High vs Low = {hr_high_vs_low:.2f}, P={p_high:.3f}", fontsize=12)

    # cosmetics
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Time since treatment start (months)", fontsize=14)
    ax.set_ylabel("Survival probability", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_prefix + ".png", dpi=600, bbox_inches="tight")
    plt.savefig(save_prefix + ".pdf", dpi=600, bbox_inches="tight")
    plt.close()
    return save_prefix
385
+
386
+
387
def generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=("val", "test")):
    """
    Load saved dfs/os arrays and generate KM+HR for Immune/Chemo separately.

    Reads the ``*.npy`` files written by main()'s inference step from
    ``result_dir``, splits each requested data split into Immune/Chemo
    cohorts via the saved treatment codes, writes per-cohort CSVs, and
    renders one KM+HR figure per (split, cohort, endpoint).
    """
    os.makedirs(fig_dir, exist_ok=True)

    for split in which_split:
        # load arrays: treatment codes plus risk/time/event per endpoint
        trt = np.load(os.path.join(result_dir, f"treatment_{split}.npy"))

        dfs_r = np.load(os.path.join(result_dir, f"dfs_{split}_risk.npy"))
        dfs_t = np.load(os.path.join(result_dir, f"dfs_{split}_time.npy"))
        dfs_e = np.load(os.path.join(result_dir, f"dfs_{split}_event.npy"))

        os_r = np.load(os.path.join(result_dir, f"os_{split}_risk.npy"))
        os_t = np.load(os.path.join(result_dir, f"os_{split}_time.npy"))
        os_e = np.load(os.path.join(result_dir, f"os_{split}_event.npy"))

        for cohort_name, mask in {
            "Immune": trt == 0,
            "Chemo": trt == 1
        }.items():
            # Very small cohorts give unusable KM curves / Cox fits.
            if mask.sum() < 20:
                print(f"[Figure7] Skip {split}-{cohort_name}: too few samples ({mask.sum()})")
                continue

            # DFS groups: tertile split of the predicted risk
            dfs_group = _risk_to_groups(dfs_r[mask])
            df_dfs = pd.DataFrame({
                "time": dfs_t[mask],
                "event": dfs_e[mask].astype(int),
                "group": dfs_group
            })

            # OS groups
            os_group = _risk_to_groups(os_r[mask])
            df_os = pd.DataFrame({
                "time": os_t[mask],
                "event": os_e[mask].astype(int),
                "group": os_group
            })

            # save CSV (optional, for reproducibility)
            df_dfs.to_csv(os.path.join(result_dir, f"dfs_{split}_{cohort_name}.csv"), index=False)
            df_os.to_csv(os.path.join(result_dir, f"os_{split}_{cohort_name}.csv"), index=False)

            # plot one figure per endpoint
            plot_km_curve_with_hr(
                df_dfs,
                title=f"Disease-Free Survival (DFS) — Kaplan-Meier Curves\n{cohort_name} {split} set (n={mask.sum()})",
                save_prefix=os.path.join(fig_dir, f"Figure7_DFS_{cohort_name}_{split}")
            )
            plot_km_curve_with_hr(
                df_os,
                title=f"Overall Survival (OS) — Kaplan-Meier Curves\n{cohort_name} {split} set (n={mask.sum()})",
                save_prefix=os.path.join(fig_dir, f"Figure7_OS_{cohort_name}_{split}")
            )

    print("✔ Figure 7 generated (DFS/OS KM + HR) for Immune/Chemo.")
446
+
447
+
448
+ # =========================================================
449
+ # Main
450
+ # =========================================================
451
def main():
    """
    End-to-end driver: build the patient dataset, train HierMM_DLS with early
    stopping on validation loss, run inference with the best checkpoint, save
    per-split prediction arrays, and render the Immune/Chemo KM + HR figures.
    """
    # -------------------------
    # Dataset (must return treatment as 19th item)
    # -------------------------
    from mm_dls.PatientDataset import PatientDataset

    dataset = PatientDataset(
        data_root="/path/to/DATA_ROOT",
        clinical_csv="/path/to/clinical.csv",
        radiomics_npy="/path/to/radiomics.npy",
        pet_npy="/path/to/pet.npy",
        n_slices=N_SLICES,
        img_size=IMG_SIZE,
    )

    # 60 / 20 / 20 split. A fixed-seed generator makes the patient partition
    # reproducible across runs, so the saved per-split arrays and figures
    # always refer to the same patients (an unseeded random_split would
    # reshuffle the cohorts on every execution).
    n_train = int(0.6 * len(dataset))
    n_val = int(0.2 * len(dataset))
    n_test = len(dataset) - n_train - n_val

    train_set, val_set, test_set = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(42),
    )

    loaders = {
        "train": DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=4),
        "val": DataLoader(val_set, BATCH_SIZE, shuffle=False, num_workers=4),
        "test": DataLoader(test_set, BATCH_SIZE, shuffle=False, num_workers=4),
    }

    # -------------------------
    # Model
    # -------------------------
    model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # Make sure the checkpoint directory exists before the first torch.save.
    os.makedirs(SAVE_DIR, exist_ok=True)

    best_val_loss = float("inf")  # sentinel: any real validation loss beats it
    wait = 0                      # epochs elapsed since last val improvement

    # -------------------------
    # Training
    # -------------------------
    for epoch in range(1, EPOCHS + 1):
        tr = run_epoch_verbose(model, loaders["train"], optimizer, DEVICE, train=True)
        va = run_epoch_verbose(model, loaders["val"], optimizer, DEVICE, train=False)

        # run_epoch_verbose returns loss first, then predictions/labels.
        tr_loss = tr[0]
        va_loss = va[0]

        # unpack val outputs for per-treatment metrics
        _, sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e, _, _ = va
        metrics = evaluate_by_treatment(sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e)

        print(f"\n[Epoch {epoch:03d}] Train Loss={tr_loss:.3f} | Val Loss={va_loss:.3f}")
        for k, v in metrics.items():
            print(
                f"  {k:7s} | "
                f"Subtype AUC={v['Subtype_AUC']:.3f} | "
                f"TNM AUC={v['TNM_AUC_macro']:.3f} | "
                f"DFS C-index={v['DFS_C_index']:.3f} | "
                f"OS C-index={v['OS_C_index']:.3f}"
            )

        # Early stopping on validation loss: checkpoint on improvement,
        # otherwise count patience and stop once it is exhausted.
        if va_loss < best_val_loss:
            best_val_loss = va_loss
            wait = 0
            torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best_model.pt"))
            print("  ✓ Best model updated")
        else:
            wait += 1
            if wait >= PATIENCE:
                print("\n⏹ Early stopping triggered")
                break

    # -------------------------
    # Inference (best model)
    # -------------------------
    print("\nRunning inference with best model...")
    model.load_state_dict(torch.load(os.path.join(SAVE_DIR, "best_model.pt"), map_location=DEVICE))

    for split in ["train", "val", "test"]:
        out = run_epoch_verbose(model, loaders[split], optimizer, DEVICE, train=False)
        (
            loss,
            sy, ss,
            ty, ts,
            trt,
            dfs_r, dfs_t, dfs_e,
            os_r, os_t, os_e,
            dfs_log, os_log
        ) = out

        # classification (labels + scores per task)
        np.save(os.path.join(SAVE_DIR, f"subtype_{split}_labels.npy"), sy)
        np.save(os.path.join(SAVE_DIR, f"subtype_{split}_scores.npy"), ss)
        np.save(os.path.join(SAVE_DIR, f"tnm_{split}_labels.npy"), ty)
        np.save(os.path.join(SAVE_DIR, f"tnm_{split}_scores.npy"), ts)
        np.save(os.path.join(SAVE_DIR, f"treatment_{split}.npy"), trt)

        # survival (cox risk + time/event)
        np.save(os.path.join(SAVE_DIR, f"dfs_{split}_risk.npy"), dfs_r)
        np.save(os.path.join(SAVE_DIR, f"dfs_{split}_time.npy"), dfs_t)
        np.save(os.path.join(SAVE_DIR, f"dfs_{split}_event.npy"), dfs_e)

        np.save(os.path.join(SAVE_DIR, f"os_{split}_risk.npy"), os_r)
        np.save(os.path.join(SAVE_DIR, f"os_{split}_time.npy"), os_t)
        np.save(os.path.join(SAVE_DIR, f"os_{split}_event.npy"), os_e)

        # 1y/3y/5y logits (optional, for AUC at specific horizons)
        np.save(os.path.join(SAVE_DIR, f"dfs_{split}_logits_1y3y5y.npy"), dfs_log)
        np.save(os.path.join(SAVE_DIR, f"os_{split}_logits_1y3y5y.npy"), os_log)

        # treatment flag convention: 0 = Immune, 1 = Chemo
        print(f"{split:5s} | loss={loss:.3f} | Immune={np.sum(trt==0)} Chemo={np.sum(trt==1)}")

    print("\n✓ Inference completed. Results saved.")

    # -------------------------
    # Figure: Immune/Chemo KM + HR
    # -------------------------
    print("\nGenerating Figure (KM + HR) ...")
    generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=("val", "test"))
    print("✓ Figure done. Files saved under ./figures")
572
+
573
+
574
# Script entry point: run the full train → infer → figure pipeline.
if __name__ == "__main__":
    main()
576
+