YongganFu commited on
Commit
6294f2a
·
verified ·
1 Parent(s): bad4dec

Upload model

Browse files
Files changed (2) hide show
  1. modeling_ministral.py +13 -9
  2. modeling_ministral_dlm.py +1 -1
modeling_ministral.py CHANGED
@@ -417,15 +417,19 @@ class Ministral3Model(Ministral3PreTrainedModel):
417
  if position_ids is None:
418
  position_ids = cache_position.unsqueeze(0)
419
 
420
- mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
421
- causal_mask = mask_function(
422
- config=self.config,
423
- input_embeds=inputs_embeds,
424
- attention_mask=attention_mask,
425
- cache_position=cache_position,
426
- past_key_values=past_key_values,
427
- position_ids=position_ids,
428
- )
 
 
 
 
429
 
430
  hidden_states = inputs_embeds
431
  position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
 
417
  if position_ids is None:
418
  position_ids = cache_position.unsqueeze(0)
419
 
420
+ if self.training:
421
+ causal_mask = None
422
+
423
+ else:
424
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
425
+ causal_mask = mask_function(
426
+ config=self.config,
427
+ input_embeds=inputs_embeds,
428
+ attention_mask=attention_mask,
429
+ cache_position=cache_position,
430
+ past_key_values=past_key_values,
431
+ position_ids=position_ids,
432
+ )
433
 
434
  hidden_states = inputs_embeds
435
  position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
modeling_ministral_dlm.py CHANGED
@@ -518,7 +518,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
518
 
519
  if labels is not None and self.config.dlm_paradigm != 'autoregressive':
520
  if masked_indices is not None:
521
- #assert p_mask is not None
522
 
523
  if loss_mask is not None:
524
  masked_indices[loss_mask == 0] = 0
 
518
 
519
  if labels is not None and self.config.dlm_paradigm != 'autoregressive':
520
  if masked_indices is not None:
521
+ assert p_mask is not None
522
 
523
  if loss_mask is not None:
524
  masked_indices[loss_mask == 0] = 0