HichTala commited on
Commit
c1e9524
·
verified ·
1 Parent(s): 0916c25

Update modeling_diffusiondet.py

Browse files
Files changed (1) hide show
  1. modeling_diffusiondet.py +11 -18
modeling_diffusiondet.py CHANGED
@@ -238,7 +238,7 @@ class DiffusionDet(PreTrainedModel):
238
 
239
  return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
240
 
241
- def forward(self, pixel_values, labels):
242
  """
243
  Args:
244
  """
@@ -256,6 +256,16 @@ class DiffusionDet(PreTrainedModel):
256
  features = self.fpn(features) # [144, 72, 36, 18]
257
  features = [features[f] for f in features.keys()]
258
 
 
 
 
 
 
 
 
 
 
 
259
  # if self.training:
260
  labels = list(map(lambda tensor: tensor.to(self.device), labels))
261
  targets, x_boxes, noises, ts = self.prepare_targets(labels)
@@ -277,23 +287,6 @@ class DiffusionDet(PreTrainedModel):
277
  loss_dict[k] *= weight_dict[k]
278
  loss_dict['loss'] = sum([loss_dict[k] for k in weight_dict.keys()])
279
 
280
- wandb_logs_values = ["loss_ce", "loss_bbox", "loss_giou"]
281
-
282
- if self.training:
283
- wandb.log({f'train/{k}': v.detach().cpu().numpy() for k, v in loss_dict.items() if k in wandb_logs_values})
284
- else:
285
- wandb.log({f'eval/{k}': v.detach().cpu().numpy() for k, v in loss_dict.items() if k in wandb_logs_values})
286
-
287
- if not self.training:
288
- pred_logits, pred_labels, pred_boxes = self.ddim_sample(pixel_values, features, images_whwh)
289
- return DiffusionDetOutput(
290
- loss=loss_dict['loss'],
291
- loss_dict=loss_dict,
292
- logits=pred_logits,
293
- labels=pred_labels,
294
- pred_boxes=pred_boxes,
295
- )
296
-
297
  return DiffusionDetOutput(
298
  loss=loss_dict['loss'],
299
  loss_dict=loss_dict,
 
238
 
239
  return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
240
 
241
+ def forward(self, pixel_values, labels=None):
242
  """
243
  Args:
244
  """
 
256
  features = self.fpn(features) # [144, 72, 36, 18]
257
  features = [features[f] for f in features.keys()]
258
 
259
+ if not self.training:
260
+ pred_logits, pred_labels, pred_boxes = self.ddim_sample(pixel_values, features, images_whwh)
261
+ return DiffusionDetOutput(
262
+ # loss=loss_dict['loss'],
263
+ # loss_dict=loss_dict,
264
+ logits=pred_logits,
265
+ labels=pred_labels,
266
+ pred_boxes=pred_boxes,
267
+ )
268
+
269
  # if self.training:
270
  labels = list(map(lambda tensor: tensor.to(self.device), labels))
271
  targets, x_boxes, noises, ts = self.prepare_targets(labels)
 
287
  loss_dict[k] *= weight_dict[k]
288
  loss_dict['loss'] = sum([loss_dict[k] for k in weight_dict.keys()])
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  return DiffusionDetOutput(
291
  loss=loss_dict['loss'],
292
  loss_dict=loss_dict,