kmn5409 commited on
Commit
486e0c9
·
1 Parent(s): 11c85a5

Upload cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco_flower.py

Browse files
cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco_flower.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = [
2
+ '../_base_/models/cascade_mask_rcnn_swin_fpn.py',
3
+ '../_base_/datasets/coco_instance.py',
4
+ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5
+ ]
6
+
7
+ model = dict(
8
+ backbone=dict(
9
+ embed_dim=96,
10
+ depths=[2, 2, 18, 2],
11
+ num_heads=[3, 6, 12, 24],
12
+ window_size=7,
13
+ ape=False,
14
+ drop_path_rate=0.2,
15
+ patch_norm=True,
16
+ use_checkpoint=False
17
+ ),
18
+ neck=dict(in_channels=[96, 192, 384, 768]),
19
+ roi_head=dict(
20
+ bbox_head=[
21
+ dict(
22
+ type='ConvFCBBoxHead',
23
+ num_shared_convs=4,
24
+ num_shared_fcs=1,
25
+ in_channels=256,
26
+ conv_out_channels=256,
27
+ fc_out_channels=1024,
28
+ roi_feat_size=7,
29
+ #num_classes=80,
30
+ num_classes=3,
31
+ bbox_coder=dict(
32
+ type='DeltaXYWHBBoxCoder',
33
+ target_means=[0., 0., 0., 0.],
34
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
35
+ reg_class_agnostic=False,
36
+ reg_decoded_bbox=True,
37
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
38
+ loss_cls=dict(
39
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
40
+ loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
41
+ dict(
42
+ type='ConvFCBBoxHead',
43
+ num_shared_convs=4,
44
+ num_shared_fcs=1,
45
+ in_channels=256,
46
+ conv_out_channels=256,
47
+ fc_out_channels=1024,
48
+ roi_feat_size=7,
49
+ #num_classes=80,
50
+ num_classes=3,
51
+ bbox_coder=dict(
52
+ type='DeltaXYWHBBoxCoder',
53
+ target_means=[0., 0., 0., 0.],
54
+ target_stds=[0.05, 0.05, 0.1, 0.1]),
55
+ reg_class_agnostic=False,
56
+ reg_decoded_bbox=True,
57
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
58
+ loss_cls=dict(
59
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
60
+ loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
61
+ dict(
62
+ type='ConvFCBBoxHead',
63
+ num_shared_convs=4,
64
+ num_shared_fcs=1,
65
+ in_channels=256,
66
+ conv_out_channels=256,
67
+ fc_out_channels=1024,
68
+ roi_feat_size=7,
69
+ #num_classes=80,
70
+ num_classes=3,
71
+ bbox_coder=dict(
72
+ type='DeltaXYWHBBoxCoder',
73
+ target_means=[0., 0., 0., 0.],
74
+ target_stds=[0.033, 0.033, 0.067, 0.067]),
75
+ reg_class_agnostic=False,
76
+ reg_decoded_bbox=True,
77
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
78
+ loss_cls=dict(
79
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
80
+ loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
81
+ ]))
82
+
83
+ dataset_type = 'COCODataset'
84
+ classes = ('bud','flower','fruit',)
85
+ img_norm_cfg = dict(
86
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
87
+
88
+ # augmentation strategy originates from DETR / Sparse RCNN
89
+ train_pipeline = [
90
+ dict(type='LoadImageFromFile'),
91
+ #dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
92
+ dict(type='LoadAnnotations', with_mask=True),
93
+ dict(type='RandomFlip', flip_ratio=0.5),
94
+ dict(type='AutoAugment',
95
+ policies=[
96
+ [
97
+ dict(type='Resize',
98
+ img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
99
+ (608, 1333), (640, 1333), (672, 1333), (704, 1333),
100
+ (736, 1333), (768, 1333), (800, 1333)],
101
+ multiscale_mode='value',
102
+ keep_ratio=True)
103
+ ],
104
+ [
105
+ dict(type='Resize',
106
+ img_scale=[(400, 1333), (500, 1333), (600, 1333)],
107
+ multiscale_mode='value',
108
+ keep_ratio=True),
109
+ dict(type='RandomCrop',
110
+ crop_type='absolute_range',
111
+ crop_size=(384, 600),
112
+ allow_negative_crop=True),
113
+ dict(type='Resize',
114
+ img_scale=[(480, 1333), (512, 1333), (544, 1333),
115
+ (576, 1333), (608, 1333), (640, 1333),
116
+ (672, 1333), (704, 1333), (736, 1333),
117
+ (768, 1333), (800, 1333)],
118
+ multiscale_mode='value',
119
+ override=True,
120
+ keep_ratio=True)
121
+ ]
122
+ ]),
123
+ dict(type='Normalize', **img_norm_cfg),
124
+ dict(type='Pad', size_divisor=32),
125
+ dict(type='DefaultFormatBundle'),
126
+ #dict(type='Collect', keys=['img']),
127
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
128
+ #dict(type='Collect', keys=['img', 'gt_labels', 'gt_masks']),
129
+ #dict(type='Collect', keys=['img']),
130
+ ]
131
+
132
+
133
+ test_pipeline = [
134
+ dict(type='LoadImageFromFile'),
135
+ dict(
136
+ type='MultiScaleFlipAug',
137
+ #img_scale=(1333, 800),
138
+ img_scale=(800, 1333),
139
+ flip=False,
140
+ transforms=[
141
+ dict(type='Resize', keep_ratio=True),
142
+ dict(type='Normalize', **img_norm_cfg),
143
+ dict(type='Pad', size_divisor=32),
144
+ dict(type='ImageToTensor', keys=['img']),
145
+ dict(type='Collect', keys=['img']), # do not pass gt_label while testing
146
+ ])
147
+ ]
148
+ '''
149
+ data = dict(train=dict(pipeline=train_pipeline),
150
+ test=dict(
151
+ img_prefix='/projectnb/ds549/students/kmn5409/tertiary_task/Swin/Swin-Transformer-Object-Detection/mmdetection/data/coco/images/',
152
+ classes=classes,
153
+ ann_file='/projectnb/ds549/students/kmn5409/tertiary_task/Swin/Swin-Transformer-Object-Detection/mmdetection/data/coco/annotations/flowers/test.json'))
154
+ '''
155
+ data = dict(
156
+ #train=dict(
157
+ train=dict(
158
+ img_prefix='/projectnb/ds549/students/kmn5409/tertiary_task/Swin/Swin-Transformer-Object-Detection/mmdetection/data/coco/images/',
159
+ classes=classes,
160
+ ann_file='/projectnb/ds549/students/kmn5409/tertiary_task/Swin/Swin-Transformer-Object-Detection/mmdetection/data/coco/annotations/flowers/train.json', pipeline=train_pipeline),
161
+
162
+ test=dict( # test data config
163
+ #type=dataset_type,
164
+ img_prefix='/projectnb/ds549/students/kmn5409/tertiary_task/Swin/Swin-Transformer-Object-Detection/mmdetection/data/coco/images/',
165
+ classes=classes,
166
+ ann_file='/projectnb/ds549/students/kmn5409/tertiary_task/Swin/Swin-Transformer-Object-Detection/mmdetection/data/coco/annotations/flowers/test.json',
167
+ pipeline=test_pipeline)
168
+
169
+ )
170
+ #)
171
+
172
+
173
+
174
+ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
175
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
176
+ 'relative_position_bias_table': dict(decay_mult=0.),
177
+ 'norm': dict(decay_mult=0.)}))
178
+ lr_config = dict(step=[27, 33])
179
+ runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
180
+ #runner = dict(type='EpochBasedRunnerAmp', max_epochs=72)
181
+
182
+ # do not use mmdet version fp16
183
+ fp16 = None
184
+ optimizer_config = dict(
185
+ type="DistOptimizerHook",
186
+ update_interval=1,
187
+ grad_clip=None,
188
+ coalesce=True,
189
+ bucket_size_mb=-1,
190
+ use_fp16=True,
191
+ )
192
+
193
+ load_from = '/projectnb/ds549/students/kmn5409/tertiary_task/Swin/Swin-Transformer-Object-Detection/cascade_mask_rcnn_swin_small_patch4_window7.pth'