humit-tagger-base / modeling_humit_tagger.py
Ahmet Yildirim
- Fix parametername of lemmatisation
9eefaf1
from transformers import (
AutoModel,
AutoTokenizer
)
import torch
from huggingface_hub import hf_hub_download
import os
import importlib.util
import sys
import shutil
from safetensors.torch import load_model
import json
import re
import copy
class HumitTaggerModel(torch.nn.Module):
# We do not need to do anything to register our class, as this class will only be used
# to make the humit-tagger easy to load and run.
@staticmethod
def register_for_auto_class(auto_class="AutoModel"):
    """Do nothing: this tagger never needs to be registered with an auto class.

    Declared as a ``staticmethod`` so it works both on the class and on an
    instance. Presumably invoked by transformers' remote-code loading with an
    ``auto_class`` keyword argument — TODO confirm.

    Args:
        auto_class: the auto class to register with (ignored).
    """
# Define our own from-pretrained to load the weights and other files needed for the tagger to work
@staticmethod
def from_pretrained(repo_name, **kwargs):
    """Download everything the tagger needs from ``repo_name`` and build it.

    ``kwargs["config"]`` must provide ``humit_tagger_configuration`` and
    ``lemma_rules_py_file`` (file names inside ``repo_name``). The tagger's
    own JSON config then names the base-model code files, the base-model
    config JSON (downloaded from the ``base_model`` repository), the fullform
    list and the weights file. The downloaded paths/configs are stored in
    ``kwargs``, which is forwarded to ``HumitTaggerModel.__init__``; any other
    keyword arguments (e.g. ``batch_size``, ``device``) pass through unchanged.

    SECURITY: the lemma-rules and base-model files are Python code executed in
    this process. Only load repositories you trust.

    Args:
        repo_name: Hub repository id of this tagger.
        **kwargs: must contain ``config``; everything is forwarded to ``__init__``.

    Returns:
        A new ``HumitTaggerModel``.
    """
    def _import_file(module_name, file_path):
        # Execute a downloaded .py file as module `module_name`, registering it in
        # sys.modules (where __init__ looks it up) and adding its folder to sys.path
        # only once, so repeated loads do not keep growing sys.path.
        module_dir = os.path.dirname(file_path)
        if module_dir not in sys.path:
            sys.path.append(module_dir)
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module

    def _read_json(path):
        # JSON is UTF-8 by definition; do not depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as js:
            return json.load(js)

    config = kwargs["config"]
    # Download and parse this model's own config
    this_model_config = _read_json(hf_hub_download(repo_id=repo_name, filename=config.humit_tagger_configuration))
    kwargs["this_model_config"] = this_model_config
    # Download the lemma-rules Python module and import it as "lemma_rules"
    lemma_rules_path = hf_hub_download(repo_id=repo_name, filename=config.lemma_rules_py_file)
    _import_file("lemma_rules", lemma_rules_path)
    # Download base-model code, the base model's config JSON (from its own repo) and the fullform list
    base_config_file = hf_hub_download(repo_id=repo_name, filename=this_model_config["base_model_config_file"])
    base_model_file = hf_hub_download(repo_id=repo_name, filename=this_model_config["base_model_model_file"])
    base_model_config_json_file = hf_hub_download(repo_id=this_model_config["base_model"], filename=this_model_config["base_model_config_json_file"])
    fullformlist_file = hf_hub_download(repo_id=repo_name, filename=this_model_config["fullformlist_file"])
    # Register the base config module first, then the base model module
    _import_file("base_config", base_config_file)
    _import_file("base_model", base_model_file)
    # Download the model weights and load the base model config
    kwargs["model_weights_path"] = hf_hub_download(repo_id=repo_name, filename=this_model_config["model_weights"])
    kwargs["base_model_json_cfg"] = _read_json(base_model_config_json_file)
    kwargs["repo_name"] = repo_name
    kwargs["fullformlist_file"] = fullformlist_file
    return HumitTaggerModel(**kwargs)
def __init__(self, **kwargs ):
    """Build the tagger network, load its weights and prepare lookup data.

    Normally called by ``from_pretrained``. Keyword arguments read here:

    * ``base_model_json_cfg`` - dict passed to ``NorbertConfig``.
    * ``this_model_config`` - this tagger's parsed JSON config.
    * ``model_weights_path`` - safetensors file loaded into this module.
    * ``repo_name`` - repository the tokenizer is loaded from.
    * ``fullformlist_file`` - JSON fullform list. If it cannot be read, a
      warning is issued and lemmatization relies on the model's rules only.
    * ``batch_size`` - optional, defaults to 8.
    * ``device`` - optional, defaults to CUDA when available, else CPU.

    The ``lemma_rules``, ``base_config`` and ``base_model`` modules must
    already be registered in ``sys.modules``. The module is left in eval mode.
    """
    super(HumitTaggerModel, self).__init__()
    json_cfg = kwargs["base_model_json_cfg"]
    self.config = kwargs["this_model_config"]
    self.LemmaHandling = sys.modules["lemma_rules"].LemmaHandling
    self.LemmaHandling.load_lemma_rules_from_obj(self.config["lemma_rules"])
    cfg = sys.modules["base_config"].NorbertConfig(**json_cfg)
    self.bert = sys.modules["base_model"].NorbertModel(cfg, pooling_type="CLS")
    hidden_size = self.bert.config.hidden_size
    self.dropout = torch.nn.Dropout(self.bert.config.hidden_dropout_prob)
    # Heads: classifier1 = sentence boundaries, classifier2 = tags,
    # classifier3 = lemma rules, seq_classifier = language
    self.classifier1 = torch.nn.Linear(hidden_size, self.config["num_labels1"])
    self.classifier2 = torch.nn.Linear(hidden_size, self.config["num_labels2"])
    self.classifier3 = torch.nn.Linear(hidden_size, self.config["num_labels3"])
    self.seq_classifier = torch.nn.Linear(hidden_size, self.config["num_labels_seq"])
    self.ignore_index = self.config["ignore_index"]
    load_model(self, kwargs["model_weights_path"])
    self.tokenizer = AutoTokenizer.from_pretrained(kwargs["repo_name"])
    self.batch_size = kwargs.get("batch_size", 8)
    if "device" in kwargs:
        self.device = torch.device(kwargs["device"])
    else:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # One position is reserved for the CLS token prepended to every chunk
    self.MAX_LENGTH_WITHOUT_CLS = self.bert.config.max_position_embeddings - 1
    self.tags = self.config["tags"]
    self.tags_str = [[" ".join(i) for i in self.config["tags"][0]], [" ".join(i) for i in self.config["tags"][1]]]
    self.to(self.device)
    self.REPLACE_DICT = self.config["replace_dict"]
    # Longest keys first: regex alternation takes the first alternative that
    # matches, so a key that is a prefix of a longer key would otherwise shadow it.
    self.REPLACE_PATTERN = "|".join(re.escape(k) for k in sorted(self.REPLACE_DICT, key=lambda k: (-len(k), k)))
    self.MAX_LENGTH = self.bert.config.max_position_embeddings
    # Note the classes that represent gen, prop and "2" (polite) tags
    self.gen_tag_classes = set()
    self.prop_tag_classes = set()
    self.t_2_tag_classes = set()
    for i, lst in enumerate(self.config["tags"][0]):
        if "gen" in lst:
            self.gen_tag_classes.add(i)
        if "prop" in lst:
            self.prop_tag_classes.add(i)
        if "2" in lst:
            self.t_2_tag_classes.add(i)
    # Load the fullform list: fullform_list[lang][word][word_class] -> lemmas.
    # The lemma collections become sets for fast membership tests in _lemmatize.
    self.fullform_list = [{}, {}]
    try:
        with open(kwargs["fullformlist_file"], "r", encoding="utf-8") as f:
            self.fullform_list = json.load(f)
        for k in range(2):
            for word_classes in self.fullform_list[k].values():
                for word_class, lemmas in word_classes.items():
                    word_classes[word_class] = set(lemmas)
    except (OSError, KeyError, IndexError, TypeError, ValueError, AttributeError) as err:
        # Best effort: tagging still works without (a fully prepared) fullform list
        import warnings
        warnings.warn(f"Could not load the fullform list: {err!r}")
    # torch modules start in training mode; switch to inference so dropout is
    # disabled while tagging. Call .train() explicitly to fine-tune.
    self.eval()
