| import os |
| import json |
| import numpy as np |
| import pandas as pd |
| import logging |
| from collections import Counter |
| from sentence_transformers import SentenceTransformer |
| import warnings |
| from datetime import datetime |
| from sklearn.preprocessing import normalize |
| import requests |
| import json |
| import argparse |
| from openai import OpenAI |
|
|
| from scripts.scripts.sign2text_mapping import sign2text |
|
|
# Silence noisy FutureWarnings emitted by the ML libraries imported above.
warnings.filterwarnings("ignore", category=FutureWarning)


# Full-verbosity log file; filemode='w' truncates AulSign.log on every run,
# so only the latest run's trace is kept.
logging.basicConfig(
    filename='AulSign.log',
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filemode='w'
)
|
|
|
|
|
|
# Module-level OpenAI client, configured entirely from environment variables.
# Used by AulSign() when `ollama` is False.
# NOTE(review): if OPENAI_API_KEY is unset this presumably fails at import
# time rather than at first use — confirm whether that is intended.
client = OpenAI(
    organization=os.getenv("OPENAI_ORGANIZATION"),
    project=os.getenv("OPENAI_PROJECT"),
    api_key=os.getenv("OPENAI_API_KEY")
)


print('Inference started...')
|
|
def query_ollama(messages, model="mistral:7b-instruct-fp16"):
    """Send a chat request to a locally running Ollama server.

    Args:
        messages: OpenAI-style chat messages (list of role/content dicts).
        model: Ollama model tag to query.

    Returns:
        The assistant's reply text on success, or an "Error: <status>, <body>"
        string when the HTTP request does not return 200.
    """
    url = "http://localhost:11434/api/chat"

    # Non-streaming request with a fixed seed and low temperature for
    # (near-)deterministic output.
    request_body = {
        "model": model,
        "messages": messages,
        "options": {"seed": 42, "temperature": 0.1},
        "stream": False,
    }

    response = requests.post(url, json=request_body)

    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"
    return response.json()["message"]["content"]
|
|
def check_repetition(text, threshold=0.2):
    """Heuristically detect a degenerate (looping) LLM answer.

    *text* is a '#'-separated sequence of canonical descriptions. The answer
    is flagged when it contains an "<unk>" token, or when the ratio of unique
    tokens to total tokens falls below *threshold*.

    Args:
        text: '#'-separated token string; may be empty or None.
        threshold: minimum allowed unique/total token ratio.

    Returns:
        True when the answer should be regenerated, False otherwise.
    """
    if not text:
        return False

    # BUG FIX: the original read `word.strip` (missing call parentheses),
    # filling `words` with bound-method objects — which silently broke both
    # the "<unk>" membership check and the unique-token count below.
    words = [word.strip() for word in text.split('#')]

    unique_words = len(set(words))
    total_words = len(words)

    if "<unk>" in words:
        logging.debug("Check repetition: '<unk>' was generated in the answer")
        return True

    is_repetitive = unique_words < total_words * threshold
    logging.debug(f"Check repetition: {is_repetitive} (Unique: {unique_words}, Total: {total_words})")
    return is_repetitive
|
|
|
|
| |
def prepare_dataset(prediction: pd.DataFrame, validation: pd.DataFrame, modality: str):
    """Join model predictions with their gold-standard columns for evaluation.

    Args:
        prediction: predictions produced by AulSign. Must contain a
            'sentence' column for 'text2sign', or a 'gold_cd' column for
            'sign2text'.
        validation: the test split, with 'fsw'/'symbol'/'word'/'sentence'
            columns (subset depending on modality).
        modality: either 'text2sign' or 'sign2text'.

    Returns:
        The merged DataFrame ready to be written out.

    Raises:
        ValueError: for an unknown modality. (The original fell through the
        if/elif chain and raised a confusing NameError on `metrics`.)
    """
    if modality == 'text2sign':
        validation = validation.rename(columns={'fsw': 'gold_fsw_seq', 'symbol': 'gold_symbol_seq', 'word': 'gold_cd'})
        metrics = prediction.merge(validation[['gold_symbol_seq', 'gold_cd', 'sentence', 'gold_fsw_seq']], on=['sentence'])
    elif modality == 'sign2text':
        validation = validation.rename(columns={'word': 'gold_cd'})
        # NOTE(review): this branch merges on 'gold_cd' (the decomposition),
        # not on 'sentence' — presumably because sign2text predictions are
        # keyed by the gold decomposition; confirm against main().
        metrics = prediction.merge(validation[['sentence', 'gold_cd']], on=['gold_cd'])
    else:
        raise ValueError(f"Unknown modality: {modality!r} (expected 'text2sign' or 'sign2text')")
    return metrics
