# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import warnings from typing import List, Optional, Tuple, Union import torch import transformers from torch import nn from torch.nn import CrossEntropyLoss from transformers import AutoModel, AutoModelForCausalLM, GenerationConfig from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from .configuration import NemotronH_Nano_Omni_Reasoning_V3_Config from .modeling_nemotron_h import NemotronHForCausalLM from .evs import EfficientVideoSampling from .audio_model import SoundEncoder, SoundProjection logger = logging.get_logger(__name__) """ The following code is adapted from the https://huggingface.co/OpenGVLab/InternVL2-Llama3-76B/blob/main/modeling_internvl_chat.py repository The chat function is adapted to handle NVLM 1-D tile-tagging design for dynamic high-resolution images. """ class SquaredReLU(nn.Module): def forward(self, x): return torch.pow(torch.nn.functional.relu(x), 2) class RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) def version_cmp(v1, v2, op='eq'): import operator from packaging import version op_func = getattr(operator, op) return op_func(version.parse(v1), version.parse(v2)) class NemotronH_Nano_Omni_Reasoning_V3(PreTrainedModel): config_class = NemotronH_Nano_Omni_Reasoning_V3_Config main_input_name = 'pixel_values' _supports_flash_attn_2 = True _supports_flash_attn = True _no_split_modules = ['NemotronHBlock'] def __init__(self, config: NemotronH_Nano_Omni_Reasoning_V3_Config): super().__init__(config) assert version_cmp(transformers.__version__, '4.36.2', 'ge') image_size = config.force_image_size patch_size = config.patch_size self.patch_size = patch_size self.template = config.template self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2)) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version self.image_tag_type = config.image_tag_type self.img_context_token_id = config.img_context_token_id self.video_context_token_id = config.video_context_token_id logger.info(f'num_image_token: {self.num_image_token}') logger.info(f'ps_version: {self.ps_version}') # Instantiate LM directly to avoid Hugging Face dynamic module lookup requiring a repo id. self.language_model = NemotronHForCausalLM(config.llm_config) self.vision_model = AutoModel.from_config(config.vision_config, trust_remote_code=True) self.vision_model.model._initialize_weights = self.vision_model.model._init_weights # WAR for transformers issue 38358 self.vision_model.radio_model.make_preprocessor_external() # Attach a separate 3D patch projection for video frames. The RADIO ViT ships with only a 2D # `embedder` (shape `[embed_dim, C·P²]`); this repo's checkpoint also carries a # `video_embedder` (shape `[embed_dim, T·C·P²]`) used for temporally-packed video patches, # so we construct the module here to make the weight bind. `T = video_temporal_patch_size` # is the number of frames collapsed into each temporal patch. self.video_temporal_patch_dim = config.video_temporal_patch_size pg = self.vision_model.radio_model.model.patch_generator pg.video_embedder = nn.Linear( in_features=self.video_temporal_patch_dim * 3 * pg.patch_size * pg.patch_size, out_features=pg.embed_dim, bias=False, ) # Align CPE position-embedding interpolation with Megatron training + vLLM inference. # The `nvidia/C-RADIOv2-H` remote code uses `align_corners=True` in eval mode, but the V3 # checkpoint was trained against `align_corners=False` (see Megatron's `radio.py`). That # single-flag mismatch shifts every pos_embed by a fraction of a cell, which compounds # through 52 ViT layers and is the main cause of HF/vLLM divergence for video (where CPE # mode is active — dynamic-res tubelets don't match the model's native 2048-sized grid). self._patch_cpe_align_corners(pg) self.vision_model = self.vision_model.to(self.language_model.config.torch_dtype) self.drop_vision_class_token = True # Construct the vision projection. # Default vit_hidden_size = config.vit_hidden_size vision_projection_hidden_size = config.projector_hidden_size llm_hidden_size = config.llm_config.hidden_size self.video_pruning_rate = config.video_pruning_rate self.mlp1 = nn.Sequential( RMSNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, eps=1e-5), nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, vision_projection_hidden_size, bias=False), SquaredReLU(), nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False) ) self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype) # Sound/audio model components (optional - only if sound_config is provided) self.sound_context_token_id = getattr(config, 'sound_context_token_id', None) if config.sound_config is not None: sound_config = config.sound_config sound_hidden_size = sound_config.hidden_size sound_projection_hidden_size = sound_config.projection_hidden_size # Initialize sound feature extractor for converting raw audio to mel spectrograms from transformers import ParakeetFeatureExtractor sampling_rate = getattr(sound_config, 'sampling_rate', 16000) feature_size = getattr(sound_config, 'num_mel_bins', 128) self.sound_feature_extractor = ParakeetFeatureExtractor( sampling_rate=sampling_rate, feature_size=feature_size, ) logger.info(f'Sound feature extractor initialized with sampling_rate={sampling_rate}, feature_size={feature_size}') # Initialize sound encoder - wraps Parakeet from transformers self.sound_encoder = SoundEncoder(config=sound_config) self.sound_encoder = self.sound_encoder.to(self.language_model.config.torch_dtype) # Initialize sound projection MLP self.sound_projection = SoundProjection( sound_hidden_size=sound_hidden_size, projection_hidden_size=sound_projection_hidden_size, llm_hidden_size=llm_hidden_size, bias=sound_config.projection_bias, ) self.sound_projection = self.sound_projection.to(self.language_model.config.torch_dtype) logger.info(f'Sound model initialized with hidden_size={sound_hidden_size}') else: self.sound_encoder = None self.sound_projection = None self.sound_feature_extractor = None self.all_tied_weights_keys = {} def forward( self, pixel_values: torch.FloatTensor, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, image_flags: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, labels: Optional[torch.LongTensor] = None, inputs_embeds = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) image_flags = image_flags.squeeze(-1) B, N, C = inputs_embeds.shape inputs_embeds = inputs_embeds.reshape(B * N, C) input_ids = input_ids.reshape(B * N) selected = (input_ids == self.img_context_token_id) vit_batch_size = pixel_values.shape[0] vit_embeds = self.extract_feature(pixel_values) del pixel_values if torch.distributed.get_rank() == 0: print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}') vit_embeds = vit_embeds[image_flags == 1] try: inputs_embeds[selected] = inputs_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) except Exception as e: vit_embeds = vit_embeds.reshape(-1, C) print(f'warning: {e}, inputs_embeds[selected].shape={inputs_embeds[selected].shape}, ' f'vit_embeds.shape={vit_embeds.shape}') n_token = selected.sum() inputs_embeds[selected] = inputs_embeds[selected] * 0.0 + vit_embeds[:n_token] del vit_embeds inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) logits = outputs.logits loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @staticmethod def _patch_cpe_align_corners(patch_generator) -> None: """Monkey-patch `patch_generator._get_pos_embeddings` so the CPE-mode eval-path interpolation uses `align_corners=False` (Megatron training + vLLM inference convention) instead of the `align_corners=True` that the `nvidia/C-RADIOv2-H` remote code ships with. """ import math import torch.nn.functional as F orig_method = patch_generator._get_pos_embeddings.__func__ if hasattr( patch_generator._get_pos_embeddings, "__func__" ) else patch_generator._get_pos_embeddings def _get_pos_embeddings_aligned(self, batch_size, input_dims): if (self.num_rows, self.num_cols) == input_dims: return self.pos_embed pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2) def window_select(pe): if input_dims[0] < pe.shape[-2]: pe = pe[..., :input_dims[0], :] if input_dims[1] < pe.shape[-1]: pe = pe[..., :, :input_dims[1]] return pe if self.cpe_mode: if self.training: # Keep the original training-time jitter path (grid_sample + align_corners=True); # only patch the eval branch, which is what Megatron/vLLM use and where the bug is. return orig_method(self, batch_size, input_dims) max_dim = max(input_dims) pos_embed = F.interpolate( pos_embed.float(), size=(max_dim, max_dim), align_corners=False, mode="bilinear" ).to(pos_embed.dtype) pos_embed = window_select(pos_embed) else: pos_embed = window_select(pos_embed) if pos_embed.shape[-2:] != input_dims: pos_embed = F.interpolate( pos_embed.float(), size=input_dims, align_corners=False, mode="bilinear" ).to(pos_embed.dtype) pos_embed = pos_embed.flatten(2).permute(0, 2, 1) return pos_embed import types patch_generator._get_pos_embeddings = types.MethodType(_get_pos_embeddings_aligned, patch_generator) def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor))) if self.ps_version == 'v1': warnings.warn("In ps_version 'v1', the height and width have not been swapped back, " 'which results in a transposed image.') else: x = x.permute(0, 2, 1, 3).contiguous() return x def extract_feature(self, pixel_values): """Run the ViT on a batch of image tiles. Handles two layouts: - A single 4D tensor `(B, 3, H, W)` with all tiles sharing the same spatial size (legacy fixed-tile path **or** dynamic-resolution path when every image in the batch resizes to the same target). - A list of 4D tensors `[(1, 3, H_i, W_i), …]` when dynamic resolution picks different target sizes per image. Each is run through the ViT independently and the output tokens are concatenated along the sequence dim. The patch grid `(h, w)` is computed from the actual input shape, not assumed square — this is required for dynamic resolution where the tile aspect ratio matches the original image. """ if isinstance(pixel_values, (list, tuple)): outs = [self._extract_feature_single(pv) for pv in pixel_values] return torch.cat(outs, dim=0) return self._extract_feature_single(pixel_values) def _extract_feature_single(self, pixel_values): vit_embeds = self.vision_model(pixel_values).features vit_embeds = vit_embeds.to(dtype=torch.bfloat16) # Compute patch grid from the input tile dims; pixel-shuffle needs the real (h, w). patch_size = self.vision_model.radio_model.model.patch_generator.patch_size B, _, H, W = pixel_values.shape h = H // patch_size w = W // patch_size vit_embeds = vit_embeds.reshape(B, h, w, -1) vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) vit_embeds = vit_embeds.reshape(B, -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def extract_video_feature(self, pixel_values_videos): """ Extract features from video frames using the 3D `video_embedder`. Consecutive `T = video_temporal_patch_dim` frames are packed into a single temporal patch before the ViT, so the output has `N_frames // T` temporal units (each with the usual number of spatial tokens) instead of one ViT output per frame. Implementation trick: RADIO's patch_generator uses a channel-agnostic `Im2Patches` rearrange followed by `self.embedder(patches)`. If we stack the T temporal frames into the channel dim — `(N_frames, C, H, W)` → `(N_frames/T, T·C, H, W)` — the rearrange produces patches of shape `(·, num_patches, T·C·P²)`, which is exactly what `video_embedder` expects. Temporarily swapping `embedder ↔ video_embedder` lets us reuse the full ViT forward without duplicating the transformer blocks, pos-embed handling, cls_token, etc. """ pg = self.vision_model.radio_model.model.patch_generator T = self.video_temporal_patch_dim N, C, H, W = pixel_values_videos.shape # Pad to a multiple of T by repeating the last frame so frame pairs align cleanly. if N % T != 0: pad = pixel_values_videos[-1:].expand(T - (N % T), -1, -1, -1) pixel_values_videos = torch.cat([pixel_values_videos, pad], dim=0) N = pixel_values_videos.shape[0] num_groups = N // T # Stack T frames into the channel dim. `.view` here preserves the (frame,channel) row-major # layout → per-patch feature order is [t=0,c=0..C-1, t=1,c=0..C-1, ...], matching how the # `video_embedder` weights are stored in the checkpoint. x = pixel_values_videos.reshape(num_groups, T * C, H, W) orig_embedder = pg.embedder pg.embedder = pg.video_embedder try: vit_embeds = self.vision_model(x).features finally: pg.embedder = orig_embedder # Same spatial post-processing as `extract_feature`. Compute `(h, w)` from the reshaped # input so dynamic-res video frames (non-square patch grid) are handled correctly. vit_embeds = vit_embeds.to(dtype=torch.bfloat16) patch_size = pg.patch_size h = H // patch_size w = W // patch_size vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def extract_sound_feature( self, input_features: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Extract and project sound features from audio input. Args: input_features: Mel spectrogram features [batch, seq_len, feature_dim] attention_mask: Optional attention mask [batch, seq_len] Returns: Sound embeddings projected to LLM hidden size [batch, encoded_seq_len, llm_hidden_size] """ if self.sound_encoder is None: raise RuntimeError("Sound encoder not initialized. Check if sound_config is provided.") # Encode audio features sound_embeds = self.sound_encoder(input_features, attention_mask) sound_embeds = sound_embeds.to(dtype=torch.bfloat16) # Project to LLM hidden size sound_embeds = self.sound_projection(sound_embeds) return sound_embeds @torch.no_grad() def generate( self, pixel_values: Optional[torch.FloatTensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, sound_clips: Optional[torch.FloatTensor] = None, sound_length: Optional[torch.Tensor] = None, input_ids: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, generation_config: Optional[GenerationConfig] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **generate_kwargs, ) -> torch.LongTensor: """Generate text given images, videos, and/or audio. Args: pixel_values: Image pixel values [num_tiles, C, H, W] pixel_values_videos: Video pixel values [num_frames, C, H, W] sound_clips: Raw audio waveforms. Can be: - A list of numpy arrays or torch tensors (one per audio clip) - A single numpy array or torch tensor for a single audio clip - Pre-extracted mel spectrogram features [batch, seq_len, num_mel_bins] sound_length: Length of each audio clip in samples (optional, used for batched audio) input_ids: Input token IDs [batch, seq_len] attention_mask: Attention mask [batch, seq_len] generation_config: Generation configuration output_hidden_states: Whether to output hidden states return_dict: Whether to return a dict **generate_kwargs: Additional generation arguments Returns: Generated token IDs """ assert self.img_context_token_id is not None has_images = pixel_values is not None has_videos = pixel_values_videos is not None has_sound = sound_clips is not None and self.sound_encoder is not None if has_images or has_videos or has_sound: image_vit_embeds, video_vit_embeds, sound_embeds = None, None, None # Process images if has_images: pixel_values = pixel_values.to(dtype=self.vision_model.config.torch_dtype) image_vit_embeds = self.extract_feature(pixel_values) # Process videos if has_videos: pixel_values_videos = pixel_values_videos.to(dtype=self.vision_model.config.torch_dtype) video_vit_embeds = self.extract_video_feature(pixel_values_videos) # Process sound/audio if has_sound: # Extract features from raw audio using the feature extractor # Handle different input types: # - list/tuple of waveforms # - 1D tensor/array (single waveform) # - 2D tensor [batch, samples] (batched raw waveforms) # - 3D tensor [batch, seq_len, num_mel_bins] (pre-extracted features) import numpy as np is_raw_waveform = False if isinstance(sound_clips, (list, tuple)): # List of audio clips (waveforms) is_raw_waveform = True waveforms = sound_clips elif isinstance(sound_clips, np.ndarray): # Numpy array - raw waveform is_raw_waveform = True waveforms = [sound_clips.squeeze()] if sound_clips.ndim > 1 else [sound_clips] elif isinstance(sound_clips, torch.Tensor): if sound_clips.dim() == 1: # 1D tensor - single raw waveform is_raw_waveform = True waveforms = [sound_clips.cpu().numpy()] elif sound_clips.dim() == 2: # 2D tensor [batch, samples] - batched raw waveforms is_raw_waveform = True waveforms = [clip.cpu().numpy() for clip in sound_clips] else: # 3D tensor [batch, seq_len, num_mel_bins] - pre-extracted features is_raw_waveform = False else: is_raw_waveform = False if is_raw_waveform: # Convert raw waveforms to mel spectrogram features audio_inputs = self.sound_feature_extractor( waveforms, sampling_rate=self.sound_feature_extractor.sampling_rate, return_tensors="pt", ) sound_input_features = audio_inputs.input_features sound_attention_mask = audio_inputs.get("attention_mask", None) else: # Already extracted features sound_input_features = sound_clips sound_attention_mask = None # Move to correct device and dtype target_device = self.sound_encoder.encoder.subsampling.linear.weight.device target_dtype = self.language_model.config.torch_dtype sound_input_features = sound_input_features.to(dtype=target_dtype, device=target_device) if sound_attention_mask is not None: sound_attention_mask = sound_attention_mask.to(device=target_device) sound_embeds = self.extract_sound_feature(sound_input_features, sound_attention_mask) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape inputs_embeds = inputs_embeds.reshape(B * N, C) input_ids_copy = input_ids.reshape(B * N) # Replace image tokens with image embeddings if image_vit_embeds is not None: image_mask = (input_ids_copy == self.img_context_token_id) assert image_mask.sum() != 0, "No image tokens found in input_ids" inputs_embeds[image_mask] = image_vit_embeds.reshape(-1, C).to(inputs_embeds.device, inputs_embeds.dtype) # Replace video tokens with video embeddings. The tokenizer has no distinct `