| import copy |
| import transformers |
| import tokenizers |
| import torch |
| from typing import Dict, Optional, Sequence, List |
| from packaging import version |
|
|
| from llava.mm_utils import tokenizer_image_token |
| from llava.train.arguments import ModelArguments, TrainingArguments, DataArguments |
| from llava.constants import IGNORE_INDEX, MM_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN |
| from llava import conversation as conversation_lib |
|
|
# True when the installed `tokenizers` package is version 0.14 or newer
# (note: despite the name, the comparison is "greater than or equal").
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse('0.14') <= version.parse(tokenizers.__version__)
|
|
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string independently.

    Every string is encoded on its own (truncated to
    ``tokenizer.model_max_length``), and the first row of the resulting
    ``input_ids`` tensor is kept.

    Returns:
        dict with
        - ``input_ids`` / ``labels``: the SAME list object holding one 1-D
          id tensor per input string;
        - ``input_ids_lens`` / ``labels_lens``: the SAME list object holding
          the count of non-``pad_token_id`` ids per string.
    """
    encodings = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]

    token_rows = [enc.input_ids[0] for enc in encodings]
    non_pad_counts = [
        enc.input_ids.ne(tokenizer.pad_token_id).sum().item() for enc in encodings
    ]

    # labels deliberately aliases input_ids (and labels_lens aliases
    # input_ids_lens); callers deep-copy before mutating.
    return dict(
        input_ids=token_rows,
        labels=token_rows,
        input_ids_lens=non_pad_counts,
        labels_lens=non_pad_counts,
    )
|
|
|
|
def _mask_targets(target, tokenized_lens, speakers):
    """Mask the header and the human turns of ``target`` in place.

    Args:
        target: indexable label sequence (e.g. a tensor), mutated in place.
        tokenized_lens: token length of the header followed by the token
            length of each turn, in order.
        speakers: the ``"from"`` value of each turn, aligned with
            ``tokenized_lens[1:]``.

    The first ``tokenized_lens[0]`` positions (the header) are masked.
    For every turn whose speaker is ``"human"``, all positions except the
    first two of that turn are set to IGNORE_INDEX.
    """
    offset = tokenized_lens[0]
    target[:offset] = IGNORE_INDEX

    for turn_len, who in zip(tokenized_lens[1:], speakers):
        if who == "human":
            # The first two tokens of the turn are left unmasked.
            target[offset + 2:offset + turn_len] = IGNORE_INDEX
        offset += turn_len
|
|
|
|
| def _add_speaker_and_signal(header, source, get_conversation=True): |
| """Add speaker and start/end signal on each round.""" |
| BEGIN_SIGNAL = "### " |
| END_SIGNAL = "\n" |
| conversation = header |
| for sentence in source: |
| from_str = sentence["from"] |
| if from_str.lower() == "human": |
| from_str = conversation_lib.default_conversation.roles[0] |
| elif from_str.lower() == "gpt": |
| from_str = conversation_lib.default_conversation.roles[1] |
| else: |
| from_str = 'unknown' |
| sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + |
| sentence["value"] + END_SIGNAL) |
| if get_conversation: |
| conversation += sentence["value"] |
| conversation += BEGIN_SIGNAL |
| return conversation |
|
|
def preprocess_multimodal(
    sources: Sequence[List[Dict]],
    data_args: DataArguments
) -> Sequence[List[Dict]]:
    """Normalize image/video placeholder tokens in conversation turns.

    Each turn whose ``'value'`` contains ``DEFAULT_VIDEO_TOKEN`` or
    ``DEFAULT_IMAGE_TOKEN`` is stripped of surrounding whitespace. When
    ``data_args.mm_use_start_end`` is set, every occurrence of the
    placeholder is wrapped in its start/end tokens. Turns are modified in
    place.

    Args:
        sources: sequence of conversations; each conversation is a list of
            turn dicts carrying a ``'value'`` string.
        data_args: supplies ``is_multimodal`` and ``mm_use_start_end``.

    Returns:
        ``sources`` itself (the same object). It is returned untouched when
        ``data_args.is_multimodal`` is false.

    Raises:
        NotImplementedError: a turn contains the video token while the
            active conversation template's version contains ``"mmtag"``.
    """
    if not data_args.is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_VIDEO_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    raise NotImplementedError(
                        "'mmtag' conversation templates are not supported for video inputs"
                    )
                replace_token = DEFAULT_VIDEO_TOKEN
                if data_args.mm_use_start_end:
                    replace_token = DEFAULT_VIDEO_START_TOKEN + replace_token + DEFAULT_VIDEO_END_TOKEN
                sentence["value"] = sentence["value"].replace(DEFAULT_VIDEO_TOKEN, replace_token)

            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    # For 'mmtag' templates, wrap the placeholder in literal <Image>...</Image> tags.
                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
                replace_token = DEFAULT_IMAGE_TOKEN
                if data_args.mm_use_start_end:
                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
|
|
|
|
def preprocess_llama_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize conversations with the LLaMA-2 template and build masked labels.

    Args:
        sources: iterable of conversations. Each conversation is a list of
            turn dicts with ``"from"`` (a key of the role map below) and
            ``"value"``.
        tokenizer: used directly on the text-only path and passed to
            ``tokenizer_image_token`` on the image path.
        has_image: when True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``.
            That requires every prompt to yield the same length.
            NOTE(review): presumably callers pass one conversation per call.

    Returns:
        dict with ``input_ids`` and ``labels``. ``labels`` is a clone of
        ``input_ids`` in which several spans are set to IGNORE_INDEX:
        - the first position;
        - each round's instruction prefix (text up to and including
          ``"[/INST] "``);
        - everything after the last parsed round.
        If the recomputed length is below ``model_max_length`` and disagrees
        with the non-pad length, the whole row is masked instead.
    """
    conv = conversation_lib.default_conversation.copy()
    # "gpt" and "model" both map to the assistant role.
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Render every source through the conversation template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading turn that is not from the user role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must alternate user/assistant; the message is the source index.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered prompts.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # Mask the instruction part of every round; the text after `sep` stays supervised.
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        # Non-padding token count, compared against cur_len below to detect drift.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        # Mask position 0 (presumably the BOS token - TODO confirm).
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            # An empty piece means the prompt ended with sep2; nothing left to parse.
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # Token counts of the full round and of its instruction prefix.
            # NOTE(review): the "- 2" is a tokenizer-specific correction
            # (presumably BOS plus a boundary token) - confirm for new tokenizers.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last parsed round (including padding) is masked.
        target[cur_len:] = IGNORE_INDEX

        # On a length mismatch, drop supervision for the whole sample rather
        # than train on misaligned labels; rows that reached model_max_length
        # (possibly truncated) are not checked.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