def forward(self, input_ids=None, attention_mask=None ):
    """Run the encoder and all four classification heads.

    Args:
        input_ids: token ids, passed straight to ``self.bert``.
        attention_mask: attention mask, passed straight to ``self.bert``.

    Returns:
        dict with ``logits1``, ``logits2``, ``logits3`` and ``seq_logits``:
        the outputs of ``classifier1``, ``classifier2``, ``classifier3`` and
        ``seq_classifier`` applied to the encoder's ``last_hidden_state``
        after dropout. The taggers read ``seq_logits`` at position 0 only.
    """
    outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
    sequence_output = self.dropout(outputs.last_hidden_state)
    return {
        "logits1": self.classifier1(sequence_output),
        "logits2": self.classifier2(sequence_output),
        "logits3": self.classifier3(sequence_output),
        "seq_logits": self.seq_classifier(sequence_output),
    }
def _preprocess_text(self,text):
new_text = re.sub(self.REPLACE_PATTERN, lambda m: self.REPLACE_DICT.get(m.group(0).upper()), text)
while new_text != text:
text = new_text
new_text = re.sub(self.REPLACE_PATTERN, lambda m: self.REPLACE_DICT.get(m.group(0).upper()), text)
return new_text
def _batchify(self, lst):
# Create batches
batched_sentences=[]
my_batch=[]
for sentence in lst:
sentence.append(self.tokenizer.sep_token_id)
my_batch.append(sentence)
if len(my_batch)==self.batch_size:
max_len=len(max(my_batch, key=len))
if max_len > self.MAX_LENGTH:
max_len = self.MAX_LENGTH
my_attentions=torch.LongTensor([[1] * len(i[0:max_len]) + [0]*(max_len-len(i[0:max_len])) for i in my_batch]).to("cpu")
my_batch=[i[0:max_len] + [0]*(max_len-len(i[0:max_len])) for i in my_batch]
to_append={
"input_ids": torch.LongTensor(my_batch).to("cpu"),
"attention_mask": my_attentions,
}
batched_sentences.append(to_append)
my_batch=[]
if len(my_batch)>0:
max_len=len(max(my_batch, key=len))
if max_len > self.MAX_LENGTH:
max_len = self.MAX_LENGTH
my_attentions=torch.LongTensor([[1] * len(i[0:max_len]) + [0]*(max_len-len(i[0:max_len])) for i in my_batch]).to("cpu")
my_batch=[i[0:max_len] + [0]*(max_len-len(i[0:max_len])) for i in my_batch]
to_append={
"input_ids": torch.LongTensor(my_batch).to("cpu"),
"attention_mask": my_attentions,
}
batched_sentences.append(to_append)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return batched_sentences
def _split_sentences(self, inp):
# Remove double spaces
inp=" ".join(inp.split())
# Here we get the whole text tokenized.
encodings = self.tokenizer(inp,add_special_tokens=False, return_tensors="pt").to(self.device)
# Save a copy of the tokenization
original_encodings=copy.deepcopy(encodings)
original_encodings=original_encodings.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Pad to the complete size (model max_size -1 (-1 to add CLS))
old_size=encodings["input_ids"][0].size()[0]
# Pad size
pad_size=self.MAX_LENGTH_WITHOUT_CLS - old_size % self.MAX_LENGTH_WITHOUT_CLS
# Number of rows
row_count=int(old_size/self.MAX_LENGTH_WITHOUT_CLS) + 1
# Do padding with pad_id to the pad_size that we have calculated.
encodings["input_ids"] = torch.nn.functional.pad(input=encodings["input_ids"], pad=(0, pad_size), mode="constant", value=self.tokenizer.pad_token_id)
# Set the last token as SENTENCE END (SEP)
encodings["input_ids"][0][old_size]=self.tokenizer.sep_token_id
# Chunk into max_length items
encodings["input_ids"]=torch.reshape(encodings["input_ids"],(row_count,self.MAX_LENGTH_WITHOUT_CLS))
# Add CLS to each item
encodings["input_ids"]=torch.cat(( torch.full((row_count,1), self.tokenizer.cls_token_id, device=self.device) ,encodings["input_ids"]),dim=1)
# Create attention mask
encodings["attention_mask"]=torch.ones_like(encodings["input_ids"], device=self.device)
# Create batches
input_ids_batched=torch.split(encodings["input_ids"], self.batch_size)
attention_mask_batched=torch.split(encodings["attention_mask"], self.batch_size)
# Set the last chunk's attention mask according to its size
attention_mask_batched[-1][-1][pad_size +1:] = 0
encodings=encodings.to("cpu")
# Now pass all chunks through the model and get the labels
# While passing, we count the number of bokmal and nynorsk markers
labels_output=[]
# First get them back to CPU to open space on GPU
input_ids_batched=[i.to("cpu") for i in input_ids_batched]
attention_mask_batched=[i.to("cpu") for i in attention_mask_batched]
if torch.cuda.is_available():
torch.cuda.empty_cache()
for input_ids, attention_masks in zip(input_ids_batched, attention_mask_batched):
current_batch={"input_ids":input_ids.to(self.device).long(), "attention_mask":attention_masks.to(self.device).long()}
outputs = self(**current_batch)
del current_batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
label_data=outputs["logits1"].argmax(-1)
labels_output.extend(label_data)
# Serialize back
labels_output=torch.stack(labels_output ,dim=0)
labels_output=labels_output[:, range(1,self.MAX_LENGTH)]
labels_output=torch.reshape(labels_output,(1,row_count * self.MAX_LENGTH_WITHOUT_CLS))
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Now the data is split into sentences
# So, now create sentence data as list so that this could be used
# in torch operations and can be input to the models
sentence_list=[]
this_sentence=[self.tokenizer.cls_token_id]
for token, label in zip(original_encodings["input_ids"][0].tolist(), labels_output[0].tolist()):
if label==0:
this_sentence.append(token)
else:
this_sentence.append(token)
sentence_list.append(this_sentence)
this_sentence=[self.tokenizer.cls_token_id]
if len(this_sentence)>1:
sentence_list.append(this_sentence)
del original_encodings
del labels_output
del attention_mask_batched
del input_ids_batched
del encodings
del old_size
del inp
del outputs
if torch.cuda.is_available():
torch.cuda.empty_cache()
return sentence_list
def _matcher(self, o):
return o.group(0)[0] + "\n\n" + o.group(0)[2]
def split_sentences(self, inp, **tag_config):
    """Split raw text into sentences of token ids.

    A newline is treated as a hard break when the character before it is not
    one of . ! ? and the character after it is not a lowercase letter (a-z,
    æ, ø, å), a comma, a backslash or a space (e.g. a title line followed by
    a capitalised line). Inside each resulting block the remaining newlines
    become spaces, and the block is passed to ``_split_sentences``. Blank
    blocks are skipped.

    Args:
        inp: the raw text.
        **tag_config: unused.

    Returns:
        list of sentences, each a list of token ids starting with CLS.
    """
    # Lookbehind/lookahead instead of consuming the neighbouring characters, so
    # consecutive short lines each get their own break ("A\nB\nC" -> A | B | C).
    blocks = re.sub(r"(?<=[^.!?])\n(?=[^a-z,æøå\\ ])", "\n\n", inp).split("\n\n")
    sentences = []
    for block in blocks:
        block = block.replace("\n", " ").strip()
        # Empty blocks would only cost a model pass that yields no sentences
        if block:
            sentences.extend(self._split_sentences(block))
    return sentences