|
|
| |
def cos_sim(a, b):
    """Return the cosine similarity between vectors *a* and *b*."""
    dot_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return dot_product / norm_product
|
|
def find_most_similar_sentence(user_embedding, train_sentences: pd.DataFrame, n=3, unk_threshold=7):
    """Return the decompositions and sentences of the *n* training examples
    most similar to *user_embedding* (cosine similarity).

    Examples whose decomposition contains more than *unk_threshold* "<unk>"
    tokens have their similarity zeroed out so they are never selected over
    clean examples.

    Returns:
        (decompositions, sentences): two parallel lists of length *n*,
        ordered from most to least similar.
    """
    embedding_matrix = np.vstack(train_sentences["embedding_sentence"].values)
    all_decompositions = train_sentences["decomposition"].values
    all_sentences = train_sentences["sentence"].values

    # L2-normalise both sides so the dot product below equals cosine similarity.
    embedding_matrix = normalize(embedding_matrix, axis=1)
    query = normalize(user_embedding.reshape(1, -1), axis=1)

    similarities = (embedding_matrix @ query.T).ravel()

    # Suppress training examples dominated by "<unk>" tokens.
    unk_per_example = np.array([d.count("<unk>") for d in all_decompositions])
    similarities[unk_per_example > unk_threshold] = 0

    best_indices = np.argsort(similarities)[-n:][::-1]

    return ([all_decompositions[i] for i in best_indices],
            [all_sentences[i] for i in best_indices])
|
|
|
|
def find_most_similar_canonical_entry(user_embedding, vocabulary: pd.DataFrame, n=30):
    """Return up to *n* distinct canonical vocabulary entries, ordered by
    cosine similarity to *user_embedding* (most similar first).

    Each vocabulary row's 'word' value is collapsed to its canonical form via
    get_most_freq(); duplicates of an already-selected canonical entry are
    skipped so the result contains *n* distinct entries.
    """
    embedding_matrix = normalize(np.vstack(vocabulary["embedding"].values), axis=1)
    all_words = vocabulary["word"].values

    # Normalised query so the dot product is a cosine similarity.
    query = normalize(user_embedding.reshape(1, -1), axis=1)
    similarities = (embedding_matrix @ query.T).ravel()

    selected = []
    # Walk rows from most to least similar, keeping the first n distinct
    # canonical entries.
    for row_idx in np.argsort(similarities)[::-1]:
        if len(selected) >= n:
            break
        entry = get_most_freq(all_words[row_idx])
        if entry not in selected:
            selected.append(entry)

    return selected
|
|
|
|
def get_most_freq(lista: list):
    """Return the one or two most frequent entries of *lista* as a canonical label.

    Entries are lower-cased and stripped before counting. With two or more
    distinct entries the top two by frequency are joined as "first|second";
    with exactly one distinct entry that entry alone is returned.

    NOTE(review): callers pass vocabulary 'word' values — presumed to be
    lists of gloss variants; a plain string argument would be iterated
    character by character. Confirm against the vocabulary JSON schema.

    Args:
        lista: iterable of gloss strings.

    Returns:
        "first|second", a single entry, or '' for empty input.
    """
    cleaned = [segno.lower().strip() for segno in lista]
    if not cleaned:
        # Guard: the original raised IndexError on empty input.
        return ''

    # BUG FIX: the original de-duplicated the list *before* counting, so every
    # frequency was 1 and most_common() degenerated to insertion order rather
    # than actual frequency. Count the cleaned entries directly instead.
    frequency_count = Counter(cleaned)
    top_two_words = frequency_count.most_common(2)

    if len(top_two_words) >= 2:
        return top_two_words[0][0] + '|' + top_two_words[1][0]
    return top_two_words[0][0]