|
|
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize conversations with a ``SeparatorStyle.TWO`` (v1) template and build masked labels.

    Args:
        sources: iterable of conversations. Each conversation is a list of
            turn dicts with ``"from"`` (a key of the role map below) and
            ``"value"``.
        tokenizer: used directly on the text-only path and passed to
            ``tokenizer_image_token`` on the image path. Its ``legacy``
            attribute is read directly, so a tokenizer without one raises
            ``AttributeError`` for multi-round conversations.
        has_image: when True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``.
            That requires equal-length prompts.
            NOTE(review): presumably one conversation per call.

    Returns:
        dict with ``input_ids`` and ``labels``. ``labels`` is a clone of
        ``input_ids`` in which several spans are set to IGNORE_INDEX:
        - the first position;
        - each round's instruction prefix (up to and including
          ``sep + assistant_role + ": "``);
        - everything after the last parsed round.
        The whole row is masked on a length mismatch below
        ``model_max_length``.
    """
    conv = conversation_lib.default_conversation.copy()
    # "gpt" and "model" both map to the assistant role.
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Render every source through the conversation template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading turn that is not from the user role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must alternate user/assistant; the message is the source index.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered prompts.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Text up to and including `sep` in each round is treated as instruction and masked.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Non-padding token count, compared against cur_len below to detect drift.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        # Mask position 0 (presumably the BOS token - TODO confirm).
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            # An empty piece means the prompt ended with sep2.
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): "- 2" is a tokenizer-specific correction
            # (presumably BOS plus a boundary token) - confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # For rounds after the first, non-legacy tokenizers with
            # tokenizers>=0.14 yield one token fewer per round.
            # NOTE(review): presumably a prefix-space difference - confirm.
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last parsed round (including padding) is masked.
        target[cur_len:] = IGNORE_INDEX

        # On a length mismatch, drop supervision for the whole sample; rows at
        # or beyond model_max_length are not checked.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
|
|
def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize conversations with the MPT template and build masked labels.

    Args:
        sources: iterable of conversations. Each conversation is a list of
            turn dicts with ``"from"`` (a key of the role map below) and
            ``"value"``.
        tokenizer: used directly on the text-only path and passed to
            ``tokenizer_image_token`` on the image path.
        has_image: when True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``.
            That requires equal-length prompts.
            NOTE(review): presumably one conversation per call.

    Returns:
        dict with ``input_ids`` and ``labels``. ``labels`` is a clone of
        ``input_ids`` in which two spans are set to IGNORE_INDEX:
        - each regrouped round's instruction prefix (up to and including
          ``conv.sep + assistant_role``);
        - everything after the last parsed round.
        The whole row is masked on a length mismatch below
        ``model_max_length``.
    """
    conv = conversation_lib.default_conversation.copy()
    # "gpt" and "model" both map to the assistant role.
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Render every source through the conversation template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading turn that is not from the user role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must alternate user/assistant; the message is the source index.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered prompts.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    # Text up to and including `sep` in each round is treated as instruction and masked.
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        # Non-padding token count, compared against cur_len below to detect drift.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # MPT uses one separator everywhere, so regroup the pieces: the first
        # three form round 0 (presumably system + first user + first
        # assistant - confirm), and the rest are joined in pairs.
        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))
        # Start at 0: the slice [:0] is empty, so nothing is masked up front.
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): "- 1" is a tokenizer-specific correction - confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            # For rounds after the first, legacy tokenizers (missing attribute
            # treated as non-legacy) with tokenizers>=0.14 yield one extra token.
            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len += 1
                instruction_len += 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last parsed round (including padding) is masked.
        target[cur_len:] = IGNORE_INDEX

        # On a length mismatch, drop supervision for the whole sample; rows at
        # or beyond model_max_length are not checked.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
