| import torch |
| import json |
| from huggingface_hub import hf_hub_download |
| import re |
| import emoji |
| from transformers import BertForSequenceClassification, BertTokenizer |
|
|
def preprocess_text(text):
    """Normalize raw input text the same way the training data was normalized.

    Steps, applied in this order:
      1. ``u/<name>`` mentions become ``[USER]``.
      2. ``r/<name>`` mentions become ``[SUBREDDIT]``.
      3. ``http://`` / ``https://`` links become ``[URL]``.
      4. Emoji are replaced by their text names (``emoji.demojize`` with
         space delimiters).
      5. The whole string is lower-cased.

    Note that step 5 also lower-cases the placeholders from steps 1-3
    (e.g. ``[USER]`` ends up as ``[user]``).

    NOTE(review): the mention patterns are not anchored to a word boundary,
    so e.g. ``menu/abc`` also matches ``u/abc``. This is intentionally kept
    as-is so inference matches the training preprocessing.

    Args:
        text (str): Raw input text.

    Returns:
        str: The normalized text.
    """
    # (pattern, replacement) pairs; order matters and must not change.
    placeholder_rules = (
        (r'u/\w+', '[USER]'),
        (r'r/\w+', '[SUBREDDIT]'),
        (r'http[s]?://\S+', '[URL]'),
    )
    for pattern, placeholder in placeholder_rules:
        text = re.sub(pattern, placeholder, text)

    return emoji.demojize(text, delimiters=(" ", " ")).lower()
|
|
def load_model_and_resources(repo_id="logasanjeev/emotions-analyzer-bert"):
    """Load the model, tokenizer, emotion labels, and thresholds from Hugging Face.

    Args:
        repo_id (str): Hugging Face Hub repository holding the fine-tuned BERT
            model, its tokenizer, and ``optimized_thresholds.json``. Defaults
            to the repository this script was written for.

    Returns:
        tuple: ``(model, tokenizer, emotion_labels, thresholds, device)`` where
        ``model`` has been moved to ``device`` and switched to eval mode, and
        ``device`` is CUDA when available, otherwise CPU.

    Raises:
        RuntimeError: If the model/tokenizer cannot be loaded, or if the
            thresholds file cannot be downloaded, parsed, or is malformed
            (missing keys, or label/threshold counts that differ). The
            underlying exception is chained as ``__cause__``.
    """
    try:
        model = BertForSequenceClassification.from_pretrained(repo_id)
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Error loading model/tokenizer: {e}") from e

    try:
        thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        with open(thresholds_file, "r", encoding="utf-8") as f:
            thresholds_data = json.load(f)
        if not (isinstance(thresholds_data, dict) and "emotion_labels" in thresholds_data and "thresholds" in thresholds_data):
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]
        # Labels and thresholds are consumed pairwise at prediction time;
        # a count mismatch would otherwise surface later as silent truncation.
        if len(emotion_labels) != len(thresholds):
            raise ValueError(
                f"optimized_thresholds.json has {len(emotion_labels)} emotion labels "
                f"but {len(thresholds)} thresholds."
            )
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {e}") from e

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()  # inference mode for layers such as dropout

    return model, tokenizer, emotion_labels, thresholds, device
|
|
# Cache of loaded resources. All start as None; predict_emotions() fills them
# on its first call via load_model_and_resources() and reuses them afterwards.
MODEL = TOKENIZER = EMOTION_LABELS = THRESHOLDS = DEVICE = None
|
|
def format_emotion_predictions(probabilities, thresholds, labels):
    """Select the labels that pass their thresholds and format them for display.

    A label is kept when its probability is greater than or equal to its
    per-label threshold. Kept labels are rendered one per line as
    ``"<label>: <probability>"`` with 4 decimal places, highest first.

    Args:
        probabilities (Sequence[float]): One probability per label.
        thresholds (Sequence[float]): One decision threshold per label.
        labels (Sequence[str]): Label names, aligned with ``probabilities``.

    Returns:
        str: The formatted predictions, or ``"No emotions predicted."`` when
        no label reaches its threshold.

    Raises:
        ValueError: If the three sequences differ in length, rather than
            silently ignoring the unmatched entries.
    """
    if not (len(probabilities) == len(thresholds) == len(labels)):
        raise ValueError(
            f"Length mismatch: {len(probabilities)} probabilities, "
            f"{len(thresholds)} thresholds, {len(labels)} labels."
        )

    selected = [
        (label, round(float(probability), 4))
        for label, probability, threshold in zip(labels, probabilities, thresholds)
        if probability >= threshold
    ]
    selected.sort(key=lambda item: item[1], reverse=True)

    if not selected:
        return "No emotions predicted."
    return "\n".join(f"{label}: {confidence:.4f}" for label, confidence in selected)


def predict_emotions(text):
    """Predict emotions for the given text using the GoEmotions BERT model.

    On the first call, the model, tokenizer, labels, thresholds and device are
    loaded via ``load_model_and_resources()`` and cached in module globals.

    Args:
        text (str): The input text to analyze.

    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): Formatted string of predicted emotions and their confidence scores.
            - processed_text (str): The preprocessed input text.

    Raises:
        RuntimeError: If the resources cannot be loaded on first use.
        ValueError: If the model output, thresholds and labels differ in length.
    """
    global MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE

    if MODEL is None:
        MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE = load_model_and_resources()

    processed_text = preprocess_text(text)

    encodings = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    input_ids = encodings['input_ids'].to(DEVICE)
    attention_mask = encodings['attention_mask'].to(DEVICE)

    with torch.no_grad():
        outputs = MODEL(input_ids, attention_mask=attention_mask)
        # Sigmoid is applied independently per label, so these are per-label
        # probabilities (not raw logits). [0] selects the single input row.
        probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    result = format_emotion_predictions(probabilities, THRESHOLDS, EMOTION_LABELS)
    return result, processed_text
|
|
if __name__ == "__main__":
    # Command-line entry point: analyze a single text passed as an argument.
    import argparse

    cli = argparse.ArgumentParser(
        description="Predict emotions using the GoEmotions BERT model."
    )
    cli.add_argument("text", type=str, help="The input text to analyze for emotions.")
    options = cli.parse_args()

    emotions, cleaned = predict_emotions(options.text)

    # One print of newline-joined lines produces the same stdout as four prints.
    report_lines = [
        f"Input: {options.text}",
        f"Processed: {cleaned}",
        "Predicted Emotions:",
        emotions,
    ]
    print("\n".join(report_lines))