|
|
def get_most_freq_fsw(lista_fsw):
    """Return *lista_fsw* unchanged when it is already a single FSW string;
    otherwise return the most frequent element of the sequence."""
    if isinstance(lista_fsw, str):
        return lista_fsw
    # most_common(1) yields a single (element, count) pair; keep the element.
    (winner, _count), = Counter(lista_fsw).most_common(1)
    return winner
|
|
|
|
def get_fsw_exact(vocabulary: pd.DataFrame, can_desc_answer, model, top_k=10):
    """Map each canonical description in *can_desc_answer* to a vocabulary
    FSW entry.

    For each description the *top_k* most similar vocabulary rows (cosine
    similarity on embeddings) are retrieved. If one of them matches the
    description exactly (after get_most_freq canonicalisation) it is chosen
    with similarity 1; otherwise the single most similar row is used.

    Returns:
        (fsw_sequence, canonical_sequence, confidence): the space-joined FSW
        string, the ' # '-joined canonical descriptions, and the geometric
        mean of the per-token similarities rounded to 3 decimals.
    """
    embedding_matrix = normalize(np.vstack(vocabulary["embedding"].values), axis=1)
    all_words = vocabulary["word"].values
    all_fsw = vocabulary["fsw"].values

    fsw_seq = []
    can_desc_association_seq = []
    joint_prob = 1

    for can_d in can_desc_answer:
        # Already-normalised query embedding, so the dot product is cosine sim.
        query = model.encode(can_d, normalize_embeddings=True).reshape(1, -1)
        similarities = (embedding_matrix @ query.T).ravel()

        # Candidate rows, best first.
        order = np.argsort(similarities)[-top_k:][::-1]
        candidate_words = all_words[order]
        candidate_fsw = all_fsw[order]
        candidate_sims = similarities[order]

        # Prefer an exact canonical match anywhere in the top-k list.
        exact_idx = next(
            (i for i, word in enumerate(candidate_words) if get_most_freq(word) == can_d.strip()),
            None,
        )

        if exact_idx is not None:
            chosen = exact_idx
            max_similarity = 1  # exact matches count as certainty
        else:
            chosen = 0  # fall back to the highest-similarity candidate
            max_similarity = candidate_sims[0]

        most_similar_word = get_most_freq(candidate_words[chosen])
        fsw = candidate_fsw[chosen]

        logging.info(fsw)
        fsw_seq.append(get_most_freq_fsw(fsw))
        joint_prob *= max_similarity
        can_desc_association_seq.append(most_similar_word)

        logging.debug(f"Word: {can_d}")
        logging.debug(f"Most similar word in vocabulary: {most_similar_word}")
        logging.debug(f"Similarity: {max_similarity}")
        logging.debug(f"Fsw_seq: {' '.join(fsw_seq)}")
        logging.debug("---")

    # Geometric mean of the per-token similarities.
    joint_prob = pow(joint_prob, 1 / len(can_desc_association_seq))

    return ' '.join(fsw_seq), ' # '.join(can_desc_association_seq), np.round(joint_prob, 3)
