# cond_gen/modelling_qwen3.py
import torch
import torch.nn.functional as F
from typing import Any, Optional
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3.modeling_qwen3 import Qwen3Config, Qwen3ForCausalLM
from transformers.cache_utils import Cache
from decoders import Qwen3DecoderLayerAdaLN
from condition_encoders import ConditionEncoder
from vocab import (
CHORD_BOS_ID,
CHORD_EOS_ID,
CHORD_N_ID,
SEGMENT_FALLBACK_ID,
STRUCTURE_BOS_ID,
STRUCTURE_EOS_ID,
)
class MAGEL(Qwen3ForCausalLM):
"""
- masks-based CE loss
- decoder layers replaced with Qwen3DecoderLayerAdaLN
"""
def __init__(
self,
config: Qwen3Config,
        **kwargs: Any,  # unused; accepted so generic loaders can pass extras
):
super().__init__(config)
adaln_dim = int(config.hidden_size)
chord_dropout_trigger_prob = float(config.magel_chord_dropout_trigger_prob)
structure_dropout_trigger_prob = float(config.magel_structure_dropout_trigger_prob)
self.vocab_size = config.vocab_size
self.adaln_dim = adaln_dim
self.condition_encoder = ConditionEncoder(hidden_size=adaln_dim)
self.chord_dropout_trigger_prob = chord_dropout_trigger_prob
self.structure_dropout_trigger_prob = structure_dropout_trigger_prob
for layer_idx in range(len(self.model.layers)):
self.model.layers[layer_idx] = Qwen3DecoderLayerAdaLN(
config,
layer_idx=layer_idx,
cond_dim=adaln_dim,
)
        # Re-record the MAGEL hyperparameters on the config so saved
        # checkpoints always carry them and reload without out-of-band flags.
self.config.magel_chord_dropout_trigger_prob = chord_dropout_trigger_prob
self.config.magel_structure_dropout_trigger_prob = structure_dropout_trigger_prob
self.post_init()
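
    # How the condition is consumed (sketch, assumed): the real logic lives in
    # decoders.Qwen3DecoderLayerAdaLN, but a typical AdaLN layer projects the
    # per-position condition vector into a scale/shift pair applied around the
    # layer norm, e.g.
    #
    #     scale, shift = cond_proj(cond).chunk(2, dim=-1)  # hypothetical names
    #     hidden = layer_norm(hidden) * (1 + scale) + shift
    #
    # so `cond_dim` above must match whatever the layers expect.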
@staticmethod
def _drop_audio_condition_spans(
ids: torch.LongTensor,
condition_mask: torch.BoolTensor,
trigger_prob: float,
replacement_id: int,
bos_id: int,
eos_id: int,
) -> torch.LongTensor:
if trigger_prob <= 0.0:
return ids
# Only drop aligned audio-condition positions; keep BOS/EOS untouched.
eligible_mask = condition_mask & (ids != bos_id) & (ids != eos_id)
if not eligible_mask.any():
return ids
dropped = ids.clone()
trigger_mask = torch.rand(ids.size(0), device=ids.device) < trigger_prob
        span_len = 25  # maximum length of each contiguous dropped span
for batch_idx in torch.nonzero(trigger_mask, as_tuple=False).flatten():
candidate_positions = torch.nonzero(
eligible_mask[batch_idx], as_tuple=False
).flatten()
num_candidates = int(candidate_positions.numel())
if num_candidates == 0:
continue
drop_ratio = torch.rand((), device=ids.device).item()
num_to_drop = int(round(drop_ratio * num_candidates))
if num_to_drop <= 0:
continue
            remaining = num_to_drop
            available_positions = candidate_positions.clone()
            # Carve contiguous spans out of the remaining candidates until the
            # drop budget is spent.
            while remaining > 0:
num_available = int(available_positions.numel())
if num_available == 0:
break
cur_span_len = min(span_len, remaining)
if num_available <= cur_span_len:
start_idx = 0
selected_positions = available_positions[:cur_span_len]
else:
max_start = num_available - cur_span_len + 1
start_idx = int(
torch.randint(0, max_start, (1,), device=ids.device).item()
)
selected_positions = available_positions[
start_idx : start_idx + cur_span_len
]
dropped[batch_idx, selected_positions] = replacement_id
keep_mask = torch.ones(
num_available,
dtype=torch.bool,
device=ids.device,
)
keep_mask[start_idx : start_idx + int(selected_positions.numel())] = False
available_positions = available_positions[keep_mask]
remaining -= int(selected_positions.numel())
return dropped
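
    # Worked example (hypothetical ids): with trigger_prob=1.0 and a drawn
    # drop_ratio of ~0.5, a row whose eligible chord positions are indices
    # 1..8 gets roughly four of them overwritten as one contiguous span:
    #
    #     ids     = [BOS, c1, c2, c3, c4, c5, c6, c7, c8, EOS]
    #     dropped = [BOS, c1, c2, N,  N,  N,  N,  c7, c8, EOS]
    #
    # Later spans are drawn from positions that survive earlier spans, so two
    # spans can end up adjacent in the original sequence.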
def _build_condition(
self,
chord_ids: Optional[torch.LongTensor],
structure_ids: Optional[torch.LongTensor],
condition_mask: Optional[torch.BoolTensor],
cond_precomputed: Optional[torch.FloatTensor],
) -> Optional[torch.FloatTensor]:
if cond_precomputed is not None:
return cond_precomputed
if chord_ids is None or structure_ids is None:
return None
if self.training:
if condition_mask is None:
raise ValueError("condition_mask is required during training.")
chord_ids = self._drop_audio_condition_spans(
ids=chord_ids,
condition_mask=condition_mask,
trigger_prob=self.chord_dropout_trigger_prob,
replacement_id=CHORD_N_ID,
bos_id=CHORD_BOS_ID,
eos_id=CHORD_EOS_ID,
)
structure_ids = self._drop_audio_condition_spans(
ids=structure_ids,
condition_mask=condition_mask,
trigger_prob=self.structure_dropout_trigger_prob,
replacement_id=SEGMENT_FALLBACK_ID,
bos_id=STRUCTURE_BOS_ID,
eos_id=STRUCTURE_EOS_ID,
)
return self.condition_encoder(chord_ids, structure_ids)
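
    # Shape sketch (assumed from usage): chord_ids and structure_ids are
    # (batch, seq) LongTensors aligned with input_ids, and ConditionEncoder is
    # presumably expected to return a (batch, seq, hidden_size) FloatTensor,
    # which forward() hands to the decoder stack as `cond_expanded`.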
def ce_loss(
self,
logits: torch.FloatTensor,
labels: Optional[torch.LongTensor],
masks: Optional[torch.LongTensor],
) -> Optional[torch.Tensor]:
if labels is None or masks is None:
return None
        # Standard causal shift: position t predicts token t + 1; `masks`
        # selects which target positions contribute to the loss.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].clone()
        valid_token_mask = masks[:, 1:].bool().contiguous()
if not valid_token_mask.any():
return shift_logits.new_zeros(())
shift_labels.masked_fill_(~valid_token_mask, -100)
loss_sum = F.cross_entropy(
shift_logits.view(-1, self.config.vocab_size),
shift_labels.view(-1).to(shift_logits.device),
ignore_index=-100,
reduction="sum",
)
valid_count = valid_token_mask.sum().to(
device=loss_sum.device,
dtype=loss_sum.dtype,
)
return loss_sum / valid_count.clamp_min(1)
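
    # For reference, the reduction above is the masked token mean
    #
    #     loss = (1 / |M|) * sum_{t in M} CE(logits[:, t], labels[:, t + 1]),
    #     M = { t : masks[:, t + 1] == 1 },
    #
    # computed as an explicit sum over valid targets divided by their count.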
def forward(
self,
        input_ids: Optional[torch.LongTensor] = None,
masks: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
chord_ids: Optional[torch.LongTensor] = None,
structure_ids: Optional[torch.LongTensor] = None,
condition_mask: Optional[torch.BoolTensor] = None,
cond_precomputed: Optional[torch.FloatTensor] = None,
) -> CausalLMOutputWithPast:
if use_cache is None:
use_cache = self.config.use_cache
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either input_ids or inputs_embeds must be provided.")
            inputs_embeds = self.model.embed_tokens(input_ids)
cond = self._build_condition(
chord_ids=chord_ids,
structure_ids=structure_ids,
condition_mask=condition_mask,
cond_precomputed=cond_precomputed,
)
base_out = self.model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cond_expanded=cond,
condition_mask=condition_mask,
cache_position=cache_position,
)
hidden_states = base_out.last_hidden_state
logits = self.lm_head(hidden_states)
loss = self.ce_loss(logits=logits, labels=labels, masks=masks)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=base_out.past_key_values,
hidden_states=base_out.hidden_states,
attentions=base_out.attentions,
)
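

if __name__ == "__main__":
    # Smoke test (illustrative sketch only): the tiny config values and the
    # zero-filled condition ids below are assumptions for demonstration, not
    # the original training setup; it runs only alongside the repo's local
    # decoders/condition_encoders/vocab modules.
    config = Qwen3Config(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        vocab_size=512,
        magel_chord_dropout_trigger_prob=0.5,
        magel_structure_dropout_trigger_prob=0.5,
    )
    model = MAGEL(config)
    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    out = model(
        input_ids=input_ids,
        masks=torch.ones(batch_size, seq_len, dtype=torch.long),
        labels=input_ids,
        chord_ids=torch.zeros(batch_size, seq_len, dtype=torch.long),
        structure_ids=torch.zeros(batch_size, seq_len, dtype=torch.long),
        condition_mask=torch.ones(batch_size, seq_len, dtype=torch.bool),
    )
    print(out.loss, out.logits.shape)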