# Copyright 2025 TeleAI Rhodes Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration classes for PRTS built on Qwen3-VL."""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig
class PRTS_Qwen3VLTextConfig(PretrainedConfig):
    r"""
    Configuration class for the PRTS text model built on Qwen3-VL.

    It extends [`PretrainedConfig`] with the Qwen3-VL text-model hyperparameters plus
    PRTS-specific special-token ids. Configuration objects inherit from
    [`PretrainedConfig`] and can be used to control the model outputs.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen3VL model.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            Number of key-value heads for Grouped Query Attention. If `None`, falls
            back to `num_attention_heads` (i.e. standard multi-head attention).
        head_dim (`int`, *optional*, defaults to 128):
            The dimension of each attention head.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function.
        max_position_embeddings (`int`, *optional*, defaults to 128000):
            The maximum sequence length.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 5000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings.
            Legacy dicts using the key ``"type"`` are accepted and normalized to
            ``"rope_type"``.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        action_token_id (`int`, *optional*):
            Token index used as placeholder for action embeddings.
        action_start_token_id (`int`, *optional*):
            Token index for action sequence start.
        action_end_token_id (`int`, *optional*):
            Token index for action sequence end.
        crl_goal_repr_token_id (`int`, *optional*):
            Token index marking the CRL goal representation position.
        crl_obs_repr_token_id (`int`, *optional*):
            Token index marking the CRL observation representation position.
        **kwargs:
            Additional keyword arguments passed to [`PretrainedConfig`] (e.g.
            `image_token_id`, `video_token_id`, `vision_start_token_id`).
    """
    model_type = "prts_qwen3_vl_text"  # TODO (zy): check if this is correct
    base_config_key = "text_config"

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=128000,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=5000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        # PRTS specific
        action_token_id=None,
        action_start_token_id=None,
        action_end_token_id=None,
        crl_goal_repr_token_id=None,
        crl_obs_repr_token_id=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # BC: legacy rope_scaling dicts use "type" where validation expects "rope_type".
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        # Validate rope config (mrope keys are Qwen-VL specific and skipped)
        rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
        # PRTS specific token IDs
        self.action_token_id = action_token_id
        self.action_start_token_id = action_start_token_id
        self.action_end_token_id = action_end_token_id
        self.crl_goal_repr_token_id = crl_goal_repr_token_id
        self.crl_obs_repr_token_id = crl_obs_repr_token_id
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class PRTS_FlowMatchingConfig_Qwen3VL(PretrainedConfig):
    r"""
    Configuration class for a PRTS flow-matching model based on Qwen3-VL.

    [`PRTS_FlowMatchingConfig_Qwen3VL`] is used to instantiate a PRTS model according to
    the specified arguments, defining the vision encoder, text encoder, action expert,
    and flow matching components. Configuration objects inherit from [`PretrainedConfig`]
    and can be used to control the model outputs.

    Args:
        text_config (`Union[PretrainedConfig, dict]`, *optional*, defaults to `PRTS_Qwen3VLTextConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[PretrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
            The config object or dictionary of the vision backbone.
        image_token_id (`int`, *optional*, defaults to 151655):
            Token id for image placeholders.
        video_token_id (`int`, *optional*, defaults to 151656):
            Token id for video placeholders.
        vision_start_token_id (`int`, *optional*, defaults to 151652):
            Token id for vision start marker.
        vision_end_token_id (`int`, *optional*, defaults to 151653):
            Token id for vision end marker.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the input and output word embeddings should be tied.
        max_action_dim (`int`, *optional*, defaults to 32):
            Maximum dimension of action vectors. Used for padding different robot action spaces.
        action_chunk_size (`int`, *optional*, defaults to 50):
            Number of action timesteps to predict in each forward pass.
        num_denoise_steps (`int`, *optional*, defaults to 4):
            Number of denoising steps for flow matching during inference.
        flow_matching_action_loss_weight (`float`, *optional*, defaults to 0.0):
            Weight for the flow matching action loss.
        use_fast_action_tokenizer (`bool`, *optional*, defaults to `True`):
            Whether to use the fast action tokenizer.
        embodiment_tag (`str`, *optional*):
            Identifies the robot embodiment used for finetuning; stores the
            delta_action_mask key so eval code can recover it without needing the
            training dataset config.
        dit_action_head_config (`dict`, *optional*):
            Overrides merged into the default DiT action head config (see below).
        crl_loss_weight (`float`, *optional*, defaults to 0.0):
            Weight for the Contrastive Reinforcement Learning (CRL) loss. Set to 0 to disable.
        crl_embed_dim (`int`, *optional*, defaults to 256):
            Dimension of the CRL embedding space for action and goal encoders.
        crl_logsumexp_reg_weight (`float`, *optional*, defaults to 0.0):
            Weight for logsumexp regularization on CRL logits.
        crl_encoder_init_w (`float`, *optional*, defaults to 1e-12):
            Cold initialization weight for the CRL encoder's last layer.
        crl_repr_norm (`bool`, *optional*, defaults to `True`):
            Whether to L2-normalize CRL representations.
        **kwargs:
            Additional keyword arguments passed to PretrainedConfig.

    Example:
        ```python
        >>> from prts.models import PRTS_FlowMatchingConfig_Qwen3VL, PRTS_Qwen3VL
        >>> # Initializing a PRTS Qwen3-VL configuration
        >>> configuration = PRTS_FlowMatchingConfig_Qwen3VL()
        >>> # Initializing a model from the configuration
        >>> model = PRTS_Qwen3VL(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.config
        ```
    """
    model_type = "prts_qwen3_vl"
    sub_configs = {
        "vision_config": Qwen3VLVisionConfig,
        "text_config": PRTS_Qwen3VLTextConfig,
    }
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        tie_word_embeddings=False,
        # PRTS specific
        max_action_dim=32,
        action_chunk_size=50,
        num_denoise_steps=4,
        flow_matching_action_loss_weight=0.,
        use_fast_action_tokenizer=True,
        # Embodiment tag: identifies the robot embodiment used for finetuning.
        # Stores the delta_action_mask key so eval code can recover it without
        # needing the training dataset config.
        embodiment_tag=None,
        # DiT action head config
        dit_action_head_config=None,
        # CRL (Contrastive Reinforcement Learning) parameters
        crl_loss_weight=0.,
        crl_embed_dim=256,
        crl_logsumexp_reg_weight=0.0,
        crl_encoder_init_w=1e-12,  # Cold initialization weight for encoder last layer
        crl_repr_norm=True,  # Whether to L2-normalize CRL representations
        **kwargs,
    ):
        # Initialize vision config (dict, None, or an already-built config instance)
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            # Config instance passed directly — keep it as-is. Without this branch
            # `self.vision_config` would never be assigned.
            self.vision_config = vision_config
        # Initialize text config
        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            # For BC use all kwargs to init `TextConfig`
            self.text_config = self.sub_configs["text_config"](**kwargs)
        else:
            # Config instance passed directly.
            self.text_config = text_config
        # PRTS-specific parameters
        self.max_action_dim = max_action_dim
        self.action_chunk_size = action_chunk_size
        self.num_denoise_steps = num_denoise_steps
        self.flow_matching_action_loss_weight = flow_matching_action_loss_weight
        self.use_fast_action_tokenizer = use_fast_action_tokenizer
        self.embodiment_tag = embodiment_tag
        # DiT action head config (nested dict)
        # cross_attention_dim defaults to text_config.hidden_size at model init time
        _default_dit_config = {
            # Architecture — aligned with GR00T N1.6 (32 layers, inner_dim=32×48=1536)
            "num_layers": 16,  # 32
            "num_attention_heads": 32,
            "attention_head_dim": 48,
            "output_dim": 1024,
            # Regularisation
            "dropout": 0.2,
            "interleave_self_attention": True,
            "norm_type": "ada_norm",
            "final_dropout": True,
            # Action-head specifics
            "add_pos_embed": True,
            # Noise schedule
            "noise_beta_alpha": 1.5,
            "noise_beta_beta": 1.0,
            "noise_s": 0.999,
            "num_timestep_buckets": 1000,
            # Attention backend
            "attn_implementation": "sdpa",
            # AlternateVLDiT — separate visual / text token cross-attention
            "use_alternate_vl_dit": True,
            "attend_text_every_n_blocks": 2,
            # MoT-style action expert: forwards full VLM ``past_key_values`` into the head;
            # expert depth defaults to text_config.num_hidden_layers (override with expert_num_layers).
            "use_mot_action_expert": False,
            "mlp_mult": 4,  # FFN hidden dim = inner_dim * mlp_mult (standard DiT only)
        }
        if dit_action_head_config is not None:
            _default_dit_config.update(dit_action_head_config)
        self.dit_action_head_config = _default_dit_config
        # CRL (Contrastive Reinforcement Learning) parameters
        self.crl_loss_weight = crl_loss_weight
        self.crl_embed_dim = crl_embed_dim
        self.crl_logsumexp_reg_weight = crl_logsumexp_reg_weight
        self.crl_encoder_init_w = crl_encoder_init_w
        self.crl_repr_norm = crl_repr_norm
        # Token IDs
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)

    # TODO (zy): consider whether passing these state/action special tokens through
    # the VL config directly would be more appropriate and flexible.
    @property
    def action_token_id(self):
        """Get action token id from text config."""
        return getattr(self.text_config, 'action_token_id', None)

    @action_token_id.setter
    def action_token_id(self, value):
        """Set action token id in text config."""
        if hasattr(self.text_config, 'action_token_id'):
            self.text_config.action_token_id = value

    def __getattribute__(self, key):
        # Delegate attribute lookups to the nested text_config when it defines the
        # key, so flat access (e.g. `config.hidden_size`) keeps working.
        if "text_config" in super().__getattribute__("__dict__") and key not in [
            "dtype",
            "_attn_implementation_internal",
        ]:
            text_config = super().__getattribute__("text_config")
            if key in text_config.__dict__:
                return getattr(text_config, key)
        return super().__getattribute__(key)
# Register the config with transformers' Auto classes so that
# `AutoConfig.from_pretrained` can resolve it via its `model_type`.
PRTS_FlowMatchingConfig_Qwen3VL.register_for_auto_class()
# Public API of this module.
__all__ = ["PRTS_FlowMatchingConfig_Qwen3VL", "PRTS_Qwen3VLTextConfig"]