|
|
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build (prompt, answer) pairs for the PLAIN template.

    Each source must contain exactly two turns. The training text is
    ``turn0 + turn1 + default_conversation.sep``, tokenized with
    ``tokenizer_image_token``. In the labels, every token belonging to the
    first turn is replaced by IGNORE_INDEX, so only the second turn and the
    separator are supervised.

    Returns:
        dict with ``input_ids`` and ``labels``, each a list of 1-D tensors.
    """
    # Validate and render all pairs before any tokenization happens.
    prompts = []
    for pair in sources:
        assert len(pair) == 2
        question, answer = pair[0]['value'], pair[1]['value']
        prompts.append(question + answer + conversation_lib.default_conversation.sep)

    input_ids = [tokenizer_image_token(text, tokenizer, return_tensors='pt') for text in prompts]
    labels = copy.deepcopy(input_ids)

    for label, pair in zip(labels, sources):
        prefix_len = len(tokenizer_image_token(pair[0]['value'], tokenizer))
        label[:prefix_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=labels)
|
|
|
|
|
|
|
|
def preprocess_gemma(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize conversations rendered via the chat template (Gemma) and build masked labels.

    Args:
        sources: iterable of conversations. Each conversation is a list of
            turn dicts with ``"from"`` (a key of the role map below) and
            ``"value"``.
        tokenizer: passed to ``conv.get_prompt`` (chat template) and to
            ``tokenizer_image_token``.
        has_image: only the image path computes round lengths. With
            ``has_image=False``, ``NotImplementedError`` is raised as soon
            as a non-empty round with exactly one assistant separator is
            reached.

    Returns:
        dict with ``input_ids`` and ``labels``. ``labels`` is a clone of
        ``input_ids`` in which several spans are set to IGNORE_INDEX:
        - the first position;
        - each round's instruction prefix (up to and including
          ``conv.sep + assistant_role + '\\n'``);
        - everything after the last parsed round.
        The whole row is masked on a length mismatch below
        ``model_max_length``.
    """
    conv = conversation_lib.default_conversation.copy()
    # "gpt" and "model" both map to the assistant role.
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Render every source through the tokenizer's chat template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading turn that is not from the user role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must alternate user/assistant; the message is the source index.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(use_chat_template=True, tokenizer=tokenizer))

    # Tokenize the rendered prompts.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    # `sep` ends an instruction (start of the assistant turn). `sep2` is the
    # boundary between a finished round and the next user turn; splitting on
    # it consumes that text, so it is re-attached to each piece below.
    sep = conv.sep + conv.roles[1] + '\n'
    sep2 = conv.sep2 + '\n' + conv.sep + conv.roles[0]
    for conversation, target in zip(conversations, targets):
        # Non-padding token count, compared against cur_len below to detect drift.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(sep2)
        cur_len = 1
        # Mask position 0 (presumably the BOS token - TODO confirm).
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break
            # Restore the end-of-turn text removed by the split (not present after the last piece).
            if i != len(rounds) - 1:
                rou += conv.sep2 + '\n'
            # Restore the user-turn opener removed by the split.
            if i >= 1 :
                rou = conv.sep + conv.roles[0] + rou

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): both counts drop one token, presumably the BOS that
            # tokenizer_image_token adds to each fragment - confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer)) - 1
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                raise NotImplementedError

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last parsed round (including padding) is masked.
        target[cur_len:] = IGNORE_INDEX

        # On a length mismatch, drop supervision for the whole sample; rows at
        # or beyond model_max_length are not checked.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