|
|
| |
def AulSign(input:str, rules_prompt_path:str, train_sentences:pd.DataFrame, vocabulary:pd.DataFrame, model, ollama:bool, modality:str):
    """
    AulSign: A function for translating between text and Formal SignWriting (FSW) or vice versa.

    This function leverages embeddings, similarity matching, and language models to facilitate
    translations based on the specified modality (`text2sign` or `sign2text`).

    Args:
        input (str):
            The sentence or sign sequence to be analyzed and translated.
        rules_prompt_path (str):
            Path to a file containing predefined prompts and rules to guide the language model.
        train_sentences (pd.DataFrame):
            A dataset containing sentences and their embeddings for training or similarity matching.
        vocabulary (pd.DataFrame):
            A table of vocabulary entries with canonical descriptions and embeddings, used for matching.
        model:
            The embedding model used to convert sentences or sign sequences into vector representations.
        ollama (bool):
            Specifies whether to use the `query_ollama` method for querying the language model.
        modality (str):
            The translation mode:
            - `'text2sign'`: Converts text to Formal SignWriting sequences.
            - `'sign2text'`: Converts Formal SignWriting to textual sentences.

    Returns:
        For `modality == "text2sign"`:
            tuple:
                - answer (str):
                    The translated text or decomposition provided by the language model.
                - fsw (list):
                    A list of Formal SignWriting sequences associated with the translation.
                - can_desc_association_seq (list):
                    A list of canonical descriptions associated with the FSW sequences.
                - joint_prob (float):
                    The joint probability of the most likely translation path.

        For `modality == "sign2text"`:
            str:
                The reconstructed textual sentence translated from the input sign sequence.

        If an invalid modality is provided:
            str:
                Returns 'error' to indicate invalid input.

    Raises:
        Exception:
            Logs and raises errors encountered during API calls or message construction.
    """
    # Embed the input once; both branches reuse this vector for retrieval.
    sent_embedding = model.encode(input, normalize_embeddings=True)

    if modality =='text2sign':
        # Retrieve the 100 canonical vocabulary entries closest to the input;
        # they are injected into the system prompt as the allowed lexicon.
        similar_canonical = find_most_similar_canonical_entry(sent_embedding, vocabulary, n=100)

        similar_canonical_str = ' # '.join(similar_canonical)

        # The prompt template contains a {similar_canonical} placeholder.
        with open(rules_prompt_path, 'r') as file:
            rules_prompt = file.read().format(similar_canonical=similar_canonical_str)

        # Few-shot examples: the 20 most similar training sentences with
        # their gold decompositions.
        decomposition, sentences = find_most_similar_sentence(
            user_embedding=sent_embedding,
            train_sentences=train_sentences,
            n=20
        )

        messages = [{"role": "system", "content": rules_prompt}]
        # NOTE(review): the loop variable `decomposition` shadows the list of
        # the same name from find_most_similar_sentence(); harmless here since
        # the list is not used afterwards, but worth renaming.
        for sentence, decomposition in zip(sentences, decomposition):
            # Each few-shot example is a user sentence + assistant decomposition.
            if sentence and decomposition:
                messages.append({"role": "user", "content": sentence})
                messages.append({"role": "assistant", "content": decomposition})
            else:
                logging.warning("Missing 'sentence' or 'decomposition' in messages.")

        messages.append({"role": "user", "content": "decompose the following sentence as shown in the previous examples"})
        messages.append({"role": "user", "content": input})

        # NOTE(review): `valid_messages` is built (and each message logged)
        # but `messages` — not `valid_messages` — is what gets sent below.
        valid_messages = []
        for message in messages:
            if 'role' in message and 'content' in message:
                valid_messages.append(message)
                logging.debug(message)
            else:
                logging.error(f"Invalid message format detected: {message}")

        if ollama:
            # Local model path: single call, no repetition retry.
            answer = query_ollama(messages)

            logging.info("\n[LOG] MISTRAL Answer:")
            logging.info(answer)

            # The model is expected to emit '#'-separated canonical descriptions.
            can_description_answer = answer.split('#')
        else:
            try:
                # Deterministic first attempt.
                completion = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=0
                )
                answer = completion.choices[0].message.content

                if check_repetition(answer):
                    # Degenerate answer: retry once with a presence penalty to
                    # discourage token looping.
                    presence_penalty = 0.6
                    completion = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        presence_penalty=presence_penalty,
                        temperature=0
                    )
                    logging.info(f"presence_penalty: {presence_penalty}")
                    answer = completion.choices[0].message.content
                    logging.info('ANSWER: GPT')
                    logging.info(answer + '\n\n')

                    can_description_answer = answer.split('#')

                else:
                    logging.info('ANSWER: GPT')
                    logging.info(answer + '\n\n')

                    can_description_answer = answer.split('#')

            except Exception as e:
                # NOTE(review): the error is only logged — if the API call
                # raises, `answer` and `can_description_answer` are undefined
                # and the code below fails with a NameError. Consider
                # re-raising or returning a sentinel here.
                logging.error(f"Error during GPT API call: {e}")

        # Ground each predicted canonical description to a vocabulary FSW entry.
        fsw, can_desc_association_seq, joint_prob = get_fsw_exact(
            vocabulary=vocabulary,
            can_desc_answer=can_description_answer,
            model=model
        )

        return answer, fsw, can_desc_association_seq, joint_prob

    elif modality =='sign2text':
        # Static prompt: no canonical-entry substitution in this direction.
        with open(rules_prompt_path, 'r') as file:
            rules_prompt = file.read()

        # Few-shot examples, inverted: decomposition as user, sentence as assistant.
        decomposition, sentences = find_most_similar_sentence(
            user_embedding=sent_embedding,
            train_sentences=train_sentences,
            n=30
        )

        messages = [{"role": "system", "content": rules_prompt}]
        # NOTE(review): same loop-variable shadowing as in the text2sign branch.
        for sentence, decomposition in zip(sentences, decomposition):
            if sentence and decomposition:
                messages.append({"role": "user", "content": decomposition})
                messages.append({"role": "assistant", "content": sentence})
            else:
                logging.warning("Missing 'sentence' or 'decomposition' in messages.")

        messages.append({"role": "user", "content": "reconstruct the sentence as shown on the examples above"})
        messages.append({"role": "user", "content": input})

        # NOTE(review): as above, `valid_messages` is collected but unused.
        valid_messages = []
        for message in messages:
            if 'role' in message and 'content' in message:
                valid_messages.append(message)
                logging.debug(message)
            else:
                logging.error(f"Invalid message format detected: {message}")

        if ollama:
            answer = query_ollama(messages)

            logging.info("\n[LOG] MISTRAL Answer:")
            logging.info(answer)

            # NOTE(review): this split result is never used in the sign2text
            # branch — the raw `answer` is returned directly.
            can_description_answer = answer.split('#')
        else:
            try:
                completion = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=0
                )
                answer = completion.choices[0].message.content
                logging.info('ANSWER: GPT')
                logging.info(answer + '\n\n')

            except Exception as e:
                # NOTE(review): as in the text2sign branch, `answer` is
                # undefined after a failed API call and the return below
                # raises NameError.
                logging.error(f"Error during GPT API call: {e}")

        return answer
    else:
        # Unknown modality: sentinel string rather than an exception.
        return 'error'
