| from transformers import Pipeline, AutoModelForTokenClassification |
| import numpy as np |
| from eval import retrieve_predictions, align_tokens_labels_from_wordids |
| from reading import read_dataset |
| from utils import read_config |
|
|
|
|
|
|
def write_sentences_to_format(
    sentences: list[str],
    filename: "str | None" = None,
    doc_id: str = "GUM_academic_discrimination",
) -> str:
    """Render sentences as tab-separated token lines, one word per line.

    The output starts with a ``# newdoc_id = <doc_id>`` header, followed by
    one line per whitespace-separated word, in the form::

        index<TAB>word<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>Seg=...

    The index restarts at 1 for every sentence and no separator line is
    emitted between sentences. A word is labelled ``B-seg`` when it is the
    first word of its sentence or starts with an uppercase character, and
    ``O`` otherwise.

    Args:
        sentences: The sentences to render. A single string is accepted as
            well; it is wrapped in a list and a warning is written to stderr.
        filename: When truthy, the rendered text is also written to this
            path (UTF-8). ``None`` or an empty string skips the write.
        doc_id: Value placed in the ``# newdoc_id`` header line.

    Returns:
        The rendered text, or ``""`` when ``sentences`` is empty.
    """
    # Function-scope import keeps this helper self-contained.
    import sys

    if not sentences:
        return ""
    if isinstance(sentences, str):
        sentences = [sentences]
        sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")

    # Collect the pieces and join once instead of growing a string with `+=`.
    pieces = [f"# newdoc_id = {doc_id}\n"]
    for sentence in sentences:
        for index, word in enumerate(sentence.strip().split(), start=1):
            seg_label = "B-seg" if index == 1 or word[0].isupper() else "O"
            pieces.append(f"{index}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n")
    full = "".join(pieces)

    if filename:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(full)

    return full
