Compatibility with v5

#4
Opened by RaushanTurganbay (HF Staff)
Files changed (1)
  1. modular_isaac.py +9 -155
modular_isaac.py CHANGED
@@ -19,8 +19,8 @@ from transformers import (
19
  Qwen3ForCausalLM,
20
  Qwen3PreTrainedModel,
21
  )
22
- from transformers.cache_utils import SlidingWindowCache, StaticCache
23
  from transformers.generation.utils import GenerationMixin
 
24
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
25
  from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3Model
26
  from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
@@ -1340,10 +1340,14 @@ class IsaacModel(Qwen3Model):
1340
  sin = sin.to(inputs_embeds.dtype)
1341
 
1342
  # Prepare attention mask
1343
- if attention_mask is not None:
1344
- attention_mask = self._update_causal_mask(
1345
- attention_mask, inputs_embeds, cache_position, past_key_values, False
1346
- )
 
 
 
 
1347
 
1348
  # Initialize hidden states
1349
  hidden_states = inputs_embeds
@@ -1370,156 +1374,6 @@ class IsaacModel(Qwen3Model):
1370
  past_key_values=past_key_values,
1371
  )
1372
 
1373
- def _update_causal_mask(
1374
- self,
1375
- attention_mask: torch.Tensor,
1376
- input_tensor: torch.Tensor,
1377
- cache_position: torch.Tensor,
1378
- past_key_values: Cache,
1379
- output_attentions: bool = False,
1380
- ):
1381
- if self.config._attn_implementation == "flash_attention_2":
1382
- if attention_mask is not None and past_key_values is not None:
1383
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
1384
- if is_padding_right:
1385
- raise ValueError(
1386
- "You are attempting to perform batched generation with padding_side='right'"
1387
- " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
1388
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1389
- )
1390
- if attention_mask is not None and 0.0 in attention_mask:
1391
- return attention_mask
1392
- return None
1393
-
1394
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1395
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1396
- # to infer the attention mask.
1397
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1398
- using_static_cache = isinstance(past_key_values, StaticCache)
1399
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
1400
-
1401
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1402
- if (
1403
- self.config._attn_implementation == "sdpa"
1404
- and not (using_static_cache or using_sliding_window_cache)
1405
- and not output_attentions
1406
- ):
1407
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1408
- attention_mask,
1409
- inputs_embeds=input_tensor,
1410
- past_key_values_length=past_seen_tokens,
1411
- sliding_window=self.config.sliding_window,
1412
- is_training=self.training,
1413
- ):
1414
- return None
1415
-
1416
- dtype, device = input_tensor.dtype, input_tensor.device
1417
- min_dtype = torch.finfo(dtype).min
1418
- sequence_length = input_tensor.shape[1]
1419
- # SlidingWindowCache or StaticCache
1420
- if using_sliding_window_cache or using_static_cache:
1421
- target_length = past_key_values.get_max_cache_shape()
1422
- # DynamicCache or no cache
1423
- else:
1424
- target_length = (
1425
- attention_mask.shape[-1]
1426
- if isinstance(attention_mask, torch.Tensor)
1427
- else past_seen_tokens + sequence_length + 1
1428
- )
1429
-
1430
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1431
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1432
- attention_mask,
1433
- sequence_length=sequence_length,
1434
- target_length=target_length,
1435
- dtype=dtype,
1436
- device=device,
1437
- cache_position=cache_position,
1438
- batch_size=input_tensor.shape[0],
1439
- config=self.config,
1440
- past_key_values=past_key_values,
1441
- )
1442
-
1443
- if (
1444
- self.config._attn_implementation == "sdpa"
1445
- and attention_mask is not None
1446
- and attention_mask.device.type in ["cuda", "xpu", "npu"]
1447
- and not output_attentions
1448
- ):
1449
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1450
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1451
- # Details: https://github.com/pytorch/pytorch/issues/110213
1452
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1453
-
1454
- return causal_mask
1455
-
1456
- @staticmethod
1457
- def _prepare_4d_causal_attention_mask_with_cache_position(
1458
- attention_mask: torch.Tensor,
1459
- sequence_length: int,
1460
- target_length: int,
1461
- dtype: torch.dtype,
1462
- device: torch.device,
1463
- cache_position: torch.Tensor,
1464
- batch_size: int,
1465
- config: Qwen3Config,
1466
- past_key_values: Cache,
1467
- ):
1468
- """
1469
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1470
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1471
-
1472
- Args:
1473
- attention_mask (`torch.Tensor`):
1474
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
1475
- sequence_length (`int`):
1476
- The sequence length being processed.
1477
- target_length (`int`):
1478
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
1479
- dtype (`torch.dtype`):
1480
- The dtype to use for the 4D attention mask.
1481
- device (`torch.device`):
1482
- The device to place the 4D attention mask on.
1483
- cache_position (`torch.Tensor`):
1484
- Indices depicting the position of the input sequence tokens in the sequence.
1485
- batch_size (`torch.Tensor`):
1486
- Batch size.
1487
- config (`Qwen3Config`):
1488
- The model's configuration class
1489
- past_key_values (`Cache`):
1490
- The cache class that is being used currently to generate
1491
- """
1492
- if attention_mask is not None and attention_mask.dim() == 4:
1493
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1494
- causal_mask = attention_mask
1495
- else:
1496
- min_dtype = torch.finfo(dtype).min
1497
- causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1498
- diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1499
- if config.sliding_window is not None:
1500
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
1501
- # the check is needed to verify is current checkpoint was trained with sliding window or not
1502
- if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
1503
- sliding_attend_mask = torch.arange(target_length, device=device) <= (
1504
- cache_position.reshape(-1, 1) - config.sliding_window
1505
- )
1506
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
1507
- causal_mask *= diagonal_attend_mask
1508
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1509
- if attention_mask is not None:
1510
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1511
- if attention_mask.shape[-1] > target_length:
1512
- attention_mask = attention_mask[:, :target_length]
1513
- mask_length = attention_mask.shape[-1]
1514
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
1515
- causal_mask.device
1516
- )
1517
- padding_mask = padding_mask == 0
1518
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1519
- padding_mask, min_dtype
1520
- )
1521
- return causal_mask
1522
-
1523
 
1524
  class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
1525
  """Isaac multimodal model for conditional generation."""
 
19
  Qwen3ForCausalLM,
20
  Qwen3PreTrainedModel,
21
  )
 
22
  from transformers.generation.utils import GenerationMixin
23
+ from transformers.masking_utils import create_causal_mask
24
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
25
  from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3Model
26
  from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
 
1340
  sin = sin.to(inputs_embeds.dtype)
1341
 
1342
  # Prepare attention mask
1343
+ attention_mask = create_causal_mask(
1344
+ config=self.config,
1345
+ input_embeds=inputs_embeds,
1346
+ attention_mask=attention_mask,
1347
+ past_key_values=past_key_values,
1348
+ position_ids=position_ids,
1349
+ cache_position=cache_position,
1350
+ )
1351
 
1352
  # Initialize hidden states
1353
  hidden_states = inputs_embeds
 
1374
  past_key_values=past_key_values,
1375
  )
1376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1377
 
1378
  class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
1379
  """Isaac multimodal model for conditional generation."""