Rebrand: Ministral DLM → Nemotron Labs Diffusion

#2
config.json CHANGED
@@ -5,14 +5,14 @@
5
  "adaptive_mask_rate": false,
6
  "ar_loss_weight": 1.0,
7
  "architectures": [
8
- "MinistralDiffEncoderModel"
9
  ],
10
  "attention_bias": false,
11
  "attention_dropout": 0.0,
12
  "attn_implementation": null,
13
  "auto_map": {
14
- "AutoConfig": "configuration_ministral_dlm.MinistralDLMConfig",
15
- "AutoModel": "modeling_ministral_dlm.MinistralDiffEncoderModel"
16
  },
17
  "block_size": 32,
18
  "bos_token_id": 1,
@@ -34,7 +34,7 @@
34
  "mask_token_id": 100,
35
  "max_position_embeddings": 262144,
36
  "mlp_bias": false,
37
- "model_type": "ministral_dlm",
38
  "multi_sampling": null,
39
  "num_ar_layers": 0,
40
  "num_attention_heads": 32,
 
5
  "adaptive_mask_rate": false,
6
  "ar_loss_weight": 1.0,
7
  "architectures": [
8
+ "NemotronLabsDiffusionEncoderModel"
9
  ],
10
  "attention_bias": false,
11
  "attention_dropout": 0.0,
12
  "attn_implementation": null,
13
  "auto_map": {
14
+ "AutoConfig": "configuration_nemotron_labs_diffusion.NemotronLabsDiffusionConfig",
15
+ "AutoModel": "modeling_nemotron_labs_diffusion.NemotronLabsDiffusionEncoderModel"
16
  },
17
  "block_size": 32,
18
  "bos_token_id": 1,
 
34
  "mask_token_id": 100,
35
  "max_position_embeddings": 262144,
36
  "mlp_bias": false,
37
+ "model_type": "nemotron_labs_diffusion",
38
  "multi_sampling": null,
39
  "num_ar_layers": 0,
40
  "num_attention_heads": 32,
configuration_ministral_dlm.py → configuration_nemotron_labs_diffusion.py RENAMED
@@ -12,7 +12,7 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
- """Ministral DLM model configuration"""
16
 
17
  from transformers.configuration_utils import PretrainedConfig
18
  from transformers.modeling_rope_utils import rope_config_validation
@@ -22,7 +22,7 @@ from transformers.utils import logging
22
  logger = logging.get_logger(__name__)
23
 
24
 
25
- class MinistralDLMConfig(PretrainedConfig):
26
  r"""
27
  This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
28
  It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
@@ -114,7 +114,7 @@ class MinistralDLMConfig(PretrainedConfig):
114
  Adaptive permutation ratio for global.
115
  """
116
 
117
- model_type = "ministral_dlm"
118
  keys_to_ignore_at_inference = ["past_key_values"]
119
 
120
  # Default tensor parallel plan for base model `Ministral`
@@ -243,5 +243,5 @@ class MinistralDLMConfig(PretrainedConfig):
243
  )
244
 
245
 
246
- __all__ = ["MinistralDLMConfig"]
247
 
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+ """Nemotron Labs Diffusion model configuration"""
16
 
17
  from transformers.configuration_utils import PretrainedConfig
18
  from transformers.modeling_rope_utils import rope_config_validation
 
22
  logger = logging.get_logger(__name__)
23
 
24
 
25
+ class NemotronLabsDiffusionConfig(PretrainedConfig):
26
  r"""
27
  This is the configuration class to store the configuration of a [`Ministral3Model`] for diffusion language models.
28
  It is used to instantiate a Ministral model according to the specified arguments, defining the model architecture.
 
114
  Adaptive permutation ratio for global.
115
  """
116
 
117
+ model_type = "nemotron_labs_diffusion"
118
  keys_to_ignore_at_inference = ["past_key_values"]
119
 
120
  # Default tensor parallel plan for base model `Ministral`
 
243
  )
244
 
245
 
246
+ __all__ = ["NemotronLabsDiffusionConfig"]
247
 
modeling_ministral.py CHANGED
@@ -25,7 +25,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
25
  from transformers.processing_utils import Unpack
26
  from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
27
  # from transformers.utils.generic import maybe_autocast
28
- from .configuration_ministral_dlm import MinistralDLMConfig
29
 
30
  #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
@@ -110,7 +110,7 @@ def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_positi
110
  class Ministral3Attention(nn.Module):
