Upload 9 files
Browse files- chat_template.jinja +85 -0
- config.json +59 -0
- generation_config.json +6 -0
- model_minimind.py +287 -0
- model_omni.py +461 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +52 -0
- tokenizer.json +0 -0
- tokenizer_config.json +335 -0
chat_template.jinja
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|im_start|>system\n' }}
|
| 3 |
+
{%- if messages[0].role == 'system' %}
|
| 4 |
+
{{- messages[0].content + '\n\n' }}
|
| 5 |
+
{%- endif %}
|
| 6 |
+
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 7 |
+
{%- for tool in tools %}
|
| 8 |
+
{{- "\n" }}
|
| 9 |
+
{{- tool | tojson }}
|
| 10 |
+
{%- endfor %}
|
| 11 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 12 |
+
{%- else %}
|
| 13 |
+
{%- if messages[0].role == 'system' %}
|
| 14 |
+
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
| 15 |
+
{%- endif %}
|
| 16 |
+
{%- endif %}
|
| 17 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 18 |
+
{%- for message in messages[::-1] %}
|
| 19 |
+
{%- set index = (messages|length - 1) - loop.index0 %}
|
| 20 |
+
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
| 21 |
+
{%- set ns.multi_step_tool = false %}
|
| 22 |
+
{%- set ns.last_query_index = index %}
|
| 23 |
+
{%- endif %}
|
| 24 |
+
{%- endfor %}
|
| 25 |
+
{%- for message in messages %}
|
| 26 |
+
{%- if message.content is string %}
|
| 27 |
+
{%- set content = message.content %}
|
| 28 |
+
{%- else %}
|
| 29 |
+
{%- set content = '' %}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
| 32 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 33 |
+
{%- elif message.role == "assistant" %}
|
| 34 |
+
{%- set reasoning_content = '' %}
|
| 35 |
+
{%- if message.reasoning_content is string %}
|
| 36 |
+
{%- set reasoning_content = message.reasoning_content %}
|
| 37 |
+
{%- else %}
|
| 38 |
+
{%- if '</think>' in content %}
|
| 39 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 40 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 41 |
+
{%- endif %}
|
| 42 |
+
{%- endif %}
|
| 43 |
+
{%- if true %}
|
| 44 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
| 45 |
+
{%- endif %}
|
| 46 |
+
{%- if message.tool_calls %}
|
| 47 |
+
{%- for tool_call in message.tool_calls %}
|
| 48 |
+
{%- if (loop.first and content) or (not loop.first) %}
|
| 49 |
+
{{- '\n' }}
|
| 50 |
+
{%- endif %}
|
| 51 |
+
{%- if tool_call.function %}
|
| 52 |
+
{%- set tool_call = tool_call.function %}
|
| 53 |
+
{%- endif %}
|
| 54 |
+
{{- '<tool_call>\n{"name": "' }}
|
| 55 |
+
{{- tool_call.name }}
|
| 56 |
+
{{- '", "arguments": ' }}
|
| 57 |
+
{%- if tool_call.arguments is string %}
|
| 58 |
+
{{- tool_call.arguments }}
|
| 59 |
+
{%- else %}
|
| 60 |
+
{{- tool_call.arguments | tojson }}
|
| 61 |
+
{%- endif %}
|
| 62 |
+
{{- '}\n</tool_call>' }}
|
| 63 |
+
{%- endfor %}
|
| 64 |
+
{%- endif %}
|
| 65 |
+
{{- '<|im_end|>\n' }}
|
| 66 |
+
{%- elif message.role == "tool" %}
|
| 67 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
| 68 |
+
{{- '<|im_start|>user' }}
|
| 69 |
+
{%- endif %}
|
| 70 |
+
{{- '\n<tool_response>\n' }}
|
| 71 |
+
{{- content }}
|
| 72 |
+
{{- '\n</tool_response>' }}
|
| 73 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 74 |
+
{{- '<|im_end|>\n' }}
|
| 75 |
+
{%- endif %}
|
| 76 |
+
{%- endif %}
|
| 77 |
+
{%- endfor %}
|
| 78 |
+
{%- if add_generation_prompt %}
|
| 79 |
+
{{- '<|im_start|>assistant\n' }}
|
| 80 |
+
{%- if open_thinking is defined and open_thinking is true %}
|
| 81 |
+
{{- '<think>\n' }}
|
| 82 |
+
{%- else %}
|
| 83 |
+
{{- '<think>\n\n</think>\n\n' }}
|
| 84 |
+
{%- endif %}
|
| 85 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MiniMindOmni"
|
| 4 |
+
],
|
| 5 |
+
"audio_hidden_size": 512,
|
| 6 |
+
"audio_ids": [
|
| 7 |
+
16
|
| 8 |
+
],
|
| 9 |
+
"audio_pad_token": 2049,
|
| 10 |
+
"audio_special_token": "<|audio_pad|>",
|
| 11 |
+
"audio_spk_token": 2051,
|
| 12 |
+
"audio_stop_token": 2050,
|
| 13 |
+
"audio_vocab_size": 2112,
|
| 14 |
+
"auto_map": {
|
| 15 |
+
"AutoConfig": "model_omni.OmniConfig",
|
| 16 |
+
"AutoModelForCausalLM": "model_omni.MiniMindOmni"
|
| 17 |
+
},
|
| 18 |
+
"bos_token_id": 1,
|
| 19 |
+
"bridge_layer": 3,
|
| 20 |
+
"dropout": 0.0,
|
| 21 |
+
"dtype": "bfloat16",
|
| 22 |
+
"eos_token_id": 2,
|
| 23 |
+
"flash_attn": true,
|
| 24 |
+
"head_dim": 96,
|
| 25 |
+
"hidden_act": "silu",
|
| 26 |
+
"hidden_size": 768,
|
| 27 |
+
"image_hidden_size": 768,
|
| 28 |
+
"image_ids": [
|
| 29 |
+
12
|
| 30 |
+
],
|
| 31 |
+
"image_special_token": "<|image_pad|>",
|
| 32 |
+
"image_token_len": 64,
|
| 33 |
+
"inference_rope_scaling": false,
|
| 34 |
+
"intermediate_size": 2432,
|
| 35 |
+
"max_position_embeddings": 32768,
|
| 36 |
+
"model_type": "minimind-o",
|
| 37 |
+
"moe_intermediate_size": 2432,
|
| 38 |
+
"norm_topk_prob": true,
|
| 39 |
+
"num_attention_heads": 8,
|
| 40 |
+
"num_experts": 4,
|
| 41 |
+
"num_experts_per_tok": 1,
|
| 42 |
+
"num_hidden_layers": 8,
|
| 43 |
+
"num_key_value_heads": 4,
|
| 44 |
+
"num_talker_hidden_layers": 4,
|
| 45 |
+
"rms_norm_eps": 1e-06,
|
| 46 |
+
"rope_scaling": null,
|
| 47 |
+
"rope_theta": 1000000.0,
|
| 48 |
+
"router_aux_loss_coef": 0.0005,
|
| 49 |
+
"spk_emb_size": 192,
|
| 50 |
+
"talker_hidden_size": 768,
|
| 51 |
+
"think_end_ids": [
|
| 52 |
+
26,
|
| 53 |
+
234,
|
| 54 |
+
234
|
| 55 |
+
],
|
| 56 |
+
"transformers_version": "4.57.6",
|
| 57 |
+
"use_moe": false,
|
| 58 |
+
"vocab_size": 6400
|
| 59 |
+
}
|
generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"transformers_version": "4.57.6"
|
| 6 |
+
}
|
model_minimind.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math, torch, torch.nn.functional as F
|
| 2 |
+
from torch import nn
|
| 3 |
+
from transformers.activations import ACT2FN
|
| 4 |
+
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
|
| 5 |
+
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
| 6 |
+
|
| 7 |
+
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
| 8 |
+
# MiniMind Config
|
| 9 |
+
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
| 10 |
+
class MiniMindConfig(PretrainedConfig):
|
| 11 |
+
model_type = "minimind"
|
| 12 |
+
def __init__(self, hidden_size=768, num_hidden_layers=8, use_moe=False, **kwargs):
|
| 13 |
+
super().__init__(**kwargs)
|
| 14 |
+
self.hidden_size = hidden_size
|
| 15 |
+
self.num_hidden_layers = num_hidden_layers
|
| 16 |
+
self.use_moe = use_moe
|
| 17 |
+
self.dropout = kwargs.get("dropout", 0.0)
|
| 18 |
+
self.vocab_size = kwargs.get("vocab_size", 6400)
|
| 19 |
+
self.bos_token_id = kwargs.get("bos_token_id", 1)
|
| 20 |
+
self.eos_token_id = kwargs.get("eos_token_id", 2)
|
| 21 |
+
self.flash_attn = kwargs.get("flash_attn", True)
|
| 22 |
+
self.num_attention_heads = kwargs.get("num_attention_heads", 8)
|
| 23 |
+
self.num_key_value_heads = kwargs.get("num_key_value_heads", 4)
|
| 24 |
+
self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads)
|
| 25 |
+
self.hidden_act = kwargs.get("hidden_act", 'silu')
|
| 26 |
+
self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64)
|
| 27 |
+
self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
|
| 28 |
+
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
|
| 29 |
+
self.rope_theta = kwargs.get("rope_theta", 1e6)
|
| 30 |
+
self.tie_word_embeddings = kwargs.get("tie_word_embeddings", True)
|
| 31 |
+
self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
|
| 32 |
+
self.rope_scaling = {
|
| 33 |
+
"beta_fast": 32,
|
| 34 |
+
"beta_slow": 1,
|
| 35 |
+
"factor": 16,
|
| 36 |
+
"original_max_position_embeddings": 2048,
|
| 37 |
+
"attention_factor": 1.0,
|
| 38 |
+
"type": "yarn"
|
| 39 |
+
} if self.inference_rope_scaling else None
|
| 40 |
+
### MoE specific configs (ignored if use_moe = False)
|
| 41 |
+
self.num_experts = kwargs.get("num_experts", 4)
|
| 42 |
+
self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 1)
|
| 43 |
+
self.moe_intermediate_size = kwargs.get("moe_intermediate_size", self.intermediate_size)
|
| 44 |
+
self.norm_topk_prob = kwargs.get("norm_topk_prob", True)
|
| 45 |
+
self.router_aux_loss_coef = kwargs.get("router_aux_loss_coef", 5e-4)
|
| 46 |
+
|
| 47 |
+
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
| 48 |
+
# MiniMind Model
|
| 49 |
+
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
| 50 |
+
class RMSNorm(torch.nn.Module):
|
| 51 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.eps = eps
|
| 54 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 55 |
+
|
| 56 |
+
def norm(self, x):
|
| 57 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 58 |
+
|
| 59 |
+
def forward(self, x):
|
| 60 |
+
return (self.weight * self.norm(x.float())).type_as(x)
|
| 61 |
+
|
| 62 |
+
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
|
| 63 |
+
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
|
| 64 |
+
if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
|
| 65 |
+
orig_max, factor, beta_fast, beta_slow, attn_factor = (
|
| 66 |
+
rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
|
| 67 |
+
rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
|
| 68 |
+
)
|
| 69 |
+
if end / orig_max > 1.0:
|
| 70 |
+
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
|
| 71 |
+
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
|
| 72 |
+
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
|
| 73 |
+
freqs = freqs * (1 - ramp + ramp / factor)
|
| 74 |
+
t = torch.arange(end, device=freqs.device)
|
| 75 |
+
freqs = torch.outer(t, freqs).float()
|
| 76 |
+
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
|
| 77 |
+
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
|
| 78 |
+
return freqs_cos, freqs_sin
|
| 79 |
+
|
| 80 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 81 |
+
def rotate_half(x): return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
|
| 82 |
+
q_embed = ((q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))).to(q.dtype)
|
| 83 |
+
k_embed = ((k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))).to(k.dtype)
|
| 84 |
+
return q_embed, k_embed
|
| 85 |
+
|
| 86 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 87 |
+
bs, slen, num_key_value_heads, head_dim = x.shape
|
| 88 |
+
if n_rep == 1: return x
|
| 89 |
+
return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim))
|
| 90 |
+
|
| 91 |
+
class Attention(nn.Module):
|
| 92 |
+
def __init__(self, config: MiniMindConfig):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.num_key_value_heads = config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
| 95 |
+
self.n_local_heads = config.num_attention_heads
|
| 96 |
+
self.n_local_kv_heads = self.num_key_value_heads
|
| 97 |
+
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
| 98 |
+
self.head_dim = config.head_dim
|
| 99 |
+
self.is_causal = True
|
| 100 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
|
| 101 |
+
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
| 102 |
+
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
| 103 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
| 104 |
+
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 105 |
+
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 106 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
| 107 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 108 |
+
self.dropout = config.dropout
|
| 109 |
+
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and config.flash_attn
|
| 110 |
+
|
| 111 |
+
def forward(self, x, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
|
| 112 |
+
bsz, seq_len, _ = x.shape
|
| 113 |
+
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
| 114 |
+
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
|
| 115 |
+
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
| 116 |
+
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
| 117 |
+
xq, xk = self.q_norm(xq), self.k_norm(xk)
|
| 118 |
+
cos, sin = position_embeddings
|
| 119 |
+
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
|
| 120 |
+
if past_key_value is not None:
|
| 121 |
+
xk = torch.cat([past_key_value[0], xk], dim=1)
|
| 122 |
+
xv = torch.cat([past_key_value[1], xv], dim=1)
|
| 123 |
+
past_kv = (xk, xv) if use_cache else None
|
| 124 |
+
xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2))
|
| 125 |
+
if self.flash and (seq_len > 1) and (not self.is_causal or past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
|
| 126 |
+
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=self.is_causal)
|
| 127 |
+
else:
|
| 128 |
+
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 129 |
+
if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
|
| 130 |
+
if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
|
| 131 |
+
output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv
|
| 132 |
+
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
|
| 133 |
+
output = self.resid_dropout(self.o_proj(output))
|
| 134 |
+
return output, past_kv
|
| 135 |
+
|
| 136 |
+
class FeedForward(nn.Module):
|
| 137 |
+
def __init__(self, config: MiniMindConfig, intermediate_size: int = None):
|
| 138 |
+
super().__init__()
|
| 139 |
+
intermediate_size = intermediate_size or config.intermediate_size
|
| 140 |
+
self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
|
| 141 |
+
self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
|
| 142 |
+
self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
|
| 143 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 144 |
+
|
| 145 |
+
def forward(self, x):
|
| 146 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 147 |
+
|
| 148 |
+
class MOEFeedForward(nn.Module):
|
| 149 |
+
def __init__(self, config: MiniMindConfig):
|
| 150 |
+
super().__init__()
|
| 151 |
+
self.config = config
|
| 152 |
+
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
| 153 |
+
self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
|
| 154 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 155 |
+
|
| 156 |
+
def forward(self, x):
|
| 157 |
+
batch_size, seq_len, hidden_dim = x.shape
|
| 158 |
+
x_flat = x.view(-1, hidden_dim)
|
| 159 |
+
scores = F.softmax(self.gate(x_flat), dim=-1)
|
| 160 |
+
topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False)
|
| 161 |
+
if self.config.norm_topk_prob: topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
|
| 162 |
+
y = torch.zeros_like(x_flat)
|
| 163 |
+
for i, expert in enumerate(self.experts):
|
| 164 |
+
mask = (topk_idx == i)
|
| 165 |
+
if mask.any():
|
| 166 |
+
token_idx = mask.any(dim=-1).nonzero().flatten()
|
| 167 |
+
weight = topk_weight[mask].view(-1, 1)
|
| 168 |
+
y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
|
| 169 |
+
elif self.training:
|
| 170 |
+
y[0, 0] += 0 * sum(p.sum() for p in expert.parameters())
|
| 171 |
+
if self.training and self.config.router_aux_loss_coef > 0:
|
| 172 |
+
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
|
| 173 |
+
self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
|
| 174 |
+
else:
|
| 175 |
+
self.aux_loss = scores.new_zeros(1).squeeze()
|
| 176 |
+
return y.view(batch_size, seq_len, hidden_dim)
|
| 177 |
+
|
| 178 |
+
class MiniMindBlock(nn.Module):
|
| 179 |
+
def __init__(self, layer_id: int, config: MiniMindConfig):
|
| 180 |
+
super().__init__()
|
| 181 |
+
self.self_attn = Attention(config)
|
| 182 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 183 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 184 |
+
self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
|
| 185 |
+
|
| 186 |
+
def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
|
| 187 |
+
residual = hidden_states
|
| 188 |
+
hidden_states, present_key_value = self.self_attn(
|
| 189 |
+
self.input_layernorm(hidden_states), position_embeddings,
|
| 190 |
+
past_key_value, use_cache, attention_mask
|
| 191 |
+
)
|
| 192 |
+
hidden_states += residual
|
| 193 |
+
hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
|
| 194 |
+
return hidden_states, present_key_value
|
| 195 |
+
|
| 196 |
+
class MiniMindModel(nn.Module):
|
| 197 |
+
def __init__(self, config: MiniMindConfig):
|
| 198 |
+
super().__init__()
|
| 199 |
+
self.config = config
|
| 200 |
+
self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
|
| 201 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 202 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 203 |
+
self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
|
| 204 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 205 |
+
freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.head_dim, end=config.max_position_embeddings, rope_base=config.rope_theta, rope_scaling=config.rope_scaling)
|
| 206 |
+
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
| 207 |
+
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
| 208 |
+
|
| 209 |
+
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, **kwargs):
|
| 210 |
+
batch_size, seq_length = input_ids.shape
|
| 211 |
+
if hasattr(past_key_values, 'layers'): past_key_values = None
|
| 212 |
+
past_key_values = past_key_values or [None] * len(self.layers)
|
| 213 |
+
start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
|
| 214 |
+
hidden_states = self.dropout(self.embed_tokens(input_ids))
|
| 215 |
+
# Recompute RoPE buffers lost during meta-device init (transformers>=5.x)
|
| 216 |
+
if self.freqs_cos[0, 0] == 0:
|
| 217 |
+
freqs_cos, freqs_sin = precompute_freqs_cis(dim=self.config.head_dim, end=self.config.max_position_embeddings, rope_base=self.config.rope_theta, rope_scaling=self.config.rope_scaling)
|
| 218 |
+
self.freqs_cos, self.freqs_sin = freqs_cos.to(hidden_states.device), freqs_sin.to(hidden_states.device)
|
| 219 |
+
position_embeddings = (self.freqs_cos[start_pos:start_pos + seq_length], self.freqs_sin[start_pos:start_pos + seq_length])
|
| 220 |
+
presents = []
|
| 221 |
+
for layer, past_key_value in zip(self.layers, past_key_values):
|
| 222 |
+
hidden_states, present = layer(
|
| 223 |
+
hidden_states,
|
| 224 |
+
position_embeddings,
|
| 225 |
+
past_key_value=past_key_value,
|
| 226 |
+
use_cache=use_cache,
|
| 227 |
+
attention_mask=attention_mask
|
| 228 |
+
)
|
| 229 |
+
presents.append(present)
|
| 230 |
+
hidden_states = self.norm(hidden_states)
|
| 231 |
+
aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
|
| 232 |
+
return hidden_states, presents, aux_loss
|
| 233 |
+
|
| 234 |
+
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
| 235 |
+
config_class = MiniMindConfig
|
| 236 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 237 |
+
def __init__(self, config: MiniMindConfig = None):
|
| 238 |
+
self.config = config or MiniMindConfig()
|
| 239 |
+
super().__init__(self.config)
|
| 240 |
+
self.model = MiniMindModel(self.config)
|
| 241 |
+
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 242 |
+
if self.config.tie_word_embeddings: self.model.embed_tokens.weight = self.lm_head.weight
|
| 243 |
+
self.post_init()
|
| 244 |
+
|
| 245 |
+
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
|
| 246 |
+
hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
|
| 247 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 248 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 249 |
+
loss = None
|
| 250 |
+
if labels is not None:
|
| 251 |
+
x, y = logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()
|
| 252 |
+
loss = F.cross_entropy(x.view(-1, x.size(-1)), y.view(-1), ignore_index=-100)
|
| 253 |
+
return MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
|
| 254 |
+
|
| 255 |
+
# https://github.com/jingyaogong/minimind/discussions/611
|
| 256 |
+
@torch.inference_mode()
|
| 257 |
+
def generate(self, inputs=None, attention_mask=None, max_new_tokens=8192, temperature=0.85, top_p=0.85, top_k=50, eos_token_id=2, streamer=None, use_cache=True, num_return_sequences=1, do_sample=True, repetition_penalty=1.0, **kwargs):
|
| 258 |
+
input_ids = kwargs.pop("input_ids", inputs).repeat(num_return_sequences, 1)
|
| 259 |
+
attention_mask = attention_mask.repeat(num_return_sequences, 1) if attention_mask is not None else None
|
| 260 |
+
past_key_values = kwargs.pop("past_key_values", None)
|
| 261 |
+
finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
|
| 262 |
+
if streamer: streamer.put(input_ids.cpu())
|
| 263 |
+
for _ in range(max_new_tokens):
|
| 264 |
+
past_len = past_key_values[0][0].shape[1] if past_key_values else 0
|
| 265 |
+
outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
|
| 266 |
+
attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None
|
| 267 |
+
logits = outputs.logits[:, -1, :] / temperature
|
| 268 |
+
if repetition_penalty != 1.0:
|
| 269 |
+
for i in range(input_ids.shape[0]): logits[i, torch.unique(input_ids[i])] /= repetition_penalty
|
| 270 |
+
if top_k > 0:
|
| 271 |
+
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
|
| 272 |
+
if top_p < 1.0:
|
| 273 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 274 |
+
mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
|
| 275 |
+
mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
|
| 276 |
+
logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')
|
| 277 |
+
next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1) if do_sample else torch.argmax(logits, dim=-1, keepdim=True)
|
| 278 |
+
if eos_token_id is not None: next_token = torch.where(finished.unsqueeze(-1), next_token.new_full((next_token.shape[0], 1), eos_token_id), next_token)
|
| 279 |
+
input_ids = torch.cat([input_ids, next_token], dim=-1)
|
| 280 |
+
past_key_values = outputs.past_key_values if use_cache else None
|
| 281 |
+
if streamer: streamer.put(next_token.cpu())
|
| 282 |
+
if eos_token_id is not None:
|
| 283 |
+
finished |= next_token.squeeze(-1).eq(eos_token_id)
|
| 284 |
+
if finished.all(): break
|
| 285 |
+
if streamer: streamer.end()
|
| 286 |
+
if kwargs.get("return_kv"): return {'generated_ids': input_ids, 'past_kv': past_key_values}
|
| 287 |
+
return input_ids
|
model_omni.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, math, torch, soundfile as sf, librosa, warnings, numpy as np, onnxruntime as ort, logging, contextlib, io
|
| 2 |
+
from types import SimpleNamespace
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
| 6 |
+
from transformers import SiglipImageProcessor, SiglipVisionModel, logging as hf_logging
|
| 7 |
+
from .model_minimind import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class OmniConfig(MiniMindConfig):
|
| 11 |
+
model_type = "minimind-o"
|
| 12 |
+
def __init__(self, **kwargs):
|
| 13 |
+
super().__init__(**kwargs)
|
| 14 |
+
self.num_talker_hidden_layers = kwargs.get("num_talker_hidden_layers", 4)
|
| 15 |
+
self.talker_hidden_size = kwargs.get("talker_hidden_size", 768)
|
| 16 |
+
self.audio_ids = kwargs.get("audio_ids", [16]) # "<|audio_pad|>" token id
|
| 17 |
+
self.audio_special_token = kwargs.get("audio_special_token", "<|audio_pad|>")
|
| 18 |
+
self.audio_hidden_size = kwargs.get("audio_hidden_size", 512)
|
| 19 |
+
self.audio_vocab_size = kwargs.get("audio_vocab_size", 2112)
|
| 20 |
+
self.audio_pad_token = kwargs.get("audio_pad_token", 2049)
|
| 21 |
+
self.audio_stop_token = kwargs.get("audio_stop_token", 2050)
|
| 22 |
+
self.audio_spk_token = kwargs.get("audio_spk_token", 2051)
|
| 23 |
+
self.spk_emb_size = kwargs.get("spk_emb_size", 192)
|
| 24 |
+
self.think_end_ids = kwargs.get("think_end_ids", [26, 234, 234]) # </think>\n\n
|
| 25 |
+
self.image_ids = kwargs.get("image_ids", [12]) # "<|image_pad|>" token id
|
| 26 |
+
self.image_special_token = kwargs.get("image_special_token", "<|image_pad|>")
|
| 27 |
+
self.image_hidden_size = kwargs.get("image_hidden_size", 768)
|
| 28 |
+
self.image_token_len = kwargs.get("image_token_len", 64)
|
| 29 |
+
self.bridge_layer = kwargs.get("bridge_layer", self.num_hidden_layers // 2 - 1)
|
| 30 |
+
|
| 31 |
+
class MMAudioProjector(nn.Module):
|
| 32 |
+
def __init__(self, in_dim, out_dim):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.mlp = nn.Sequential(
|
| 35 |
+
nn.LayerNorm(in_dim),
|
| 36 |
+
nn.Linear(in_dim, out_dim),
|
| 37 |
+
nn.GELU(),
|
| 38 |
+
nn.Linear(out_dim, out_dim),
|
| 39 |
+
)
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return self.mlp(x)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MMVisionProjector(nn.Module):
|
| 45 |
+
def __init__(self, in_dim, out_dim, source_tokens=64, target_tokens=64):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.mlp = nn.Sequential(
|
| 48 |
+
nn.LayerNorm(in_dim),
|
| 49 |
+
nn.Linear(in_dim, out_dim),
|
| 50 |
+
nn.GELU(),
|
| 51 |
+
nn.Linear(out_dim, out_dim),
|
| 52 |
+
)
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
return self.mlp(x)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TalkerHead(nn.Module):
|
| 58 |
+
def __init__(self, in_features, out_features, num_layers=8, rank=256):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.num_layers = num_layers
|
| 61 |
+
self.base = nn.Linear(in_features, out_features, bias=False)
|
| 62 |
+
self.adapters = nn.ModuleList([nn.Sequential(nn.Linear(in_features, rank, bias=False), nn.GELU(), nn.Linear(rank, out_features, bias=False)) for _ in range(num_layers)])
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
base_out = self.base(x)
|
| 65 |
+
return [base_out + adapter(x) for adapter in self.adapters]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class TalkerEmbedding(nn.Module):
|
| 69 |
+
def __init__(self, num_embeddings, embedding_dim, num_layers=8, rank=256):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.num_layers = num_layers
|
| 72 |
+
self.base = nn.Embedding(num_embeddings, embedding_dim)
|
| 73 |
+
self.adapters = nn.ModuleList([nn.Sequential(nn.Embedding(num_embeddings, rank), nn.GELU(), nn.Linear(rank, embedding_dim, bias=False)) for _ in range(num_layers)])
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
base_out = self.base(x)
|
| 76 |
+
return sum(base_out[:, i, :] + self.adapters[i](x[:, i, :]) for i in range(len(self.adapters))) / self.num_layers
|
| 77 |
+
|
| 78 |
+
class SenseVoiceAudioProcessor:
|
| 79 |
+
def __init__(self, frontend): self.frontend = frontend
|
| 80 |
+
def __call__(self, wav, sampling_rate=16000, return_tensors="pt", return_attention_mask=True, **kwargs):
|
| 81 |
+
if isinstance(wav, np.ndarray): wav = torch.from_numpy(wav).float()
|
| 82 |
+
if wav.dim() == 1: wav = wav.unsqueeze(0)
|
| 83 |
+
with torch.no_grad():
|
| 84 |
+
fbank, flen = self.frontend(wav, torch.tensor([wav.size(1)]))
|
| 85 |
+
return SimpleNamespace(input_features=fbank, attention_mask=(torch.arange(fbank.size(1)) < flen[0]).long().unsqueeze(0))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TalkerModule(nn.Module):
|
| 89 |
+
def __init__(self, config):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.talker_config = MiniMindConfig(hidden_size=config.talker_hidden_size, use_moe=config.use_moe)
|
| 92 |
+
self.layers = nn.ModuleList([MiniMindBlock(l, self.talker_config) for l in range(config.num_talker_hidden_layers)])
|
| 93 |
+
self.norm = RMSNorm(config.talker_hidden_size, eps=config.rms_norm_eps)
|
| 94 |
+
self.lm_head = TalkerHead(config.talker_hidden_size, config.audio_vocab_size)
|
| 95 |
+
self.embed_tokens = TalkerEmbedding(config.audio_vocab_size, config.talker_hidden_size)
|
| 96 |
+
self.codec_proj = nn.Sequential(nn.Linear(config.talker_hidden_size, config.talker_hidden_size), nn.GELU(), nn.Linear(config.talker_hidden_size, config.talker_hidden_size), RMSNorm(config.talker_hidden_size, eps=config.rms_norm_eps))
|
| 97 |
+
self.embed_proj = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size), nn.GELU(), nn.Linear(config.hidden_size, config.talker_hidden_size), RMSNorm(config.talker_hidden_size, eps=config.rms_norm_eps))
|
| 98 |
+
self.text_scale, self.audio_scale = nn.Parameter(torch.tensor(3.0)), nn.Parameter(torch.tensor(1.0))
|
| 99 |
+
self.spk_proj = nn.Linear(config.spk_emb_size, config.talker_hidden_size, bias=False)
|
| 100 |
+
freqs_cos, freqs_sin = precompute_freqs_cis(dim=self.talker_config.head_dim, end=config.max_position_embeddings, rope_base=config.rope_theta, rope_scaling=config.rope_scaling)
|
| 101 |
+
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
| 102 |
+
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class MiniMindOmni(MiniMindForCausalLM):
|
| 106 |
+
config_class = OmniConfig
|
| 107 |
+
def __init__(self, config: OmniConfig = None, audio_encoder_path=None, vision_model_path=None):
|
| 108 |
+
config = config or OmniConfig()
|
| 109 |
+
super().__init__(config)
|
| 110 |
+
object.__setattr__(self, 'thinker', self.model) # alias: self.thinker == self.model
|
| 111 |
+
object.__setattr__(self.model, 'lm_head', self.lm_head) # alias: self.thinker.lm_head == self.lm_head
|
| 112 |
+
self.talker = TalkerModule(config)
|
| 113 |
+
self.audio_proj = MMAudioProjector(config.audio_hidden_size, config.hidden_size)
|
| 114 |
+
self.vision_proj = MMVisionProjector(config.image_hidden_size, config.hidden_size, target_tokens=config.image_token_len)
|
| 115 |
+
self.audio_pad_token, self.audio_stop_token, self.audio_spk_token = config.audio_pad_token, config.audio_stop_token, config.audio_spk_token
|
| 116 |
+
audio_encoder, audio_processor = self.load_sensevoice(audio_encoder_path) if audio_encoder_path else (None, None)
|
| 117 |
+
object.__setattr__(self, 'audio_encoder', audio_encoder)
|
| 118 |
+
object.__setattr__(self, 'audio_processor', audio_processor)
|
| 119 |
+
vision_encoder, vision_processor = self.load_vision(vision_model_path) if vision_model_path else (None, None)
|
| 120 |
+
object.__setattr__(self, 'vision_encoder', vision_encoder)
|
| 121 |
+
object.__setattr__(self, 'vision_processor', vision_processor)
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def load_sensevoice(path):
|
| 125 |
+
if not os.path.exists(path):
|
| 126 |
+
warnings.warn(f"[MiniMindOmni] SenseVoice path not found: {path}")
|
| 127 |
+
return None, None
|
| 128 |
+
logging.getLogger().setLevel(logging.ERROR)
|
| 129 |
+
hf_logging.set_verbosity_error()
|
| 130 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
| 131 |
+
from funasr import AutoModel
|
| 132 |
+
m = AutoModel(model=path, trust_remote_code=True, disable_update=True, device="cpu")
|
| 133 |
+
encoder, frontend = m.model.encoder, m.kwargs["frontend"]
|
| 134 |
+
for p in encoder.parameters(): p.requires_grad = False
|
| 135 |
+
return encoder.eval().float(), SenseVoiceAudioProcessor(frontend.eval())
|
| 136 |
+
|
| 137 |
+
@torch.compiler.disable
|
| 138 |
+
def encode_audio_inputs(self, audio_inputs, audio_lens=None):
|
| 139 |
+
if (audio_inputs is None) or (self.audio_encoder is None) or (not audio_inputs.any()): return None
|
| 140 |
+
batch_mask = audio_inputs.flatten(1).any(1)
|
| 141 |
+
enc_dtype = next(self.audio_encoder.parameters()).dtype
|
| 142 |
+
valid_fbank = audio_inputs[batch_mask].to(dtype=enc_dtype)
|
| 143 |
+
if audio_lens is not None:
|
| 144 |
+
valid_lens = audio_lens[batch_mask].to(valid_fbank.device)
|
| 145 |
+
else:
|
| 146 |
+
valid_lens = torch.tensor([valid_fbank.size(1)] * valid_fbank.size(0), device=valid_fbank.device)
|
| 147 |
+
with torch.no_grad():
|
| 148 |
+
emb, _ = self.audio_encoder(valid_fbank, valid_lens)
|
| 149 |
+
proj_dtype = next(self.audio_proj.parameters()).dtype
|
| 150 |
+
emb_list = [self.audio_proj(emb[i, :max(1, min(valid_lens[i].item(), emb.size(1)))].unsqueeze(0).to(proj_dtype)).squeeze(0) for i in range(emb.size(0))]
|
| 151 |
+
if batch_mask.all(): return emb_list
|
| 152 |
+
out = [None] * audio_inputs.size(0)
|
| 153 |
+
j = 0
|
| 154 |
+
for i in range(audio_inputs.size(0)):
|
| 155 |
+
if batch_mask[i]:
|
| 156 |
+
out[i] = emb_list[j]
|
| 157 |
+
j += 1
|
| 158 |
+
return out
|
| 159 |
+
|
| 160 |
+
@torch.compiler.disable
|
| 161 |
+
def inject_audio_features(self, tokens, h, audio_feats, seqlen):
|
| 162 |
+
if audio_feats is None or not self.config.audio_ids:
|
| 163 |
+
return h
|
| 164 |
+
marker = self.config.audio_ids[0]
|
| 165 |
+
out = []
|
| 166 |
+
for b in range(h.size(0)):
|
| 167 |
+
hb, seq, i = h[b], tokens[b].tolist(), 0
|
| 168 |
+
af = audio_feats[b] if audio_feats[b] is not None else None
|
| 169 |
+
while i < len(seq):
|
| 170 |
+
if seq[i] == marker:
|
| 171 |
+
start = i
|
| 172 |
+
while i < len(seq) and seq[i] == marker:
|
| 173 |
+
i += 1
|
| 174 |
+
if af is not None:
|
| 175 |
+
inject_len = min(af.size(0), i - start)
|
| 176 |
+
hb = torch.cat((hb[:start], af[:inject_len], hb[start + inject_len:]), dim=0)
|
| 177 |
+
af = None
|
| 178 |
+
else:
|
| 179 |
+
i += 1
|
| 180 |
+
out.append(hb)
|
| 181 |
+
return torch.stack(out)
|
| 182 |
+
|
| 183 |
+
@staticmethod
|
| 184 |
+
def load_vision(path):
|
| 185 |
+
if path is None or not os.path.exists(path):
|
| 186 |
+
warnings.warn(f"[MiniMindOmni] Vision model path not found: {path}. vision_encoder will be None!")
|
| 187 |
+
return None, None
|
| 188 |
+
hf_logging.set_verbosity_error()
|
| 189 |
+
try:
|
| 190 |
+
model = SiglipVisionModel.from_pretrained(path)
|
| 191 |
+
except (RuntimeError, ValueError):
|
| 192 |
+
return None, None
|
| 193 |
+
processor = SiglipImageProcessor.from_pretrained(path)
|
| 194 |
+
for p in model.parameters():
|
| 195 |
+
p.requires_grad = False
|
| 196 |
+
return model.eval(), processor
|
| 197 |
+
|
| 198 |
+
@torch.compiler.disable
|
| 199 |
+
def get_image_embeddings(self, image_inputs):
|
| 200 |
+
if hasattr(image_inputs, 'keys'):
|
| 201 |
+
image_inputs = {k: v.squeeze(1) if v.ndim > 2 and v.shape[1] == 1 else v for k, v in image_inputs.items()}
|
| 202 |
+
pixel_attention_mask = image_inputs.get('pixel_attention_mask')
|
| 203 |
+
if pixel_attention_mask is not None and not pixel_attention_mask.any():
|
| 204 |
+
pv = image_inputs['pixel_values']
|
| 205 |
+
return pv.new_zeros(pv.size(0), pv.size(1), self.config.image_hidden_size)
|
| 206 |
+
with torch.no_grad():
|
| 207 |
+
outputs = self.vision_encoder(**image_inputs)
|
| 208 |
+
return outputs.last_hidden_state
|
| 209 |
+
|
| 210 |
+
@torch.compiler.disable
|
| 211 |
+
def encode_image_inputs(self, pixel_values):
|
| 212 |
+
if pixel_values is None or self.vision_encoder is None: return None
|
| 213 |
+
mask = pixel_values.flatten(1).any(1)
|
| 214 |
+
if not mask.any(): return pixel_values.new_zeros(pixel_values.size(0), self.config.image_token_len, self.config.hidden_size)
|
| 215 |
+
with torch.no_grad(): emb = self.vision_encoder(pixel_values=pixel_values[mask]).last_hidden_state
|
| 216 |
+
if emb.dim() == 2: emb = emb.unsqueeze(0)
|
| 217 |
+
emb = self.vision_proj(emb)
|
| 218 |
+
if mask.all(): return emb
|
| 219 |
+
idx = mask.nonzero().view(-1, 1, 1).expand_as(emb)
|
| 220 |
+
return emb.new_zeros(pixel_values.size(0), *emb.shape[1:]).scatter(0, idx, emb)
|
| 221 |
+
|
| 222 |
+
@torch.compiler.disable
|
| 223 |
+
def count_vision_proj(self, tokens, h, vision_tensors=None, seqlen=512):
|
| 224 |
+
if vision_tensors is None or not self.config.image_ids:
|
| 225 |
+
return h
|
| 226 |
+
marker, vf = self.config.image_ids[0], vision_tensors
|
| 227 |
+
if vf.dim() == 3:
|
| 228 |
+
vf = vf.unsqueeze(1)
|
| 229 |
+
out = []
|
| 230 |
+
for b in range(h.size(0)):
|
| 231 |
+
hb, seq, k, i = h[b], tokens[b].tolist(), 0, 0
|
| 232 |
+
while i < len(seq):
|
| 233 |
+
if seq[i] == marker:
|
| 234 |
+
start = i
|
| 235 |
+
while i < len(seq) and seq[i] == marker:
|
| 236 |
+
i += 1
|
| 237 |
+
if k < vf.size(1):
|
| 238 |
+
hb = torch.cat((hb[:start], vf[b][k][:i - start], hb[i:]), dim=0)[:seqlen]
|
| 239 |
+
k += 1
|
| 240 |
+
else:
|
| 241 |
+
i += 1
|
| 242 |
+
out.append(hb)
|
| 243 |
+
return torch.stack(out)
|
| 244 |
+
|
| 245 |
+
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, audio_inputs=None, audio_lens=None, pixel_values=None, **args):
|
| 246 |
+
if len(input_ids.shape) == 2:
|
| 247 |
+
batch_size, seq_length = input_ids.shape
|
| 248 |
+
text_ids = input_ids
|
| 249 |
+
audio_ids = torch.full((batch_size, 8, seq_length), self.audio_pad_token, dtype=torch.long, device=input_ids.device)
|
| 250 |
+
else:
|
| 251 |
+
batch_size, _, seq_length = input_ids.shape
|
| 252 |
+
text_ids, audio_ids = input_ids[:, 8, :], input_ids[:, :8, :]
|
| 253 |
+
if hasattr(past_key_values, 'layers'): past_key_values = None
|
| 254 |
+
n_thinker, n_talker = len(self.thinker.layers), len(self.talker.layers)
|
| 255 |
+
past_key_values = past_key_values or ([None] * (n_thinker + n_talker))
|
| 256 |
+
start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
|
| 257 |
+
# Recompute RoPE buffers lost during meta-device init (transformers>=5.x)
|
| 258 |
+
if self.thinker.freqs_cos[0, 0] == 0:
|
| 259 |
+
freqs_cos, freqs_sin = precompute_freqs_cis(dim=self.config.head_dim, end=self.config.max_position_embeddings, rope_base=self.config.rope_theta, rope_scaling=self.config.rope_scaling)
|
| 260 |
+
self.thinker.freqs_cos, self.thinker.freqs_sin = freqs_cos.to(input_ids.device), freqs_sin.to(input_ids.device)
|
| 261 |
+
if self.talker.freqs_cos[0, 0] == 0:
|
| 262 |
+
freqs_cos, freqs_sin = precompute_freqs_cis(dim=self.talker.talker_config.head_dim, end=self.config.max_position_embeddings, rope_base=self.config.rope_theta, rope_scaling=self.config.rope_scaling)
|
| 263 |
+
self.talker.freqs_cos, self.talker.freqs_sin = freqs_cos.to(input_ids.device), freqs_sin.to(input_ids.device)
|
| 264 |
+
presents = []
|
| 265 |
+
|
| 266 |
+
# ======= Thinker: text-only input, output text logits =======
|
| 267 |
+
hidden_states = self.thinker.dropout(self.thinker.embed_tokens(text_ids))
|
| 268 |
+
position_embeddings = (self.thinker.freqs_cos[start_pos:start_pos + seq_length], self.thinker.freqs_sin[start_pos:start_pos + seq_length])
|
| 269 |
+
if audio_inputs is not None and start_pos == 0:
|
| 270 |
+
audio_features = self.encode_audio_inputs(audio_inputs, audio_lens)
|
| 271 |
+
hidden_states = self.inject_audio_features(text_ids, hidden_states, audio_features, seq_length)
|
| 272 |
+
if pixel_values is not None and start_pos == 0:
|
| 273 |
+
if hasattr(pixel_values, 'keys'):
|
| 274 |
+
img_emb = self.get_image_embeddings(pixel_values).to(hidden_states.dtype)
|
| 275 |
+
vision_tensors = self.vision_proj(img_emb)
|
| 276 |
+
else:
|
| 277 |
+
if len(pixel_values.shape) == 6:
|
| 278 |
+
pixel_values = pixel_values.squeeze(2)
|
| 279 |
+
if len(pixel_values.shape) == 4:
|
| 280 |
+
pixel_values = pixel_values.unsqueeze(1)
|
| 281 |
+
bs, num, c, im_h, im_w = pixel_values.shape
|
| 282 |
+
stack_dim = 1 if bs > 1 else 0
|
| 283 |
+
vision_tensors = torch.stack([
|
| 284 |
+
self.encode_image_inputs(pixel_values[:, i, :, :, :])
|
| 285 |
+
for i in range(num)
|
| 286 |
+
], dim=stack_dim)
|
| 287 |
+
hidden_states = self.count_vision_proj(tokens=text_ids, h=hidden_states, vision_tensors=vision_tensors, seqlen=seq_length)
|
| 288 |
+
bridge_states = hidden_states
|
| 289 |
+
for i, (layer, past_key_value) in enumerate(zip(self.thinker.layers, past_key_values[:n_thinker])):
|
| 290 |
+
hidden_states, present = layer(hidden_states, position_embeddings, past_key_value=past_key_value, use_cache=use_cache, attention_mask=attention_mask)
|
| 291 |
+
presents.append(present)
|
| 292 |
+
if i == self.config.bridge_layer: bridge_states = hidden_states
|
| 293 |
+
h_thinker = self.thinker.norm(hidden_states)
|
| 294 |
+
|
| 295 |
+
# ======= Talker: thinker hidden + audio codes, output audio logits =======
|
| 296 |
+
talker_emb = self.talker.embed_tokens(audio_ids)
|
| 297 |
+
spk_emb = args.get('spk_emb', None)
|
| 298 |
+
if spk_emb is not None:
|
| 299 |
+
spk_mask = (audio_ids[:, 0, :] == self.audio_spk_token).unsqueeze(-1)
|
| 300 |
+
talker_emb = torch.where(spk_mask, self.talker.spk_proj(spk_emb).unsqueeze(1), talker_emb)
|
| 301 |
+
hidden_states = self.talker.embed_proj(bridge_states) * self.talker.text_scale + self.talker.codec_proj(talker_emb) * self.talker.audio_scale
|
| 302 |
+
talker_pos_emb = (self.talker.freqs_cos[start_pos:start_pos + seq_length], self.talker.freqs_sin[start_pos:start_pos + seq_length])
|
| 303 |
+
for layer, past_key_value in zip(self.talker.layers, past_key_values[n_thinker:]):
|
| 304 |
+
hidden_states, present = layer(hidden_states, talker_pos_emb, past_key_value=past_key_value, use_cache=use_cache, attention_mask=attention_mask)
|
| 305 |
+
presents.append(present)
|
| 306 |
+
h_talker = self.talker.norm(hidden_states)
|
| 307 |
+
|
| 308 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 309 |
+
aux_loss = sum(l.mlp.aux_loss for l in list(self.thinker.layers) + list(self.talker.layers) if isinstance(l.mlp, MOEFeedForward))
|
| 310 |
+
aux_loss += sum(p.sum() for p in self.audio_proj.parameters()) * 0 + sum(p.sum() for p in self.vision_proj.parameters()) * 0 + sum(p.sum() for p in self.talker.lm_head.adapters.parameters()) * 0 + sum(p.sum() for p in self.talker.spk_proj.parameters()) * 0 # dummy gradient
|
| 311 |
+
text_logits = self.thinker.lm_head(h_thinker[:, slice_indices, :])
|
| 312 |
+
audio_logits = self.talker.lm_head(h_talker[:, slice_indices, :])
|
| 313 |
+
|
| 314 |
+
out = MoeCausalLMOutputWithPast(aux_loss=aux_loss, logits=text_logits, past_key_values=presents)
|
| 315 |
+
out.audio_logits = audio_logits
|
| 316 |
+
return out
|
| 317 |
+
|
| 318 |
+
@torch.inference_mode()
|
| 319 |
+
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
|
| 320 |
+
stream=False, rp=1., use_cache=True, return_audio_codes=False, **args):
|
| 321 |
+
if stream:
|
| 322 |
+
return self.stream_generate(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, return_audio_codes, **args)
|
| 323 |
+
tokens = list(self.stream_generate(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, return_audio_codes, **args))
|
| 324 |
+
return tokens[-1] if tokens else input_ids
|
| 325 |
+
|
| 326 |
+
def stream_generate(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, return_audio_codes=False, **args):
|
| 327 |
+
start_pos, past_kvs, text_finished, first_finished = input_ids.shape[1], None, False, True
|
| 328 |
+
audio_codes = [[] for _ in range(8)]
|
| 329 |
+
audio_stop_pos = [None] * 8
|
| 330 |
+
audio_buffer = torch.full((1, 8, start_pos), self.audio_pad_token, dtype=torch.long, device=input_ids.device)
|
| 331 |
+
spk_emb = args.get('spk_emb', None)
|
| 332 |
+
ref_codes = args.get('ref_codes', None)
|
| 333 |
+
ref_len = ref_codes.shape[2] if ref_codes is not None else 0
|
| 334 |
+
spk_reserve = 1 if spk_emb is not None else 0
|
| 335 |
+
fill_end = start_pos
|
| 336 |
+
fill_start = max(spk_reserve, start_pos - ref_len)
|
| 337 |
+
if ref_codes is not None and fill_start < fill_end:
|
| 338 |
+
audio_buffer[:, :, fill_start:fill_end] = ref_codes[:, :, -(fill_end - fill_start):]
|
| 339 |
+
if spk_emb is not None and fill_start > 0:
|
| 340 |
+
audio_buffer[:, :, fill_start - 1] = self.audio_spk_token
|
| 341 |
+
think_end_step, generated_tokens = None, ([] if args.get('open_thinking', False) else None)
|
| 342 |
+
while input_ids.shape[1] < start_pos + max_new_tokens:
|
| 343 |
+
if past_kvs is None or not use_cache:
|
| 344 |
+
out = self.forward(torch.cat((audio_buffer, input_ids.unsqueeze(1)), dim=1), past_key_values=past_kvs, use_cache=use_cache, **args)
|
| 345 |
+
else:
|
| 346 |
+
out = self.forward(torch.cat((audio_buffer[:, :, -1:], input_ids[:, -1:].unsqueeze(1)), dim=1), past_key_values=past_kvs, use_cache=use_cache, **args)
|
| 347 |
+
past_kvs = out.past_key_values
|
| 348 |
+
|
| 349 |
+
logits = out.logits[0, -1, :].clone() / (temperature + 1e-9)
|
| 350 |
+
logits[list(set(input_ids[0].tolist()))] /= rp
|
| 351 |
+
if top_p and top_p < 1.0:
|
| 352 |
+
sorted_l, sorted_i = torch.sort(logits, descending=True)
|
| 353 |
+
mask = torch.cumsum(F.softmax(sorted_l, dim=-1), dim=-1) > top_p
|
| 354 |
+
mask[1:], mask[0] = mask[:-1].clone(), False
|
| 355 |
+
logits[sorted_i[mask]] = -float('Inf')
|
| 356 |
+
text_token = torch.multinomial(F.softmax(logits, dim=-1), 1).item()
|
| 357 |
+
|
| 358 |
+
if text_finished:
|
| 359 |
+
text_token = args.get('enter_token_id', 201) if first_finished else args.get('pad_token_id', 0)
|
| 360 |
+
first_finished = False
|
| 361 |
+
|
| 362 |
+
step = input_ids.shape[1] - start_pos # 已生成token数(0=首次,此时模型处理prompt末尾token)
|
| 363 |
+
audio_step = step - 1 # 延迟1步:输出第1个text时无audio,输出第2个text时layer0开始
|
| 364 |
+
if generated_tokens is not None:
|
| 365 |
+
generated_tokens.append(text_token)
|
| 366 |
+
if not think_end_step and generated_tokens[-len(self.config.think_end_ids):] == list(self.config.think_end_ids): think_end_step = step + 2
|
| 367 |
+
audio_step = (step - think_end_step) if think_end_step else -1
|
| 368 |
+
for i, al in enumerate(out.audio_logits):
|
| 369 |
+
if audio_step < i:
|
| 370 |
+
audio_codes[i].append(self.audio_pad_token)
|
| 371 |
+
else:
|
| 372 |
+
logits_i = al[0, -1, :].clone() / 0.2
|
| 373 |
+
for prev_code in audio_codes[i][-3:]: logits_i[prev_code] /= 1.05
|
| 374 |
+
top_val, top_idx = logits_i.topk(50)
|
| 375 |
+
code = top_idx[torch.multinomial(F.softmax(top_val, dim=-1), 1)].item()
|
| 376 |
+
audio_codes[i].append(code)
|
| 377 |
+
if audio_stop_pos[i] is None and code >= 2048: audio_stop_pos[i] = len(audio_codes[i]) - 1
|
| 378 |
+
|
| 379 |
+
if text_finished and audio_codes[7][-1] == self.audio_stop_token: break
|
| 380 |
+
|
| 381 |
+
input_ids = torch.cat((input_ids, torch.tensor([[text_token]], device=input_ids.device)), dim=1)
|
| 382 |
+
audio_buffer = torch.cat((audio_buffer, torch.full((1, 8, 1), self.audio_pad_token, dtype=torch.long, device=input_ids.device)), dim=2)
|
| 383 |
+
for i in range(min(audio_step + 1, 8)): audio_buffer[0, i, -1] = audio_codes[i][-1]
|
| 384 |
+
|
| 385 |
+
audio_frame = None
|
| 386 |
+
if return_audio_codes and audio_step >= 7:
|
| 387 |
+
frame = [audio_codes[i][step - 7 + i] for i in range(8)]
|
| 388 |
+
active_layers = sum(1 for i in range(8) if audio_stop_pos[i] is None or step - 7 + i < audio_stop_pos[i])
|
| 389 |
+
if active_layers >= 8: audio_frame = frame
|
| 390 |
+
if not text_finished:
|
| 391 |
+
yield input_ids[:, start_pos:], audio_frame
|
| 392 |
+
if text_token == eos_token_id: text_finished = True
|
| 393 |
+
else:
|
| 394 |
+
yield None, audio_frame
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# ==== Realtime VAD (与模型本体零耦合,纯工程层) ====
|
| 398 |
+
class SileroVAD:
|
| 399 |
+
def __init__(self, path):
|
| 400 |
+
opts = ort.SessionOptions()
|
| 401 |
+
opts.inter_op_num_threads = opts.intra_op_num_threads = 1
|
| 402 |
+
opts.log_severity_level = 4
|
| 403 |
+
self.session = ort.InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts)
|
| 404 |
+
self.h, self.c = np.zeros((2, 1, 64), dtype=np.float32), np.zeros((2, 1, 64), dtype=np.float32)
|
| 405 |
+
|
| 406 |
+
def reset(self):
|
| 407 |
+
self.h[:], self.c[:] = 0, 0
|
| 408 |
+
|
| 409 |
+
def __call__(self, chunk, sr=16000):
|
| 410 |
+
out, self.h, self.c = self.session.run(None, {"input": chunk.reshape(1, -1).astype(np.float32), "h": self.h, "c": self.c, "sr": np.array(sr, dtype="int64")})
|
| 411 |
+
return float(out[0][0])
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class RealtimeSession:
|
| 415 |
+
def __init__(self, vad_path, sr=16000, threshold=0.8, min_speech_ms=128, min_silence_ms=800):
|
| 416 |
+
self.vad, self.sr, self.threshold = SileroVAD(vad_path), sr, threshold
|
| 417 |
+
self.min_speech, self.min_silence = int(sr * min_speech_ms / 1000), int(sr * min_silence_ms / 1000)
|
| 418 |
+
self.reset()
|
| 419 |
+
|
| 420 |
+
def reset(self):
|
| 421 |
+
self.vad.reset()
|
| 422 |
+
self.buffer, self.ring, self.speaking, self.generating, self.interrupt = [], [], False, False, False
|
| 423 |
+
self.speech_samples = self.silence_samples = self.tail_silence = 0
|
| 424 |
+
|
| 425 |
+
def push_chunk(self, chunk, W=1024):
|
| 426 |
+
for i in range(0, max(len(chunk), 1), W):
|
| 427 |
+
w = chunk[i:i + W]
|
| 428 |
+
if len(w) < W:
|
| 429 |
+
w = np.pad(w, (0, W - len(w)))
|
| 430 |
+
prob = self.vad(w, self.sr)
|
| 431 |
+
if prob > self.threshold:
|
| 432 |
+
self.silence_samples = self.tail_silence = 0
|
| 433 |
+
self.speech_samples += len(w)
|
| 434 |
+
self.buffer.append(w)
|
| 435 |
+
if self.speech_samples >= self.min_speech and not self.speaking:
|
| 436 |
+
self.speaking = True
|
| 437 |
+
self.buffer = self.ring + self.buffer
|
| 438 |
+
self.ring = []
|
| 439 |
+
if self.generating and self.speaking:
|
| 440 |
+
self.interrupt = True
|
| 441 |
+
return 'interrupt'
|
| 442 |
+
elif self.speaking:
|
| 443 |
+
self.silence_samples += len(w)
|
| 444 |
+
self.tail_silence += 1
|
| 445 |
+
self.buffer.append(w)
|
| 446 |
+
if self.silence_samples >= self.min_silence:
|
| 447 |
+
if self.tail_silence > 1:
|
| 448 |
+
del self.buffer[-(self.tail_silence - 1):]
|
| 449 |
+
self.speaking, self.speech_samples, self.silence_samples, self.tail_silence = False, 0, 0, 0
|
| 450 |
+
return 'speech_end'
|
| 451 |
+
else:
|
| 452 |
+
if self.speech_samples > 0:
|
| 453 |
+
self.buffer.clear()
|
| 454 |
+
self.speech_samples = 0
|
| 455 |
+
self.ring = [w]
|
| 456 |
+
return 'listening'
|
| 457 |
+
|
| 458 |
+
def get_audio(self):
|
| 459 |
+
audio = np.concatenate(self.buffer) if self.buffer else np.array([], dtype=np.float32)
|
| 460 |
+
self.buffer.clear()
|
| 461 |
+
return audio
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:21530f9bbc540f461e2c0e29292ad359781d4d984d1e0c994510945f9b0edaab
|
| 3 |
+
size 226324754
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|im_start|>",
|
| 4 |
+
"<|im_end|>",
|
| 5 |
+
"<|object_ref_start|>",
|
| 6 |
+
"<|object_ref_end|>",
|
| 7 |
+
"<|box_start|>",
|
| 8 |
+
"<|box_end|>",
|
| 9 |
+
"<|quad_start|>",
|
| 10 |
+
"<|quad_end|>",
|
| 11 |
+
"<|vision_start|>",
|
| 12 |
+
"<|vision_end|>",
|
| 13 |
+
"<|vision_pad|>",
|
| 14 |
+
"<|image_pad|>",
|
| 15 |
+
"<|video_pad|>",
|
| 16 |
+
"<|audio_start|>",
|
| 17 |
+
"<|audio_end|>",
|
| 18 |
+
"<|audio_pad|>",
|
| 19 |
+
"<tts_pad>",
|
| 20 |
+
"<tts_text_bos>",
|
| 21 |
+
"<tts_text_eod>",
|
| 22 |
+
"<tts_text_bos_single>"
|
| 23 |
+
],
|
| 24 |
+
"bos_token": {
|
| 25 |
+
"content": "<|im_start|>",
|
| 26 |
+
"lstrip": false,
|
| 27 |
+
"normalized": false,
|
| 28 |
+
"rstrip": false,
|
| 29 |
+
"single_word": false
|
| 30 |
+
},
|
| 31 |
+
"eos_token": {
|
| 32 |
+
"content": "<|im_end|>",
|
| 33 |
+
"lstrip": false,
|
| 34 |
+
"normalized": false,
|
| 35 |
+
"rstrip": false,
|
| 36 |
+
"single_word": false
|
| 37 |
+
},
|
| 38 |
+
"pad_token": {
|
| 39 |
+
"content": "<|endoftext|>",
|
| 40 |
+
"lstrip": false,
|
| 41 |
+
"normalized": false,
|
| 42 |
+
"rstrip": false,
|
| 43 |
+
"single_word": false
|
| 44 |
+
},
|
| 45 |
+
"unk_token": {
|
| 46 |
+
"content": "<|endoftext|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false
|
| 51 |
+
}
|
| 52 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": false,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<|endoftext|>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<|im_start|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "<|im_end|>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
},
|
| 30 |
+
"3": {
|
| 31 |
+
"content": "<|object_ref_start|>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false,
|
| 36 |
+
"special": true
|
| 37 |
+
},
|
| 38 |
+
"4": {
|
| 39 |
+
"content": "<|object_ref_end|>",
|
| 40 |
+
"lstrip": false,
|
| 41 |
+
"normalized": false,
|
| 42 |
+
"rstrip": false,
|
| 43 |
+
"single_word": false,
|
| 44 |
+
"special": true
|
| 45 |
+
},
|
| 46 |
+
"5": {
|
| 47 |
+
"content": "<|box_start|>",
|
| 48 |
+
"lstrip": false,
|
| 49 |
+
"normalized": false,
|
| 50 |
+
"rstrip": false,
|
| 51 |
+
"single_word": false,
|
| 52 |
+
"special": true
|
| 53 |
+
},
|
| 54 |
+
"6": {
|
| 55 |
+
"content": "<|box_end|>",
|
| 56 |
+
"lstrip": false,
|
| 57 |
+
"normalized": false,
|
| 58 |
+
"rstrip": false,
|
| 59 |
+
"single_word": false,
|
| 60 |
+
"special": true
|
| 61 |
+
},
|
| 62 |
+
"7": {
|
| 63 |
+
"content": "<|quad_start|>",
|
| 64 |
+
"lstrip": false,
|
| 65 |
+
"normalized": false,
|
| 66 |
+
"rstrip": false,
|
| 67 |
+
"single_word": false,
|
| 68 |
+
"special": true
|
| 69 |
+
},
|
| 70 |
+
"8": {
|
| 71 |
+
"content": "<|quad_end|>",
|
| 72 |
+
"lstrip": false,
|
| 73 |
+
"normalized": false,
|
| 74 |
+
"rstrip": false,
|
| 75 |
+
"single_word": false,
|
| 76 |
+
"special": true
|
| 77 |
+
},
|
| 78 |
+
"9": {
|
| 79 |
+
"content": "<|vision_start|>",
|
| 80 |
+
"lstrip": false,
|
| 81 |
+
"normalized": false,
|
| 82 |
+
"rstrip": false,
|
| 83 |
+
"single_word": false,
|
| 84 |
+
"special": true
|
| 85 |
+
},
|
| 86 |
+
"10": {
|
| 87 |
+
"content": "<|vision_end|>",
|
| 88 |
+
"lstrip": false,
|
| 89 |
+
"normalized": false,
|
| 90 |
+
"rstrip": false,
|
| 91 |
+
"single_word": false,
|
| 92 |
+
"special": true
|
| 93 |
+
},
|
| 94 |
+
"11": {
|
| 95 |
+
"content": "<|vision_pad|>",
|
| 96 |
+
"lstrip": false,
|
| 97 |
+
"normalized": false,
|
| 98 |
+
"rstrip": false,
|
| 99 |
+
"single_word": false,
|
| 100 |
+
"special": true
|
| 101 |
+
},
|
| 102 |
+
"12": {
|
| 103 |
+
"content": "<|image_pad|>",
|
| 104 |
+
"lstrip": false,
|
| 105 |
+
"normalized": false,
|
| 106 |
+
"rstrip": false,
|
| 107 |
+
"single_word": false,
|
| 108 |
+
"special": true
|
| 109 |
+
},
|
| 110 |
+
"13": {
|
| 111 |
+
"content": "<|video_pad|>",
|
| 112 |
+
"lstrip": false,
|
| 113 |
+
"normalized": false,
|
| 114 |
+
"rstrip": false,
|
| 115 |
+
"single_word": false,
|
| 116 |
+
"special": true
|
| 117 |
+
},
|
| 118 |
+
"14": {
|
| 119 |
+
"content": "<|audio_start|>",
|
| 120 |
+
"lstrip": false,
|
| 121 |
+
"normalized": false,
|
| 122 |
+
"rstrip": false,
|
| 123 |
+
"single_word": false,
|
| 124 |
+
"special": true
|
| 125 |
+
},
|
| 126 |
+
"15": {
|
| 127 |
+
"content": "<|audio_end|>",
|
| 128 |
+
"lstrip": false,
|
| 129 |
+
"normalized": false,
|
| 130 |
+
"rstrip": false,
|
| 131 |
+
"single_word": false,
|
| 132 |
+
"special": true
|
| 133 |
+
},
|
| 134 |
+
"16": {
|
| 135 |
+
"content": "<|audio_pad|>",
|
| 136 |
+
"lstrip": false,
|
| 137 |
+
"normalized": false,
|
| 138 |
+
"rstrip": false,
|
| 139 |
+
"single_word": false,
|
| 140 |
+
"special": true
|
| 141 |
+
},
|
| 142 |
+
"17": {
|
| 143 |
+
"content": "<tts_pad>",
|
| 144 |
+
"lstrip": false,
|
| 145 |
+
"normalized": false,
|
| 146 |
+
"rstrip": false,
|
| 147 |
+
"single_word": false,
|
| 148 |
+
"special": true
|
| 149 |
+
},
|
| 150 |
+
"18": {
|
| 151 |
+
"content": "<tts_text_bos>",
|
| 152 |
+
"lstrip": false,
|
| 153 |
+
"normalized": false,
|
| 154 |
+
"rstrip": false,
|
| 155 |
+
"single_word": false,
|
| 156 |
+
"special": true
|
| 157 |
+
},
|
| 158 |
+
"19": {
|
| 159 |
+
"content": "<tts_text_eod>",
|
| 160 |
+
"lstrip": false,
|
| 161 |
+
"normalized": false,
|
| 162 |
+
"rstrip": false,
|
| 163 |
+
"single_word": false,
|
| 164 |
+
"special": true
|
| 165 |
+
},
|
| 166 |
+
"20": {
|
| 167 |
+
"content": "<tts_text_bos_single>",
|
| 168 |
+
"lstrip": false,
|
| 169 |
+
"normalized": false,
|
| 170 |
+
"rstrip": false,
|
| 171 |
+
"single_word": false,
|
| 172 |
+
"special": true
|
| 173 |
+
},
|
| 174 |
+
"21": {
|
| 175 |
+
"content": "<tool_call>",
|
| 176 |
+
"lstrip": false,
|
| 177 |
+
"normalized": false,
|
| 178 |
+
"rstrip": false,
|
| 179 |
+
"single_word": false,
|
| 180 |
+
"special": false
|
| 181 |
+
},
|
| 182 |
+
"22": {
|
| 183 |
+
"content": "</tool_call>",
|
| 184 |
+
"lstrip": false,
|
| 185 |
+
"normalized": false,
|
| 186 |
+
"rstrip": false,
|
| 187 |
+
"single_word": false,
|
| 188 |
+
"special": false
|
| 189 |
+
},
|
| 190 |
+
"23": {
|
| 191 |
+
"content": "<tool_response>",
|
| 192 |
+
"lstrip": false,
|
| 193 |
+
"normalized": false,
|
| 194 |
+
"rstrip": false,
|
| 195 |
+
"single_word": false,
|
| 196 |
+
"special": false
|
| 197 |
+
},
|
| 198 |
+
"24": {
|
| 199 |
+
"content": "</tool_response>",
|
| 200 |
+
"lstrip": false,
|
| 201 |
+
"normalized": false,
|
| 202 |
+
"rstrip": false,
|
| 203 |
+
"single_word": false,
|
| 204 |
+
"special": false
|
| 205 |
+
},
|
| 206 |
+
"25": {
|
| 207 |
+
"content": "<think>",
|
| 208 |
+
"lstrip": false,
|
| 209 |
+
"normalized": false,
|
| 210 |
+
"rstrip": false,
|
| 211 |
+
"single_word": false,
|
| 212 |
+
"special": false
|
| 213 |
+
},
|
| 214 |
+
"26": {
|
| 215 |
+
"content": "</think>",
|
| 216 |
+
"lstrip": false,
|
| 217 |
+
"normalized": false,
|
| 218 |
+
"rstrip": false,
|
| 219 |
+
"single_word": false,
|
| 220 |
+
"special": false
|
| 221 |
+
},
|
| 222 |
+
"27": {
|
| 223 |
+
"content": "<|buffer1|>",
|
| 224 |
+
"lstrip": false,
|
| 225 |
+
"normalized": false,
|
| 226 |
+
"rstrip": false,
|
| 227 |
+
"single_word": false,
|
| 228 |
+
"special": false
|
| 229 |
+
},
|
| 230 |
+
"28": {
|
| 231 |
+
"content": "<|buffer2|>",
|
| 232 |
+
"lstrip": false,
|
| 233 |
+
"normalized": false,
|
| 234 |
+
"rstrip": false,
|
| 235 |
+
"single_word": false,
|
| 236 |
+
"special": false
|
| 237 |
+
},
|
| 238 |
+
"29": {
|
| 239 |
+
"content": "<|buffer3|>",
|
| 240 |
+
"lstrip": false,
|
| 241 |
+
"normalized": false,
|
| 242 |
+
"rstrip": false,
|
| 243 |
+
"single_word": false,
|
| 244 |
+
"special": false
|
| 245 |
+
},
|
| 246 |
+
"30": {
|
| 247 |
+
"content": "<|buffer4|>",
|
| 248 |
+
"lstrip": false,
|
| 249 |
+
"normalized": false,
|
| 250 |
+
"rstrip": false,
|
| 251 |
+
"single_word": false,
|
| 252 |
+
"special": false
|
| 253 |
+
},
|
| 254 |
+
"31": {
|
| 255 |
+
"content": "<|buffer5|>",
|
| 256 |
+
"lstrip": false,
|
| 257 |
+
"normalized": false,
|
| 258 |
+
"rstrip": false,
|
| 259 |
+
"single_word": false,
|
| 260 |
+
"special": false
|
| 261 |
+
},
|
| 262 |
+
"32": {
|
| 263 |
+
"content": "<|buffer6|>",
|
| 264 |
+
"lstrip": false,
|
| 265 |
+
"normalized": false,
|
| 266 |
+
"rstrip": false,
|
| 267 |
+
"single_word": false,
|
| 268 |
+
"special": false
|
| 269 |
+
},
|
| 270 |
+
"33": {
|
| 271 |
+
"content": "<|buffer7|>",
|
| 272 |
+
"lstrip": false,
|
| 273 |
+
"normalized": false,
|
| 274 |
+
"rstrip": false,
|
| 275 |
+
"single_word": false,
|
| 276 |
+
"special": false
|
| 277 |
+
},
|
| 278 |
+
"34": {
|
| 279 |
+
"content": "<|buffer8|>",
|
| 280 |
+
"lstrip": false,
|
| 281 |
+
"normalized": false,
|
| 282 |
+
"rstrip": false,
|
| 283 |
+
"single_word": false,
|
| 284 |
+
"special": false
|
| 285 |
+
},
|
| 286 |
+
"35": {
|
| 287 |
+
"content": "<|buffer9|>",
|
| 288 |
+
"lstrip": false,
|
| 289 |
+
"normalized": false,
|
| 290 |
+
"rstrip": false,
|
| 291 |
+
"single_word": false,
|
| 292 |
+
"special": false
|
| 293 |
+
}
|
| 294 |
+
},
|
| 295 |
+
"additional_special_tokens": [
|
| 296 |
+
"<|im_start|>",
|
| 297 |
+
"<|im_end|>",
|
| 298 |
+
"<|object_ref_start|>",
|
| 299 |
+
"<|object_ref_end|>",
|
| 300 |
+
"<|box_start|>",
|
| 301 |
+
"<|box_end|>",
|
| 302 |
+
"<|quad_start|>",
|
| 303 |
+
"<|quad_end|>",
|
| 304 |
+
"<|vision_start|>",
|
| 305 |
+
"<|vision_end|>",
|
| 306 |
+
"<|vision_pad|>",
|
| 307 |
+
"<|image_pad|>",
|
| 308 |
+
"<|video_pad|>",
|
| 309 |
+
"<|audio_start|>",
|
| 310 |
+
"<|audio_end|>",
|
| 311 |
+
"<|audio_pad|>",
|
| 312 |
+
"<tts_pad>",
|
| 313 |
+
"<tts_text_bos>",
|
| 314 |
+
"<tts_text_eod>",
|
| 315 |
+
"<tts_text_bos_single>"
|
| 316 |
+
],
|
| 317 |
+
"audio_bos_token": "<|audio_start|>",
|
| 318 |
+
"audio_eos_token": "<|audio_end|>",
|
| 319 |
+
"audio_token": "<|audio_pad|>",
|
| 320 |
+
"bos_token": "<|im_start|>",
|
| 321 |
+
"clean_up_tokenization_spaces": false,
|
| 322 |
+
"eos_token": "<|im_end|>",
|
| 323 |
+
"extra_special_tokens": {},
|
| 324 |
+
"image_token": "<|image_pad|>",
|
| 325 |
+
"legacy": true,
|
| 326 |
+
"model_max_length": 131072,
|
| 327 |
+
"pad_token": "<|endoftext|>",
|
| 328 |
+
"sp_model_kwargs": {},
|
| 329 |
+
"spaces_between_special_tokens": false,
|
| 330 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 331 |
+
"unk_token": "<|endoftext|>",
|
| 332 |
+
"video_token": "<|video_pad|>",
|
| 333 |
+
"vision_bos_token": "<|vision_start|>",
|
| 334 |
+
"vision_eos_token": "<|vision_end|>"
|
| 335 |
+
}
|