|
|
|
|
class DiscoursePipeline(Pipeline):
    """Pipeline that turns raw text into a list of discourse units.

    The text is rendered to token-per-line format, loaded through
    ``read_dataset``, scored through ``retrieve_predictions``, and the
    resulting labels are used to cut the original text with ``text_to_edus``.
    """

    def __init__(self, model, tokenizer, output_folder="./pipe_out", sat_model: str = "sat-3l", **kwargs):
        """Load the token-classification model and store the run configuration.

        Args:
            model: Checkpoint identifier given to ``from_pretrained``.
            tokenizer: Tokenizer forwarded to the base ``Pipeline``.
            output_folder: Folder passed to the dataset reader and predictor.
            sat_model: Value stored under the ``"sat_model"`` config key.
            **kwargs: Forwarded unchanged to the base ``Pipeline``.
        """
        loaded_model = AutoModelForTokenClassification.from_pretrained(model)
        super().__init__(model=loaded_model, tokenizer=tokenizer, **kwargs)

        tokenizer_settings = {
            "padding": "max_length",
            "truncation": True,
            "max_length": 512,
        }
        self.config = {
            "model_checkpoint": model,
            "sent_spliter": "sat",
            "task": "seg",
            "type": "tok",
            "trace": False,
            "report_to": "none",
            "sat_model": sat_model,
            "tok_config": tokenizer_settings,
        }
        # Rebinds self.model to the checkpoint identifier, replacing the loaded
        # model object that was handed to the base class above.
        # NOTE(review): presumably retrieve_predictions expects the checkpoint
        # identifier rather than a model instance - confirm against eval.py.
        self.model = model
        self.output_folder = output_folder

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess, forward and postprocess parameter dicts."""
        return {}, {}, {}

    def preprocess(self, text: str):
        """Render ``text`` line by line and load it as a dataset."""
        # Kept on the instance because postprocess needs the raw text again.
        self.original_text = text
        lines = text.split("\n")
        formatted = write_sentences_to_format(lines, filename=None)
        dataset, _ = read_dataset(
            formatted,
            output_path=self.output_folder,
            config=self.config,
            add_lang_token=True,
            add_frame_token=True,
        )
        return {"dataset": dataset}

    def _forward(self, inputs):
        """Run ``retrieve_predictions`` on the dataset built by ``preprocess``."""
        dataset = inputs["dataset"]
        raw_preds, gold_ids, _ = retrieve_predictions(
            self.model,
            dataset,
            self.output_folder,
            self.tokenizer,
            self.config,
        )
        return {"preds": raw_preds, "labels": gold_ids, "dataset": dataset}

    def postprocess(self, outputs):
        """Turn the raw predictions into discourse units of the original text."""
        # Pick the highest-scoring entry along the last axis.
        best_ids = np.argmax(outputs["preds"], axis=-1)
        word_labels = align_tokens_labels_from_wordids(best_ids, outputs["dataset"], self.tokenizer)
        return text_to_edus(self.original_text, word_labels)
|
|
def get_plain_text_from_format(formatted_text: str) -> str:
    """Extract the words of a tab-separated (CoNLL-U / .tok style) text.

    Lines starting with ``#`` and lines without a tab are ignored. For every
    remaining line the second tab-separated column is taken as the word.

    Args:
        formatted_text: The formatted content itself (not a file path).

    Returns:
        The words joined by single spaces, with surrounding whitespace
        stripped; ``""`` when no token line is present.
    """
    words = []
    for raw_line in formatted_text.split("\n"):
        # Drop the carriage return left behind by Windows line endings so it
        # cannot end up inside a word on two-column lines.
        line = raw_line.rstrip("\r")
        if line.startswith("#"):
            continue
        fields = line.split("\t")
        if len(fields) > 1:
            words.append(fields[1])
    return " ".join(words).strip()
|
|
|
|
def get_preds_from_format(formatted_text: str) -> str:
    """Extract the labels of a tab-separated (CoNLL-U / .tok style) text.

    Lines starting with ``#`` and lines without a tab are ignored. For every
    remaining line the last tab-separated column is taken as the label.

    Args:
        formatted_text: The formatted content itself (not a file path).

    Returns:
        The labels joined by single spaces, with surrounding whitespace
        stripped; ``""`` when no token line is present.
    """
    labels = []
    for raw_line in formatted_text.split("\n"):
        # Drop the carriage return left behind by Windows line endings;
        # otherwise it would be glued to the label in the last column.
        line = raw_line.rstrip("\r")
        if line.startswith("#"):
            continue
        fields = line.split("\t")
        if len(fields) > 1:
            labels.append(fields[-1])
    return " ".join(labels).strip()
|
|
|
|
def text_to_edus(text: str, labels: list[str]) -> list[str]:
    """Split raw text into EDUs according to a sequence of per-word labels.

    A word labelled ``Seg=B-seg`` or ``Conn=B-conn`` opens a new EDU. Every
    other label (for example ``Seg=O``, ``Conn=O`` or ``Conn=I-conn``)
    continues the current EDU, so no word of the input is ever dropped.

    Args:
        text: The raw text; words are obtained by splitting on whitespace.
        labels: One label per word, in the same order as the words.

    Returns:
        The EDUs, each one being its words joined by single spaces.

    Raises:
        ValueError: If the number of words differs from the number of labels.
    """
    words = text.strip().split()
    if len(words) != len(labels):
        raise ValueError(f"Longueur mismatch: {len(words)} mots vs {len(labels)} labels")

    begin_labels = ("Conn=B-conn", "Seg=B-seg")
    edus: list[str] = []
    current_edu: list[str] = []

    for word, label in zip(words, labels):
        # Close the running EDU only when a begin label shows up; any other
        # label keeps accumulating into the current EDU.
        if label in begin_labels and current_edu:
            edus.append(" ".join(current_edu))
            current_edu = []
        current_edu.append(word)

    # Flush the trailing EDU, if any.
    if current_edu:
        edus.append(" ".join(current_edu))

    return edus
|
|