|
|
def preprocess_mistral(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize conversations rendered via the chat template (Mistral) and build masked labels.

    Args:
        sources: iterable of conversations. Each conversation is a list of
            turn dicts with ``"from"`` (a key of the role map below) and
            ``"value"``.
        tokenizer: passed to ``conv.get_prompt`` (chat template), used
            directly on the text-only path and passed to
            ``tokenizer_image_token`` on the image path.
        has_image: when True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``.
            That requires equal-length prompts.
            NOTE(review): presumably one conversation per call.

    Returns:
        dict with ``input_ids`` and ``labels``. ``labels`` is a clone of
        ``input_ids`` in which several spans are set to IGNORE_INDEX:
        - the first position;
        - each round's instruction prefix (up to and including
          ``" [/INST]"``);
        - everything after the last parsed round.
        The whole row is masked on a length mismatch below
        ``model_max_length``.
    """
    conv = conversation_lib.default_conversation.copy()
    # "gpt" and "model" both map to the assistant role.
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Render every source through the tokenizer's chat template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading turn that is not from the user role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must alternate user/assistant; the message is the source index.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt(use_chat_template=True, tokenizer=tokenizer))

    # Tokenize the rendered prompts.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    # Text up to and including `sep` in each round is treated as instruction and masked.
    sep = " [/INST]"
    for conversation, target in zip(conversations, targets):
        # Non-padding token count, compared against cur_len below to detect drift.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # str.split always returns at least one element, so `rou` is bound after the loop.
        rounds = conversation.split(conv.sep2)
        cur_len = 1
        # Mask position 0 (presumably the BOS token - TODO confirm).
        target[:cur_len] = IGNORE_INDEX
        for rou in rounds:
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): "- 1" is a tokenizer-specific correction
            # (presumably the BOS added to the fragment) - confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last parsed round (including padding) is masked.
        target[cur_len:] = IGNORE_INDEX
        # If the leftover text after the final separator ends in a space
        # (e.g. a chat template that emits eos + ' '), count one more token so
        # the length check lines up; that position stays masked.
        # Use endswith() rather than rou[-1]: the loop usually stops on the
        # empty string after a trailing sep2, and ""[-1] raises IndexError.
        if rou.endswith(' '):
            cur_len += 1

        # On a length mismatch, drop supervision for the whole sample; rows at
        # or beyond model_max_length are not checked.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
|
|
def preprocess_thoth(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize conversations with the 'thoth' template and build masked labels.

    Args:
        sources: iterable of conversations. Each conversation is a list of
            turn dicts with ``"from"`` (a key of the role map below) and
            ``"value"``.
        tokenizer: used directly on the text-only path and passed to
            ``tokenizer_image_token`` on the image path.
        has_image: when True, prompts are tokenized with
            ``tokenizer_image_token`` and combined with ``torch.stack``.
            That requires equal-length prompts.
            NOTE(review): presumably one conversation per call.

    Returns:
        dict with ``input_ids`` and ``labels``. ``labels`` is a clone of
        ``input_ids`` in which several spans are set to IGNORE_INDEX:
        - the first position;
        - each round's instruction prefix (up to and including
          ``conv.sep + assistant_role + ": "``);
        - everything after the last parsed round.
        The whole row is masked on a length mismatch below
        ``model_max_length``.
    """
    conv = conversation_lib.default_conversation.copy()
    # "gpt" and "model" both map to the assistant role.
    roles = {"human": conv.roles[0], "gpt": conv.roles[1], "model": conv.roles[1]}

    # Render every source through the conversation template.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading turn that is not from the user role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must alternate user/assistant; the message is the source index.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered prompts.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    # Text up to and including `sep` in each round is treated as instruction and masked.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Non-padding token count, compared against cur_len below to detect drift.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        # Mask position 0 (presumably the BOS token - TODO confirm).
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): "- 2" is a tokenizer-specific correction - confirm.
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            # Every round except the first advances one extra position.
            # NOTE(review): presumably the sep2 token between rounds - confirm.
            cur_len += round_len + 1
            if i == 0:
                cur_len -= 1
        # Everything after the last parsed round (including padding) is masked.
        target[cur_len:] = IGNORE_INDEX

        # On a length mismatch, drop supervision for the whole sample; rows at
        # or beyond model_max_length are not checked.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