111
  """Multi-headed attention from 'Attention Is All You Need' paper"""
112
 
113
- def __init__(self, config: MinistralDLMConfig, layer_idx: int):
114
  super().__init__()
115
  self.config = config
116
  self.layer_idx = layer_idx
@@ -234,7 +234,7 @@ class Ministral3RMSNorm(nn.Module):
234
 
235
 
236
  class Ministral3DecoderLayer(GradientCheckpointingLayer):
237
- def __init__(self, config: MinistralDLMConfig, layer_idx: int):
238
  super().__init__()
239
  self.hidden_size = config.hidden_size
240
 
@@ -284,7 +284,7 @@ class Ministral3DecoderLayer(GradientCheckpointingLayer):
284
 
285
  @auto_docstring
286
  class Ministral3PreTrainedModel(PreTrainedModel):
287
- config: MinistralDLMConfig
288
  base_model_prefix = "model"
289
  supports_gradient_checkpointing = True
290
  _no_split_modules = ["Ministral3DecoderLayer"]
@@ -304,7 +304,7 @@ class Ministral3PreTrainedModel(PreTrainedModel):
304
  class Ministral3RotaryEmbedding(nn.Module):
305
  inv_freq: torch.Tensor # fix linting for `register_buffer`
306
 
307
- def __init__(self, config: MinistralDLMConfig, device=None):
308
  super().__init__()
309
  self.max_seq_len_cached = config.max_position_embeddings
310
  self.original_max_seq_len = config.max_position_embeddings
@@ -323,7 +323,7 @@ class Ministral3RotaryEmbedding(nn.Module):
323
 
324
  @staticmethod
325
  def compute_default_rope_parameters(
326
- config: Optional[MinistralDLMConfig] = None,
327
  device: Optional["torch.device"] = None,
328
  seq_len: Optional[int] = None,
329
  ) -> tuple["torch.Tensor", float]:
@@ -370,7 +370,7 @@ class Ministral3RotaryEmbedding(nn.Module):
370
 
371
  @auto_docstring
372
  class Ministral3Model(Ministral3PreTrainedModel):
373
- def __init__(self, config: MinistralDLMConfig):
374
  super().__init__(config)
375
  self.padding_idx = config.pad_token_id
376
  self.vocab_size = config.vocab_size
 
25
  from transformers.processing_utils import Unpack
26
  from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
27
  # from transformers.utils.generic import maybe_autocast
28
+ from .configuration_nemotron_labs_diffusion import NemotronLabsDiffusionConfig
29
 
30
  #ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask_older_torch
31
 
 
110
  class Ministral3Attention(nn.Module):
111
  """Multi-headed attention from 'Attention Is All You Need' paper"""
112
 
113
+ def __init__(self, config: NemotronLabsDiffusionConfig, layer_idx: int):
114
  super().__init__()
115
  self.config = config
116
  self.layer_idx = layer_idx
 
234
 
235
 
236
  class Ministral3DecoderLayer(GradientCheckpointingLayer):
237
+ def __init__(self, config: NemotronLabsDiffusionConfig, layer_idx: int):
238
  super().__init__()
239
  self.hidden_size = config.hidden_size
240
 
 
284
 
285
  @auto_docstring
286
  class Ministral3PreTrainedModel(PreTrainedModel):
287
+ config: NemotronLabsDiffusionConfig
288
  base_model_prefix = "model"
289
  supports_gradient_checkpointing = True
290
  _no_split_modules = ["Ministral3DecoderLayer"]
 
304
  class Ministral3RotaryEmbedding(nn.Module):
305
  inv_freq: torch.Tensor # fix linting for `register_buffer`
306
 
307
+ def __init__(self, config: NemotronLabsDiffusionConfig, device=None):
308
  super().__init__()
309
  self.max_seq_len_cached = config.max_position_embeddings
310
  self.original_max_seq_len = config.max_position_embeddings
 
323
 
324
  @staticmethod
325
  def compute_default_rope_parameters(
326
+ config: Optional[NemotronLabsDiffusionConfig] = None,
327
  device: Optional["torch.device"] = None,
328
  seq_len: Optional[int] = None,
329
  ) -> tuple["torch.Tensor", float]:
 
370
 
371
  @auto_docstring
372
  class Ministral3Model(Ministral3PreTrainedModel):
373
+ def __init__(self, config: NemotronLabsDiffusionConfig):
374
  super().__init__(config)
375
  self.padding_idx = config.pad_token_id
