"""
GRPO training for self-improvement math environment.

Group Relative Policy Optimization (GRPO) is dramatically simpler and more
stable than PPO for LLM fine-tuning on math tasks:

  - No value function / critic needed
  - No GAE, no gamma, no lambda
  - No KL instability from per-step advantage collapse
  - Advantages computed as within-group z-scores: A_i = (R_i - mean_R) / std_R
  - Proven on math RL: DeepSeek-Math, Qwen-Math, DAPO all use GRPO variants

The algorithm per question:
  1. Generate K solutions (default K=4)
  2. Score each with the existing reward pipeline (PRM + SymPy + format)
  3. A_i = (R_i - mean(R)) / (std(R) + eps)
  4. policy_loss = -mean_i [ A_i * sum_t log pi(a_t | s_{<t}) / T_i ]
  5. Skip the group if all rewards are identical (zero gradient signal)
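
A minimal, illustrative sketch of steps 3 and 5 in plain Python. This is not
the code path used by this script: ``group_advantages`` is a hypothetical
name, and the population standard deviation is an assumption, since the
formula above does not say which estimator applies.

```python
import statistics


def group_advantages(rewards, eps=1e-8):
    # Step 3: within-group z-scores, A_i = (R_i - mean(R)) / (std(R) + eps).
    mean_r = statistics.fmean(rewards)
    std_r = statistics.pstdev(rewards)  # assumption: population std
    if std_r == 0.0:
        # Step 5: identical rewards carry no gradient signal, so skip.
        return None
    return [(r - mean_r) / (std_r + eps) for r in rewards]
```

A caller would drop any group for which this returns ``None`` before
building the policy loss in step 4.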

Expected improvement curve:
  - Iterations 1-5:  reward mean rising, policy learning to avoid R=0 outputs
  - Iterations 5-15: GSM8K accuracy starts moving (+2-5%)
  - Iterations 15-30: continued improvement toward ~70-75%+ from 63.6% baseline

Usage:
    python scripts/run_grpo_training.py \\
        --base-model checkpoints/dual_task_v1 \\
        --gsm8k-data data/sft/gsm8k_sft.jsonl \\
        --num-iterations 30 \\
        --group-size 4 \\
        --questions-per-iter 16

    # Faster smoke test (no PRM, 3 iters):
    python scripts/run_grpo_training.py \\
        --base-model checkpoints/dual_task_v1 \\
        --num-iterations 3 --group-size 4 --questions-per-iter 8 \\
        --no-prm --skip-initial-eval --run-name smoke_grpo
"""

from __future__ import annotations

import argparse
import atexit
import copy
import csv
import json
import logging
import random
import re
import shutil
import sys
import time
import types
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))

from scripts.convert_gsm8k_to_sft import parse_gsm8k_answer
from scripts.eval_sft_inference import evaluate_gsm8k
from src.rl.prm_scorer import ProcessRewardScorer
from src.sft.solution_format import extract_final_answer_numeric_str
from src.utils.attn_backend import select_attn_implementation
from src.rl.math_environment_curriculum import CurriculumMathEnvironment
from src.config.prompts import create_generator_messages

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Logging infrastructure
# ---------------------------------------------------------------------------

class TeeStream:
    """Mirrors every write to a terminal stream into a log file.

    Wrapping sys.stdout and sys.stderr with this object ensures that *all*
    output (bare print() calls, tqdm bars, third-party library writes) lands
    in the run log file in addition to the terminal.

    A separate FileHandler on the root logger (see _add_file_logging) captures
    the Python logging subsystem independently, because logging.StreamHandler
    stores a reference to the stream at creation time and therefore bypasses
    any later sys.stderr reassignment.  Both mechanisms together guarantee that
    nothing escapes the log file.
    """

    def __init__(self, primary, secondary):
        self.primary = primary
        self.secondary = secondary

    def write(self, data: str) -> int:
        self.primary.write(data)
        self.secondary.write(data)
        return len(data)

    def flush(self) -> None:
        self.primary.flush()
        self.secondary.flush()

    def isatty(self) -> bool:
        return getattr(self.primary, "isatty", lambda: False)()

    def fileno(self) -> int:
        return self.primary.fileno()


def _add_file_logging(log_path: Path) -> logging.FileHandler:
    """Attach a FileHandler to the root logger.

    Every logger.info / logger.warning / … call, from any module, will be
    written to ``log_path`` in addition to the terminal.  This complements
    TeeStream: TeeStream captures bare print() / sys.stderr writes; this
    handler captures the logging subsystem, which uses its own internal stream
    reference that TeeStream cannot intercept.
    """
    fh = logging.FileHandler(log_path, mode="a", encoding="utf-8")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter(
        "%(asctime)s %(levelname)-8s %(name)s - %(message)s"
    ))
    logging.getLogger().addHandler(fh)
    return fh


if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True   # auto-tune fastest conv algo per shape


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def _infer_eval_dataset_name(data_path: str) -> str:
    """Derive a short human-readable label from the eval data file path."""
    stem = Path(data_path).stem.lower()
    if "aqua" in stem:
        return "AQuA-RAT"
    if "math" in stem:
        return "MATH"
    if "gsm" in stem:
        return "GSM8K"
    return Path(data_path).stem


def load_gsm8k(path: str) -> List[Dict[str, str]]:
    """Return list of {"question": ..., "gold_final": ...} from a JSONL file."""
    pairs: List[Dict[str, str]] = []
    p = Path(path)
    if not p.exists():
        logger.warning("Training data not found at %s", path)
        return pairs
    with p.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue

            question = ""
            gold = ""
            if "question" in rec and "answer" in rec:
                question = rec["question"].strip()
                _, gold = parse_gsm8k_answer(str(rec["answer"]))
            elif "messages" in rec:
                user_text = ""
                asst_text = ""
                for msg in rec["messages"]:
                    if msg.get("role") == "user" and not user_text:
                        user_text = msg.get("content", "").strip()
                    elif msg.get("role") == "assistant" and not asst_text:
                        asst_text = msg.get("content", "")
                if "Problem:" in user_text:
                    question = user_text.split("Problem:", 1)[1].strip()
                else:
                    question = user_text
                answer_str = extract_final_answer_numeric_str(asst_text) or ""
                gold = answer_str.strip()

            if question and gold:
                pairs.append({"question": question, "gold_final": gold})
    logger.info("Loaded %d QA pairs from %s", len(pairs), path)
    return pairs


# ---------------------------------------------------------------------------
# MATH harder dataset
# ---------------------------------------------------------------------------

def _extract_boxed(text: str) -> Optional[str]:
    r"""Extract the first ``\boxed{...}`` in *text*, allowing one level of nested braces (e.g. ``\frac{3}{4}``)."""
    m = re.search(r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", text)
    return m.group(1).strip() if m else None


def _boxed_to_numeric(answer: str) -> Optional[str]:
    """
    Convert a ``\\boxed{...}`` answer to a plain numeric string.

    Returns a string of the form ``"42"`` or ``"3.5000"`` when the answer
    is a recognisable integer, decimal, or simple fraction (``3/4`` or
    ``\\frac{3}{4}``).  Returns ``None`` for symbolic / multi-part answers
    like ``3\\sqrt{2}`` or ``(1, 2)``.
    """
    ans = answer.strip()
    # Direct integer
    try:
        return str(int(ans))
    except ValueError:
        pass
    # Direct float (includes "3.5", "0.75", etc.)
    try:
        v = float(ans)
        return str(int(v)) if v == int(v) else f"{v:.4f}"
    except (ValueError, OverflowError):  # "inf" parses as a float, but int(inf) overflows
        pass
    # LaTeX fraction  \frac{num}{den}
    m = re.fullmatch(r"\\frac\{(\d+)\}\{(\d+)\}", ans)
    if m:
        num, den = int(m.group(1)), int(m.group(2))
        if den:
            v = num / den
            return str(int(v)) if v == int(v) else f"{v:.4f}"
    # Plain fraction  num/den
    m = re.fullmatch(r"(\d+)/(\d+)", ans)
    if m:
        num, den = int(m.group(1)), int(m.group(2))
        if den:
            v = num / den
            return str(int(v)) if v == int(v) else f"{v:.4f}"
    return None


def load_math_dataset(
    local_path: Optional[str] = None,
    cache_path: str = "data/math/math_numeric.jsonl",
    max_difficulty: int = 3,
) -> List[Dict[str, str]]:
    """
    Load a subset of the MATH competition dataset filtered to problems with
    numerically-verifiable answers (integers, decimals, simple fractions).

    Loading order
    -------------
    1. ``local_path`` if provided and the file exists.
    2. ``cache_path`` if that file exists (written on first HF download).
    3. HuggingFace ``competition_math`` dataset; filtered + written to
       ``cache_path`` for subsequent runs.

    Only problems with ``Level ≤ max_difficulty`` are included.  Difficulty
    1-2 ≈ AMC-8 level (comparable to hard GSM8K); difficulty 3 ≈ AMC-10.
    Levels 4-5 are the hardest competition problems and usually too hard for a
    1.5B model to get any reward signal from (win_rate ≈ 0 → skipped groups
    every iter).
    """
    for candidate in filter(None, [local_path, cache_path]):
        p = Path(candidate)
        if p.exists():
            pairs: List[Dict[str, str]] = []
            with p.open(encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        try:
                            pairs.append(json.loads(line))
                        except json.JSONDecodeError:
                            pass
            if pairs:
                logger.info("Loaded %d MATH pairs from %s", len(pairs), p)
                return pairs

    # Download from HuggingFace
    logger.info(
        "MATH dataset not found locally; downloading from HuggingFace "
        "(qwedsacf/competition_math, difficulty ≤ %d, numeric answers only)...",
        max_difficulty,
    )
    # Try HF sources in priority order.  Only keep sources confirmed reachable;
    # lighteval/MATH and hendrycks/competition_math have network/naming issues.
    _HF_SOURCES = [
        ("qwedsacf/competition_math", {}),           # reliable public mirror
        ("lighteval/MATH-Hard",       {"name": "default"}),  # hard subset
    ]
    ds = None
    for hf_name, hf_kwargs in _HF_SOURCES:
        try:
            from datasets import load_dataset  # type: ignore
            ds = load_dataset(hf_name, split="train", trust_remote_code=True, **hf_kwargs)
            logger.info("Loaded HuggingFace dataset: %s (%d items)", hf_name, len(ds))
            break
        except Exception as exc:
            logger.warning("Could not load %s: %s; trying next source.", hf_name, exc)
    if ds is None:
        logger.warning(
            "All MATH dataset sources failed. Proceeding with GSM8K only. "
            "To load offline: download from https://github.com/hendrycks/math "
            "and pass --math-data <path_to_jsonl>."
        )
        return []

    pairs = []
    for item in ds:
        level_str = item.get("level", "Level 5")
        try:
            level = int(level_str.split()[-1])
        except (ValueError, IndexError):
            level = 5
        if level > max_difficulty:
            continue

        question = item.get("problem", "").strip()
        solution = item.get("solution", "")
        boxed    = _extract_boxed(solution)
        if not boxed:
            continue
        numeric  = _boxed_to_numeric(boxed)
        if not numeric:
            continue
        pairs.append({"question": question, "gold_final": numeric})

    if pairs:
        out_p = Path(cache_path)
        out_p.parent.mkdir(parents=True, exist_ok=True)
        with out_p.open("w", encoding="utf-8") as f:
            for p_item in pairs:
                f.write(json.dumps(p_item) + "\n")
        logger.info("Cached %d MATH numeric pairs to %s", len(pairs), out_p)
    else:
        logger.warning("No MATH pairs passed the numeric filter; check the dataset.")

    return pairs


# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Self-play verification cascade
# ---------------------------------------------------------------------------
# Routes each self-play group to the right verification tool based on
# problem type and difficulty, then gates the GRPO update on the result.
# Returns False (→ skip group) when no tool can verify cleanly, preventing
# circular PRM-only reward from anchoring the training signal.

import re as _re

_FINAL_ANSWER_RE = _re.compile(r"final answer[:\s]*([^\n]+)", _re.I)

# Problem-type routing tables
_PAL_TOPICS     = frozenset({"arithmetic", "algebra", "prealgebra", "grounded"})
_SYMPY_TOPICS   = frozenset({
    "number_theory", "intermediate_algebra", "precalculus",
    "counting_and_probability",
})
_EXCLUDE_TOPICS = frozenset({"geometry"})  # spatial reasoning; cannot verify programmatically


def _extract_final_answer(solution: str) -> Optional[str]:
    """Extract the text after 'Final Answer:' from a solution string."""
    m = _FINAL_ANSWER_RE.search(solution)
    return m.group(1).strip() if m else None


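def _pal_eval(answer_str: str) -> Optional[float]:
    """Tier 1: arithmetic / basic algebra via eval restricted to numeric literals and operators."""
    import ast
    allowed = (ast.Expression, ast.BinOp, ast.UnaryOp, ast.Constant, ast.operator, ast.unaryop)
    try:
        tree = ast.parse(answer_str.strip(), mode="eval")
        # Empty builtins alone can be escaped via attribute access, so whitelist node types.
        if not all(isinstance(node, allowed) for node in ast.walk(tree)):
            return None
        val = eval(compile(tree, "<answer>", "eval"), {"__builtins__": {}}, {})  # noqa: S307
        f = float(val)
        return None if f != f else f  # NaN guard
    except Exception:
        return None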


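def _sympy_eval(answer_str: str) -> Optional[float]:
    """Tier 2: symbolic evaluation via SymPy for algebra, number theory, etc."""
    try:
        from sympy import sympify, N as _N  # type: ignore
        # NOTE: sympify() uses eval() internally, so it is not a sandbox.
        f = float(_N(sympify(answer_str), 15))
        # Reject NaN and +/-inf (e.g. "oo"): callers expect a finite number.
        return f if f == f and abs(f) != float("inf") else None
    except Exception:
        return None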


def _verify_self_play_answer(
    solutions: List[str],
    target_topic: str,
    target_difficulty: float,
) -> bool:
    """
    Tiered verification cascade for self-play groups.

    Returns True only when at least half of the solutions agree on an answer
    that an independent tool (PAL eval or SymPy) can evaluate to a finite number.

    Returns False (drop this group, no gradient) when:
      * topic is geometry (spatial reasoning, can't verify programmatically)
      * difficulty >= 4.0 (should have been blocked at generation, guard here too)
      * no tool can parse a numerical answer from any solution
      * fewer than half of solutions agree on the most common answer

    Coverage for GSM8K + MATH:
      GSM8K              → PAL tier, ~95%+ verified
      MATH L1-L2 algebra → PAL + SymPy fallback, ~80% verified
      MATH number theory / intermediate algebra → SymPy primary, ~70% verified
      MATH geometry      → excluded entirely (~3-5% of MATH)
      MATH L4-L5         → excluded at generation time (see call site)
    """
    topic = target_topic.lower().replace(" ", "_")

    # Hard exclusions (guard even if called after generation-time check)
    if topic in _EXCLUDE_TOPICS or target_difficulty >= 4.0:
        return False

    answers: List[float] = []
    for sol in solutions:
        raw = _extract_final_answer(sol)
        if raw is None:
            continue

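        val: Optional[float]
        if topic in _PAL_TOPICS or target_difficulty <= 2:
            tools = (_pal_eval, _sympy_eval)
        elif topic in _SYMPY_TOPICS:
            tools = (_sympy_eval, _pal_eval)
        else:
            # Unknown topic: try both
            tools = (_pal_eval, _sympy_eval)
        # Explicit None check: a valid answer of 0.0 is falsy, so `a or b`
        # would wrongly discard it whenever the fallback tool returns None.
        val = None
        for tool in tools:
            val = tool(raw)
            if val is not None:
                break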

        if val is not None:
            answers.append(round(val, 6))

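    if not answers:
        return False  # Tier 4: cannot verify, so exclude

    majority = max(set(answers), key=answers.count)
    # Require agreement from at least half of ALL solutions.  Unparseable
    # solutions count as disagreement, and floor(K/2) would accept 2 of 5.
    return 2 * answers.count(majority) >= len(solutions)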


def compute_grounded_reward(
    question: str,
    solution: str,
    gold_final: str,
    math_env: CurriculumMathEnvironment,
) -> Dict[str, float]:
    """Score a solution against a known gold answer (grounded path).

    Returns a dict with:
      combined_score  – 0.50×correct + 0.40×process(prm_final,prm_mean) + 0.10×fmt
      step_accuracy   – fraction of PRM steps rated > 0.5 (the core process metric)
      lccp            – chain integrity: fraction of steps before the first failure
      prm_mean_score  – PRM mean across all steps
      prm_final_score – PRM score on the final reasoning step
      gt_match        – bool, whether pred matches gold
      format_score    – format compliance score
    """
    result = math_env.compute_grounded_reward(
        question=question,
        solution=solution,
        gold_final=gold_final,
    )
    return {
        "combined_score":  float(result.get("combined_score",  0.0)),
        "step_accuracy":   float(result.get("step_accuracy",   0.0)),
        "lccp":            float(result.get("lccp",            0.0)),
        "prm_mean_score":  float(result.get("prm_mean_score",  0.0)),
        "prm_final_score": float(result.get("prm_final_score", 0.0)),
        "gt_match":        bool(result.get("gt_match",         False)),
        "format_score":    float(result.get("format_score",    0.0)),
    }


def compute_self_play_reward(
    question: str,
    solution: str,
    target_topic: str,
    target_difficulty: float,
    math_env: CurriculumMathEnvironment,
) -> Tuple[float, float, float, Dict]:
    """Score a self-generated question + solution (self-play path).

    Returns (combined_reward, question_reward, solution_reward, q_metrics).

    Reward breakdown: R = 0.40×question_quality + 0.60×solution_quality,
    where question_quality captures topic match, difficulty fit, clarity,
    novelty, and solvability.  This completes the Theme #4 self-improvement
    loop, in which the model is rewarded for generating *good challenges*,
    not only for solving them.

    q_metrics contains the full question quality breakdown:
      topic_match, difficulty_fit, clarity, novelty, solvability, overall_score
    """
    result = math_env.compute_reward(
        question=question,
        solution=solution,
        target_topic=target_topic,
        target_difficulty=target_difficulty,
    )
    combined  = float(result["combined_score"])
    sol_score = result.get("solution_metrics", {})
    s_reward  = float(sol_score.get("overall_score", 0.0)) if isinstance(sol_score, dict) else 0.0

    # question_reward is NOT a top-level key in compute_reward()'s return dict.
    # The question quality score lives inside question_metrics["overall_score"].
    # Key mapping from QuestionEvalResult.to_dict():
    #   overall_score    → scalar  (overall question quality)
    #   topic_match      → scalar
    #   difficulty_score → scalar  (fit to target difficulty; named _score not _fit)
    #   clarity          → scalar
    #   solvability_score→ scalar  (the dict version is under "solvability"; don't use that)
    #   novelty_combined → scalar  (the dict version is under "novelty"; don't use that)
    q_metrics_raw = result.get("question_metrics", {}) or {}
    # Use the gated question reward (zeroed when solution is invalid): this is
    # what actually contributed to combined_score, not the raw overall_score.
    q_reward = float(result.get("effective_question_reward", q_metrics_raw.get("overall_score", 0.0)))
    q_metrics: Dict = {
        "overall_score":  q_reward,
        "topic_match":    float(q_metrics_raw.get("topic_match",       0.0)),
        "difficulty_fit": float(q_metrics_raw.get("difficulty_score",  0.0)),
        "clarity":        float(q_metrics_raw.get("clarity",           0.0)),
        "novelty":        float(q_metrics_raw.get("novelty_combined",  0.0)),
        "solvability":    float(q_metrics_raw.get("solvability_score", 0.0)),
        # Chain integrity score from Phase 2+ unified calculator (None if inactive)
        "sp_chain_integrity_score": result.get("sp_chain_integrity_score"),
    }
    return combined, q_reward, s_reward, q_metrics


@torch.no_grad()
def generate_question(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    instruction: str,
    max_new_tokens: int,
    device: torch.device,
    temperature: float = 0.85,
) -> str:
    """Generate a math question from a curriculum instruction.

    Uses centralized prompts from src/config/prompts.py to ensure consistency
    across SFT training, GRPO, PPO, and inference.

    Returns the raw decoded question text (no special tokens).
    """
    # Use centralized prompt configuration
    messages = create_generator_messages(instruction)
    
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Fallback if chat template is missing
        system = messages[0]["content"]
        user = messages[1]["content"]
        prompt = f"{system}\n\n{user}\n"

    enc = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    prompt_len = enc["input_ids"].shape[1]

    stop_ids: List[int] = []
    if tokenizer.eos_token_id is not None:
        stop_ids.append(tokenizer.eos_token_id)
    im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if isinstance(im_end, int) and im_end not in stop_ids:
        stop_ids.append(im_end)

    out = model.generate(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.95,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        eos_token_id=stop_ids or None,
        use_cache=True,
    )
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------

@torch.no_grad()
def generate_questions_batched(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    instruction: str,
    K_q: int,
    max_new_tokens: int,
    temperature: float,
    device: torch.device,
) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """
    Generate K_q question candidates from a single curriculum instruction in
    one batched model.generate() call.  Returns the same four-tuple as
    ``generate_solutions_batched`` so the question token IDs can be passed
    directly to ``grpo_loss_for_group`` for the question-level GRPO update.

    Uses the same centralized prompts (``create_generator_messages``) as
    ``generate_question()`` so the chat format is identical whether running
    single-question or batched two-phase generation.

    Returns:
        questions       : K_q decoded question strings
        input_ids_list  : K_q full (prompt+response) token ID tensors
        response_masks  : K_q bool masks (True = non-pad response token)
        old_log_probs   : K_q scalar tensors (sum log π_old over response),
                          no_grad; used as denominator in IS ratio.
    """
    messages = create_generator_messages(instruction)
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
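    except Exception:
        # Fallback if chat template is missing (mirrors generate_question()).
        # The previous f-string referenced an undefined name `system`.
        system = messages[0]["content"]
        user = messages[1]["content"]
        prompt = f"{system}\n\n{user}\n"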

    stop_ids = _build_stop_token_ids(tokenizer)
    pad_id: int = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    enc = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    prompt_len: int = enc["input_ids"].shape[1]

    input_ids_batch = enc["input_ids"].expand(K_q, -1).contiguous()
    attn_mask_batch = enc["attention_mask"].expand(K_q, -1).contiguous()

    model.eval()
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids_batch,
            attention_mask=attn_mask_batch,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            pad_token_id=pad_id,
            eos_token_id=stop_ids,
            use_cache=True,
        )

    questions: List[str] = []
    input_ids_list: List[torch.Tensor] = []
    response_masks: List[torch.Tensor] = []

    pad_id_t = torch.tensor(pad_id, device=device, dtype=out.dtype)
    for i in range(K_q):
        full_ids = out[i]
        response_section = full_ids[prompt_len:]
        mask = torch.zeros(full_ids.shape[0], dtype=torch.bool, device=device)
        mask[prompt_len:] = response_section != pad_id_t
        question = tokenizer.decode(response_section, skip_special_tokens=True).strip()
        questions.append(question)
        input_ids_list.append(full_ids)
        response_masks.append(mask)

    # Single batched forward pass for all K_q old log-probs (same trick as solutions).
    old_log_probs: List[torch.Tensor] = []
    with torch.no_grad():
        attn_mask_lp = (out != pad_id_t)
        attn_mask_lp[:, :prompt_len] = True
        batch_logits = model(
            input_ids=out,
            attention_mask=attn_mask_lp.long(),
            use_cache=False,
            return_dict=True,
        ).logits  # [K_q, total_len, vocab]

        for i in range(K_q):
            full_ids = out[i]
            mask = response_masks[i]
            shift_logits = batch_logits[i, :-1]
            shift_labels  = full_ids[1:]
            shift_mask    = mask[1:]
            lp_tokens = F.log_softmax(shift_logits, dim=-1)[
                torch.arange(shift_logits.size(0), device=device),
                shift_labels,
            ]
            resp_lps = lp_tokens[shift_mask]
            old_log_probs.append(
                resp_lps.sum().detach() if resp_lps.numel() > 0
                else torch.tensor(0.0, device=device)
            )

    return questions, input_ids_list, response_masks, old_log_probs

def _build_stop_token_ids(tokenizer: AutoTokenizer) -> Optional[List[int]]:
    """
    Return the token IDs that should stop generation, or None if there are none.

    Qwen2.5-chat models end turns with <|im_end|> (ID 151645).  If that
    token is not the same as eos_token_id we include both so that .generate()
    halts cleanly instead of running to max_new_tokens and emitting repetitive
    garbage.
    """
    stop_ids: List[int] = []
    if tokenizer.eos_token_id is not None:
        stop_ids.append(tokenizer.eos_token_id)
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # Tokenizers without <|im_end|> may map it to the unknown-token ID; skip that.
    if (
        isinstance(im_end_id, int)
        and im_end_id != tokenizer.unk_token_id
        and im_end_id not in stop_ids
    ):
        stop_ids.append(im_end_id)
    return stop_ids or None


def generate_solutions_batched(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    K: int,
    max_new_tokens: int,
    temperature: float,
    device: torch.device,
) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """
    Generate K solutions for a prompt in a **single batched** model.generate() call.

    Batching all K sequences together achieves near-100% GPU utilisation vs
    the old sequential loop (which was <20% utilised).  On an A100 with K=8,
    this is typically 4-8× faster than K sequential calls.

    ``prompt`` must come from ``math_env.format_solution_prompt(question)``
    so the chat-template system/user wrapping exactly matches the SFT
    training format.

    Returns:
        solutions       : K decoded strings (prompt stripped, specials removed)
        input_ids_list  : K full (prompt+response) token ID tensors
        response_masks  : K bool masks (True = non-pad response token)
        old_log_probs   : K scalar tensors, sum(log π_old(token)) over response,
                          computed no_grad; used for IS clip ratio in the loss.
    """
    stop_ids = _build_stop_token_ids(tokenizer)
    pad_id: int = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    enc = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=1024,
    ).to(device)
    prompt_len: int = enc["input_ids"].shape[1]

    # Expand the prompt K times along the batch dimension.  expand() is a
    # zero-copy view; contiguous() then materialises the K rows.
    input_ids_batch = enc["input_ids"].expand(K, -1).contiguous()
    attn_mask_batch = enc["attention_mask"].expand(K, -1).contiguous()

    model.eval()
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids_batch,
            attention_mask=attn_mask_batch,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            pad_token_id=pad_id,
            eos_token_id=stop_ids,
            use_cache=True,
        )
        # out: [K, prompt_len + padded_response_len]

    # ── 1. Build masks and decode solutions ──────────────────────────────────
    solutions: List[str] = []
    input_ids_list: List[torch.Tensor] = []
    response_masks: List[torch.Tensor] = []

    pad_id_t = torch.tensor(pad_id, device=device, dtype=out.dtype)
    for i in range(K):
        full_ids = out[i]
        response_section = full_ids[prompt_len:]
        mask = torch.zeros(full_ids.shape[0], dtype=torch.bool, device=device)
        mask[prompt_len:] = response_section != pad_id_t
        solution = tokenizer.decode(response_section, skip_special_tokens=True)
        solutions.append(solution)
        input_ids_list.append(full_ids)
        response_masks.append(mask)

    # ── 2. Batched old_log_probs: ONE forward pass for all K sequences ────────
    # The old sequential approach called compute_sequence_log_prob K times
    # (K separate CPU→GPU round-trips + K forward passes).  A single batched
    # forward pass over out[K, total_len] gives the same result K× faster.
    #
    # Attention mask: always attend to prompt tokens; attend to response tokens
    # only where they are non-pad.  This matches what the model saw during
    # model.generate() and prevents padding from distorting log probs.
    old_log_probs: List[torch.Tensor] = []
    with torch.no_grad():
        attn_mask_lp = (out != pad_id_t)          # [K, total_len]
        attn_mask_lp[:, :prompt_len] = True        # prompt always attended

        batch_logits = model(
            input_ids=out,
            attention_mask=attn_mask_lp.long(),
            use_cache=False,
            return_dict=True,
        ).logits  # [K, total_len, vocab]

        for i in range(K):
            full_ids = out[i]
            mask = response_masks[i]

            shift_logits = batch_logits[i, :-1]      # [total_len-1, vocab]
            shift_labels  = full_ids[1:]              # [total_len-1]
            shift_mask    = mask[1:]                  # [total_len-1]

            lp_tokens = F.log_softmax(shift_logits, dim=-1)[
                torch.arange(shift_logits.size(0), device=device),
                shift_labels,
            ]  # [total_len-1]
            resp_lps = lp_tokens[shift_mask]
            old_log_probs.append(
                resp_lps.sum().detach() if resp_lps.numel() > 0
                else torch.tensor(0.0, device=device)
            )

    return solutions, input_ids_list, response_masks, old_log_probs


def compute_sequence_log_prob(
    model: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass through model to get sum of log probs for response tokens.

    Returns scalar tensor (differentiable).
    """
    # input_ids: [seq_len]  →  unsqueeze to [1, seq_len]
    ids = input_ids.unsqueeze(0)
    # Causal LM: logits[i] predicts token[i+1]
    outputs = model(input_ids=ids, use_cache=False, return_dict=True)
    logits = outputs.logits[0]  # [seq_len, vocab]

    # Shift: predict token t+1 from logit at position t
    shift_logits = logits[:-1]           # [seq_len-1, vocab]
    shift_labels = input_ids[1:]         # [seq_len-1]
    shift_mask = response_mask[1:]       # [seq_len-1]  (response tokens)

    log_probs = F.log_softmax(shift_logits, dim=-1)  # [seq_len-1, vocab]
    token_log_probs = log_probs[
        torch.arange(shift_logits.size(0), device=shift_logits.device),
        shift_labels,
    ]  # [seq_len-1]

    # Sum log probs over response tokens only
    response_log_probs = token_log_probs[shift_mask]
    if response_log_probs.numel() == 0:
        return torch.tensor(0.0, requires_grad=True, device=input_ids.device)
    return response_log_probs.sum()


# ---------------------------------------------------------------------------
# GRPO update for one question group
# ---------------------------------------------------------------------------

def grpo_loss_for_group(
    model: AutoModelForCausalLM,
    input_ids_list: List[torch.Tensor],
    response_masks: List[torch.Tensor],
    rewards: List[float],
    old_log_probs: List[torch.Tensor],
    clip_eps: float = 0.2,
    kl_coef: float = 0.0,
    ref_model: Optional[AutoModelForCausalLM] = None,
    eps: float = 1e-8,
) -> Optional[torch.Tensor]:
    """
    Compute GRPO loss for a group of K solutions to the same question.

    IS clip (``clip_eps > 0``):
        ratio  = π_θ(response) / π_old(response)   [sequence level]
        L_GRPO = -min(ratio × A, clip(ratio, 1-ε, 1+ε) × A) / T

    Reference-policy KL penalty (``kl_coef > 0``, ``ref_model`` required):
        KL(π_θ ‖ π_ref) ≈ (log π_θ − log π_ref) / T   per sequence
        L_total = L_GRPO + β × KL

    Here A is the group-normalised advantage and T is the number of response
    tokens in the sequence.

    The KL term acts as an anchor: it prevents the policy from drifting so
    far from its starting point that it forgets the SFT knowledge baked in
    during dual_task_v1 fine-tuning.  β=0.04 is a conservative starting
    value (matches DeepSeekMath GRPO default).

    Returns None if all rewards are identical (zero gradient signal) or if
    no sequence in the group has any response tokens.
    """
    rewards_arr = np.array(rewards, dtype=np.float32)
    std_r = rewards_arr.std()
    if std_r < eps:
        return None

    mean_r = rewards_arr.mean()
    advantages = (rewards_arr - mean_r) / (std_r + eps)
    advantages = np.clip(advantages, -5.0, 5.0)

    _device = next(model.parameters()).device
    group_loss = torch.tensor(0.0, device=_device)
    n_valid = 0

    model.train()
    for ids, mask, adv, old_lp in zip(
        input_ids_list, response_masks, advantages, old_log_probs
    ):
        n_response = int(mask[1:].sum().item())
        if n_response == 0:
            continue  # no response tokens: skip before paying for a forward pass
        new_lp = compute_sequence_log_prob(model, ids, mask)  # differentiable

        adv_t = torch.tensor(adv, dtype=new_lp.dtype, device=_device)

        # ── GRPO surrogate (with optional IS clip) ────────────────────────
        if clip_eps > 0:
            ratio = torch.exp(new_lp - old_lp.to(_device).detach())
            surr_unclipped = ratio * adv_t / n_response
            surr_clipped   = (
                torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
                * adv_t / n_response
            )
            loss_i = -torch.min(surr_unclipped, surr_clipped)
        else:
            loss_i = -(adv_t * new_lp / n_response)

        # ── Reference-policy KL penalty ───────────────────────────────────
        # KL(π_θ ‖ π_ref) ≈ mean_token(log π_θ − log π_ref)
        # Adding +β×KL to the minimisation objective penalises drift from
        # the reference (frozen) checkpoint.  This is differentiable through
        # new_lp; ref_lp is always detached (no grad through frozen model).
        if kl_coef > 0.0 and ref_model is not None:
            with torch.no_grad():
                ref_lp = compute_sequence_log_prob(ref_model, ids, mask)
            kl_per_token = (new_lp - ref_lp.to(_device).detach()) / n_response
            loss_i = loss_i + kl_coef * kl_per_token

        group_loss = group_loss + loss_i
        n_valid += 1

    if n_valid == 0:
        return None
    return group_loss / n_valid


# ---------------------------------------------------------------------------
# Evaluation helpers
# ---------------------------------------------------------------------------

def _log_eval_result(label: str, res: Dict, best: Optional[float]) -> None:
    """Print a structured evaluation summary that mirrors the training objective."""
    cs      = float(res.get("combined_score",  0.0))
    cr      = float(res.get("correct_rate",    0.0))
    step_a  = float(res.get("step_accuracy",   0.0))
    lccp    = float(res.get("lccp",            0.0))
    prm     = float(res.get("prm_mean",        0.0))
    prm_fin = float(res.get("prm_final",       0.0))
    fmt     = float(res.get("format_mean",     0.0))
    n_sc    = int(res.get("n_scored", res.get("total", 0)))
    fa_acc  = float(res.get("final_answer_accuracy", cr))
    pak     = res.get("pass_at_k")
    pak_k   = int(res.get("pass_at_k_k", 4))

    best_str = f" (best={best:.4f})" if best is not None else ""
    logger.info(
        "Training Score  [%s]: %.4f%s  |  n=%d",
        label, cs, best_str, n_sc,
    )
    logger.info(
        "  Components    : 0.50×correct(%.1f%%) + 0.40×process + 0.10×fmt(%.3f)",
        100 * cr, fmt,
    )
    logger.info(
        "  Process score : prm_mean=%.3f  prm_final=%.3f  → weighted=%.3f",
        prm, prm_fin, 0.60 * prm_fin + 0.40 * prm,
    )
    logger.info(
        "  Step accuracy : %.1f%%  (bag-of-steps: fraction of steps PRM >0.5)",
        100 * step_a,
    )
    logger.info(
        "  Chain integrity (LCCP): %.1f%%  ← fraction of steps before first failure\n"
        "    [LCCP=100%% → all steps correct; LCCP=0%% → first step wrong]",
        100 * lccp,
    )
    if pak is not None:
        logger.info(
            "  pass@%d (T=0.8): %.1f%%  |  greedy correct: %.1f%%  "
            "← ceiling vs floor gap",
            pak_k, 100 * pak, 100 * cr,
        )
    logger.info(
        "  (debug) final-answer accuracy: %.1f%%",
        100 * fa_acc,
    )


def evaluate_policy(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    eval_data_path: str,
    max_samples: int,
    max_new_tokens: int,
    math_env: Optional[Any] = None,
    pass_at_k: int = 4,
) -> Dict[str, object]:
    """Run GSM8K evaluation using the SAME reward formula as GRPO training.

    When *math_env* is supplied a ``reward_fn`` is constructed that calls
    ``math_env.compute_grounded_reward(question, solution, gold)``.  This
    returns ``combined_score = 0.50×correct + 0.40×process(0.60×prm_final
    + 0.40×prm_mean) + 0.10×format``, making the eval metric IDENTICAL to
    the GRPO training objective.  Any improvement in step quality, chain
    integrity, or format compliance shows up immediately in the accuracy
    number instead of being hidden behind the coarse binary final-answer
    signal.
    """
    if not Path(eval_data_path).exists():
        return {"accuracy": 0.0, "combined_score": 0.0, "total": 0}
    model.eval()

    reward_fn = None
    if math_env is not None:
        import logging as _log_mod
        _mec_logger  = _log_mod.getLogger("src.rl.math_environment_curriculum")
        _prm_logger  = _log_mod.getLogger("src.rl.prm_scorer")

        def reward_fn(question: str, solution: str, gold: str) -> Dict:
            """Thin wrapper that silences per-sample INFO logs during eval."""
            _old_mec = _mec_logger.level
            _old_prm = _prm_logger.level
            _mec_logger.setLevel(_log_mod.WARNING)
            _prm_logger.setLevel(_log_mod.WARNING)
            try:
                return math_env.compute_grounded_reward(question, solution, gold)
            finally:
                _mec_logger.setLevel(_old_mec)
                _prm_logger.setLevel(_old_prm)

    results = evaluate_gsm8k(
        model=model,
        tokenizer=tokenizer,
        data_path=eval_data_path,
        max_samples=max_samples,
        max_new_tokens=max_new_tokens,
        reward_fn=reward_fn,
        pass_at_k=pass_at_k,
        dataset_name=_infer_eval_dataset_name(eval_data_path),
    )
    model.train()
    return results


# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------

def main() -> None:
    parser = argparse.ArgumentParser(description="GRPO training for self-improvement math")
    parser.add_argument("--base-model", default="checkpoints/dual_task_v1")
    parser.add_argument("--output-dir", default="checkpoints/grpo")
    parser.add_argument("--gsm8k-data", default="data/sft/gsm8k_sft.jsonl")
    parser.add_argument("--eval-data-path", default="data/sft/dual_task_val.jsonl")
    parser.add_argument("--num-iterations", type=int, default=30)
    parser.add_argument(
        "--group-size", type=int, default=4,
        help="K: number of solutions per question per GRPO group (default 4).",
    )
    parser.add_argument(
        "--q-group-size", type=int, default=1,
        help="K_q: question candidates per self-play group (default 1 = disabled). "
             "When ≥2, a second question-level GRPO update is added: K_q questions are "
             "sampled from the same instruction, each solved group-size times; the "
             "per-question reward (mean of its solution rewards) drives a GRPO update "
             "on the question tokens.  Recommended: 2 with --group-size 4 to keep "
             "total self-play compute the same as K_q=1 with group-size 8.",
    )
    parser.add_argument(
        "--questions-per-iter", type=int, default=16,
        help="Number of questions per training iteration (default 16).",
    )
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-new-tokens", type=int, default=400)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--eval-every", type=int, default=5)
    parser.add_argument("--eval-max-samples", type=int, default=250)
    parser.add_argument("--eval-max-new-tokens", type=int, default=512)
    parser.add_argument(
        "--eval-pass-at-k", type=int, default=0,
        help="Number of sampled solutions per eval problem for pass@k (0 to disable). "
             "Makes eval directly comparable to training batch_acc (both K samples at T=0.8). "
             "Disabled by default; enable with e.g. --eval-pass-at-k 4 for demo runs only "
             "(adds K×eval_samples extra forward passes).",
    )
    parser.add_argument("--use-prm", dest="use_prm", action="store_true", default=True)
    parser.add_argument("--no-prm", dest="use_prm", action="store_false")
    parser.add_argument("--prm-model", default="Qwen/Qwen2.5-Math-PRM-7B")
    parser.add_argument("--skip-initial-eval", action="store_true")
    parser.add_argument("--run-name", default=None)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument(
        "--kl-coef", type=float, default=0.04,
        help="Reference-policy KL penalty coefficient β. 0 = disabled. Default 0.04.",
    )
    parser.add_argument(
        "--math-data", type=str, default=None,
        help="Path to MATH dataset JSONL. If absent, downloads from HuggingFace "
             "(competition_math) and caches to data/math/math_numeric.jsonl.",
    )
    parser.add_argument(
        "--math-mix-ratio", type=float, default=0.3,
        help="Fraction of each question batch drawn from MATH (vs GSM8K). "
             "0 = GSM8K only, 1 = MATH only. Default 0.3.",
    )
    parser.add_argument(
        "--math-mix-ratio-late", type=float, default=None,
        help="If set, ramp MATH fraction from --math-mix-ratio to this value "
             "starting at --math-ramp-start (default iter 15; linear ramp over the next 10 iters). "
             "Example: --math-mix-ratio 0.3 --math-mix-ratio-late 0.5 "
             "raises difficulty progressively once the policy is stable.",
    )
    parser.add_argument(
        "--math-ramp-start", type=int, default=15,
        help="Iteration at which to begin the MATH ratio ramp. Default 15.",
    )
    parser.add_argument(
        "--math-max-difficulty", type=int, default=3,
        help="Maximum MATH difficulty level to include (1-5). Default 3.",
    )
    parser.add_argument(
        "--clip-eps", type=float, default=0.2,
        help="Importance-sampling clip ratio ε (PPO-style clip applied inside GRPO). "
             "0 = disabled (plain GRPO). Default 0.2.",
    )
    parser.add_argument(
        "--warmup-iters", type=int, default=3,
        help="Number of linear LR warmup iterations before cosine decay. Default 3.",
    )
    parser.add_argument(
        "--min-lr-ratio", type=float, default=0.1,
        help="Cosine decay floor as a fraction of peak LR (default 0.1 = 10%%).",
    )
    parser.add_argument(
        "--difficulty-alpha", type=float, default=2.0,
        help="Sharpness of difficulty-weighted question sampling. "
             "Higher = stronger preference for on-the-margin questions (win_rate ≈ 0.5). "
             "0 = uniform random sampling. Default 2.0.",
    )
    parser.add_argument(
        "--overlong-filter", dest="overlong_filter",
        action="store_true", default=True,
        help="Skip solutions that hit max-new-tokens (truncated = no Final Answer). Default on.",
    )
    parser.add_argument(
        "--no-overlong-filter", dest="overlong_filter", action="store_false",
        help="Disable overlong-response filtering.",
    )
    parser.add_argument(
        "--save-every", type=int, default=1,
        help="Save a full checkpoint every N iterations (default 1 = every iter). "
             "Best-policy is always saved when accuracy improves, independently of this flag.",
    )
    parser.add_argument(
        "--keep-last", type=int, default=0,
        help="Keep only the last K iter_* checkpoints on disk (0 = keep all). "
             "best_policy/ is never pruned.",
    )
    parser.add_argument(
        "--self-play-ratio", type=float, default=0.3,
        help="Fraction of each question batch that uses SELF-PLAY (model generates the "
             "question from a curriculum instruction, then solves it, rewarded on "
             "0.40 × question_quality + 0.60 × solution_quality). "
             "The remaining (1 - ratio) uses GROUNDED questions from GSM8K / MATH with "
             "gold-answer reward. "
             "0.0 = fully grounded (original behaviour), 1.0 = fully self-play. "
             "Default 0.3, which mirrors the PPO default of 30%% grounded / 70%% self-play "
             "(inverted here because grounded is our primary accuracy signal).",
    )
    # ── Phase-curriculum parameters ──────────────────────────────────────────
    parser.add_argument(
        "--min-warmup", type=int, default=10,
        help="Minimum iterations in Phase 1 (grounded-only) before considering graduation "
             "to Phase 2 (self-play ramp). Prevents graduating on a lucky early batch. "
             "Default 10.",
    )
    parser.add_argument(
        "--selfplay-gt-thresh", type=float, default=0.55,
        help="gt_match_rate threshold required to graduate from Phase 1 to Phase 2. "
             "Measures raw answer correctness (SymPy exact match), not reward-gamed "
             "combined_score. Default 0.55.",
    )
    parser.add_argument(
        "--selfplay-grounded-thresh", type=float, default=0.60,
        help="grounded_accuracy (combined_score > 0.5) threshold for Phase 1 graduation. "
             "Default 0.60.",
    )
    parser.add_argument(
        "--selfplay-step-thresh", type=float, default=0.65,
        help="step_accuracy (PRM steps rated > 0.5) threshold for Phase 1 graduation. "
             "Ensures the model has learned clean step format before entering self-play. "
             "Default 0.65.",
    )
    parser.add_argument(
        "--selfplay-ramp-iters", type=int, default=20,
        help="Number of iterations to ramp self-play ratio from ~0%% to --self-play-ratio "
             "(Phase 2). Grounded anchor stays at ≥30%% throughout. Default 20.",
    )
    parser.add_argument(
        "--grounded-floor", type=float, default=0.50,
        help="Minimum gt_match_rate to maintain during Phase 3. If it falls below this "
             "value, self-play is suspended until grounded performance recovers. "
             "Should be slightly below --selfplay-gt-thresh. Default 0.50.",
    )
    # ── Unified accuracy calculator parameters ───────────────────────────────
    parser.add_argument(
        "--extractor-model", default="Qwen/Qwen2.5-0.5B-Instruct",
        help="Small model used for step chain extraction in the unified accuracy "
             "calculator (Phase 2+). Loaded in 4-bit to minimise VRAM. "
             "Default Qwen/Qwen2.5-0.5B-Instruct.",
    )
    parser.add_argument(
        "--extraction-cache", default=None,
        help="Path to a pre-built JSON extraction cache from "
             "scripts/precompute_extraction_cache.py. When provided, grounded-data "
             "extractions are served from cache instead of calling the extractor LLM "
             "at training time. Only novel self-play solutions require live extraction. "
             "Default None (extraction always uses the LLM).",
    )
    args = parser.parse_args()

    # ── Run identity ─────────────────────────────────────────────────────────
    # Establish run_name first: everything that follows (including log paths)
    # derives from it.
    run_name = args.run_name or f"grpo_{datetime.now():%Y%m%d_%H%M%S}"
    out_dir = Path(args.output_dir) / run_name
    out_dir.mkdir(parents=True, exist_ok=True)

    # ── Log directory ────────────────────────────────────────────────────────
    # One canonical directory for ALL run artefacts that are not model weights:
    #   console_output.log  - full terminal mirror (logger.* + print + tqdm)
    #   config.json         - serialised CLI args for reproducibility
    #   metrics.csv         - one row per iteration, written live
    #   summary.json        - written at the end of training
    log_dir = Path("logs") / "grpo" / run_name
    log_dir.mkdir(parents=True, exist_ok=True)

    # ── Console log file ─────────────────────────────────────────────────────
    console_log_path = log_dir / "console_output.log"
    _console_log_file = console_log_path.open("a", encoding="utf-8", buffering=1)

    # 1) FileHandler on the root logger → every logger.*() call goes to file.
    #    This is necessary because logging.StreamHandler stores a reference to
    #    sys.stderr at *creation* time (inside logging.basicConfig above), so
    #    reassigning sys.stderr later has no effect on existing handlers.
    _file_handler = _add_file_logging(console_log_path)

    # 2) TeeStream on sys.stdout / sys.stderr → every print() / tqdm bar /
    #    library write also goes to file.  Both together cover 100% of output.
    _original_stdout = sys.stdout
    _original_stderr = sys.stderr
    sys.stdout = TeeStream(_original_stdout, _console_log_file)
    sys.stderr = TeeStream(_original_stderr, _console_log_file)

    logger.info("=" * 70)
    logger.info("GRPO run: %s", run_name)
    logger.info("Checkpoints : %s", out_dir)
    logger.info("Logs        : %s", log_dir)
    logger.info("Console log : %s", console_log_path)
    logger.info("=" * 70)

    # ── Persist config for reproducibility ───────────────────────────────────
    (log_dir / "config.json").write_text(
        json.dumps(vars(args), indent=2, default=str), encoding="utf-8"
    )

    # ── Live CSV metrics writer ──────────────────────────────────────────────
    # Written one row per iteration so you can tail / open in Excel mid-run.
    _metrics_csv_path = log_dir / "metrics.csv"
    _csv_file: Optional[Any] = None
    _csv_writer: Optional[Any] = None

    def _append_metrics_csv(row: Dict[str, Any]) -> None:
        """Append one metrics row to metrics.csv; writes header on first call."""
        nonlocal _csv_file, _csv_writer
        # Normalise floats to fixed precision so the CSV is human-readable.
        flat = {
            k: (f"{v:.6f}" if isinstance(v, float) else v)
            for k, v in row.items()
        }
        if _csv_writer is None:
            _csv_file = _metrics_csv_path.open("w", newline="", encoding="utf-8")
            _csv_writer = csv.DictWriter(
                _csv_file,
                fieldnames=list(flat.keys()),
                extrasaction="ignore",
            )
            _csv_writer.writeheader()
        _csv_writer.writerow(flat)
        _csv_file.flush()  # type: ignore[union-attr]

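    # ── Teardown: restore streams and close files on exit ────────────────────
    # atexit handlers run on normal interpreter exit, including after a
    # KeyboardInterrupt or an unhandled exception (e.g. a CUDA OOM error).
    # They do not run if the process is killed by a signal Python does not
    # handle (e.g. SIGKILL from the kernel OOM killer) or exits via os._exit().
    # This approximates a finally block without requiring the entire training
    # body to be re-indented.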
    def _teardown_logging() -> None:
        sys.stdout = _original_stdout
        sys.stderr = _original_stderr
        logging.getLogger().removeHandler(_file_handler)
        if not getattr(_file_handler.stream, "closed", False):
            _file_handler.close()
        if _csv_file is not None and not getattr(_csv_file, "closed", False):
            _csv_file.close()
        if not _console_log_file.closed:
            _console_log_file.close()

    atexit.register(_teardown_logging)

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    attn_impl = select_attn_implementation()
    logger.info("Device: %s | attn: %s", device, attn_impl)
    if torch.cuda.is_available():
        _gpu = torch.cuda.get_device_properties(0)
        logger.info(
            "GPU: %s | %.1f GB VRAM | capability sm_%d%d",
            _gpu.name, _gpu.total_memory / 1e9, _gpu.major, _gpu.minor,
        )
    logger.info(
        "Run config: K=%d K_q=%d N=%d lr=%.1e T=%.2f max_new=%d | "
        "clip_eps=%.2f kl_coef=%.4f warmup=%d | diff_alpha=%.1f | "
        "self_play=%.0f%% grounded=%.0f%% | "
        "math_mix=%.0f%% math_maxdiff=%d | overlong_filter=%s | "
        "eval_every=%d eval_N=%d | grad_clip=%.2f save_every=%d keep_last=%d | "
        "question_GRPO=%s",
        args.group_size, args.q_group_size, args.questions_per_iter, args.learning_rate,
        args.temperature, args.max_new_tokens,
        args.clip_eps, args.kl_coef, args.warmup_iters,
        args.difficulty_alpha,
        100 * args.self_play_ratio, 100 * (1 - args.self_play_ratio),
        100 * args.math_mix_ratio, args.math_max_difficulty,
        args.overlong_filter,
        args.eval_every, args.eval_max_samples,
        args.max_grad_norm, args.save_every, args.keep_last,
        f"ENABLED (K_q={args.q_group_size})" if args.q_group_size > 1 else "disabled",
    )

    # ── Load model ───────────────────────────────────────────────────────────
    logger.info("Loading model from %s ...", args.base_model)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # SFT adapter checkpoints often don't save chat_template, which causes
    # tokenizer.apply_chat_template() to raise an error inside evaluate_gsm8k.
    # The error is silently swallowed there, giving 0% accuracy even for a
    # capable model.
    if tokenizer.chat_template is None:
        _base_model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
        _meta_file = Path(args.base_model) / "pipeline_meta.json"
        if _meta_file.exists():
            _meta = json.loads(_meta_file.read_text(encoding="utf-8"))
            _base_model_name = _meta.get("base_model", _base_model_name)
        logger.info(
            "Tokenizer has no chat_template; loading from base model %s", _base_model_name
        )
        try:
            _base_tok = AutoTokenizer.from_pretrained(_base_model_name, trust_remote_code=True)
            if _base_tok.chat_template is not None:
                tokenizer.chat_template = _base_tok.chat_template
                logger.info("Chat template loaded successfully.")
        except Exception as _e:
            logger.warning("Could not load chat template from base model: %s", _e)

    # PEFT <= 0.12 crashes inside merge_and_unload() when the
    # transformers.integrations.tensor_parallel module is missing.
    if "transformers.integrations.tensor_parallel" not in sys.modules:
        sys.modules["transformers.integrations.tensor_parallel"] = types.ModuleType(
            "tensor_parallel"
        )

    model_path = Path(args.base_model)
    is_adapter = (model_path / "adapter_config.json").exists()

    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map={"": device},
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )

    if is_adapter:
        # Determine actual base model from pipeline_meta.json (written by SFT pipeline).
        _meta_path = model_path / "pipeline_meta.json"
        _base_for_weights = "Qwen/Qwen2.5-Math-1.5B-Instruct"
        if _meta_path.exists():
            _base_for_weights = json.loads(
                _meta_path.read_text(encoding="utf-8")
            ).get("base_model", _base_for_weights)
        logger.info("Detected PEFT adapter: loading base %s then merging %s",
                    _base_for_weights, args.base_model)
        _base = AutoModelForCausalLM.from_pretrained(_base_for_weights, **load_kwargs)
        model = PeftModel.from_pretrained(_base, args.base_model).merge_and_unload()
        model = model.to(device)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.base_model, **load_kwargs)

    # PEFT.merge_and_unload() leaves requires_grad=False on every param.
    # Re-enable unconditionally so GRPO's optimizer actually updates weights.
    params_before = sum(p.numel() for p in model.parameters() if p.requires_grad)
    for p in model.parameters():
        p.requires_grad_(True)
    params_after = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if params_before == 0 and params_after > 0:
        logger.warning(
            "All parameters were frozen on load (PEFT merge_and_unload bug). "
            "Re-enabled requires_grad; any earlier runs with frozen parameters trained nothing."
        )

    # Flash-Attn 2 reduces attention memory from O(T²) to O(T), so gradient
    # checkpointing gives almost no extra saving while costing ~30% more
    # backward time.  Disable it when Flash is active (mirrors PPO runner).
    # gradient_checkpointing_enable requires use_reentrant=False on modern
    # PyTorch: the default True is deprecated and causes silent issues.
    # Also set use_cache=False: HF models can't use KV cache together with
    # gradient checkpointing (incompatible memory management).
    flash_active = attn_impl == "flash_attention_2"
    if not flash_active:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        if hasattr(model, "config"):
            model.config.use_cache = False
        logger.info("Gradient checkpointing ENABLED (use_reentrant=False, use_cache=False).")
    else:
        logger.info(
            "Flash-Attn 2 active: gradient checkpointing OFF "
            "(Flash already gives O(T) attention memory)."
        )

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    logger.info(
        "Trainable parameters: %s / %s (%.1f%%)",
        f"{n_params:,}", f"{n_total:,}", 100.0 * n_params / max(n_total, 1),
    )

    # ── Reference policy (frozen copy) ───────────────────────────────────────
    # A deep copy of the policy at t=0, kept frozen forever.  Used in the KL
    # penalty to anchor the policy against catastrophic forgetting of SFT
    # knowledge: L += β × (log π_θ - log π_ref) / T.
    # Memory cost: ~3 GB (1.5B × 2 bytes BF16), negligible on 80 GB.
    ref_model: Optional[AutoModelForCausalLM] = None
    if args.kl_coef > 0.0:
        logger.info(
            "Creating frozen reference policy (kl_coef=%.4f, ~%.1f GB VRAM)...",
            args.kl_coef, sum(p.numel() for p in model.parameters()) * 2 / 1e9,
        )
        ref_model = copy.deepcopy(model)
        ref_model.requires_grad_(False)
        ref_model.eval()
        logger.info("Reference policy ready.")
    else:
        logger.info("KL coef = 0: no reference policy created.")

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.learning_rate,
        fused=torch.cuda.is_available(),
    )

    # ── LR schedule: linear warmup → cosine decay ────────────────────────────
    # Linear warmup avoids the large initial gradient spike when the policy
    # starts updating from an SFT checkpoint.  Cosine decay then smoothly
    # reduces LR toward min_lr as training progresses (standard in RLHF runs).
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
    _n_warmup = max(1, args.warmup_iters)
    _n_total  = max(1, args.num_iterations)
    _n_decay  = max(1, _n_total - _n_warmup)
    _min_lr   = args.learning_rate * args.min_lr_ratio
    _warmup_sched = LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=_n_warmup,
    )
    _cosine_sched = CosineAnnealingLR(
        optimizer,
        T_max=_n_decay,
        eta_min=_min_lr,
    )
    scheduler = SequentialLR(
        optimizer,
        schedulers=[_warmup_sched, _cosine_sched],
        milestones=[_n_warmup],
    )
    logger.info(
        "LR schedule: %.1e warmup(%d iters) → cosine decay(%d iters, min=%.1e)",
        args.learning_rate, _n_warmup, _n_decay, _min_lr,
    )

    # ── Load data ────────────────────────────────────────────────────────────
    gsm8k_pairs = load_gsm8k(args.gsm8k_data)
    if not gsm8k_pairs:
        logger.error("No training data found at %s; cannot train. Exiting.", args.gsm8k_data)
        sys.exit(1)

    # Optional MATH dataset mixing
    math_pairs: List[Dict[str, str]] = []
    if args.math_mix_ratio > 0.0:
        math_pairs = load_math_dataset(
            local_path=args.math_data,
            max_difficulty=args.math_max_difficulty,
        )
        if math_pairs:
            logger.info(
                "MATH mixing: %.0f%% MATH (%d problems) + %.0f%% GSM8K (%d problems)",
                100 * args.math_mix_ratio, len(math_pairs),
                100 * (1 - args.math_mix_ratio), len(gsm8k_pairs),
            )
        else:
            logger.warning("No MATH pairs loaded; using GSM8K only.")

    # Combined pool used for difficulty sampling; kept separate for VRAM-aware
    # batch construction (sampler draws from each pool proportionally).
    qa_pairs = gsm8k_pairs  # for reward env (all GSM8K gold answers needed)

    # ── Load PRM (optional) ──────────────────────────────────────────────────
    prm: Optional[ProcessRewardScorer] = None
    if args.use_prm:
        try:
            prm = ProcessRewardScorer(
                model_name=args.prm_model,
                device=device,
                load_in_4bit=True,
            )
            logger.info("PRM loaded: %s (4-bit)", args.prm_model)
        except Exception as exc:
            logger.warning("PRM load failed (%s); running without PRM.", exc)

    # Build a minimal math_env just for its reward utilities (compute_grounded_reward).
    # value_model=None is safe: it's only stored as self.value and never invoked on
    # the grounded-reward path, so GRPO avoids the ~3 GB ValueHead backbone entirely.
    from src.rl.unified_accuracy import StepChainExtractor, UnifiedAccuracyCalculator
    _extractor = StepChainExtractor(
        model_name=args.extractor_model,
        device=str(device),
        cache_path=args.extraction_cache,
    )
    _unified_calc = UnifiedAccuracyCalculator(extractor=_extractor, question_evaluator=None)
    logger.info(
        "Unified accuracy calculator ready (extractor=%s, cache=%s)",
        args.extractor_model,
        args.extraction_cache or "none",
    )
    # Eagerly load the extractor model now to avoid a 30–60 s stall on the
    # first training iteration that triggers live (non-cached) extraction.
    logger.info("Warming up step-chain extractor (eager load)...")
    _extractor.warmup()
    logger.info("Extractor warmup complete")

    # ── LLM-backed question classifier (replaces keyword regex) ──────────────
    # Uses the already-loaded policy model for topic classification during
    # self-play reward computation. ~60-120 ms per call, cached, falls back
    # to regex on any error. Dramatically more accurate than keyword matching
    # for geometry, calculus, competition_math, and statistics.
    from src.rl.llm_question_classifier import LLMQuestionClassifier
    _llm_classifier = LLMQuestionClassifier(
        model=model,
        tokenizer=tokenizer,
        device=device,
        cache_size=10_000,
    )

    math_env = CurriculumMathEnvironment(
        policy_model=model,
        value_model=None,
        tokenizer=tokenizer,
        # Feed all training questions as the novelty reference set so
        # session_novelty is measured against the actual training distribution:
        # a self-play question that mimics a dataset question gets low novelty.
        reference_questions=[p["question"] for p in gsm8k_pairs],
        grounded_qa_pairs=qa_pairs,
        prm_scorer=prm,
        max_solution_tokens=args.max_new_tokens,
        device=device,
        unified_accuracy_calc=_unified_calc,
    )
    # Inject LLM classifier into the question quality evaluator
    math_env.question_evaluator.classifier = _llm_classifier
    # Wire the question_evaluator into the unified calc after math_env is available
    _unified_calc.question_evaluator = math_env.question_evaluator

    # Bootstrap curriculum from dataset skill_ids when the training data
    # contains structured records (NuminaMath / OpenMathInstruct format).
    # Falls back to the keyword-classifier path for plain GSM8K.
    _raw_records: list = []
    _train_path = Path(args.gsm8k_data)
    if _train_path.exists():
        with _train_path.open(encoding="utf-8") as _f:
            for _line in _f:
                _line = _line.strip()
                if _line:
                    try:
                        _raw_records.append(json.loads(_line))
                    except json.JSONDecodeError:
                        pass
    if any("skill_id" in r for r in _raw_records[:20]):
        logger.info(
            "Detected structured dataset (%d records): bootstrapping "
            "curriculum from skill_ids instead of keyword classifier.",
            len(_raw_records),
        )
        math_env.curriculum_manager.initialize_from_dataset(_raw_records)
    else:
        logger.info("Plain dataset detected; using keyword-classifier bootstrap.")

    # ── Difficulty-adaptive sampling state ───────────────────────────────────
    # Track per-question win-rate.  Questions where the model scores correctly
    # 20-80% of the time are "on the margin" and provide the richest gradient
    # signal.  Questions it always gets right (win_rate≈1) or always gets wrong
    # (win_rate≈0) contribute little after the first few iterations.
    from collections import defaultdict
    _q_wins:     Dict[str, int] = defaultdict(int)
    _q_attempts: Dict[str, int] = defaultdict(int)

    def _question_key(q: str) -> str:
        """Stable MD5 fingerprint of the question text; collisions are negligible for any pool size."""
        import hashlib
        return hashlib.md5(q.encode(), usedforsecurity=False).hexdigest()

    def _sample_by_difficulty(
        pool: List[Dict[str, str]], n: int, alpha: float
    ) -> List[Dict[str, str]]:
        """
        Sample ``n`` questions from ``pool``, weighting by how informative each is.

        Informativeness = 1 - |win_rate - 0.5| × 2   ∈ [0, 1]
          win_rate = 0.0 or 1.0  → informativeness = 0  (model already knows / lost cause)
          win_rate = 0.5         → informativeness = 1  (most uncertain = best signal)

        ``alpha`` sharpens the weighting (higher = stronger preference for win_rate≈0.5).
        Unseen questions get weight 0.75 to encourage exploration.
        A weight floor of 0.05 prevents any seen question from being permanently excluded.
        """
        if alpha <= 0.0:
            return random.sample(pool, min(n, len(pool)))

        weights = []
        for qa in pool:
            key = _question_key(qa["question"])
            att = _q_attempts[key]
            if att == 0:
                w = 0.75
            else:
                win_rate = _q_wins[key] / att
                info = 1.0 - abs(win_rate - 0.5) * 2.0  # ∈ [0, 1]
                w = max(info ** alpha, 0.05)
            weights.append(w)

        total_w = sum(weights)
        probs = [w / total_w for w in weights]
        chosen = np.random.choice(
            len(pool), size=min(n, len(pool)), replace=False, p=probs
        )
        return [pool[i] for i in chosen]

    # ── Metrics log ──────────────────────────────────────────────────────────
    metrics_log: List[Dict] = []

    # ── Initial eval ─────────────────────────────────────────────────────────
    if not args.skip_initial_eval:
        logger.info("=" * 70)
        logger.info("INITIAL EVALUATION (Iteration 0)")
        logger.info("=" * 70)
        initial_eval = evaluate_policy(
            model, tokenizer,
            args.eval_data_path, args.eval_max_samples, args.eval_max_new_tokens,
            math_env=math_env,
            pass_at_k=args.eval_pass_at_k,
        )
        # accuracy == combined_score = 0.50×correct + 0.40×process(prm_final,prm_mean) + 0.10×fmt
        # This is identical to the GRPO training objective.
        _log_eval_result("INITIAL (iter 0)", initial_eval, best=None)
        metrics_log.append({"iteration": 0, **initial_eval})
        best_accuracy  = float(initial_eval.get("accuracy",     0.0))
        best_combined  = float(initial_eval.get("combined_score", 0.0))
        best_prm_mean  = float(initial_eval.get("prm_mean",     0.0))
    else:
        best_accuracy = 0.0
        best_combined = 0.0
        best_prm_mean = 0.0

    # ── Training curriculum phase FSM ────────────────────────────────────────
    # Phase 1 (GROUNDED_ONLY): self-play ratio is forced to 0 until the model
    #   has established reliable answer correctness (gt_match_rate) and step
    #   quality (step_accuracy) on grounded data.
    # Phase 2 (SELFPLAY_RAMP): self-play ratio ramps from ~0 → self_play_ratio
    #   ceiling over selfplay_ramp_iters, keeping ≥30% grounded as an anchor.
    # Phase 3 (CONTINUOUS): ratio holds at ceiling; grounded floor is monitored
    #   and self-play is suspended whenever gt_match_rate drops below the floor.
    from enum import Enum, auto as _auto

    class _Phase(Enum):
        GROUNDED_ONLY = _auto()
        SELFPLAY_RAMP = _auto()
        CONTINUOUS    = _auto()

    _phase: _Phase = _Phase.GROUNDED_ONLY
    _selfplay_iterations: int = 0    # iterations spent in Phase 2+
    _selfplay_suspended: bool = False
    _effective_sp_ratio: float = 0.0  # computed each iteration from phase

    # ── Chain scoring calibration state ──────────────────────────────────────
    # During Phase 2 SELFPLAY_RAMP the extractor runs in shadow mode (computing
    # scores but NOT affecting rewards) to build a rolling calibration window.
    # use_chain_scoring only flips True when both the chain↔PRM correlation AND
    # the extraction success rate cross their thresholds: a data-driven gate,
    # not a schedule-driven one.
    _use_chain_as_primary: bool = False     # True once calibration passes
    _chain_prm_correlation: float = 0.0    # rolling Pearson r (chain vs PRM)
    _extraction_success_rate: float = 0.0  # rolling extraction success fraction
    # Cross-iteration rolling window (up to 200 paired samples)
    _rolling_chain_scores:  List[float] = []
    _rolling_prm_scores:    List[float] = []
    _rolling_successes:     List[int]   = []   # 1 = successful extraction, 0 = failed
    _CALIB_WINDOW = 50    # minimum samples before computing correlation
    _CALIB_MAX    = 200   # cap rolling lists at this length
    # Throttle shadow extraction: only run the extractor on every Nth grounded
    # solution during calibration. Reduces overhead ~4× while still reaching
    # the 50-sample window within a few iterations.
    _SHADOW_EVERY   = 4
    _shadow_extract_counter: int = 0

    # ── Training ─────────────────────────────────────────────────────────────
    for iteration in range(1, args.num_iterations + 1):
        iter_start = time.perf_counter()
        logger.info("=" * 70)
        logger.info("GRPO ITERATION %d/%d", iteration, args.num_iterations)
        logger.info("=" * 70)

        # Sample questions: difficulty-weighted from the mixed pool.
        # When math_pairs is non-empty, draw proportionally: N*ratio from MATH
        # and N*(1-ratio) from GSM8K.  The difficulty sampler handles each pool
        # independently so MATH problems get their own win-rate tracking.
        #
        # MATH ratio ramp: once past --math-ramp-start, linearly increase the
        # MATH fraction toward --math-mix-ratio-late over the next 10 iterations.
        # This progressively raises difficulty after the policy has stabilised.
        _effective_math_ratio = args.math_mix_ratio
        if args.math_mix_ratio_late is not None and iteration > args.math_ramp_start:
            _ramp_progress = min(1.0, (iteration - args.math_ramp_start) / 10.0)
            _effective_math_ratio = (
                args.math_mix_ratio
                + _ramp_progress * (args.math_mix_ratio_late - args.math_mix_ratio)
            )

        if math_pairs and _effective_math_ratio > 0.0:
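            # Clamp n_math to the batch size so --math-mix-ratio 1.0 yields a
            # MATH-only batch of exactly questions_per_iter questions.
            n_math  = min(args.questions_per_iter,
                          max(1, round(args.questions_per_iter * _effective_math_ratio)))
            n_gsm8k = args.questions_per_iter - n_math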
            math_batch  = _sample_by_difficulty(math_pairs,  n_math,  alpha=args.difficulty_alpha)
            gsm8k_batch = _sample_by_difficulty(gsm8k_pairs, n_gsm8k, alpha=args.difficulty_alpha)
            questions_batch = math_batch + gsm8k_batch
            random.shuffle(questions_batch)
        else:
            questions_batch = _sample_by_difficulty(
                gsm8k_pairs, args.questions_per_iter, alpha=args.difficulty_alpha
            )
        cur_lr = optimizer.param_groups[0]["lr"]
        # Temperature annealing: linearly decay T from args.temperature to half
        # that value over the run.  Early iterations need high T for
        # exploration; later ones need lower T to consolidate learned
        # strategies (and close the training/eval gap).
        _anneal_frac = min(1.0, (iteration - 1) / max(1, args.num_iterations - 1))
        _annealed_temp = args.temperature * (1.0 - 0.5 * _anneal_frac)  # e.g. 0.8 → 0.4
        logger.info(
            "LR this iteration: %.2e | T=%.3f | MATH ratio=%.0f%%",
            cur_lr, _annealed_temp, 100 * _effective_math_ratio,
        )

        all_rewards:   List[float] = []
        all_q_rewards: List[float] = []
        _grounded_rewards:   List[float] = []
        _sp_rewards:         List[float] = []
        _grounded_step_accs: List[float] = []
        _grounded_lccps:     List[float] = []
        _grounded_gt_matches: List[bool] = []
        # Chain scoring accumulators (populated only in Phase 2+ when
        # math_env.use_chain_scoring is True)
        _chain_arith_scores:     List[float] = []
        _chain_dep_scores:       List[float] = []
        _chain_integrity_scores: List[float] = []
        _sp_chain_scores:        List[float] = []   # self-play chain integrity
        _skipped_zero_var:   int = 0   # groups skipped due to zero reward variance
        # Per-component question quality accumulators
        _qc_topic:      List[float] = []
        _qc_diff:       List[float] = []
        _qc_clarity:    List[float] = []
        _qc_novelty:    List[float] = []
        _qc_solvability: List[float] = []

        skipped = 0
        n_groups = 0
        n_self_play = 0
        q_gen_attempts  = 0    # total generate_question() calls
        q_gen_valid     = 0    # non-empty questions produced (len > 10 chars)
        q_quality_good  = 0    # self-play groups where question_reward > 0.5
        total_loss_val = 0.0

        # Determine how many of this iteration's groups use self-play question
        # generation vs grounded (dataset) questions.
        # Phase-driven ratio: Phase 1 forces 0; Phase 2 ramps from 0 to ceiling;
        # Phase 3 holds at ceiling (args.self_play_ratio). Grounded floor recovery
        # (computed at end of previous iteration) overrides to 0 regardless of phase.
        if _phase == _Phase.GROUNDED_ONLY:
            _effective_sp_ratio = 0.0
        elif _phase == _Phase.SELFPLAY_RAMP:
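            # Ramp linearly from 0 to the --self-play-ratio ceiling over
            # selfplay_ramp_iters, never letting grounded drop below 30%.
            _ramp_frac = min(1.0, _selfplay_iterations / max(1, args.selfplay_ramp_iters))
            _effective_sp_ratio = min(args.self_play_ratio * _ramp_frac, 0.70)
            _grounded_anchor = 1.0 - _effective_sp_ratio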
        else:  # CONTINUOUS
            _effective_sp_ratio = args.self_play_ratio

        if _selfplay_suspended:
            _effective_sp_ratio = 0.0   # grounded floor recovery pass

        n_self_play_target = int(round(len(questions_batch) * _effective_sp_ratio))

        # Build a random set of group indices that will use self-play.
        # Random interleaving distributes self-play uniformly across the batch
        # instead of front-loading all self-play groups, which would cause the
        # gradient to shift mid-batch as the objective changes character.
        _all_indices = list(range(len(questions_batch)))
        random.shuffle(_all_indices)
        _self_play_indices = set(_all_indices[:n_self_play_target])

        # Zero gradients once before the loop; we accumulate them via
        # per-group .backward() calls instead of building one giant graph.
        # Keeping all K*N forward passes alive until a single backward()
        # at the end would hold O(K*N) computation graphs in GPU memory
        # simultaneously (64 graphs at K=4, N=16), risking OOM.  Calling
        # .backward() immediately after each group frees that graph right
        # away; gradients accumulate in .grad tensors without extra memory.
        optimizer.zero_grad()

        pbar = tqdm(questions_batch, desc=f"Iter {iteration} GRPO groups", unit="q")
        for _group_idx, qa in enumerate(pbar):

            # ── Decide: self-play (model generates question) or grounded ─────
            # Random interleaving: self-play slots chosen before the loop.
            use_self_play = _group_idx in _self_play_indices

            if use_self_play:
                # 1. Sample a curriculum instruction (topic + difficulty target)
                instruction, target_topic, target_difficulty = math_env.sample_instruction()

                # MATH L4-L5: exclude from self-play generation. Problems at this
                # difficulty produce unanchored reward because the verification
                # cascade cannot reliably confirm answers.  Fall back to grounded.
                if target_difficulty >= 4.0:
                    use_self_play = False

            # Re-check the flag so that a group demoted above really takes the
            # grounded branch (the `else` below) instead of generating a question.
            if use_self_play:
                # ── SELF-PLAY BRANCH ─────────────────────────────────────────
                # 2. Model generates the question from the instruction.
                #    This is the "proposer" role in Theme #4 self-improvement:
                #    the model creates its own challenge.
                q_gen_attempts += 1

                # ── TWO-PHASE QUESTION GRPO (when --q-group-size ≥ 2) ────────
                # Phase 1: sample K_q question candidates, store their token
                #   IDs for a question-level GRPO update.
                # Phase 2: for each candidate, generate M=group_size solutions,
                #   score them, and run a solution-level GRPO update.
                # The per-question reward (mean solution reward) is then used
                # to run GRPO on the question tokens, so gradients flow back
                # through the question tokens for the first time.
                if args.q_group_size > 1:
                    _q_temp = min(0.90, _annealed_temp + 0.05)
                    q_cands, q_ids_all, q_masks_all, q_olps_all = generate_questions_batched(
                        model=model,
                        tokenizer=tokenizer,
                        instruction=instruction,
                        K_q=args.q_group_size,
                        max_new_tokens=128,
                        temperature=_q_temp,
                        device=device,
                    )
                    # Keep only candidates with enough substance
                    _valid_q = [
                        (q, ids, mask, olp)
                        for q, ids, mask, olp
                        in zip(q_cands, q_ids_all, q_masks_all, q_olps_all)
                        if len(q.strip()) >= 10
                    ]
                    if not _valid_q:
                        logger.debug("Two-phase SP: all %d question candidates too short, skipping.", args.q_group_size)
                        skipped += 1
                        continue
                    q_gen_valid += 1
                    n_self_play += 1

                    # Phase 2: score solutions for each valid question candidate
                    _question_agg_rewards: List[float] = []   # one per valid candidate
                    _q_total_loss_val: float = 0.0

                    for _q_text, _q_ids, _q_mask, _q_olp in _valid_q:
                        solution_prompt = math_env.format_solution_prompt(_q_text)
                        sols_q, ids_q, masks_q, olps_q = generate_solutions_batched(
                            model=model,
                            tokenizer=tokenizer,
                            prompt=solution_prompt,
                            K=args.group_size,
                            max_new_tokens=args.max_new_tokens,
                            temperature=_annealed_temp,
                            device=device,
                        )
                        # Overlong filter
                        if args.overlong_filter:
                            _vf = [
                                t for t in zip(sols_q, ids_q, masks_q, olps_q)
                                if int(t[2].sum().item()) < args.max_new_tokens
                            ]
                            if _vf:
                                sols_q, ids_q, masks_q, olps_q = map(list, zip(*_vf))  # type: ignore
                            else:
                                skipped += 1
                                _question_agg_rewards.append(0.0)
                                continue

                        # Score solutions
                        _sol_rewards: List[float] = []
                        for _sol in sols_q:
                            _r, _q_rew, _, _q_met = compute_self_play_reward(
                                question=_q_text,
                                solution=_sol,
                                target_topic=target_topic,
                                target_difficulty=target_difficulty,
                                math_env=math_env,
                            )
                            _sol_rewards.append(_r)
                            all_q_rewards.append(_q_rew)
                            _qc_topic.append(_q_met["topic_match"])
                            _qc_diff.append(_q_met["difficulty_fit"])
                            _qc_clarity.append(_q_met["clarity"])
                            _qc_novelty.append(_q_met["novelty"])
                            _qc_solvability.append(_q_met["solvability"])

                        all_rewards.extend(_sol_rewards)
                        _sp_rewards.extend(_sol_rewards)

                        # Aggregate question reward = mean of its solution rewards
                        _q_agg = float(np.mean(_sol_rewards))
                        _question_agg_rewards.append(_q_agg)

                        # ── Solution-level GRPO update ───────────────────────
                        _sol_loss = grpo_loss_for_group(
                            model=model,
                            input_ids_list=ids_q,
                            response_masks=masks_q,
                            rewards=_sol_rewards,
                            old_log_probs=olps_q,
                            clip_eps=args.clip_eps,
                            kl_coef=args.kl_coef,
                            ref_model=ref_model,
                        )
                        if _sol_loss is not None:
                            _sol_loss.backward()
                            total_loss_val += _sol_loss.item()
                            _q_total_loss_val += _sol_loss.item()
                            n_groups += 1
                        else:
                            skipped += 1
                            _skipped_zero_var += 1

                    # ── Question-level GRPO update ───────────────────────────
                    # Advantages are computed over the K_q question-reward
                    # scalars.  The IS ratio is exp(new_lp_question - old_lp_question).
                    # kl_coef=0 here: there is no reference distribution for questions.
                    _q_ids_v   = [t[1] for t in _valid_q]
                    _q_masks_v = [t[2] for t in _valid_q]
                    _q_olps_v  = [t[3] for t in _valid_q]

                    _q_loss = grpo_loss_for_group(
                        model=model,
                        input_ids_list=_q_ids_v,
                        response_masks=_q_masks_v,
                        rewards=_question_agg_rewards,
                        old_log_probs=_q_olps_v,
                        clip_eps=args.clip_eps,
                        kl_coef=0.0,   # no ref model for question tokens
                        ref_model=None,
                    )
                    if _q_loss is not None:
                        _q_loss.backward()
                        logger.debug(
                            "Q-GRPO: loss=%.4f q_rewards=%s (variance=%.4f)",
                            _q_loss.item(),
                            [f"{r:.3f}" for r in _question_agg_rewards],
                            float(np.var(_question_agg_rewards)),
                        )

                    # Group-level quality: at least one candidate scored > 0.5
                    if any(r > 0.5 for r in _question_agg_rewards):
                        q_quality_good += 1

                    # pbar update then skip to next group (all done above)
                    _mean_r_sp = float(np.mean(all_rewards[-len(_valid_q)*args.group_size:])) if all_rewards else 0.0
                    _q_acc_pct = 100.0 * q_quality_good / max(1, n_self_play)
                    pbar.set_postfix(
                        loss=f"{_q_total_loss_val / max(1, len(_valid_q)):.4f}",
                        mean_r=f"{_mean_r_sp:.3f}",
                        q_acc=f"{_q_acc_pct:.0f}%",
                        q_rew=f"{float(np.mean(all_q_rewards)):.3f}" if all_q_rewards else "n/a",
                        skip=skipped,
                    )
                    continue  # ← everything handled above; jump to next group

                # ── K_q=1: original single-question path (no question GRPO) ──
                question = generate_question(
                    model=model,
                    tokenizer=tokenizer,
                    instruction=instruction,
                    max_new_tokens=128,   # questions are short
                    device=device,
                    # Slightly warmer than solution temperature for diversity,
                    # but anneals with the same schedule to stay consistent.
                    temperature=min(0.90, _annealed_temp + 0.05),
                )
                # A valid question must have at least some substance.
                # Reject single-word, empty, or nonsensical outputs.
                if len(question.strip()) < 10:
                    logger.debug(
                        "Self-play: generated question too short (%d chars), skipping group.",
                        len(question.strip()),
                    )
                    skipped += 1
                    continue
                q_gen_valid += 1
                n_self_play += 1
                gold = None   # no gold answer; rewarded on question quality
            else:
                # ── GROUNDED BRANCH ──────────────────────────────────────────
                # Use pre-existing dataset question with known gold answer.
                question = qa["question"]
                gold = qa["gold_final"]
                target_topic = "grounded"
                target_difficulty = 0.5

            # --- Generate K solutions (batched: single model.generate call) ---
            solution_prompt = math_env.format_solution_prompt(question)
            solutions, input_ids_list, response_masks, old_log_probs_list = (
                generate_solutions_batched(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=solution_prompt,
                    K=args.group_size,
                    max_new_tokens=args.max_new_tokens,
                    temperature=_annealed_temp,
                    device=device,
                )
            )

            # --- Overlong filter: drop truncated solutions (no Final Answer) ---
            # A response that hit max_new_tokens was cut off mid-generation;
            # it almost certainly didn't produce a valid "Final Answer: X" line,
            # so its reward is unreliable noise.  Dropping it keeps the group
            # advantage estimates clean.
            if args.overlong_filter:
                _valid = [
                    (sol, ids, mask, olp)
                    for sol, ids, mask, olp
                    in zip(solutions, input_ids_list, response_masks, old_log_probs_list)
                    if int(mask.sum().item()) < args.max_new_tokens
                ]
                if _valid:
                    solutions, input_ids_list, response_masks, old_log_probs_list = (
                        zip(*_valid)  # type: ignore[assignment]
                    )
                    solutions        = list(solutions)
                    input_ids_list   = list(input_ids_list)
                    response_masks   = list(response_masks)
                    old_log_probs_list = list(old_log_probs_list)
                else:
                    # All K solutions were truncated; skip group.
                    skipped += 1
                    continue

            # --- Score each solution (self-play: Q+S reward; grounded: S only) ---
            rewards = []
            _sp_q_rew_this_group: List[float] = []
            for sol in solutions:
                if use_self_play:
                    # compute_reward = 0.40×question_quality + 0.60×solution_quality
                    # This is the core Theme #4 signal: the model is rewarded
                    # for generating a well-formed, appropriately difficult,
                    # solvable question AND for solving it correctly.
                    r, q_rew, _, q_met = compute_self_play_reward(
                        question=question,
                        solution=sol,
                        target_topic=target_topic,
                        target_difficulty=target_difficulty,
                        math_env=math_env,
                    )
                    _sp_q_rew_this_group.append(q_rew)
                    all_q_rewards.append(q_rew)
                    # Collect per-component breakdown (same question, all K solutions
                    # get the same q_metrics; average to reduce noise).
                    _qc_topic.append(q_met["topic_match"])
                    _qc_diff.append(q_met["difficulty_fit"])
                    _qc_clarity.append(q_met["clarity"])
                    _qc_novelty.append(q_met["novelty"])
                    _qc_solvability.append(q_met["solvability"])
                    # Self-play chain integrity (Phase 2+ only; None in Phase 1)
                    _sp_ci = q_met.get("sp_chain_integrity_score")
                    if _sp_ci is not None:
                        _sp_chain_scores.append(float(_sp_ci))
                else:
                    r_dict = compute_grounded_reward(
                        question=question,
                        solution=sol,
                        gold_final=gold,
                        math_env=math_env,
                    )
                    r = r_dict["combined_score"]
                    _grounded_step_accs.append(r_dict["step_accuracy"])
                    _grounded_lccps.append(r_dict["lccp"])
                    _grounded_gt_matches.append(bool(r_dict["gt_match"]))
                    if r_dict.get("chain_arith_score") is not None:
                        _chain_arith_scores.append(float(r_dict["chain_arith_score"]))
                    if r_dict.get("chain_dep_score") is not None:
                        _chain_dep_scores.append(float(r_dict["chain_dep_score"]))
                    if r_dict.get("chain_integrity_score") is not None:
                        _chain_integrity_scores.append(float(r_dict["chain_integrity_score"]))

                    # Shadow extraction for calibration: during SELFPLAY_RAMP,
                    # run the chain extractor even before use_chain_scoring is
                    # activated so we can measure chain↔PRM correlation.  These
                    # scores do NOT affect the reward; they only feed the
                    # calibration window that decides when to flip use_chain_scoring.
                    # Throttled to every _SHADOW_EVERY solutions to avoid making
                    # each iteration ~10× slower (extractor adds ~8s per call).
                    _shadow_extract_counter += 1
                    if (
                        _phase == _Phase.SELFPLAY_RAMP
                        and not _use_chain_as_primary
                        and _unified_calc is not None
                        and _shadow_extract_counter % _SHADOW_EVERY == 0
                    ):
                        _prm_ps = (
                            0.60 * r_dict.get("prm_final_score", 0.0)
                            + 0.40 * r_dict.get("prm_mean_score", 0.0)
                        )
                        try:
                            _shadow = _unified_calc.compute(
                                solution=sol,
                                gold_answer=gold,
                                question=question,
                                topic=target_topic,
                                phase="grounded",
                            )
                            _rolling_chain_scores.append(_shadow.chain_integrity_score)
                            _rolling_prm_scores.append(_prm_ps)
                            _rolling_successes.append(1 if _shadow.extraction_succeeded else 0)
                        except Exception:
                            _rolling_successes.append(0)
                rewards.append(r)
            all_rewards.extend(rewards)
            # Route to path-specific accumulators for separate batch_acc reporting
            if use_self_play:
                _sp_rewards.extend(rewards)
            else:
                _grounded_rewards.extend(rewards)

            # A self-play group is "accurate" if the question it generated scored
            # above 0.5 on question quality, meaning it was clear, on-topic,
            # appropriately difficult, and solvable.
            if use_self_play and _sp_q_rew_this_group:
                if float(np.mean(_sp_q_rew_this_group)) > 0.5:
                    q_quality_good += 1

            # --- PAL/SymPy verification gate (self-play only) ---
            # Drop the group if the tiered cascade cannot confirm a consistent,
            # independently-verifiable answer.  This prevents circular PRM reward
            # from being the sole correctness anchor on self-play examples.
            if use_self_play:
                if not _verify_self_play_answer(solutions, target_topic, target_difficulty):
                    skipped += 1
                    continue  # no gradient for this group

            # --- Update difficulty stats (grounded questions only; self-play
            #     questions are ephemeral and have no stable key) ---
            if not use_self_play:
                _key = _question_key(question)
                _q_attempts[_key] += len(solutions)
                # Win = reward in the top half of THIS group, not an absolute 0.5 threshold.
                # Using a relative threshold avoids the case where all solutions score 0.55
                # (all "wins" → easy) or all score 0.45 (all "losses" → impossible) when the
                # rewards are actually similar and carry no difficulty information.
                _group_median = float(np.median(rewards))
                _q_wins[_key] += sum(1 for r in rewards if r > _group_median)

            # --- GRPO loss (IS clip + optional KL penalty) + immediate backward ---
            # Skip near-uniform groups early: when reward std < 0.02 (on a [0,1]
            # scale) all advantages collapse to ~0 and the gradient contribution
            # is negligible, i.e. wasted compute. This is a stricter
            # guard than the eps=1e-8 inside grpo_loss_for_group, which only
            # catches exactly-equal rewards (e.g. rewards all near 0.998 pass through it).
            _reward_std = float(np.std(rewards))
            if _reward_std < 0.02:
                skipped += 1
                _skipped_zero_var += 1
                _pf_zv: Dict = dict(mean_r=f"{np.mean(rewards):.3f}", skip=skipped, loss="0var")
                pbar.set_postfix(**_pf_zv)
                continue

            group_loss = grpo_loss_for_group(
                model=model,
                input_ids_list=input_ids_list,
                response_masks=response_masks,
                rewards=rewards,
                old_log_probs=old_log_probs_list,
                clip_eps=args.clip_eps,
                kl_coef=args.kl_coef,
                ref_model=ref_model,
            )

            if group_loss is None:
                skipped += 1
                _skipped_zero_var += 1
                _pf: Dict = dict(mean_r=f"{np.mean(rewards):.3f}", skip=skipped, loss="skip")
                if n_self_play > 0 and all_q_rewards:
                    _q_acc_pct = 100.0 * q_quality_good / max(1, n_self_play)
                    _pf["q_acc"] = f"{_q_acc_pct:.0f}%"
                pbar.set_postfix(**_pf)
                continue

            # Backprop immediately; this frees the group's computation graph.
            # Gradients from all valid groups accumulate in param.grad.
            group_loss.backward()
            total_loss_val += group_loss.item()
            n_groups += 1
            _pf = dict(
                mean_r=f"{np.mean(rewards):.3f}",
                loss=f"{group_loss.item():.4f}",
                skip=skipped,
            )
            if n_self_play > 0 and all_q_rewards:
                # Show live question-gen accuracy in the tqdm bar.
                # q_acc = fraction of self-play groups whose generated question
                # scored > 0.5 on quality (clear, on-topic, solvable).
                _q_acc_pct = 100.0 * q_quality_good / max(1, n_self_play)
                _pf["q_acc"] = f"{_q_acc_pct:.0f}%"
                _pf["q_rew"]  = f"{float(np.mean(all_q_rewards)):.3f}"
            pbar.set_postfix(**_pf)

        # --- Gradient step: normalise accumulated grads then step ---
        if n_groups > 0:
            # Divide accumulated grads by n_groups to get the true average
            # (equivalent to averaging the group losses before backward).
            if n_groups > 1:
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.div_(n_groups)
            torch.nn.utils.clip_grad_norm_(
                [p for p in model.parameters() if p.requires_grad],
                args.max_grad_norm,
            )
            optimizer.step()
            loss_val = total_loss_val / n_groups
        else:
            loss_val = 0.0
        scheduler.step()

        iter_time = time.perf_counter() - iter_start
        mean_r   = float(np.mean(all_rewards))             if all_rewards   else 0.0
        std_r    = float(np.std(all_rewards))              if all_rewards   else 0.0
        acc_r    = float(np.mean([r > 0.5 for r in all_rewards])) if all_rewards else 0.0
        grounded_acc_r = (
            float(np.mean([r > 0.5 for r in _grounded_rewards]))
            if _grounded_rewards else 0.0
        )
        mean_step_acc = (
            float(np.mean(_grounded_step_accs))
            if _grounded_step_accs else 0.0
        )
        mean_lccp = (
            float(np.mean(_grounded_lccps))
            if _grounded_lccps else 0.0
        )
        mean_q_r = float(np.mean(all_q_rewards)) if all_q_rewards else 0.0

        # Chain scoring batch means (non-None only in Phase 2+)
        mean_chain_arith     = float(np.mean(_chain_arith_scores))     if _chain_arith_scores     else None
        mean_chain_dep       = float(np.mean(_chain_dep_scores))       if _chain_dep_scores       else None
        mean_chain_integrity = float(np.mean(_chain_integrity_scores)) if _chain_integrity_scores else None
        mean_sp_chain        = float(np.mean(_sp_chain_scores))        if _sp_chain_scores        else None

        # ── gt_match_rate: raw answer-correctness on grounded examples ────────
        # This is the primary Phase-1 graduation signal: unlike grounded_acc_r,
        # which is (combined_score > 0.5), gt_match_rate is the direct SymPy
        # exact-match fraction and cannot be gamed by a high PRM/format score.
        gt_match_rate = (
            float(sum(_grounded_gt_matches) / len(_grounded_gt_matches))
            if _grounded_gt_matches else 0.0
        )

        # ── Phase FSM transitions ─────────────────────────────────────────────
        if _phase == _Phase.GROUNDED_ONLY:
            _graduation_ready = (
                gt_match_rate    >= args.selfplay_gt_thresh
                and grounded_acc_r >= args.selfplay_grounded_thresh
                and mean_step_acc  >= args.selfplay_step_thresh
                and iteration      >= args.min_warmup
            )
            if _graduation_ready:
                _phase = _Phase.SELFPLAY_RAMP
                logger.info(
                    "PHASE → SELFPLAY_RAMP at iter %d "
                    "(gt_match=%.2f grounded_acc=%.2f step_acc=%.2f) - "
                    "shadow extraction active; chain scoring deferred until "
                    "calibration passes (corr≥0.70, success_rate≥0.80)",
                    iteration, gt_match_rate, grounded_acc_r, mean_step_acc,
                )
                # NOTE: do NOT set math_env.use_chain_scoring = True here.
                # The extractor runs in shadow mode first; use_chain_scoring
                # flips to True below once calibration thresholds are met.
        elif _phase in (_Phase.SELFPLAY_RAMP, _Phase.CONTINUOUS):
            _selfplay_iterations += 1
            if _phase == _Phase.SELFPLAY_RAMP and _selfplay_iterations >= args.selfplay_ramp_iters:
                _phase = _Phase.CONTINUOUS
                logger.info(
                    "PHASE → CONTINUOUS at iter %d (ramp complete after %d iters)",
                    iteration, _selfplay_iterations,
                )

            # ── Data-driven chain scoring activation ─────────────────────────
            # Trim rolling window to _CALIB_MAX before computing correlation.
            if len(_rolling_chain_scores) > _CALIB_MAX:
                _rolling_chain_scores = _rolling_chain_scores[-_CALIB_MAX:]
                _rolling_prm_scores   = _rolling_prm_scores[-_CALIB_MAX:]
                _rolling_successes    = _rolling_successes[-_CALIB_MAX:]

            if not _use_chain_as_primary and len(_rolling_chain_scores) >= _CALIB_WINDOW:
                from scipy.stats import pearsonr  # noqa: PLC0415
                try:
                    _r, _ = pearsonr(
                        _rolling_chain_scores[-_CALIB_WINDOW:],
                        _rolling_prm_scores[-_CALIB_WINDOW:],
                    )
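                    # pearsonr returns NaN (not an exception) for constant
                    # input; treat that as "no correlation".
                    _chain_prm_correlation = float(_r) if np.isfinite(_r) else 0.0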
                except Exception:
                    _chain_prm_correlation = 0.0
                _rolling_n = len(_rolling_successes[-_CALIB_WINDOW:])
                _extraction_success_rate = (
                    sum(_rolling_successes[-_CALIB_WINDOW:]) / _rolling_n
                    if _rolling_n > 0 else 0.0
                )
                if (
                    _chain_prm_correlation >= 0.70
                    and _extraction_success_rate >= 0.80
                ):
                    _use_chain_as_primary = True
                    math_env.use_chain_scoring = True
                    logger.info(
                        "CHAIN PRIMARY activated at iter %d: "
                        "corr=%.2f extraction_rate=%.2f (window=%d) - "
                        "unified calculator now drives reward scoring",
                        iteration, _chain_prm_correlation,
                        _extraction_success_rate, _CALIB_WINDOW,
                    )
                else:
                    logger.debug(
                        "Chain calibration: corr=%.2f success_rate=%.2f "
                        "(need corr≥0.70, success≥0.80; window=%d/%d)",
                        _chain_prm_correlation, _extraction_success_rate,
                        len(_rolling_chain_scores), _CALIB_WINDOW,
                    )

            # Grounded floor monitoring: suspend self-play if answer correctness
            # (gt_match_rate) drops below args.grounded_floor.  Self-play
            # resumes automatically once gt_match_rate is back above the floor.
            _prev_suspended = _selfplay_suspended
            _selfplay_suspended = (
                bool(_grounded_gt_matches) and gt_match_rate < args.grounded_floor
            )
            if _selfplay_suspended and not _prev_suspended:
                logger.warning(
                    "GROUNDED FLOOR: gt_match_rate=%.2f fell below floor=%.2f - "
                    "suspending self-play for recovery",
                    gt_match_rate, args.grounded_floor,
                )
            elif not _selfplay_suspended and _prev_suspended:
                logger.info(
                    "GROUNDED FLOOR: gt_match_rate=%.2f recovered above floor=%.2f - "
                    "resuming self-play",
                    gt_match_rate, args.grounded_floor,
                )

        # Question generation accuracy metrics (self-play only)
        q_gen_valid_rate = (q_gen_valid   / q_gen_attempts)  if q_gen_attempts  > 0 else 0.0
        q_quality_rate   = (q_quality_good / n_self_play)    if n_self_play     > 0 else 0.0
        # Per-component averages (all non-empty across K solutions × groups)
        mean_q_topic     = float(np.mean(_qc_topic))       if _qc_topic      else 0.0
        mean_q_diff      = float(np.mean(_qc_diff))        if _qc_diff       else 0.0
        mean_q_clarity   = float(np.mean(_qc_clarity))     if _qc_clarity    else 0.0
        mean_q_novelty   = float(np.mean(_qc_novelty))     if _qc_novelty    else 0.0
        mean_q_solvab    = float(np.mean(_qc_solvability)) if _qc_solvability else 0.0

        _cur_lr = optimizer.param_groups[0]["lr"]

        # ── LLM classifier stats (every 5 iters to avoid log spam) ─────────
        if iteration % 5 == 0:
            _llm_classifier.log_stats()

        # ── Primary summary line ─────────────────────────────────────────────
        logger.info(
            "Iter %d | loss=%.4f | reward mean=%.3f std=%.3f | "
            "gt_match=%.1f%% | grounded_acc=%.1f%% | step_acc=%.1f%% | lccp=%.1f%% | "
            "batch_acc=%.1f%% | phase=%s sp_ratio=%.0f%% | "
            "groups=%d skipped=%d(0var=%d) | lr=%.2e | %.1fs",
            iteration, loss_val, mean_r, std_r,
            100 * gt_match_rate,
            100 * grounded_acc_r,
            100 * mean_step_acc,
            100 * mean_lccp,
            100 * acc_r,
            _phase.name, 100 * _effective_sp_ratio,
            n_groups, skipped, _skipped_zero_var, _cur_lr, iter_time,
        )
        # Starvation warning: if >30% of groups were skipped due to zero reward
        # variance (all K solutions same score), the curriculum difficulty is
        # mis-calibrated β€” either too easy (all correct) or too hard (all wrong).
        _total_attempted = n_groups + skipped
        if _total_attempted > 0 and _skipped_zero_var / _total_attempted > 0.30:
            logger.warning(
                "STARVATION: %.0f%% of groups skipped (zero variance). "
                "grounded_acc=%.1f%% suggests curriculum is %s. "
                "Consider adjusting --difficulty-alpha.",
                100 * _skipped_zero_var / _total_attempted,
                100 * grounded_acc_r,
                "too easy (raise alpha)" if grounded_acc_r > 0.75 else "too hard (lower alpha)",
            )

        # ── Question-generation accuracy line (only when self-play is active) ─
        if n_self_play > 0:
            logger.info(
                "  Question generation: %d/%d valid (%.0f%%) | "
                "q_reward=%.3f | q_acc=%.1f%% (>0.5 quality) | "
                "topic=%.2f diff=%.2f clarity=%.2f novelty=%.2f solvability=%.2f",
                q_gen_valid, q_gen_attempts, 100 * q_gen_valid_rate,
                mean_q_r, 100 * q_quality_rate,
                mean_q_topic, mean_q_diff, mean_q_clarity,
                mean_q_novelty, mean_q_solvab,
            )

        iter_metrics: Dict = {
            "iteration":             iteration,
            "loss":                  loss_val,
            "mean_reward":           mean_r,
            "std_reward":            std_r,
            "batch_accuracy":        acc_r,
            "grounded_accuracy":     grounded_acc_r,
            "gt_match_rate":         round(gt_match_rate, 4),
            "step_accuracy":         mean_step_acc,
            "lccp":                  mean_lccp,
            "n_groups":              n_groups,
            "skipped_groups":        skipped,
            "learning_rate":         _cur_lr,
            "iter_time_s":           iter_time,
            # ── Phase curriculum metrics ────────────────────────────────────
            "training_phase":        _phase.name,
            "effective_sp_ratio":    round(_effective_sp_ratio, 3),
            "selfplay_suspended":    int(_selfplay_suspended),
            # ── Chain scoring metrics (Phase 2+, None in Phase 1) ────────────
            "chain_arith_score":       round(mean_chain_arith, 4)     if mean_chain_arith     is not None else None,
            "chain_dep_score":         round(mean_chain_dep, 4)       if mean_chain_dep       is not None else None,
            "chain_integrity_score":   round(mean_chain_integrity, 4) if mean_chain_integrity is not None else None,
            "sp_chain_integrity_score": round(mean_sp_chain, 4)       if mean_sp_chain        is not None else None,
            # ── Chain calibration metrics (populated during SELFPLAY_RAMP shadow mode)
            "chain_prm_correlation":   round(_chain_prm_correlation, 3),
            "extraction_success_rate": round(_extraction_success_rate, 3),
            "chain_scoring_active":    int(_use_chain_as_primary),
            # ── Question-generation metrics ─────────────────────────────────
            "n_self_play_groups":    n_self_play,
            "q_gen_attempts":        q_gen_attempts,
            "q_gen_valid":           q_gen_valid,
            "q_gen_valid_rate":      round(q_gen_valid_rate, 4),
            "mean_question_reward":  round(mean_q_r, 4),
            "q_quality_rate":        round(q_quality_rate, 4),
            "q_topic_match":         round(mean_q_topic,   4),
            "q_difficulty_fit":      round(mean_q_diff,    4),
            "q_clarity":             round(mean_q_clarity, 4),
            "q_novelty":             round(mean_q_novelty, 4),
            "q_solvability":         round(mean_q_solvab,  4),
        }

        # --- Eval ---
        if iteration % args.eval_every == 0:
            _eval_ds_label = _infer_eval_dataset_name(args.eval_data_path)
            logger.info("Evaluating %s (%d samples)...", _eval_ds_label, args.eval_max_samples)
            eval_res = evaluate_policy(
                model, tokenizer,
                args.eval_data_path, args.eval_max_samples, args.eval_max_new_tokens,
                math_env=math_env,
                pass_at_k=args.eval_pass_at_k,
            )
            # accuracy == combined_score: 0.50×correct + 0.40×process(prm_final,prm_mean) + 0.10×fmt
            cur_combined = float(eval_res.get("combined_score", best_combined))
            cur_prm_mean = float(eval_res.get("prm_mean",       best_prm_mean))

            _log_eval_result(f"iter {iteration}", eval_res, best=best_combined)

            # ── Checkpoint: save when combined_score improves by more than 1e-4 ─
            # combined_score is a continuous variable; any improvement in
            # correctness, PRM quality, SymPy, or format moves it.
            if cur_combined > best_combined + 1e-4:
                reason = f"combined {cur_combined:.4f} > {best_combined:.4f}"
                best_combined  = cur_combined
                best_prm_mean  = max(best_prm_mean, cur_prm_mean)
                best_accuracy  = best_combined
                best_path = out_dir / "best_policy"
                model.save_pretrained(str(best_path))
                tokenizer.save_pretrained(str(best_path))
                logger.info("New best saved → %s  (%s)", best_path, reason)

            iter_metrics.update(eval_res)

        # --- Save checkpoint (respect --save-every / --keep-last) ---
        is_last_iter = iteration == args.num_iterations
        should_save = is_last_iter or (
            args.save_every > 0 and iteration % args.save_every == 0
        )
        if should_save:
            ckpt_path = out_dir / f"iter_{iteration:04d}"
            ckpt_path.mkdir(exist_ok=True)
            model.save_pretrained(str(ckpt_path))
            tokenizer.save_pretrained(str(ckpt_path))

            # Prune older iter_* checkpoints beyond the rolling window.
            if args.keep_last and args.keep_last > 0:
                existing = sorted(
                    p for p in out_dir.iterdir()
                    if p.is_dir() and p.name.startswith("iter_")
                )
                to_remove = existing[: -args.keep_last]
                for old in to_remove:
                    try:
                        shutil.rmtree(old)
                        logger.info("Pruned old checkpoint: %s", old.name)
                    except OSError as exc:
                        logger.warning("Could not prune %s: %s", old.name, exc)

        # ── Write metrics to both JSONL (full history) and CSV (live row) ────
        metrics_log.append(iter_metrics)
        (out_dir / "metrics.jsonl").write_text(
            "\n".join(json.dumps(m) for m in metrics_log), encoding="utf-8"
        )
        # CSV: one row per iteration, flushed immediately so you can
        # `tail -f logs/grpo/<run>/metrics.csv` or open it in Excel mid-run.
        # `iter_metrics.update(eval_res)` overwrites step_accuracy/lccp on eval iters.
        # We capture the is_eval flag here for clarity.
        _is_eval_iter = "combined_score" in iter_metrics
        _append_metrics_csv({
            "iteration":        iter_metrics["iteration"],
            "timestamp":        datetime.now().isoformat(timespec="seconds"),
            # ── Per-iteration training signal ───────────────────────────────
            "loss":             iter_metrics.get("loss", 0.0),
            "mean_reward":      iter_metrics.get("mean_reward", 0.0),
            "std_reward":       iter_metrics.get("std_reward", 0.0),
            "batch_accuracy":   iter_metrics.get("batch_accuracy", 0.0),
            "grounded_acc":     iter_metrics.get("grounded_accuracy", 0.0),
            "gt_match_rate":    iter_metrics.get("gt_match_rate", 0.0),
            # step_accuracy / lccp: training value on non-eval iters,
            # eval value on eval iters (update() overwrites them).
            "step_accuracy":    iter_metrics.get("step_accuracy", 0.0),
            "lccp":             iter_metrics.get("lccp", 0.0),
            "n_groups":         iter_metrics.get("n_groups", 0),
            "skipped_groups":   iter_metrics.get("skipped_groups", 0),
            "n_sp_groups":      iter_metrics.get("n_self_play_groups", 0),
            "sp_ratio":         iter_metrics.get("effective_sp_ratio", 0.0),
            "sp_suspended":     iter_metrics.get("selfplay_suspended", 0),
            "training_phase":   iter_metrics.get("training_phase", ""),
            "learning_rate":    iter_metrics.get("learning_rate", 0.0),
            "iter_time_s":      iter_metrics.get("iter_time_s", 0.0),
            # ── Question-generation quality ─────────────────────────────────
            "q_reward":         iter_metrics.get("mean_question_reward", ""),
            "q_valid_rate":     iter_metrics.get("q_gen_valid_rate",     ""),
            "q_novelty":        iter_metrics.get("q_novelty",            ""),
            "q_solvability":    iter_metrics.get("q_solvability",        ""),
            # ── Chain scoring calibration ───────────────────────────────────
            "chain_prm_corr":   iter_metrics.get("chain_prm_correlation", ""),
            "chain_scoring_on": iter_metrics.get("chain_scoring_active",  ""),
            # ── Eval checkpoint metrics (every eval_every iters) ────────────
            "eval_combined":    iter_metrics.get("combined_score",          "") if _is_eval_iter else "",
            "eval_correct_rt":  iter_metrics.get("correct_rate",            "") if _is_eval_iter else "",
            "eval_prm":         iter_metrics.get("prm_mean",                "") if _is_eval_iter else "",
            "eval_step_acc":    iter_metrics.get("step_accuracy",           "") if _is_eval_iter else "",
            "eval_lccp":        iter_metrics.get("lccp",                    "") if _is_eval_iter else "",
            "eval_format":      iter_metrics.get("format_mean",             "") if _is_eval_iter else "",
            "eval_n_scored":    iter_metrics.get("n_scored",                "") if _is_eval_iter else "",
            "eval_final_ans":   iter_metrics.get("final_answer_accuracy",   "") if _is_eval_iter else "",
        })

    logger.info("=" * 70)
    logger.info("GRPO training complete.")
    logger.info(
        "Best training-objective score : %.4f  "
        "(0.50×correct + 0.40×process[0.60×prm_final+0.40×prm_mean] + 0.10×fmt)",
        best_combined,
    )
    logger.info("Best PRM component mean       : %.3f", best_prm_mean)
    logger.info("Checkpoints                   : %s", out_dir)
    logger.info("Logs                          : %s", log_dir)
    logger.info("Console log                   : %s", console_log_path)
    logger.info("=" * 70)

    # ── Final summary ─────────────────────────────────────────────────────────
    summary: Dict[str, Any] = {
        "run_name":          run_name,
        "best_accuracy":     best_combined,   # accuracy == combined_score
        "best_combined":     best_combined,
        "best_prm_mean":     best_prm_mean,
        "total_iterations":  args.num_iterations,
        "checkpoints_dir":   str(out_dir),
        "log_dir":           str(log_dir),
        "console_log":       str(console_log_path),
        "metrics_csv":       str(_metrics_csv_path),
        "metrics_jsonl":     str(out_dir / "metrics.jsonl"),
    }
    (log_dir / "summary.json").write_text(
        json.dumps(summary, indent=2, default=str), encoding="utf-8"
    )
    logger.info("Summary written to %s", log_dir / "summary.json")

    # ── Auto-generate demo plots ───────────────────────────────────────────────
    _metrics_jsonl = out_dir / "metrics.jsonl"
    if _metrics_jsonl.exists():
        try:
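            # importlib.util is a submodule; a bare `import importlib` does not
            # guarantee it is loaded, so import it explicitly.
            import importlib.util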
            if importlib.util.find_spec("matplotlib") is None:
                logger.warning(
                    "matplotlib not installed; skipping auto-plot. "
                    "Install with: pip install matplotlib  then run: "
                    "python scripts/plot_grpo_run.py %s",
                    _metrics_jsonl,
                )
            else:
                from scripts.plot_grpo_run import generate_plots as _gen_plots
                _plot_dir = _gen_plots(_metrics_jsonl)
                logger.info("Plots saved → %s", _plot_dir)
        except Exception as _plot_exc:
            logger.warning(
                "Plot generation failed (%s: %s). "
                "Run manually: python scripts/plot_grpo_run.py %s",
                type(_plot_exc).__name__, _plot_exc, _metrics_jsonl,
            )

    # Explicit teardown. atexit is the safety net for crashes; calling it here
    # ensures everything is flushed and closed before the process returns
    # normally. atexit won't double-close because _teardown_logging is
    # idempotent via the .closed checks.
    _teardown_logging()


if __name__ == "__main__":
    main()