|
|
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Tokenize conversations and build labels for the active conversation template.

    Dispatch, checked in this order:
    1. a PLAIN or LLAMA_2 separator style;
    2. a version starting with "v1";
    3. the versions "mpt", "gemma", "thoth" and "mistral".

    Any other template uses the generic path:
    - every turn gets a "### <role>: ...\\n" signal;
    - the system header and the turns are concatenated and tokenized;
    - the labels are a deep copy with the header and human turns masked
      by ``_mask_targets``.

    Returns:
        dict with ``input_ids`` and ``labels``.
    """
    template = conversation_lib.default_conversation

    if template.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if template.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if template.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)

    handlers_by_version = {
        "mpt": preprocess_mpt,
        "gemma": preprocess_gemma,
        "thoth": preprocess_thoth,
        "mistral": preprocess_mistral,
    }
    handler = handlers_by_version.get(template.version)
    if handler is not None:
        return handler(sources, tokenizer, has_image=has_image)

    # Generic fallback. The header depends only on the template, so it is
    # identical for every source.
    header = f"{template.system}\n\n"
    # _add_speaker_and_signal also rewrites each turn's "value" in place.
    conversations = [_add_speaker_and_signal(header, source) for source in sources]

    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    else:
        input_ids = _tokenize_fn(conversations, tokenizer)["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        segments = [header] + [turn["value"] for turn in source]
        if has_image:
            segment_lens = [len(tokenizer_image_token(segment, tokenizer)) for segment in segments]
        else:
            segment_lens = _tokenize_fn(segments, tokenizer)["input_ids_lens"]
        _mask_targets(target, segment_lens, [turn["from"] for turn in source])

    return dict(input_ids=input_ids, labels=targets)
|
|