kashif HF Staff committed on
Commit
578c911
·
verified ·
1 Parent(s): a784e6c

tokenizer: sync with Carbon-500M/3B (warnings, add_special_tokens pass-through, tolist in decode)

Browse files
Files changed (1) hide show
  1. tokenizer.py +11 -1
tokenizer.py CHANGED
@@ -16,6 +16,7 @@ Supports token_mask for Fine-grained Nucleotide Supervision (FNS):
16
 
17
  import os
18
  import json
 
19
  import itertools
20
  from typing import List, Optional, Tuple, Dict, Union, Any
21
 
@@ -322,7 +323,7 @@ class HybridDNATokenizer(PreTrainedTokenizer):
322
  else:
323
  base_ids = self._base_tokenizer.encode(
324
  segment_content,
325
- add_special_tokens=False
326
  )
327
  token_ids.extend(base_ids)
328
  if return_token_mask:
@@ -344,6 +345,8 @@ class HybridDNATokenizer(PreTrainedTokenizer):
344
  skip_special_tokens: bool = False,
345
  **kwargs
346
  ) -> str:
 
 
347
  if isinstance(token_ids, int):
348
  token_ids = [token_ids]
349
 
@@ -430,6 +433,13 @@ class HybridDNATokenizer(PreTrainedTokenizer):
430
  auto_dna_tags: Optional[bool] = None,
431
  **kwargs
432
  ) -> Dict[str, Any]:
 
 
 
 
 
 
 
433
  is_batch = isinstance(text, list)
434
  texts = text if is_batch else [text]
435
 
 
16
 
17
  import os
18
  import json
19
+ import warnings
20
  import itertools
21
  from typing import List, Optional, Tuple, Dict, Union, Any
22
 
 
323
  else:
324
  base_ids = self._base_tokenizer.encode(
325
  segment_content,
326
+ add_special_tokens=add_special_tokens
327
  )
328
  token_ids.extend(base_ids)
329
  if return_token_mask:
 
345
  skip_special_tokens: bool = False,
346
  **kwargs
347
  ) -> str:
348
+ if hasattr(token_ids, 'tolist'):
349
+ token_ids = token_ids.tolist()
350
  if isinstance(token_ids, int):
351
  token_ids = [token_ids]
352
 
 
433
  auto_dna_tags: Optional[bool] = None,
434
  **kwargs
435
  ) -> Dict[str, Any]:
436
+ if add_special_tokens:
437
+ warnings.warn(
438
+ "HybridTokenizer does not support add_special_tokens=True, ignoring.",
439
+ UserWarning
440
+ )
441
+ add_special_tokens = False
442
+
443
  is_batch = isinstance(text, list)
444
  texts = text if is_batch else [text]
445