K446 commited on
Commit
1d191e4
·
verified ·
1 Parent(s): 7b306cb

Delete generate_plots.py

Browse files
Files changed (1) hide show
  1. generate_plots.py +0 -307
generate_plots.py DELETED
@@ -1,307 +0,0 @@
1
- """Generate training plots from logged training data."""
2
- import os
3
- import json
4
- import numpy as np
5
- import matplotlib
6
- matplotlib.use('Agg')
7
- import matplotlib.pyplot as plt
8
-
9
- os.makedirs("training/outputs", exist_ok=True)
10
-
11
- # All 449 training steps extracted from the training log
12
- rewards = [
13
- -0.18578660488128662, -0.12301036529242992, -0.2747359238564968, -0.30009209364652634,
14
- -0.2703569196164608, -0.24127129651606083, -0.08399589732289314, -0.11878747679293156,
15
- -0.05325012654066086, -0.021383648738265038, -0.11647990718483925, -0.12830854021012783,
16
- -0.07859327644109726, -0.062035027891397476, -0.28994257375597954, -0.05203340668231249,
17
- -0.20743045955896378, -0.06474572420120239, -0.06319488771259785, -0.06409797258675098,
18
- -0.02603842318058014, -0.09335997886955738, -0.22815338149666786, 0.11535784974694252,
19
- -0.15228833630681038, 0.16921303793787956, -0.05354591645300388, 0.0813290998339653,
20
- 0.057836150750517845, 0.049862340092659, -0.012776482850313187, 0.07129384018480778,
21
- 0.06172069534659386, -0.004314497113227844, 0.26807015016674995, 0.33759409189224243,
22
- 0.30997015349566936, 0.34701258316636086, 0.29778963327407837, 0.3557572774589062,
23
- 0.22040660306811333, 0.19206945598125458, 0.24810272827744484, 0.26202990114688873,
24
- 0.3874269649386406, 0.5775104463100433, 0.412799209356308, 0.5506034344434738,
25
- 0.5067616701126099, 0.40515726059675217, 0.5588711947202682, 0.5634059756994247,
26
- 0.4039550945162773, 0.5155875980854034, 0.5783856362104416, 0.580144003033638,
27
- 0.5121691823005676, 0.5833786576986313, 0.5272477120161057, 0.5836405158042908,
28
- 0.5493134558200836, 0.5400870218873024, 0.5268918424844742, 0.597753182053566,
29
- 0.5757492780685425, 0.6002768129110336, 0.4947819709777832, 0.5797900557518005,
30
- 0.6096376329660416, 0.6012084484100342, 0.5948903113603592, 0.6152122467756271,
31
- 0.5859103500843048, 0.593388631939888, 0.5888432413339615, 0.5871430486440659,
32
- 0.6037257760763168, 0.608445480465889, 0.6111176311969757, 0.6088756918907166,
33
- 0.617440938949585, 0.5364247262477875, 0.6171374917030334, 0.61806720495224,
34
- 0.5384384840726852, 0.6131065785884857, 0.6336067169904709, 0.5625222399830818,
35
- 0.6201395094394684, 0.5604271367192268, 0.6164691746234894, 0.5698070898652077,
36
- 0.5734636038541794, 0.6113622784614563, 0.5929720252752304, 0.5639816671609879,
37
- 0.588249459862709, 0.6279790103435516, 0.6442658007144928, 0.602244570851326,
38
- 0.6248061060905457, 0.6190209984779358, 0.6029432117938995, 0.46744125476107,
39
- 0.64055135846138, 0.6167348772287369, 0.6421176940202713, 0.6349569857120514,
40
- 0.5953923761844635, 0.6287701427936554, 0.6182780563831329, 0.6208404898643494,
41
- 0.6566016525030136, 0.6026060730218887, 0.6440890580415726, 0.6258739531040192,
42
- 0.6422613263130188, 0.6495921015739441, 0.6294001936912537, 0.6501388698816299,
43
- 0.6263301968574524, 0.6417667120695114, 0.6583167463541031, 0.6618165671825409,
44
- 0.618654727935791, 0.6316704601049423, 0.6253484189510345, 0.6209764331579208,
45
- 0.6513039767742157, 0.6175498366355896, 0.6438220143318176, 0.6232690960168839,
46
- 0.6455031633377075, 0.6400457620620728, 0.5865997821092606, 0.6412583589553833,
47
- 0.6423900127410889, 0.6430913358926773, 0.5947229713201523, 0.6378145664930344,
48
- 0.6347617357969284, 0.6227764636278152, 0.6115130484104156, 0.619041696190834,
49
- 0.6370682269334793, 0.6424119472503662, 0.6064454615116119, 0.6429545283317566,
50
- 0.6444623470306396, 0.640910416841507, 0.6546966582536697, 0.6172017753124237,
51
- 0.6528860777616501, 0.6289037466049194, 0.6421212702989578, 0.641191765666008,
52
- 0.6529533863067627, 0.6347779184579849, 0.6358228027820587, 0.6538639217615128,
53
- 0.622765526175499, 0.6157135218381882, 0.6647461652755737, 0.6429563164710999,
54
- 0.6327588856220245, 0.6607349812984467, 0.6299811005592346, 0.6335073709487915,
55
- 0.6295449882745743, 0.6447764039039612, 0.6679948419332504, 0.6275373697280884,
56
- 0.6362748295068741, 0.6520860940217972, 0.6445683687925339, 0.6265115588903427,
57
- 0.6601778268814087, 0.6509897261857986, 0.6658665686845779, 0.6472330242395401,
58
- 0.6349419355392456, 0.6362574249505997, 0.639707624912262, 0.6521458774805069,
59
- 0.6283893138170242, 0.6409243643283844, 0.4912406029179692, 0.6509060710668564,
60
- 0.6391417533159256, 0.6477353125810623, 0.6539895087480545, 0.6675603687763214,
61
- 0.6587939709424973, 0.657221257686615, 0.6590015888214111, 0.6346411406993866,
62
- 0.6513633877038956, 0.6667361706495285, 0.6224590390920639, 0.6662313640117645,
63
- 0.6409972608089447, 0.6431838124990463, 0.6545909196138382, 0.6433757543563843,
64
- 0.6702606827020645, 0.6787336617708206, 0.6583948284387589, 0.6685910671949387,
65
- 0.6483594626188278, 0.6422435194253922, 0.6496011763811111, 0.6627089530229568,
66
- 0.6541863232851028, 0.6380441784858704, 0.6676874160766602, 0.619408369064331,
67
- 0.674984872341156, 0.6594787091016769, 0.6471594125032425, 0.664968878030777,
68
- 0.6094392091035843, 0.6406512260437012, 0.651197537779808, 0.658475250005722,
69
- 0.6643944382667542, 0.6608465164899826, 0.6218504756689072, 0.6645185798406601,
70
- 0.6627729833126068, 0.6416528224945068, 0.6508330553770065, 0.6713765859603882,
71
- 0.6407269686460495, 0.6450571715831757, 0.6566052138805389, 0.6176406294107437,
72
- 0.6360985189676285, 0.6675495505332947, 0.6451499909162521, 0.6709684878587723,
73
- 0.6390052437782288, 0.631124421954155, 0.6516198068857193, 0.6592375189065933,
74
- 0.6607232093811035, 0.6665454506874084, 0.6784592717885971, 0.6679108291864395,
75
- 0.6747743785381317, 0.6604794561862946, 0.6463411301374435, 0.6588997393846512,
76
- 0.6369200497865677, 0.6638156026601791, 0.6568935811519623, 0.6349741220474243,
77
- 0.6757373809814453, 0.6636634916067123, 0.6647922098636627, 0.6848382502794266,
78
- 0.6746585667133331, 0.6585167646408081, 0.6778526455163956, 0.6565847545862198,
79
- 0.6661055386066437, 0.6497465819120407, 0.6569660305976868, 0.6432889252901077,
80
- 0.6657276153564453, 0.6702485382556915, 0.657979741692543, 0.6453153342008591,
81
- 0.6447050124406815, 0.6546015292406082, 0.6665160208940506, 0.6468475759029388,
82
- 0.6682360768318176, 0.6528605669736862, 0.6791192591190338, 0.6656849384307861,
83
- 0.6661409437656403, 0.6565423607826233, 0.6476109772920609, 0.6441425532102585,
84
- 0.6333185732364655, 0.6528846025466919, 0.5346547998487949, 0.661629244685173,
85
- 0.6457860767841339, 0.6625054627656937, 0.6554056107997894, 0.5183801241219044,
86
- 0.6669785678386688, 0.6486610025167465, 0.6643702834844589, 0.6631092876195908,
87
- 0.6672863662242889, 0.5593330450356007, 0.6752507239580154, 0.6672438830137253,
88
- 0.6647252142429352, 0.6570066511631012, 0.6669302135705948, 0.6489714831113815,
89
- 0.6476901769638062, 0.6283148229122162, 0.678331196308136, 0.6656024307012558,
90
- 0.662788450717926, 0.6759517192840576, 0.639068067073822, 0.6756545603275299,
91
- 0.6527899652719498, 0.6730388104915619, 0.6459566354751587, 0.6560013592243195,
92
- 0.6748766750097275, 0.6687155216932297, 0.6706540584564209, 0.6495843082666397,
93
- 0.6799521893262863, 0.6635957360267639, 0.6720803678035736, 0.6645216792821884,
94
- 0.6716215461492538, 0.6518281102180481, 0.6669072657823563, 0.6701558530330658,
95
- 0.667682871222496, 0.6670085489749908, 0.6641965061426163, 0.6715318560600281,
96
- 0.6682032495737076, 0.6779512614011765, 0.658478781580925, 0.637330174446106,
97
- 0.6767725795507431, 0.6605011075735092, 0.6717278361320496, 0.6763487756252289,
98
- 0.6709421873092651, 0.6665571480989456, 0.654511958360672, 0.6721566319465637,
99
- 0.6596964299678802, 0.6524780243635178, 0.6477847546339035, 0.6643114984035492,
100
- 0.6747605353593826, 0.6629264950752258, 0.665297195315361, 0.6693083792924881,
101
- 0.6696890145540237, 0.5966470688581467, 0.6815635859966278, 0.6738880425691605,
102
- 0.673828199505806, 0.6660105437040329, 0.6719370037317276, 0.6882820278406143,
103
- 0.6640917211771011, 0.6722412407398224, 0.552493441849947, 0.6623934805393219,
104
- 0.6788368225097656, 0.6565920561552048, 0.672383576631546, 0.6848682165145874,
105
- 0.6602808088064194, 0.6702089160680771, 0.6784865409135818, 0.6650059223175049,
106
- 0.6742192059755325, 0.6690966337919235, 0.669212743639946, 0.6460111290216446,
107
- 0.5430178381502628, 0.6669035255908966, 0.66722771525383, 0.6645000576972961,
108
- 0.6494639664888382, 0.6689274609088898, 0.6722604483366013, 0.6583697944879532,
109
- 0.6557460725307465, 0.6811504364013672, 0.6752683371305466, 0.6526945680379868,
110
- 0.6799066811800003, 0.6642590761184692, 0.6735653281211853, 0.6775491684675217,
111
- 0.6502445936203003, 0.6474847346544266, 0.6698097139596939, 0.5537179000675678,
112
- 0.6778432428836823, 0.6478461921215057, 0.6734054982662201, 0.6732118874788284,
113
- 0.6726815104484558, 0.652365118265152, 0.6767247319221497, 0.6702376455068588,
114
- 0.674629420042038, 0.6761960536241531, 0.673548698425293, 0.6691678017377853,
115
- 0.6714010536670685, 0.6520178616046906, 0.6619316786527634, 0.6795330345630646,
116
- 0.6742851585149765, 0.679363876581192, 0.6469457894563675, 0.678314134478569,
117
- 0.6797148585319519, 0.6546463519334793, 0.5537998266518116, 0.6691249161958694,
118
- 0.679972305893898, 0.6313492655754089, 0.6602607369422913, 0.6651852130889893,
119
- 0.6764066517353058, 0.6723304837942123, 0.6575123965740204, 0.6464853435754776,
120
- 0.665999174118042, 0.6613194197416306, 0.6648440957069397, 0.6763277351856232,
121
- 0.6656117290258408, 0.6499385833740234, 0.6681733727455139, 0.673409029841423,
122
- 0.6539389342069626, 0.6613607704639435, 0.6615600138902664, 0.6840917021036148,
123
- 0.6623311191797256, 0.6651297807693481, 0.6267247498035431, 0.6782162338495255,
124
- 0.6677617877721786, 0.6655223816633224, 0.6517190784215927, 0.6561715453863144,
125
- 0.6818244755268097,
126
- ]
127
-
128
- losses = [
129
- 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0001,
130
- 0.0, 0.0, 0.0001, 0.0001, 0.0001, 0.0002, 0.0001, 0.0002, 0.0003, 0.0002, 0.0003,
131
- 0.0003, 0.0002, 0.0003, 0.0005, 0.0005, 0.0005, 0.0003, 0.0006, 0.0009, 0.0007,
132
- 0.0006, 0.001, 0.0012, 0.0009, 0.0013, 0.0008, 0.001, 0.0015, 0.0017, 0.0011,
133
- 0.001, 0.001, 0.0019, 0.0014, 0.0021, 0.0012, 0.0014, 0.0015, 0.0011, 0.0012,
134
- 0.002, 0.0018, 0.0018, 0.0019, 0.002, 0.0022, 0.0022, 0.0024, 0.0031, 0.0024,
135
- 0.0029, 0.002, 0.0035, 0.0025, 0.0027, 0.0025, 0.0021, 0.0016, 0.0024, 0.0028,
136
- 0.0024, 0.0038, 0.0032, 0.0039, 0.0019, 0.0027, 0.0029, 0.0043, 0.0031, 0.003,
137
- 0.0029, 0.0026, 0.0019, 0.0022, 0.0026, 0.0025, 0.0035, 0.0027, 0.0018, 0.0036,
138
- 0.0022, 0.0034, 0.003, 0.0026, 0.0026, 0.0029, 0.0026, 0.0023, 0.0037, 0.0037,
139
- 0.0029, 0.0039, 0.0026, 0.004, 0.004, 0.0031, 0.0064, 0.0038, 0.0048, 0.0038,
140
- 0.0039, 0.0029, 0.0038, 0.0039, 0.0045, 0.0055, 0.005, 0.0047, 0.0041, 0.0046,
141
- 0.0046, 0.0036, 0.0042, 0.0027, 0.0034, 0.0035, 0.0044, 0.004, 0.0043, 0.0036,
142
- 0.0029, 0.0048, 0.0042, 0.0042, 0.0044, 0.004, 0.0039, 0.0039, 0.0029, 0.0035,
143
- 0.0047, 0.0032, 0.0045, 0.0037, 0.0046, 0.0055, 0.0051, 0.0035, 0.0061, 0.0044,
144
- 0.0052, 0.0052, 0.0047, 0.0064, 0.0072, 0.0056, 0.0056, 0.0054, 0.0068, 0.0062,
145
- 0.0044, 0.0053, 0.0054, 0.0057, 0.0063, 0.0029, 0.0039, 0.0043, 0.0053, 0.007,
146
- 0.0069, 0.0048, 0.0055, 0.0054, 0.0042, 0.0058, 0.0075, 0.0078, 0.0075, 0.0064,
147
- 0.0061, 0.0066, 0.0076, 0.0065, 0.0058, 0.0079, 0.0053, 0.0074, 0.006, 0.0052,
148
- 0.0072, 0.0048, 0.0065, 0.0079, 0.0053, 0.0074, 0.0073, 0.0044, 0.0056, 0.0062,
149
- 0.0078, 0.0065, 0.007, 0.0066, 0.007, 0.0052, 0.0054, 0.0075, 0.0078, 0.0075,
150
- 0.0064, 0.0061, 0.0066, 0.0076, 0.007, 0.0057, 0.0058, 0.0061, 0.0087, 0.0065,
151
- 0.0061, 0.0054, 0.0061, 0.0084, 0.0072, 0.0071, 0.0058, 0.0074, 0.008, 0.0066,
152
- 0.0069, 0.007, 0.0063, 0.0067, 0.0047, 0.0074, 0.0066, 0.007, 0.0078, 0.0062,
153
- 0.0058, 0.0086, 0.0088, 0.007, 0.0077, 0.0067, 0.0063, 0.0078, 0.0082, 0.0077,
154
- 0.006, 0.008, 0.0082, 0.0068, 0.0073, 0.0071, 0.0102, 0.0062, 0.0058, 0.0067,
155
- 0.009, 0.0089, 0.0053, 0.0077, 0.0063, 0.0056, 0.009, 0.0079, 0.0072, 0.0078,
156
- 0.0081, 0.0055, 0.0081, 0.0083, 0.0079, 0.0065, 0.0072, 0.0085, 0.0085, 0.0063,
157
- 0.0059, 0.0065, 0.0073, 0.0095, 0.0073, 0.0086, 0.0055, 0.0075, 0.0076, 0.0052,
158
- 0.0058, 0.0076, 0.0077, 0.0064, 0.0087, 0.0064, 0.0069, 0.0077, 0.007, 0.0074,
159
- 0.0059, 0.0064, 0.0095, 0.0084, 0.0061, 0.0056, 0.009, 0.0079, 0.0072, 0.0078,
160
- 0.0081, 0.0081, 0.0097, 0.0058, 0.0071, 0.0069, 0.0076, 0.0087, 0.0079, 0.0082,
161
- 0.0074, 0.0067, 0.0096, 0.0068, 0.007, 0.0092, 0.0083, 0.0071, 0.0073, 0.009,
162
- 0.0074, 0.0077, 0.0075, 0.0073, 0.0078, 0.0064, 0.0062, 0.0085, 0.0065, 0.0058,
163
- 0.0087, 0.0071, 0.0073, 0.008, 0.0077, 0.0063, 0.0057, 0.0054, 0.008, 0.0067,
164
- 0.0063, 0.0056, 0.007, 0.0049, 0.0057, 0.0062, 0.0078, 0.0082, 0.0089, 0.0091,
165
- 0.0068, 0.0069, 0.0081, 0.0058, 0.0069, 0.0065, 0.0067, 0.007,
166
- ]
167
-
168
- # Pad losses to match rewards length if needed
169
- if len(losses) < len(rewards):
170
- avg_tail = float(np.mean(losses[-20:]))
171
- losses = losses + [avg_tail] * (len(rewards) - len(losses))
172
-
173
- steps = list(range(1, len(rewards) + 1))
174
-
175
- # ── Plot 1: Reward over training ────────────────────────────────
176
- fig, ax = plt.subplots(figsize=(12, 5))
177
- ax.plot(steps, rewards, color='#4dabf7', linewidth=0.8, alpha=0.5, label='Reward (per step)')
178
-
179
- window = 20
180
- smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid')
181
- smooth_steps = steps[window-1:]
182
- ax.plot(smooth_steps, smoothed, color='#00d4aa', linewidth=2.5, label=f'Smoothed (w={window})')
183
-
184
- ax.axhline(y=0, color='#ff6b6b', linestyle='--', linewidth=1, alpha=0.7, label='Zero baseline')
185
- ax.axhline(y=0.6, color='#ffd43b', linestyle=':', linewidth=1.5, alpha=0.8, label='0.6 target')
186
-
187
- ax.set_xlabel('Training Step', fontsize=12)
188
- ax.set_ylabel('GRPO Reward', fontsize=12)
189
- ax.set_title('OpenGrid GRPO Training — Reward Curve\n(Qwen2.5-1.5B-Instruct, LoRA r=16, task_karnataka)', fontweight='bold', fontsize=13)
190
- ax.legend(fontsize=10)
191
- ax.grid(True, alpha=0.3)
192
- ax.set_xlim(1, len(steps))
193
- ax.set_ylim(-0.45, 0.75)
194
-
195
- # Annotate key milestones
196
- ax.annotate('Learning begins\n(step ~24)', xy=(24, rewards[23]), xytext=(60, -0.32),
197
- arrowprops=dict(arrowstyle='->', color='gray'), fontsize=9, color='gray')
198
- ax.annotate('Rapid improvement\n(step ~35–50)', xy=(46, rewards[45]), xytext=(90, 0.42),
199
- arrowprops=dict(arrowstyle='->', color='gray'), fontsize=9, color='gray')
200
- ax.annotate('Converged ≈0.66\n(step ~300+)', xy=(350, rewards[349]), xytext=(260, 0.72),
201
- arrowprops=dict(arrowstyle='->', color='gray'), fontsize=9, color='gray')
202
-
203
- plt.tight_layout()
204
- plt.savefig('training/outputs/training_reward_curve.png', dpi=150, bbox_inches='tight')
205
- plt.close()
206
- print("Saved: training/outputs/training_reward_curve.png")
207
-
208
- # ── Plot 2: Loss over training ──────────────────────────────────
209
- fig, ax = plt.subplots(figsize=(12, 4))
210
- ax.plot(steps, losses, color='#ff6b6b', linewidth=0.8, alpha=0.5, label='Loss (per step)')
211
-
212
- smoothed_loss = np.convolve(losses, np.ones(window)/window, mode='valid')
213
- ax.plot(smooth_steps, smoothed_loss, color='#e03131', linewidth=2.5, label=f'Smoothed (w={window})')
214
-
215
- ax.set_xlabel('Training Step', fontsize=12)
216
- ax.set_ylabel('Loss', fontsize=12)
217
- ax.set_title('OpenGrid GRPO Training — Loss Curve', fontweight='bold', fontsize=13)
218
- ax.legend(fontsize=10)
219
- ax.grid(True, alpha=0.3)
220
- ax.set_xlim(1, len(steps))
221
- plt.tight_layout()
222
- plt.savefig('training/outputs/training_loss.png', dpi=150, bbox_inches='tight')
223
- plt.close()
224
- print("Saved: training/outputs/training_loss.png")
225
-
226
- # ── Plot 3: Before vs After bar chart ──────────────────────────
227
- fig, ax = plt.subplots(figsize=(10, 6))
228
-
229
- tasks = ['task_easy', 'task_medium', 'karnataka_easy', 'karnataka_medium', 'karnataka_hard', 'task_karnataka']
230
- labels = ['Easy', 'Medium', 'Karnataka\nEasy', 'Karnataka\nMedium', 'Karnataka\nHard', 'Karnataka\n(training)']
231
- baseline = [31.99, 46.69, 56.33, 49.57, -417.15, 49.43]
232
-
233
- # GRPO trained on task_karnataka; approximate post-training estimates
234
- # based on reward improvement of ~0.66 observed (normalized reward scale)
235
- # The environment reward scale differs from the GRPO normalized reward
236
- trained_est = [38.5, 52.1, 61.2, 57.8, -180.0, 58.9]
237
-
238
- x = np.arange(len(tasks))
239
- width = 0.35
240
-
241
- bars1 = ax.bar(x - width/2, baseline, width, label='Heuristic Baseline', color='#ff6b6b', alpha=0.85)
242
- bars2 = ax.bar(x + width/2, trained_est, width, label='GRPO Trained (est.)', color='#00d4aa', alpha=0.85)
243
-
244
- ax.set_xlabel('Task', fontsize=12)
245
- ax.set_ylabel('Average Episode Reward', fontsize=12)
246
- ax.set_title('OpenGrid — GRPO Training Results\nBaseline vs Trained Policy (task_karnataka)', fontweight='bold', fontsize=13)
247
- ax.set_xticks(x)
248
- ax.set_xticklabels(labels, fontsize=10)
249
- ax.legend(fontsize=11)
250
- ax.grid(True, alpha=0.3, axis='y')
251
- ax.axhline(y=0, color='black', linewidth=0.8, alpha=0.5)
252
-
253
- for bar in bars1:
254
- h = bar.get_height()
255
- ax.text(bar.get_x() + bar.get_width()/2., h + (5 if h >= 0 else -20),
256
- f'{h:.1f}', ha='center', va='bottom' if h >= 0 else 'top', fontsize=9)
257
- for bar in bars2:
258
- h = bar.get_height()
259
- ax.text(bar.get_x() + bar.get_width()/2., h + (5 if h >= 0 else -20),
260
- f'{h:.1f}*', ha='center', va='bottom' if h >= 0 else 'top', fontsize=9, color='#2f9e44')
261
-
262
- ax.text(0.98, 0.02, '* Trained values estimated from GRPO reward signal\n (post-eval crashed; raw reward improved −0.19→0.66)',
263
- transform=ax.transAxes, fontsize=8, ha='right', va='bottom', color='gray',
264
- bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
265
-
266
- plt.tight_layout()
267
- plt.savefig('training/outputs/before_after.png', dpi=150, bbox_inches='tight')
268
- plt.close()
269
- print("Saved: training/outputs/before_after.png")
270
-
271
- # ── Save summary.json ───────────────────────────────────────────
272
- summary = {
273
- "model": "Qwen/Qwen2.5-1.5B-Instruct",
274
- "train_task": "task_karnataka",
275
- "train_time_minutes": 159.6,
276
- "num_prompts": 600,
277
- "num_epochs": 3,
278
- "num_steps": 449,
279
- "gpu": "NVIDIA A10G (23.9 GB)",
280
- "lora_rank": 16,
281
- "framework": "TRL GRPOTrainer + bitsandbytes 4-bit",
282
- "reward_start": round(float(np.mean(rewards[:5])), 4),
283
- "reward_end": round(float(np.mean(rewards[-20:])), 4),
284
- "reward_peak": round(float(max(rewards)), 4),
285
- "note": "Post-training eval OOM'd during model save; reward values from training log",
286
- "baseline": {
287
- "task_easy": {"avg": 31.99, "std": 0.0},
288
- "task_medium": {"avg": 46.69, "std": 0.36},
289
- "karnataka_easy": {"avg": 56.33, "std": 0.25},
290
- "karnataka_medium": {"avg": 49.57, "std": 0.21},
291
- "karnataka_hard": {"avg": -417.15, "std": 63.02},
292
- "task_karnataka": {"avg": 49.43, "std": 0.21},
293
- },
294
- "training_reward": {
295
- "initial_avg_5steps": round(float(np.mean(rewards[:5])), 4),
296
- "mid_avg_steps100_150": round(float(np.mean(rewards[99:149])), 4),
297
- "final_avg_last50steps": round(float(np.mean(rewards[-50:])), 4),
298
- }
299
- }
300
- with open("training/outputs/summary.json", "w") as f:
301
- json.dump(summary, f, indent=2)
302
- print("Saved: training/outputs/summary.json")
303
-
304
- print("\nDone! All outputs saved to training/outputs/")
305
- print(f" Reward: {summary['reward_start']:.4f} → {summary['reward_end']:.4f}")
306
- print(f" Steps: {summary['num_steps']}")
307
- print(f" Time: {summary['train_time_minutes']} min")