import fasttext
import numpy as np
import re
import string
from copy import deepcopy


class MaskLID:
    """A class for code-switching language identification using iterative masking."""

    def __init__(self, model_path, languages=-1):
        """Initialize the MaskLID class.

        Args:
            model_path (str): Path to the fastText model file.
            languages (list or int, optional): List of language labels to
                restrict predictions to, or -1 to use all of the model's
                labels. Defaults to -1.
        """
        self.model = fasttext.load_model(model_path)
        self.output_matrix = self.model.get_output_matrix()
        self.labels = self.model.get_labels()
        self.language_indices = self._compute_language_indices(languages)
        self.labels = [self.labels[i] for i in self.language_indices]


    def _compute_language_indices(self, languages):
        """Compute the indices of the selected languages.

        Args:
            languages (list or int): List of language labels, or -1 for all.

        Returns:
            list: Indices of the selected languages in the model's label list.
        """
        if isinstance(languages, list):
            return [self.labels.index(l) for l in set(languages) if l in self.labels]
        return list(range(len(self.labels)))


    def _softmax(self, x):
        """Compute numerically stable softmax values for each score in array x.

        Args:
            x (numpy.ndarray): Input array.

        Returns:
            numpy.ndarray: Softmax output.
        """
        exp_x = np.exp(x - np.max(x))
        return exp_x / np.sum(exp_x)


    def _normalize_text(self, text):
        """Normalize input text.

        Replaces newlines, digits, and a small set of special characters
        with spaces, then collapses runs of whitespace.

        Args:
            text (str): Input text.

        Returns:
            str: Normalized text.
        """
        replace_by = " "
        replacement_map = {ord(c): replace_by for c in '_:•#{|}' + string.digits}
        text = text.replace('\n', replace_by)
        text = text.translate(replacement_map)
        return re.sub(r'\s+', replace_by, text).strip()
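
    # Illustrative example: _normalize_text("Hello\n#World 123") returns
    # "Hello World" -- the newline, '#', and digits all become spaces,
    # which are then collapsed.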


    def predict(self, text, k=1):
        """Predict the language of the input text.

        Args:
            text (str): Input text.
            k (int, optional): Number of top predictions to retrieve. Defaults to 1.

        Returns:
            tuple: Top predicted labels and their probabilities.
        """
        sentence_vector = self.model.get_sentence_vector(text)
        result_vector = np.dot(self.output_matrix, sentence_vector)
        # Softmax is computed over all model labels and then restricted to the
        # selected languages, so the returned scores need not sum to 1.
        softmax_result = self._softmax(result_vector)[self.language_indices]
        top_k_indices = np.argsort(softmax_result)[-k:][::-1]
        top_k_labels = [self.labels[i] for i in top_k_indices]
        top_k_probs = softmax_result[top_k_indices]
        return tuple(top_k_labels), top_k_probs
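
    # A minimal usage sketch; the model path is hypothetical and the label
    # format depends on the model (e.g., GlotLID-style '__label__eng_Latn'):
    #   lid = MaskLID("lid_model.bin")
    #   labels, probs = lid.predict("This is an English sentence.", k=3)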


    def compute_v(self, sentence_vector):
        """Compute the per-language logits for a given sentence vector.

        Args:
            sentence_vector (numpy.ndarray): Sentence vector.

        Returns:
            list: (label, logit) pairs sorted by logit in descending order.
        """
        result_vector = np.dot(self.output_matrix[self.language_indices, :], sentence_vector)
        return sorted(zip(self.labels, result_vector), key=lambda x: x[1], reverse=True)


    def compute_v_per_word(self, text):
        """Compute language logits for each word in the input text.

        Args:
            text (str): Input text.

        Returns:
            dict: Mapping from "{index}_{word}" keys to per-word logits.
        """
        text = self._normalize_text(text)
        words = self.model.get_line(text)[0]
        words = [w for w in words if w != '</s>']
        subword_ids = [self.model.get_subwords(w)[1] for w in words]
        # Each word vector is the sum of its subword input vectors.
        word_vectors = [np.sum([self.model.get_input_vector(idx) for idx in sid], axis=0) for sid in subword_ids]

        dict_text = {}
        for i, word in enumerate(words):
            key = f"{i}_{word}"
            dict_text[key] = {'logits': self.compute_v(word_vectors[i])}

        return dict_text
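
    # The returned mapping looks like (labels and values illustrative):
    #   {'0_Hello': {'logits': [('__label__eng_Latn', 4.1), ...]},
    #    '1_monde': {'logits': [('__label__fra_Latn', 3.7), ...]}}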


    def mask_label_top_k(self, dict_text, label, top_keep, top_remove):
        """Mask words whose top predictions include a given label.

        Words where `label` ranks within the top `top_keep` predictions are
        copied into the masked (deleted) set; words where it ranks within
        the top `top_remove` predictions are removed from the remaining set.

        Args:
            dict_text (dict): Dictionary of per-word logits.
            label (str): Label to mask.
            top_keep (int): Rank cutoff for keeping a word in the masked set.
            top_remove (int): Rank cutoff for removing a word from the remaining set.

        Returns:
            tuple: Dictionaries of remaining and masked words.
        """
        dict_remained = deepcopy(dict_text)
        dict_deleted = {}

        for key, value in dict_text.items():
            logits = value['logits']
            labels = [t[0] for t in logits]

            if label in labels[:top_keep]:
                dict_deleted[key] = dict_remained[key]

            if label in labels[:top_remove]:
                dict_remained.pop(key, None)

        return dict_remained, dict_deleted
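
    # For example (illustrative numbers): with top_keep=3 and top_remove=5,
    # a word that ranks the target label 4th is removed from the remaining
    # text but is not added to the masked segment.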


    @staticmethod
    def get_sizeof(text):
        """Compute the size of text in bytes.

        Args:
            text (str): Input text.

        Returns:
            int: Size of the UTF-8 encoded text in bytes.
        """
        return len(text.encode('utf-8'))


    @staticmethod
    def custom_sort(word):
        """Sort key for "{index}_{word}" keys: the leading word index.

        Args:
            word (str): Input key.

        Returns:
            int or float: The leading index, or infinity if none is found.
        """
        match = re.match(r'^(\d+)_', word)
        if match:
            return int(match.group(1))
        return float('inf')
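
    # e.g., custom_sort("12_hello") -> 12 and custom_sort("hello") -> inf,
    # so "{index}_{word}" keys sort back into their original word order.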


    def sum_logits(self, dict_data, label):
        """Compute the sum of logits for a specific label across all words.

        Args:
            dict_data (dict): Dictionary of per-word logits.
            label (str): Label to sum logits for.

        Returns:
            float: Total sum of logits for the given label.
        """
        total = 0
        for value in dict_data.values():
            logits = value['logits']
            labels = [t[0] for t in logits]
            if label in labels:
                total += logits[labels.index(label)][1]
        return total


    def predict_codeswitch(self, text, beta, alpha, min_prob, min_length,
                           max_lambda=1, max_retry=3, alpha_step_increase=5,
                           beta_step_increase=5):
        """Predict the languages mixed in the input text by iterative masking.

        Args:
            text (str): Input text.
            beta (int): Rank cutoff for keeping words in a masked segment.
            alpha (int): Rank cutoff for removing words from the remaining text.
            min_prob (float): Minimum probability for accepting a segment's language.
            min_length (int): Minimum byte length of text worth classifying.
            max_lambda (int, optional): Maximum number of language segments to
                extract. Defaults to 1.
            max_retry (int, optional): Maximum number of retries with relaxed
                cutoffs. Defaults to 3.
            alpha_step_increase (int, optional): Step increase for alpha on retry. Defaults to 5.
            beta_step_increase (int, optional): Step increase for beta on retry. Defaults to 5.

        Returns:
            dict: Mapping from predicted language label to the text assigned to it.
        """
        info = {}
        index = 0
        retry = 0

        dict_data = self.compute_v_per_word(text)

        while index < max_lambda and retry < max_retry:
            # Dominant language of the current (remaining) text.
            pred = self.predict(text, k=1)
            label = pred[0][0]

            # Keep the pre-mask state so a failed attempt can be rolled back.
            prev_text = text
            prev_dict_data = dict_data

            dict_data, dict_masked = self.mask_label_top_k(dict_data, label, beta, alpha)

            masked_text = ' '.join(x.split('_', 1)[1] for x in dict_masked.keys())
            text = ' '.join(x.split('_', 1)[1] for x in dict_data.keys())

            if self.get_sizeof(masked_text) > min_length or index == 0:
                temp_pred = self.predict(masked_text)

                # Accept the segment if its language is confirmed with enough
                # confidence (the first segment is always accepted).
                if (temp_pred[1][0] > min_prob and temp_pred[0][0] == label) or index == 0:
                    info[index] = {
                        'label': label,
                        'text': masked_text,
                        'text_keys': list(dict_masked.keys()),
                        'size': self.get_sizeof(masked_text),
                        'sum_logit': self.sum_logits(dict_masked, label)
                    }
                    index += 1
                else:
                    # Roll back and retry with relaxed cutoffs.
                    text = prev_text
                    dict_data = prev_dict_data
                    beta += beta_step_increase
                    alpha += alpha_step_increase
                    retry += 1
            else:
                text = prev_text
                dict_data = prev_dict_data
                beta += beta_step_increase
                alpha += alpha_step_increase
                retry += 1

            if self.get_sizeof(text) < min_length:
                break

        # Group the masked words by predicted label.
        post_info = {}
        for value in info.values():
            key = value['label']
            if key in post_info:
                post_info[key].extend(value['text_keys'])
            else:
                post_info[key] = list(value['text_keys'])

        # Restore the original word order within each language segment.
        for key in post_info:
            post_info[key] = ' '.join(
                [x.split('_', 1)[1] for x in sorted(set(post_info[key]), key=self.custom_sort)]
            )

        return post_info
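

if __name__ == "__main__":
    # A minimal, hypothetical usage sketch. The model path, label names, and
    # parameter values below are assumptions for illustration, not part of
    # the class itself; any fastText-style LID model should work.
    lid = MaskLID("lid_model.bin", languages=-1)  # hypothetical model file

    mixed = "je veux vivre my life comme je le veux"  # French-English example
    segments = lid.predict_codeswitch(
        mixed,
        beta=20,        # rank cutoff for keeping words in a segment
        alpha=3,        # rank cutoff for removing words from the remainder
        min_prob=0.90,  # confidence needed to accept a segment's language
        min_length=10,  # minimum segment size in bytes
        max_lambda=3,   # extract up to three language segments
    )
    for lang, segment in segments.items():
        print(lang, "->", segment)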