def _lemmatize(self, tag, LANG):
# Here, a "tag" is a list of words in one sentence, their tags and an ordering of lemma classes according the lemmatization model for each word.
# We go over all words, and apply our algorithm for lemmatization
# 1. If the "pron" tag is found in the tags
# then, we check if the "gen" tag also exists
# if there is the "gen" tag in tags and if there is "s" at the end of the word, we remove that s
# and return the rest of the word as lemma
# 2. OR, we continue with "høflig" processing
# if the word is "De" and if it has the tag "høflig" then we set the lemma as "De", otherwise "de"
# 3. OR, we continue with checking the word and its word class (subst, verb, adj, etc.) towards the fullform lists.
# if the word and its word class exists in the fullformlist (of the language bokmål or nynorsk according the the language parameter)
# then we set the lemma from the fullform list.
# if there are multiple lemmas in the fullform list, then we check each lemma suggested by the model
# we pick the lemma amon the lemmas suggested by the fullformlist that comes the first among the lemmas suggested by model
# 4. OR, we set the first lemma suggested by the model
# 5. OR, just in case, one way or another if we cannot set a lemma, we set the word as the lemma
# Go over all words in the sentence
for i in range(len(tag)):
# If there is prop in tags
if tag[i]["t"] in self.prop_tag_classes:
# set the lemma as the word
tag[i]["l"]=tag[i]["w"]
# if there is gen in tags then remove the last Ss
if tag[i]["t"] in self.gen_tag_classes:
if tag[i]["l"].endswith("'s") or tag[i]["l"].endswith("'S"):
tag[i]["l"]=tag[i]["l"][:-2]
elif tag[i]["l"].endswith("s") or tag[i]["l"].endswith("S") or tag[i]["l"].endswith("'"):
tag[i]["l"]=tag[i]["l"][:-1]
continue
# if høflig
if tag[i]["w"]=="De":
if tag[i]["t"] in self.t_2_tag_classes:
tag[i]["l"]="De"
continue
else:
tag[i]["l"]="de"
continue
# for the rest of the cases of the word, lowercase the word and check against the fullform list
word=tag[i]["w"].lower()
word_class = self.tags[0][tag[i]["t"]][0]
# get the lemma from the fullform list
fullform_list_lemma = self.fullform_list[LANG].get(word, {}).get(word_class)
# if there is not a lemma in the fullformlist
# use the first lemma from the model
if fullform_list_lemma==None:
tag[i]["l"] = self.LemmaHandling.get_lemma_given_word_and_lemma_list_index(tag[i]["w"], tag[i]["l"][0] )
# if there is only one fullformlist-lemma:
elif len(fullform_list_lemma) == 1:
tag[i]["l"] = next(iter(fullform_list_lemma))
# if there are multiple lemmas in the fullformlist
# here we disambugate among these lemmas using the alternatives from the model
elif len(fullform_list_lemma) > 1:
tag[i]["l"] = next((selected_lemma for x in tag[i]["l"] if (selected_lemma := self.LemmaHandling.get_lemma_given_word_and_lemma_list_index(tag[i]["w"], x )) in fullform_list_lemma), self.LemmaHandling.get_lemma_given_word_and_lemma_list_index(tag[i]["w"], tag[i]["l"][0] ) )
# This branch will probably not be called but kept just in case
# If none of the cases above, use the first lemma suggested by the model
else:
tag[i]["l"] = self.LemmaHandling.get_lemma_given_word_and_lemma_list_index(tag[i]["w"], tag[i]["l"][0] )
# This if will probable not be true either but kept just in case
# If a lemma could not be assigned after all these operations
# then asign the word itself
# Check by if the lemma field is still a list or if the field-type is string the legth is 0
if type(tag[i]["l"]) == list or len(tag[i]["l"]) == 0:
tag[i]["l"] = tag[i]["w"]
return tag
def tag_sentence_list(self, lst, **tag_config):
# If the sentences are not tokenized, tokenize while batching:
tokenized_batches = []
if type(lst[0])==str:
tokenized_batches = []
for i in range(0, len(lst), self.batch_size):
batch_texts = lst[i:i + self.batch_size]
encoded_batch = self.tokenizer(batch_texts, padding=True, truncation=True, max_length=self.MAX_LENGTH, return_tensors="pt", return_token_type_ids=False)
encoded_batch["input_ids"].to("cpu")
encoded_batch["attention_mask"].to("cpu")
tokenized_batches.append(encoded_batch)
# sentences are already tokenized, then batchify them:
else:
tokenized_batches = self._batchify(lst)
# If lemmatization will be applied
if tag_config["lemmatize"]:
# If language will be identified per sentence
if tag_config["lang_per_sentence"]:
id_to_lang = self.config["id_to_lang"]
# If the output will be to a python list
if tag_config["write_output_to"]==None:
all_tagged_sentences = []
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemma_indices = torch.topk(all_out["logits3"].flatten(start_dim=2, end_dim=2), len(self.LemmaHandling.lemma_list))
#batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch_langs = torch.argmax(all_out["seq_logits"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas, lang in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemma_indices.indices.tolist(), batch_langs[:, 0].tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, lang)
all_tagged_sentences.append({"lang":id_to_lang[lang], "sent": [ {"w":i["w"], "t":self.tags[lang][i["t"]], "l":i["l"]} for i in this_sentence]})
return all_tagged_sentences
# If the output is in TSV format to a pipe (stdout or a file handle)
elif tag_config["output_tsv"]:
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemma_indices = torch.topk(all_out["logits3"].flatten(start_dim=2, end_dim=2), len(self.LemmaHandling.lemma_list))
#batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch_langs = torch.argmax(all_out["seq_logits"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas, lang in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemma_indices.indices.tolist(), batch_langs[:, 0].tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, lang)
this_sentence=[ {"w":i["w"], "t":self.tags_str[lang][i["t"]], "l":i["l"]} for i in this_sentence]
tag_config["write_output_to"].write(id_to_lang[lang])
for lin in this_sentence:
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["w"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["l"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["t"])
tag_config["write_output_to"].write("\n")
tag_config["write_output_to"].write("\n")
# If output format will be json to a pipe (stdout or a file handle)
else:
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemma_indices = torch.topk(all_out["logits3"].flatten(start_dim=2, end_dim=2), len(self.LemmaHandling.lemma_list))
#batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch_langs = torch.argmax(all_out["seq_logits"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas, lang in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemma_indices.indices.tolist(), batch_langs[:, 0].tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, lang)
json.dump({"lang":id_to_lang[lang], "sent":[ {"w":i["w"], "t":self.tags[lang][i["t"]], "l":i["l"]} for i in this_sentence]}, tag_config["write_output_to"])
tag_config["write_output_to"].write("\n")
# If the language is set as parameter
elif tag_config["lang"] != -1:
LANG = tag_config["lang"]
LANG_STR = self.config["id_to_lang"][LANG]
# If the output will be to a python list
if tag_config["write_output_to"]==None:
all_tagged_sentences = []
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemma_indices = torch.topk(all_out["logits3"].flatten(start_dim=2, end_dim=2), len(self.LemmaHandling.lemma_list))
#batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemma_indices in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemma_indices.indices.tolist()): #batch_lemmas.tolist(),
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemma_indices[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, LANG)
all_tagged_sentences.append({"lang":LANG_STR, "sent": [ {"w":i["w"], "t":self.tags[LANG][i["t"]], "l":i["l"]} for i in this_sentence]})
return all_tagged_sentences
# If the output is in TSV format to a pipe (stdout or a file handle)
elif tag_config["output_tsv"]:
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemma_indices = torch.topk(all_out["logits3"].flatten(start_dim=2, end_dim=2), len(self.LemmaHandling.lemma_list))
#batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemma_indices.indices.tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, LANG)
this_sentence=[ {"w":i["w"], "t":self.tags_str[LANG][i["t"]], "l":i["l"]} for i in this_sentence]
tag_config["write_output_to"].write(LANG_STR)
for lin in this_sentence:
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["w"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["l"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["t"])
tag_config["write_output_to"].write("\n")
tag_config["write_output_to"].write("\n")
# If output format will be json to a pipe (stdout or a file handle)
else:
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemma_indices = torch.topk(all_out["logits3"].flatten(start_dim=2, end_dim=2), len(self.LemmaHandling.lemma_list))
#batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemma_indices.indices.tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, LANG)
json.dump({"lang":LANG_STR, "sent": [ {"w":i["w"], "t":self.tags[LANG][i["t"]], "l":i["l"]} for i in this_sentence]}, tag_config["write_output_to"])
tag_config["write_output_to"].write("\n")
# If language will be identified according to the majority of all sentences:
else:
all_tags=[]
all_lemmas=[]
all_langs=[]
all_input_ids=[]
# Go over all batches and each sentence in each batch
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemma_indices = torch.topk(all_out["logits3"].flatten(start_dim=2, end_dim=2), len(self.LemmaHandling.lemma_list))
#batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch_langs = torch.argmax(all_out["seq_logits"], dim=-1)
all_input_ids.extend(batch["input_ids"].tolist())
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
all_langs.extend(batch_langs[:, 0].tolist())
all_tags.extend(batch_tags.tolist())
all_lemmas.extend(batch_lemma_indices.indices.tolist())
# Identify the language
tag_config["lang"] = 1 if sum(all_langs)/len(all_langs)>=0.5 else 0
LANG = tag_config["lang"]
LANG_STR = self.config["id_to_lang"][LANG]
# If the output will be returned as python list:
if tag_config["write_output_to"]==None:
all_tagged_sentences = []
for input_ids, tags, lemmas in zip(all_input_ids, all_tags, all_lemmas):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, LANG)
all_tagged_sentences.append({"lang":LANG_STR, "sent": [ {"w":i["w"], "t":self.tags[LANG][i["t"]], "l":i["l"]} for i in this_sentence] })
return all_tagged_sentences
# If the output is in TSV format
elif tag_config["output_tsv"]:
for input_ids, tags, lemmas in zip(all_input_ids, all_tags, all_lemmas):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, LANG)
this_sentence=[ {"w":i["w"], "t":self.tags_str[LANG][i["t"]], "l":i["l"]} for i in this_sentence]
tag_config["write_output_to"].write(LANG_STR)
for lin in this_sentence:
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["w"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["l"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["t"])
tag_config["write_output_to"].write("\n")
tag_config["write_output_to"].write("\n")
# If output format will be json
else:
for input_ids, tags, lemmas in zip(all_input_ids, all_tags, all_lemmas):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma[0] == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag, "l":lemma})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag, "l":lemma})
this_sentence = self._lemmatize(this_sentence, LANG)
json.dump({"lang":LANG_STR, "sent":[ {"w":i["w"], "t":self.tags[LANG][i["t"]], "l":i["l"]} for i in this_sentence]}, tag_config["write_output_to"])
tag_config["write_output_to"].write("\n")
# If lemmatization will not be applied:
else:
# If language will be identified per sentence
if tag_config["lang_per_sentence"]:
id_to_lang = self.config["id_to_lang"]
# If the output will be to a python list
if tag_config["write_output_to"]==None:
all_tagged_sentences = []
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch_langs = torch.argmax(all_out["seq_logits"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas, lang in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemmas.tolist(), batch_langs[:, 0].tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
all_tagged_sentences.append({"lang":id_to_lang[lang], "sent": [ {"w":i["w"], "t":self.tags[lang][i["t"]]} for i in this_sentence]})
return all_tagged_sentences
# If the output is in TSV format to a pipe (stdout or a file handle)
elif tag_config["output_tsv"]:
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch_langs = torch.argmax(all_out["seq_logits"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas, lang in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemmas.tolist(), batch_langs[:, 0].tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
this_sentence=[ {"w":i["w"], "t":self.tags_str[lang][i["t"]] } for i in this_sentence]
tag_config["write_output_to"].write(id_to_lang[lang])
for lin in this_sentence:
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["w"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["t"])
tag_config["write_output_to"].write("\n")
tag_config["write_output_to"].write("\n")
# If output format will be json to a pipe (stdout or a file handle)
else:
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch_langs = torch.argmax(all_out["seq_logits"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas, lang in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemmas.tolist(), batch_langs[:, 0].tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
json.dump({"lang":id_to_lang[lang], "sent":[ {"w":i["w"], "t":self.tags[lang][i["t"]]} for i in this_sentence]}, tag_config["write_output_to"])
tag_config["write_output_to"].write("\n")
# If the language is set as parameter
elif tag_config["lang"] != -1:
LANG = tag_config["lang"]
LANG_STR = self.config["id_to_lang"][LANG]
# If the output will be to a python list
if tag_config["write_output_to"]==None:
all_tagged_sentences = []
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemmas.tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
all_tagged_sentences.append({"lang":LANG_STR, "sent": [ {"w":i["w"], "t":self.tags[LANG][i["t"]]} for i in this_sentence]})
return all_tagged_sentences
# If the output is in TSV format to a pipe (stdout or a file handle)
elif tag_config["output_tsv"]:
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemmas.tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
this_sentence=[ {"w":i["w"], "t":self.tags_str[LANG][i["t"]]} for i in this_sentence]
tag_config["write_output_to"].write(LANG_STR)
for lin in this_sentence:
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["w"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["t"])
tag_config["write_output_to"].write("\n")
tag_config["write_output_to"].write("\n")
# If output format will be json to a pipe (stdout or a file handle)
else:
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
for input_ids, tags, lemmas in zip(batch["input_ids"].tolist(), batch_tags.tolist(),
batch_lemmas.tolist()):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
json.dump({"lang":LANG_STR, "sent": [ {"w":i["w"], "t":self.tags[LANG][i["t"]]} for i in this_sentence]}, tag_config["write_output_to"])
tag_config["write_output_to"].write("\n")
# If language will be identified according to the majority of all sentences:
else:
all_tags=[]
all_lemmas=[]
all_langs=[]
all_input_ids=[]
# Go over all batches and each sentence in each batch
for batch in tokenized_batches:
all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
batch_tags = torch.argmax(all_out["logits2"], dim=-1)
batch_lemmas = torch.argmax(all_out["logits3"], dim=-1)
batch_langs = torch.argmax(all_out["seq_logits"], dim=-1)
all_input_ids.extend(batch["input_ids"].tolist())
batch["input_ids"].to("cpu")
batch["attention_mask"].to("cpu")
all_langs.extend(batch_langs[:, 0].tolist())
all_tags.extend(batch_tags.tolist())
all_lemmas.extend(batch_lemmas.tolist())
# Identify the language
tag_config["lang"] = 1 if sum(all_langs)/len(all_langs)>=0.5 else 0
LANG = tag_config["lang"]
LANG_STR = self.config["id_to_lang"][LANG]
# If the output will be returned as python list:
if tag_config["write_output_to"]==None:
all_tagged_sentences = []
for input_ids, tags, lemmas in zip(all_input_ids, all_tags, all_lemmas):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
all_tagged_sentences.append({"lang":LANG_STR, "sent": [ {"w":i["w"], "t":self.tags[LANG][i["t"]]} for i in this_sentence] })
return all_tagged_sentences
# If the output is in TSV format
elif tag_config["output_tsv"]:
for input_ids, tags, lemmas in zip(all_input_ids, all_tags, all_lemmas):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
this_sentence=[ {"w":i["w"], "t":self.tags_str[LANG][i["t"]]} for i in this_sentence]
tag_config["write_output_to"].write(LANG_STR)
for lin in this_sentence:
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["w"])
tag_config["write_output_to"].write("\t")
tag_config["write_output_to"].write(lin["t"])
tag_config["write_output_to"].write("\n")
tag_config["write_output_to"].write("\n")
# If output format will be json
else:
for input_ids, tags, lemmas in zip(all_input_ids, all_tags, all_lemmas):
this_sentence=[]
for inps, tag, lemma in zip(input_ids[1:], tags[1:], lemmas[1:]):
if inps == self.tokenizer.sep_token_id or inps == self.tokenizer.pad_token_id:
break
if lemma == 0: # If there is no lemma here, that means we haven't reached the end of the word
if len(this_sentence)>0:
this_sentence[-1]["w"] += self.tokenizer.decode(inps)
else:
this_sentence.append({"w": self.tokenizer.decode(inps), "t": tag})
else:
this_sentence.append({"w":self.tokenizer.decode(inps).strip(), "t":tag})
json.dump({"lang":LANG_STR, "sent":[ {"w":i["w"], "t":self.tags[LANG][i["t"]]} for i in this_sentence]}, tag_config["write_output_to"])
tag_config["write_output_to"].write("\n")
def _check_if_text_file_and_return_content(self, filepath):
    """Read ``filepath`` as text and return its full content.

    Returns ``False`` instead of raising when the file cannot be opened or
    read as text. Any exception (missing path, directory, permission
    problem, bytes that do not decode with the default encoding, ...) is
    swallowed. Callers tell success apart with a ``str`` type check.
    """
    content = False
    try:
        with open(filepath, 'r') as handle:
            content = handle.read()
    except Exception:
        # Best effort by design: any failure is reported as "not a text file".
        content = False
    return content
@torch.no_grad()
def tag(self, inp=None, **tag_config):
    """Tag (and by default lemmatise) a string, a list of sentences, or a directory.

    Args:
        inp: ``None``, a ``str`` or a ``list`` of sentence strings.
            A ``str`` is split with ``split_sentences``, or one sentence per
            non-empty line when ``one_sentence_per_line`` is set.
            Ignored when ``input_directory`` is given.
        **tag_config: options (all optional). Every option is also forwarded
            to ``tag_sentence_list`` / ``split_sentences``.
            lemmatisation / lemmatization: lemmatisation is disabled only when
                the value compares equal to ``False``. It is enabled by
                default. The value is normalised into ``lemmatize``.
            one_sentence_per_line (bool): default ``False``.
            lang (str): a key of ``self.config["lang_to_id"]``. A missing or
                unknown code becomes ``-1``, i.e. the language is detected.
            lang_per_sentence (bool): detect the language per sentence.
            output_tsv (bool): TSV output instead of JSON.
            write_output_to: ``None``/absent means stdout. ``"list"`` returns
                the results. A file path is opened for writing and closed
                before returning. An open text stream is used as-is and left
                open.
            input_directory / output_directory (str): tag every readable file
                under ``input_directory`` into a ``<name>.tagged`` file at the
                same relative location under ``output_directory``.

    Returns:
        What ``tag_sentence_list`` returns (the tagged sentences when
        ``write_output_to="list"``), otherwise ``None``.

    Raises:
        ValueError: if ``lang`` and ``lang_per_sentence`` are both set, if
            ``input_directory`` is given without ``output_directory``, or if
            it is combined with ``write_output_to``.
    """
    self.eval()
    # Accept both spellings. The alias is removed and normalised into "lemmatize".
    if "lemmatisation" in tag_config:
        tag_config["lemmatize"] = not (tag_config.pop("lemmatisation") == False)  # noqa: E712
    elif "lemmatization" in tag_config:
        tag_config["lemmatize"] = not (tag_config.pop("lemmatization") == False)  # noqa: E712
    else:
        tag_config["lemmatize"] = True
    tag_config.setdefault("one_sentence_per_line", False)
    # Map the language code to its id. -1 means "identify the language".
    if "lang" in tag_config and tag_config["lang"] in self.config["lang_to_id"]:
        tag_config["lang"] = self.config["lang_to_id"][tag_config["lang"]]
    else:
        tag_config["lang"] = -1
    tag_config.setdefault("output_tsv", False)
    if "lang_per_sentence" not in tag_config:
        tag_config["lang_per_sentence"] = False
    elif tag_config["lang_per_sentence"]:
        tag_config["lang_per_sentence"] = True
    if tag_config["lang"] != -1 and tag_config["lang_per_sentence"]:
        raise ValueError("lang_per_sentence and lang parameters cannot be set at the same time. ")

    if "input_directory" in tag_config:
        if "output_directory" not in tag_config:
            raise ValueError("output_directory must be defined if input_directory is defined. ")
        if tag_config.get("write_output_to") is not None:
            raise ValueError("If an input and output directory is given, then write_output_to cannot be used as the output will be written to as files in output_directory.")
        # Progress messages go to stderr. If stderr is closed they go to
        # stdout, and if both are closed they go to a "tag.log" file.
        if not sys.stderr.closed:
            log = sys.stderr
        elif not sys.stdout.closed:
            log = sys.stdout
        else:
            log = open("tag.log", "w")
        try:
            for dir_path, _, files in os.walk(tag_config["input_directory"]):
                for f in files:
                    input_path = os.path.join(dir_path, f)
                    out_path = os.path.join(tag_config["output_directory"], os.path.relpath(dir_path, tag_config["input_directory"]), f + ".tagged")
                    file_content = self._check_if_text_file_and_return_content(input_path)
                    if not isinstance(file_content, str):
                        print(f"Could not properly open and read {input_path}.", file=log)
                        continue
                    file_content = self._preprocess_text(file_content)
                    print(f"Tagging {input_path} to {out_path}.", file=log)
                    os.makedirs(os.path.dirname(out_path), exist_ok=True)
                    if tag_config["one_sentence_per_line"]:
                        # Collapse runs of whitespace inside each non-empty line.
                        sentences = [" ".join(line.split()) for line in file_content.split("\n") if line != ""]
                    else:
                        sentences = self.split_sentences(file_content, **tag_config)
                    with open(out_path, "w") as opened_file:
                        tag_config["write_output_to"] = opened_file
                        self.tag_sentence_list(sentences, **tag_config)
        finally:
            if log is not sys.stdout and log is not sys.stderr:
                log.close()
        return None

    # Resolve the output destination. A file opened here from a path is
    # closed before returning. Caller-provided streams are left open.
    opened_here = None
    target = tag_config.get("write_output_to")
    if target is None:
        tag_config["write_output_to"] = sys.stdout
    elif isinstance(target, str):
        if target == "list":
            tag_config["write_output_to"] = None
        else:
            opened_here = open(target, "w")
            tag_config["write_output_to"] = opened_here
    try:
        if isinstance(inp, str):
            if tag_config["one_sentence_per_line"]:
                lines = [line for line in inp.split("\n") if line != ""]
                sentences = [" ".join(self._preprocess_text(line).split()) for line in lines]
            else:
                sentences = self.split_sentences(inp, **tag_config)
            return self.tag_sentence_list(sentences, **tag_config)
        if isinstance(inp, list):
            # One sentence per list item. Blank items are dropped.
            stripped = [item.strip() for item in inp]
            sentences = [" ".join(self._preprocess_text(item).split()) for item in stripped if item != ""]
            return self.tag_sentence_list(sentences, **tag_config)
        return None
    finally:
        if opened_here is not None:
            opened_here.close()
def identify_language_sentence_list(self, lst, **tag_config):
    """Predict languages for a list of sentences.

    Args:
        lst: raw sentence strings, or pre-tokenized sentences. When the first
            item is not a ``str``, the list is handed to ``self._batchify``.
        **tag_config: ``lang_per_item`` (bool, default ``False``) returns one
            language per item. Otherwise the majority language is returned
            for every item. Other keys are ignored.

    Returns:
        A list of language codes from ``self.config["id_to_lang"]``, one per
        item of ``lst``. It is empty when ``lst`` is empty.
    """
    if len(lst) == 0:
        return []
    if isinstance(lst[0], str):
        # Tokenize while batching.
        tokenized_batches = [
            self.tokenizer(lst[start:start + self.batch_size], padding=True, truncation=True,
                           max_length=self.MAX_LENGTH, return_tensors="pt", return_token_type_ids=False)
            for start in range(0, len(lst), self.batch_size)
        ]
    else:
        # Sentences are already tokenized, so only batchify them.
        tokenized_batches = self._batchify(lst)
    lang_ids = []
    for batch in tokenized_batches:
        all_out = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))
        # One language id per sequence, taken from the first position.
        lang_ids.extend(torch.argmax(all_out["seq_logits"], dim=-1)[:, 0].tolist())
    id_to_lang = self.config["id_to_lang"]
    if tag_config.get("lang_per_item", False):
        return [id_to_lang[i] for i in lang_ids]
    # Majority vote between ids 0 and 1. A tie goes to id 1.
    majority = 1 if sum(lang_ids) / len(lang_ids) >= 0.5 else 0
    return [id_to_lang[majority]] * len(lst)
@torch.no_grad()
def identify_language(self, inp=None, **tag_config):
    """Identify the language of a string, a list of strings, or files in a directory.

    Args:
        inp: ``None``, a ``str`` or a ``list`` of ``str``. Ignored when
            ``input_directory`` is given.
        **tag_config: options (all optional):
            split_sentences (bool): split a string into sentences first.
            one_sentence_per_line (bool): each non-empty line of a string
                (or file) is one item.
            lang_per_sentence (bool): report a language per sentence/line
                rather than one for the whole input/file.
            lang_per_item (bool): forwarded to
                ``identify_language_sentence_list``. It selects per-item
                languages instead of the majority language.
            output_tsv (bool): TSV output instead of JSON.
            write_output_to: ``None``/absent means stdout. ``"list"`` returns
                the results. A file path is opened for writing and closed
                before returning. An open text stream is written to as-is.
            input_directory (str): process every file below this directory.
            output_directory (str): write one ``<name>.lang`` file per input
                file instead of streaming results.
            fast_mode (bool): with ``input_directory``, classify each file
                from its first 3000 characters, batching files together.
            lang: ignored and removed, since the language is what is detected.

    Returns:
        A list of result dicts when results are collected in memory, else
        ``None``. Unreadable or empty files are reported with ``"lang": "err"``.

    Raises:
        ValueError: for incompatible option combinations.
    """
    self.eval()
    tag_config.setdefault("one_sentence_per_line", False)
    tag_config.pop("lang", None)
    tag_config.setdefault("output_tsv", False)
    tag_config.setdefault("lang_per_sentence", False)
    if ("input_directory" in tag_config and "output_directory" in tag_config
            and tag_config.get("write_output_to") is not None):
        raise ValueError("If an input and output directory is given, then write_output_to cannot be used as the output will be written to as files in output_directory.")

    # Resolve the output destination. A stream opened here from a path is
    # remembered so it is closed on every exit path.
    opened_here = None
    target = tag_config.get("write_output_to")
    if target is None:
        tag_config["write_output_to"] = sys.stdout
    elif isinstance(target, str) and target == "list":
        if tag_config["output_tsv"]:
            raise ValueError("write_output_to cannot be set to list if output_tsv is set.")
        if tag_config.get("output_directory"):
            raise ValueError("write_output_to cannot be set to list if output_directory is set.")
        tag_config["write_output_to"] = None
    elif isinstance(target, str) and "output_directory" not in tag_config:
        opened_here = open(target, "w")
        tag_config["write_output_to"] = opened_here
    if "output_directory" in tag_config:
        # Results go to per-file ".lang" files (or the returned list), not a stream.
        tag_config["write_output_to"] = None
    tag_config.setdefault("split_sentences", False)
    tag_config.setdefault("lang_per_item", False)
    out = tag_config["write_output_to"]

    def emit_file_result(collected, file_name, lang):
        """Collect (when there is no stream) or write one {"f", "lang"} result."""
        if out is None:
            collected.append({"f": file_name, "lang": lang})
        elif tag_config["output_tsv"]:
            out.write(file_name + "\t" + lang + "\n")
        else:
            json.dump({"f": file_name, "lang": lang}, out)
            out.write("\n")

    def identify_files_fast():
        """fast_mode: one language per file, from the first 3000 characters of each file."""
        if "input_directory" not in tag_config:
            raise ValueError("input_directory must be defined if fast_mode is set.")
        if tag_config["split_sentences"]:
            raise ValueError("fast_mode does not split sentences, so split_sentences cannot be set in this mode.")
        if tag_config["lang_per_item"]:
            raise ValueError("fast_mode does not identify languages of each line or sentence in a file, so lang_per_item cannot be set in this mode.")
        if tag_config["lang_per_sentence"]:
            raise ValueError("fast_mode does not identify languages of sentence in a file, so lang_per_sentence cannot be set in this mode.")
        collected, file_names, contents = [], [], []

        def classify_pending():
            """Classify the buffered file prefixes, emit their results, and empty the buffers."""
            batch = self.tokenizer(contents, padding=True, truncation=True, max_length=self.MAX_LENGTH,
                                   return_tensors="pt", return_token_type_ids=False)
            seq_logits = self(batch["input_ids"].to(self.device), batch["attention_mask"].to(self.device))["seq_logits"]
            lang_ids = torch.argmax(seq_logits, dim=-1)[:, 0].tolist()
            del batch, seq_logits
            torch.cuda.empty_cache()
            for file_name, lang_id in zip(file_names, lang_ids):
                emit_file_result(collected, file_name, self.config["id_to_lang"][lang_id])
            file_names.clear()
            contents.clear()

        for dir_path, _, files in os.walk(tag_config["input_directory"]):
            for f in files:
                # Flush a full batch before reading the current file, so the
                # current file is not skipped.
                if len(file_names) == self.batch_size:
                    classify_pending()
                input_path = os.path.join(dir_path, f)
                try:
                    with open(input_path, "r") as ff:
                        content = ff.read(3000).replace("\n", " ").replace("\r", "")
                except (OSError, ValueError):
                    # Unreadable or undecodable files are skipped.
                    continue
                file_names.append(input_path)
                contents.append(content)
        if file_names:
            classify_pending()
        return collected if len(collected) > 0 else None

    def identify_directory():
        """Identify the language(s) of every file under input_directory."""
        collected = []
        # Sentence-level output needs one language per sentence, not the
        # file's majority language repeated.
        item_config = dict(tag_config)
        if tag_config["lang_per_sentence"]:
            item_config["lang_per_item"] = True
        for dir_path, _, files in os.walk(tag_config["input_directory"]):
            for f in files:
                input_path = os.path.join(dir_path, f)
                file_content = self._check_if_text_file_and_return_content(input_path)
                if not isinstance(file_content, str):
                    emit_file_result(collected, input_path, "err")
                    continue
                file_content = self._preprocess_text(file_content)
                if tag_config["one_sentence_per_line"]:
                    items = [line for line in file_content.split("\n") if line != ""]
                else:
                    items = self.split_sentences(file_content, **tag_config)
                if len(items) == 0:
                    # Nothing to classify, e.g. an empty file.
                    emit_file_result(collected, input_path, "err")
                    continue
                langs = self.identify_language_sentence_list(items, **item_config)
                if not tag_config["one_sentence_per_line"]:
                    # Decode token ids back to text, dropping the leading special
                    # token and everything from "[SEP]" on.
                    items = [self.tokenizer.decode(s[1:]).split("[SEP]")[0].strip() for s in items]
                pairs = list(zip(items, langs))
                if "output_directory" in tag_config:
                    out_path = os.path.join(tag_config["output_directory"], os.path.relpath(dir_path, tag_config["input_directory"]), f + ".lang")
                    os.makedirs(os.path.dirname(out_path), exist_ok=True)
                    with open(out_path, "w") as opened_file:
                        if not tag_config["lang_per_sentence"]:
                            if tag_config["output_tsv"]:
                                opened_file.write(langs[0])
                            else:
                                json.dump({"lang": langs[0]}, opened_file)
                        elif tag_config["output_tsv"]:
                            for sen, lan in pairs:
                                opened_file.write(sen + "\t" + lan + "\n")
                        else:
                            json.dump([{"s": sen, "lang": lan} for sen, lan in pairs], opened_file)
                elif not tag_config["lang_per_sentence"]:
                    emit_file_result(collected, input_path, langs[0])
                elif out is None:
                    collected.extend({"s": sen, "lang": lan} for sen, lan in pairs)
                elif tag_config["output_tsv"]:
                    for sen, lan in pairs:
                        out.write(sen + "\t" + lan + "\n")
                    out.write("\n")
                else:
                    json.dump([{"s": sen, "lang": lan} for sen, lan in pairs], out)
                    out.write("\n")
        # Directory mode closes any non-standard output stream when done.
        if out is not None and out is not sys.stdout and out is not sys.stderr:
            out.close()
        return collected if len(collected) > 0 else None

    def identify_input():
        """Identify the language(s) of a str or list ``inp``."""
        shown = None  # display text for each item, when it differs from the item itself
        if isinstance(inp, str):
            if tag_config["split_sentences"]:
                items = self.split_sentences(self._preprocess_text(inp), **tag_config)
                shown = [self.tokenizer.decode(s[1:]).split("[SEP]")[0].strip() for s in items]
                if tag_config["lang_per_sentence"]:
                    tag_config["lang_per_item"] = True
            elif tag_config["one_sentence_per_line"]:
                items = [self._preprocess_text(line) for line in inp.split("\n") if line != ""]
                if tag_config["lang_per_sentence"]:
                    tag_config["lang_per_item"] = True
            else:
                # The whole string is one item.
                items = [self._preprocess_text(inp)]
        elif isinstance(inp, list):
            stripped = [s.strip() for s in inp]
            items = [self._preprocess_text(s) for s in stripped if s != ""]
        else:
            return None
        langs = self.identify_language_sentence_list(items, **tag_config) if len(items) > 0 else []
        sentences = items if shown is None else shown
        if out is None:
            return [{"s": sen, "lang": lan} for sen, lan in zip(sentences, langs)]
        if tag_config["output_tsv"]:
            for sen, lan in zip(sentences, langs):
                out.write(sen + "\t" + lan + "\n")
        else:
            json.dump([{"s": sen, "lang": lan} for sen, lan in zip(sentences, langs)], out)
        return None

    try:
        if tag_config.get("fast_mode"):
            return identify_files_fast()
        if "input_directory" in tag_config:
            return identify_directory()
        return identify_input()
    finally:
        if opened_here is not None:
            opened_here.close()