376
  self.vocab_size = config.vocab_size
modeling_ministral_dlm.py → modeling_nemotron_labs_diffusion.py RENAMED
@@ -29,11 +29,11 @@ import math
29
 
30
  from .chat_utils import generate_with_prefix_cache_block_diff
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
- from .configuration_ministral_dlm import MinistralDLMConfig
33
 
34
 
35
  @dataclass
36
- class MinistralDiffOutputWithPast(ModelOutput):
37
  loss: torch.FloatTensor | None = None
38
  logits: torch.FloatTensor | None = None
39
  causal_logits: torch.FloatTensor | None = None
@@ -87,7 +87,7 @@ def _extract_draft_kv_cache(past_key_values: DynamicCache, clean_len: int, block
87
 
88
 
89
  # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
90
- class MinistralFlexAttention(Ministral3Attention):
91
  def __init__(self, *args, **kwargs):
92
  super().__init__(*args, **kwargs)
93
 
@@ -434,14 +434,14 @@ def gumbel_topk(log_w: torch.Tensor, k: int) -> torch.Tensor:
434
  return mask
435
 
436
 
437
- class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
438
  """
439
  A single model with:
440
  - a bidirectional encoder + diffusion-LM head over A
441
  - a causal decoder + LM head over B, conditioned on F_A
442
  """
443
 
444
- def __init__(self, config: MinistralDLMConfig):
445
  super().__init__(config)
446
 
447
  self.mask_token_id = config.mask_token_id
@@ -450,7 +450,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
450
  diffusion_config.diffusion_lm = True
451
 
452
  if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
453
- diffusion_config.attn_class = MinistralFlexAttention
454
  elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
455
  diffusion_config.attn_class = Ministral3Attention
456
 
@@ -867,7 +867,7 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
867
  else:
868
  loss = (loss, num_mask_tokens)
869
 
870
- return MinistralDiffOutputWithPast(
871
  loss=loss if not is_teacher else logits,
872
  logits=logits,
873
  causal_logits=causal_logits,
@@ -1109,4 +1109,4 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
1109
  return x[:, : -(block_length * 2)], nfe
1110
 
1111
 
1112
- __all__ = ["MinistralDiffEncoderModel", "MinistralFlexAttention"]
 
29
 
30
  from .chat_utils import generate_with_prefix_cache_block_diff
31
  from .modeling_ministral import Ministral3Model, Ministral3PreTrainedModel, Ministral3Attention, apply_rotary_pos_emb, repeat_kv, _get_llama_4_attn_scale
32
+ from .configuration_nemotron_labs_diffusion import NemotronLabsDiffusionConfig
33
 
34
 
35
  @dataclass
36
+ class NemotronLabsDiffusionOutputWithPast(ModelOutput):
37
  loss: torch.FloatTensor | None = None
38
  logits: torch.FloatTensor | None = None
39
  causal_logits: torch.FloatTensor | None = None
 
87
 
88
 
89
  # with reference to https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
90
+ class NemotronLabsDiffusionFlexAttention(Ministral3Attention):
91
  def __init__(self, *args, **kwargs):
92
  super().__init__(*args, **kwargs)
93
 
 
434
  return mask
435
 
436
 
437
+ class NemotronLabsDiffusionEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
438
  """
439
  A single model with:
440
  - a bidirectional encoder + diffusion-LM head over A
441
  - a causal decoder + LM head over B, conditioned on F_A
442
  """
443
 
444
+ def __init__(self, config: NemotronLabsDiffusionConfig):
445
  super().__init__(config)
446
 
447
  self.mask_token_id = config.mask_token_id
 
450
  diffusion_config.diffusion_lm = True
451
 
452
  if config.dlm_paradigm in ['block_diff', 'sbd_block_diff']:
453
+ diffusion_config.attn_class = NemotronLabsDiffusionFlexAttention
454
  elif config.dlm_paradigm in ['bidirectional', 'autoregressive']:
455
  diffusion_config.attn_class = Ministral3Attention
456
 
 
867
  else:
868
  loss = (loss, num_mask_tokens)
869
 
870
+ return NemotronLabsDiffusionOutputWithPast(
871
  loss=loss if not is_teacher else logits,
872
  logits=logits,
873
  causal_logits=causal_logits,
 
1109
  return x[:, : -(block_length * 2)], nfe
1110
 
1111
 
1112
+ __all__ = ["NemotronLabsDiffusionEncoderModel", "NemotronLabsDiffusionFlexAttention"]