| |
|
|
def main(modality, setup, input=None):
    """Run AulSign either on one ad-hoc input or over the test split.

    Args:
        modality: 'text2sign' or 'sign2text'.
        setup: preprocessing-setup name; selects the data directory and the
            train-sentence embedding file. When None a default
            "filtered_01" embedding file is used.
        input: optional single sentence (text2sign) or sign sequence
            (sign2text). When omitted, the test CSV under *setup*'s data
            path is translated row by row and results are written to
            result/<modality>_<timestamp>/.
    """
    np.random.seed(42)
    current_time = datetime.now().strftime("%Y_%m_%d_%H_%M")
    # NOTE(review): with setup=None this becomes "data/preprocess_output_None/..."
    # — only reached in batch mode; confirm the intended default path.
    data_path = f"data/preprocess_output_{setup}/file_comparison"
    corpus_embeddings_path = 'tools/corpus_embeddings.json'
    if setup is None:
        sentences_train_embeddings_path = f"tools/sentences_train_embeddings_filtered_01.json"
    else:
        sentences_train_embeddings_path = f"tools/sentences_train_embeddings_{setup}.json"
    rules_prompt_path_text2sign = 'tools/rules_prompt_text2sign.txt'
    rules_prompt_path_sign2text = 'tools/rules_prompt_sign2text.txt'

    # Sentence-embedding model used for all retrieval steps.
    model_name = "mixedbread-ai/mxbai-embed-large-v1"
    model = SentenceTransformer(model_name)

    # Vocabulary (canonical descriptions + FSW + embeddings).
    with open(corpus_embeddings_path, 'r') as file:
        corpus_embeddings = pd.DataFrame(json.load(file))

    # Training sentences with precomputed embeddings for few-shot retrieval.
    with open(sentences_train_embeddings_path, 'r') as file:
        sentences_train_embeddings = pd.DataFrame(json.load(file))

    if input:
        # --- Single-input (interactive) mode ---
        if modality == 'text2sign':
            answer, fsw_seq, can_desc_association_seq, joint_prob = AulSign(
                input=input,
                rules_prompt_path=rules_prompt_path_text2sign,
                train_sentences=sentences_train_embeddings,
                vocabulary=corpus_embeddings,
                model=model,
                ollama=False,
                modality=modality
            )

            print(f"Canonical Descriptions: {can_desc_association_seq}")
            print(f"Translation (FSW): {fsw_seq}")

        elif modality == 'sign2text':
            # Map the raw FSW input to canonical descriptions via the
            # vocabulary before prompting the LLM.
            mapped_input = sign2text(input,corpus_embeddings_path)
            logging.info(f"\nReconstructed Sentence via Vocaboulary: {mapped_input}")
            answer= AulSign(
                input=mapped_input,
                rules_prompt_path=rules_prompt_path_sign2text,
                train_sentences=sentences_train_embeddings,
                vocabulary=corpus_embeddings,
                model=model,
                ollama=False,
                modality=modality
            )
            print(f"Input Sign Voucaboualry Mapping: {input}")
            print(f"Translation (Text): {answer}")

    else:
        # --- Batch (evaluation) mode over the test split ---
        test_path = os.path.join(data_path, f"test.csv")
        test = pd.read_csv(test_path)
        # NOTE(review): head(1) limits evaluation to a single row — this
        # looks like a debugging leftover; confirm before running a real
        # evaluation.
        test = test.head(1)

        if modality == 'text2sign':
            list_sentence = []
            list_answer = []
            list_fsw_seq = []
            can_desc_association_list = []
            prob_of_association_list = []

            for index, row in test.iterrows():
                sentence = row['sentence']
                answer, fsw_seq, can_desc_association_seq, joint_prob = AulSign(
                    input=sentence,
                    rules_prompt_path=rules_prompt_path_text2sign,
                    train_sentences=sentences_train_embeddings,
                    vocabulary=corpus_embeddings,
                    model=model,
                    ollama=False,
                    modality=modality
                )

                list_sentence.append(sentence)
                list_answer.append(answer)
                list_fsw_seq.append(fsw_seq)
                can_desc_association_list.append(can_desc_association_seq)
                prob_of_association_list.append(joint_prob)

            df_pred = pd.DataFrame({
                'sentence': list_sentence,
                'pseudo_cd': list_answer,
                'pred_cd': can_desc_association_list,
                'joint_prob': prob_of_association_list,
                'pred_fsw_seq': list_fsw_seq
            })
            output_path = os.path.join('result', f"{modality}_{current_time}")
            os.makedirs(output_path, exist_ok=True)
            # Attach gold columns before writing out.
            df_pred = prepare_dataset(df_pred,test,modality)
            df_pred.to_csv(os.path.join(output_path, f'result_{current_time}.csv'), index=False)

        elif modality == 'sign2text':

            list_answer = []
            list_gold_cd = []

            for index, row in test.iterrows():
                # 'word' holds the gold canonical decomposition used as input.
                dec_sentence = row['word']
                answer = AulSign(
                    input=dec_sentence,
                    rules_prompt_path=rules_prompt_path_sign2text,
                    train_sentences=sentences_train_embeddings,
                    vocabulary=corpus_embeddings,
                    model=model,
                    ollama=False,
                    modality=modality
                )
                list_gold_cd.append(dec_sentence)
                list_answer.append(answer)

            df_pred = pd.DataFrame({
                'pseudo_sentence': list_answer,
                'gold_cd': list_gold_cd,
            })
            output_path = os.path.join('result', f"{modality}_{current_time}")
            os.makedirs(output_path, exist_ok=True)
            df_pred = prepare_dataset(df_pred,test,modality)
            df_pred.to_csv(os.path.join(output_path, f'result_{current_time}.csv'), index=False)
|
|
if __name__ == "__main__":
    # Command-line entry point: --mode selects the translation direction,
    # --input optionally supplies a single sentence/sign sequence (batch
    # evaluation over the test split runs when it is omitted).
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--mode", required=True, help="Mode of operation: text2sign or sign2text")
    arg_parser.add_argument("--input", help="Input text or sign sequence")
    cli_args = arg_parser.parse_args()

    main(cli_args.mode, setup=None, input=cli_args.input)