YongganFu committed on
Commit
281fcf8
·
verified ·
1 Parent(s): 6c967fb

Upload model

Browse files
config.json CHANGED
@@ -24,7 +24,7 @@
24
  "dp_varying_mask_ratio": false,
25
  "enable_self_spec": false,
26
  "enforce_mask": false,
27
- "eos_token_id": 2,
28
  "global_loss_avg": false,
29
  "head_dim": 128,
30
  "hidden_act": "silu",
 
24
  "dp_varying_mask_ratio": false,
25
  "enable_self_spec": false,
26
  "enforce_mask": false,
27
+ "eos_token_id": 11,
28
  "global_loss_avg": false,
29
  "head_dim": 128,
30
  "hidden_act": "silu",
generation_config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
- "eos_token_id": 2,
5
  "transformers_version": "4.55.4",
6
  "use_cache": false
7
  }
 
1
  {
2
  "_from_model_config": true,
3
  "bos_token_id": 1,
4
+ "eos_token_id": 11,
5
  "transformers_version": "4.55.4",
6
  "use_cache": false
7
  }
modeling_ministral_dlm.py CHANGED
@@ -872,6 +872,9 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
872
 
873
 
874
  def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None):
 
 
 
875
  out_ids, nfe = generate_with_prefix_cache_block_diff(
876
  model=self,
877
  prompt=prompt_ids,
@@ -986,6 +989,86 @@ class MinistralDiffEncoderModel(Ministral3PreTrainedModel, GenerationMixin):
986
 
987
  return logits, past_key_values
988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
989
  @torch.no_grad()
990
  def self_spec_generate(
991
  self,
 
872
 
873
 
874
  def generate(self, prompt_ids, max_new_tokens, steps, block_length, shift_logits, threshold, causal_context=True, temperature=0, eos_token_id=None):
875
+ if eos_token_id is None:
876
+ eos_token_id = getattr(self.config, 'eos_token_id', None)
877
+
878
  out_ids, nfe = generate_with_prefix_cache_block_diff(
879
  model=self,
880
  prompt=prompt_ids,
 
989
 
990
  return logits, past_key_values
991
 
992
+
993
@torch.no_grad()
def ar_generate(
    self,
    prompt_ids: torch.Tensor,
    max_new_tokens: int = 128,
    temperature: float = 0.0,
    eos_token_id: Optional[int] = None,
) -> tuple:
    """Autoregressive generation calling the encoder directly (injected by build_hf_tidar_repo).

    Bypasses MinistralDiffEncoderModel.forward() to avoid diffusion-specific
    code paths. Calls self.encoder with explicit cache_position, position_ids
    and use_cache=True; the stated intent is that the KV cache and causal
    masking behave like MistralForCausalLM / vLLM.

    Side effect: every encoder layer whose ``self_attn`` has a
    ``diffusion_lm`` attribute gets it set to ``False``. The flag is NOT
    restored when this method returns.

    Args:
        prompt_ids: Token ids of shape (batch_size, prompt_len).
        max_new_tokens: Upper bound on the number of tokens to generate.
            Values <= 0 return a copy of the prompt without running the model.
        temperature: If > 0, sample from softmax(logits / temperature);
            otherwise take the argmax (greedy decoding).
        eos_token_id: Token id that ends a sequence. Defaults to
            ``self.config.eos_token_id`` when available; if it resolves to
            ``None``, generation always runs for ``max_new_tokens`` steps.

    Returns:
        (output_ids, nfe) where output_ids includes the prompt and nfe is the
        number of decoding steps taken. In a batch, rows that reached EOS
        before the others are padded with ``eos_token_id``.
    """
    # Nothing to generate: avoid a wasted prefill and the torch.cat([]) error
    # the loop below would otherwise hit with an empty token list.
    if max_new_tokens <= 0:
        return prompt_ids.clone(), 0

    # Switch attention layers out of diffusion mode (persistent, see docstring).
    for layer in self.encoder.layers:
        if hasattr(layer.self_attn, 'diffusion_lm'):
            layer.self_attn.diffusion_lm = False

    if eos_token_id is None:
        eos_token_id = getattr(self.config, 'eos_token_id', None)

    device = prompt_ids.device
    batch_size, prompt_len = prompt_ids.shape

    # Prefill: run the whole prompt once to populate the KV cache.
    past_key_values = DynamicCache()
    cache_position = torch.arange(prompt_len, device=device)
    position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)

    enc_out = self.encoder(
        input_ids=prompt_ids,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=True,
        cache_position=cache_position,
    )
    past_key_values = enc_out.past_key_values
    # Logits for the position right after the prompt.
    next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)

    generated_tokens = []
    nfe = 0
    # Per-row "already emitted EOS" mask, shaped like next_token: (batch, 1).
    finished = torch.zeros(batch_size, 1, dtype=torch.bool, device=device)

    for step in range(max_new_tokens):
        nfe += 1

        if temperature > 0:
            probs = torch.softmax(next_logit / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_logit, dim=-1, keepdim=True)

        if eos_token_id is not None:
            # Rows that already ended keep emitting EOS instead of sampling
            # arbitrary tokens past their end of sequence.
            next_token = next_token.masked_fill(finished, eos_token_id)
            finished = finished | (next_token == eos_token_id)

        generated_tokens.append(next_token)

        # Stop once every row has produced EOS at some step (not necessarily
        # the same one).
        if eos_token_id is not None and finished.all():
            break

        # The last iteration needs no further forward pass.
        if step < max_new_tokens - 1:
            cur_pos = prompt_len + step
            step_cache_pos = torch.tensor([cur_pos], device=device)
            step_pos_ids = step_cache_pos.unsqueeze(0).expand(batch_size, -1)

            enc_out = self.encoder(
                input_ids=next_token,
                position_ids=step_pos_ids,
                past_key_values=past_key_values,
                use_cache=True,
                cache_position=step_cache_pos,
            )
            past_key_values = enc_out.past_key_values
            next_logit = self.diffusion_head(enc_out.last_hidden_state[:, -1:, :]).squeeze(1)

    all_generated = torch.cat(generated_tokens, dim=1)
    output_ids = torch.cat([prompt_ids, all_generated], dim=1)
    return output_ids, nfe
1070
+
1071
+
1072
  @torch.no_grad()
1073
  def self_spec_generate(
1074
  self,