"""
Named Entity Recognition (NER) using Transformers.

Extracts entities such as PERSON, LOCATION, ORGANIZATION, and MISC from text.
"""

import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, List, Union

from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

# Note: logging.basicConfig() only takes effect the first time it is called, so this
# module-level configuration wins unless a later call overrides it with force=True.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


class TransformerNER:
    """Named entity recognition backed by a Hugging Face token-classification pipeline."""

    # Shortcuts for common NER checkpoints; any full Hugging Face model id is also accepted.
    MODELS = {
        "dslim-bert": "dslim/bert-base-NER",
        "dbmdz-bert": "dbmdz/bert-large-cased-finetuned-conll03-english",
        "xlm-roberta": "xlm-roberta-large-finetuned-conll03-english",
    }
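
    # Illustrative usage (a sketch, not exercised by the CLI below): shortcuts in
    # MODELS are resolved to full model ids by _load_model(), and any Hugging Face
    # model id can be passed directly.
    #
    #   ner = TransformerNER("dbmdz-bert")            # shortcut from MODELS
    #   ner = TransformerNER("dslim/bert-base-NER")   # full model id
    #   ner.switch_model("xlm-roberta", aggregation_strategy="first")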

    def __init__(self, model_name: str = "dslim/bert-base-NER", aggregation_strategy: str = "simple"):
        """
        Initialize the NER pipeline with the specified model.

        Default model: dslim/bert-base-NER (a lightweight BERT model fine-tuned for NER).
        """
        self.logger = logging.getLogger(__name__)
        self.current_model_name = model_name
        self.cache_dir = os.path.join(os.path.dirname(__file__), "model_cache")
        os.makedirs(self.cache_dir, exist_ok=True)

        self._load_model(model_name, aggregation_strategy)

    def _load_model(self, model_name: str, aggregation_strategy: str = "simple"):
        """Load or reload the model with the given parameters."""
        # Resolve a MODELS shortcut (e.g. "dbmdz-bert") to its full Hugging Face model id.
        if model_name in self.MODELS:
            resolved_name = self.MODELS[model_name]
        else:
            resolved_name = model_name

        self.current_model_name = model_name
        self.aggregation_strategy = aggregation_strategy

        self.logger.info(f"Loading model: {resolved_name}")
        self.logger.info(f"Cache directory: {self.cache_dir}")
        self.logger.info(f"Aggregation strategy: {aggregation_strategy}")

        # The aggregation strategy controls how subword tokens are merged back into
        # word-level entity spans ("simple", "first", "average", or "max").
        self.tokenizer = AutoTokenizer.from_pretrained(resolved_name, cache_dir=self.cache_dir)
        self.model = AutoModelForTokenClassification.from_pretrained(resolved_name, cache_dir=self.cache_dir)
        self.nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer, aggregation_strategy=aggregation_strategy)
        self.logger.info("Model loaded successfully!")

    def switch_model(self, model_name: str, aggregation_strategy: str = None):
        """Switch to a different model dynamically."""
        if aggregation_strategy is None:
            aggregation_strategy = self.aggregation_strategy

        try:
            self._load_model(model_name, aggregation_strategy)
            return True
        except Exception as e:
            self.logger.error(f"Failed to load model '{model_name}': {e}")
            return False

    def change_aggregation(self, aggregation_strategy: str):
        """Change the aggregation strategy for the current model."""
        try:
            self._load_model(self.current_model_name, aggregation_strategy)
            return True
        except Exception as e:
            self.logger.error(f"Failed to change aggregation to '{aggregation_strategy}': {e}")
            return False

    def _post_process_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Post-process entities with simple heuristics to fix common boundary and
        classification issues in the raw model output.
        """
        corrected = []

        for entity in entities:
            text = entity["text"].strip()
            entity_type = entity["entity"]

            # Drop entities that are empty after stripping whitespace.
            if not text:
                continue

            corrected_entity = entity.copy()

            # Heuristic 1: single-word ORG entities ending in a vowel are often
            # misclassified given names.
            if entity_type == "ORG" and len(text.split()) == 1:
                if any(text.lower().endswith(suffix) for suffix in ['i', 'a', 'o']):
                    corrected_entity["entity"] = "PER"
                    self.logger.debug(f"Fixed: '{text}' ORG -> PER")

            # Heuristic 2: well-known country names should always be LOC.
            countries = ['India', 'China', 'USA', 'UK', 'Germany', 'France', 'Japan']
            if text in countries and entity_type != "LOC":
                corrected_entity["entity"] = "LOC"
                self.logger.debug(f"Fixed: '{text}' {entity_type} -> LOC")

            # Heuristic 3: an ORG span that swallowed a following verb (e.g. "Tesla launches")
            # is trimmed to its first word and relabelled PER.
            words = text.split()
            if len(words) >= 2 and entity_type == "ORG":
                if words[0].istitle() and words[1].lower() in ['launches', 'announces', 'says', 'opens', 'creates', 'launch']:
                    corrected_entity["text"] = words[0]
                    corrected_entity["entity"] = "PER"
                    corrected_entity["end"] = corrected_entity["start"] + len(words[0])
                    self.logger.info(f"Split entity: '{text}' -> PER: '{words[0]}'")

            # Heuristic 4: spans containing product/technology words are relabelled MISC.
            tech_terms = ['electric', 'suv', 'car', 'vehicle', 'app', 'software', 'ai', 'robot', 'global']
            text_words = text.lower().split()
            if any(term in text_words for term in tech_terms):
                if entity_type != "MISC":
                    corrected_entity["entity"] = "MISC"
                    self.logger.info(f"Fixed: '{text}' {entity_type} -> MISC")
                else:
                    self.logger.debug(f"Already MISC: '{text}'")

            corrected.append(corrected_entity)

        return corrected

    def extract_entities(self, text: str, return_both: bool = False) -> Union[Dict[str, List[Dict[str, Any]]], List[Dict[str, Any]]]:
        """
        Extract named entities from text.

        Returns a list of entities with their labels, scores, and character positions.
        If return_both=True, returns a dict with 'cleaned' and 'corrected' keys;
        if return_both=False, returns just the corrected entities (backward compatible).
        """
        entities = self.nlp(text)

        # Normalise the raw pipeline output (entity_group/word -> entity/text) and cast
        # numpy scalars to plain Python types so the result is JSON-serialisable.
        cleaned_entities = []
        for entity in entities:
            cleaned_entities.append({
                "entity": entity["entity_group"],
                "text": entity["word"],
                "score": round(float(entity["score"]), 4),
                "start": int(entity["start"]),
                "end": int(entity["end"]),
            })

        corrected_entities = self._post_process_entities(cleaned_entities)

        if return_both:
            return {
                "cleaned": cleaned_entities,
                "corrected": corrected_entities
            }
        else:
            return corrected_entities

    def extract_entities_debug(self, text: str) -> Dict[str, List[Dict[str, Any]]]:
        """
        Extract entities and return both cleaned and corrected versions for debugging.
        """
        return self.extract_entities(text, return_both=True)

    def extract_entities_by_type(self, text: str) -> Dict[str, List[str]]:
        """
        Extract entities grouped by type.

        Returns a dictionary keyed by entity type, with de-duplicated entity texts as values.
        """
        entities = self.extract_entities(text)

        grouped = {}
        for entity in entities:
            entity_type = entity["entity"]
            if entity_type not in grouped:
                grouped[entity_type] = []
            if entity["text"] not in grouped[entity_type]:
                grouped[entity_type].append(entity["text"])

        return grouped

    def format_output(self, entities: List[Dict[str, Any]], text: str) -> str:
        """
        Format entities for display with context.
        """
        output = []
        output.append("=" * 60)
        output.append("NAMED ENTITY RECOGNITION RESULTS")
        output.append("=" * 60)
        output.append(f"\nOriginal Text:\n{text}\n")
        output.append("-" * 40)
        output.append("Entities Found:")
        output.append("-" * 40)

        if not entities:
            output.append("No entities found.")
        else:
            for entity in entities:
                output.append(f"• [{entity['entity']}] '{entity['text']}' (confidence: {entity['score']})")

        return "\n".join(output)

    def format_debug_output(self, debug_results: Dict[str, List[Dict[str, Any]]], text: str) -> str:
        """
        Format debug output showing both cleaned and corrected entities.
        """
        output = []
        output.append("=" * 70)
        output.append("NER DEBUG: BEFORE & AFTER POST-PROCESSING")
        output.append("=" * 70)
        output.append(f"\nOriginal Text:\n{text}\n")

        cleaned = debug_results["cleaned"]
        corrected = debug_results["corrected"]

        output.append("🔍 BEFORE Post-Processing (Raw Model Output):")
        output.append("-" * 50)
        if not cleaned:
            output.append("No entities found by model.")
        else:
            for entity in cleaned:
                output.append(f"• [{entity['entity']}] '{entity['text']}' (confidence: {entity['score']})")

        output.append("")

        output.append("✨ AFTER Post-Processing (Corrected):")
        output.append("-" * 50)
        if not corrected:
            output.append("No entities after correction.")
        else:
            for entity in corrected:
                output.append(f"• [{entity['entity']}] '{entity['text']}' (confidence: {entity['score']})")

        output.append("")
        output.append("📝 Changes Made:")
        output.append("-" * 25)

        changes_found = False

        # Report label changes: same text, different entity type before vs. after.
        for corrected_entity in corrected:
            original_entity = None
            for cleaned_entity in cleaned:
                if (cleaned_entity['text'] == corrected_entity['text'] and
                        cleaned_entity['entity'] != corrected_entity['entity']):
                    original_entity = cleaned_entity
                    break

            if original_entity:
                output.append(f"  Fixed: '{original_entity['text']}' {original_entity['entity']} → {corrected_entity['entity']}")
                changes_found = True

        # Report boundary changes: the corrected text is a proper substring of a raw span.
        for corrected_entity in corrected:
            found_exact_match = False
            for cleaned_entity in cleaned:
                if (cleaned_entity['text'] == corrected_entity['text'] and
                        cleaned_entity['entity'] == corrected_entity['entity']):
                    found_exact_match = True
                    break

            if not found_exact_match:
                for cleaned_entity in cleaned:
                    if (corrected_entity['text'] in cleaned_entity['text'] and
                            corrected_entity['text'] != cleaned_entity['text']):
                        output.append(f"  Split: '{cleaned_entity['text']}' → '{corrected_entity['text']}'")
                        changes_found = True
                        break

        if not changes_found:
            output.append("  No changes made by post-processing.")

        return "\n".join(output)


def interactive_mode(ner: TransformerNER):
    """
    Interactive REPL that keeps the model loaded and processes multiple texts.
    """
    print("\n" + "=" * 60)
    print("INTERACTIVE NER MODE")
    print("=" * 60)
    print("Enter text to analyze (or 'quit' to exit)")
    print("Commands: 'help' for full list, 'model <name>' to switch models")
    print("=" * 60)

    grouped_mode = False
    json_mode = False
    debug_mode = False

    def show_help():
        print("\n" + "=" * 50)
        print("INTERACTIVE COMMANDS")
        print("=" * 50)
        print("Output Modes:")
        print(f"  grouped      - Toggle grouped output (currently: {'ON' if grouped_mode else 'OFF'})")
        print(f"  json         - Toggle JSON output (currently: {'ON' if json_mode else 'OFF'})")
        print(f"  debug        - Toggle before/after post-processing view (currently: {'ON' if debug_mode else 'OFF'})")
        print("\nModel Management:")
        print("  model <name> - Switch to model (e.g., 'model dbmdz-bert')")
        print("  models       - List available model shortcuts")
        print("  agg <strat>  - Change aggregation (simple/first/average/max)")
        print("\nFile Operations:")
        print("  file <path>  - Analyze text from file")
        print("\nInformation:")
        print("  info         - Show current configuration")
        print("  help         - Show this help")
        print("  quit         - Exit interactive mode")
        print("=" * 50)

    def show_models():
        print("\nAvailable model shortcuts:")
        print("-" * 50)
        for shortcut, full_name in TransformerNER.MODELS.items():
            current = " (current)" if shortcut == ner.current_model_name or full_name == ner.current_model_name else ""
            print(f"  {shortcut:<15} -> {full_name}{current}")
        print("\nUsage: 'model <shortcut>' (e.g., 'model dbmdz-bert')")
        print("Aggregation strategies: simple, first, average, max")
        print("Usage: 'agg <strategy>' (e.g., 'agg first')")

    def show_info():
        resolved_name = ner.MODELS.get(ner.current_model_name, ner.current_model_name)
        print("\nCurrent Configuration:")
        print(f"  Model: {ner.current_model_name}")
        print(f"  Full name: {resolved_name}")
        print(f"  Aggregation: {ner.aggregation_strategy}")
        print(f"  Grouped mode: {'ON' if grouped_mode else 'OFF'}")
        print(f"  JSON mode: {'ON' if json_mode else 'OFF'}")
        print(f"  Debug mode: {'ON' if debug_mode else 'OFF'}")
        print(f"  Cache dir: {ner.cache_dir}")

    def switch_model(model_name: str):
        print(f"Switching to model: {model_name}")
        if ner.switch_model(model_name):
            print(f"✅ Successfully switched to {model_name}")
            return True
        else:
            print(f"❌ Failed to switch to {model_name}")
            return False

    def change_aggregation(strategy: str):
        valid_strategies = ["simple", "first", "average", "max"]
        if strategy not in valid_strategies:
            print(f"❌ Invalid aggregation strategy. Valid options: {valid_strategies}")
            return False

        print(f"Changing aggregation to: {strategy}")
        if ner.change_aggregation(strategy):
            print(f"✅ Successfully changed aggregation to {strategy}")
            return True
        else:
            print(f"❌ Failed to change aggregation to {strategy}")
            return False

    def process_file(file_path: str):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_text = f.read()
            print(f"📁 Processing file: {file_path}")
            return file_text.strip()
        except Exception as e:
            print(f"❌ Error reading file '{file_path}': {e}")
            return None

    while True:
        try:
            print("\n> ", end="", flush=True)
            user_input = input().strip()

            if not user_input:
                continue

            # Split the input into a command word and an optional argument string.
            parts = user_input.split(None, 1)
            command = parts[0].lower()
            args = parts[1] if len(parts) > 1 else ""

            # Exit commands.
            if command in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            # Output-mode toggles.
            elif command == 'grouped':
                grouped_mode = not grouped_mode
                print(f"Grouped mode: {'ON' if grouped_mode else 'OFF'}")
                continue

            elif command == 'json':
                json_mode = not json_mode
                print(f"JSON mode: {'ON' if json_mode else 'OFF'}")
                continue

            elif command == 'debug':
                debug_mode = not debug_mode
                print(f"Debug mode: {'ON' if debug_mode else 'OFF'}")
                continue

            # Informational commands.
            elif command in ['models', 'list-models']:
                show_models()
                continue

            elif command == 'info':
                show_info()
                continue

            elif command == 'help':
                show_help()
                continue

            # Model management.
            elif command == 'model':
                if not args:
                    print("❌ Please specify a model name. Use 'models' to see available options.")
                    continue
                switch_model(args)
                continue

            elif command in ['agg', 'aggregation']:
                if not args:
                    print("❌ Please specify an aggregation strategy: simple, first, average, max")
                    continue
                change_aggregation(args)
                continue

            # File input: replace the input line with the file's contents.
            elif command == 'file':
                if not args:
                    print("❌ Please specify a file path.")
                    continue
                file_content = process_file(args)
                if file_content:
                    user_input = file_content
                else:
                    continue

            # Anything that is not a command is treated as text to analyze.
            text = user_input

            if debug_mode:
                debug_results = ner.extract_entities_debug(text)
                debug_output = ner.format_debug_output(debug_results, text)
                print(debug_output)
            else:
                if grouped_mode:
                    entities = ner.extract_entities_by_type(text)
                else:
                    entities = ner.extract_entities(text)

                if json_mode:
                    print(json.dumps(entities, indent=2))
                elif grouped_mode:
                    print("\nEntities by type:")
                    print("-" * 30)
                    if not entities:
                        print("No entities found.")
                    else:
                        for entity_type, entity_list in entities.items():
                            print(f"{entity_type}: {', '.join(entity_list)}")
                else:
                    if not entities:
                        print("No entities found.")
                    else:
                        print("\nEntities found:")
                        print("-" * 20)
                        for entity in entities:
                            print(f"• [{entity['entity']}] '{entity['text']}' (confidence: {entity['score']})")

        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            break
        except EOFError:
            print("\nGoodbye!")
            break
        except Exception as e:
            logger.error(f"Error processing text: {e}")


def main():
    parser = argparse.ArgumentParser(description="Extract named entities from text using Transformers")
    parser.add_argument("--text", type=str, help="Text to analyze")
    parser.add_argument("--file", type=str, help="File containing text to analyze")
    parser.add_argument("--model", type=str, default="dslim/bert-base-NER",
                        help="HuggingFace model to use. Shortcuts: dslim-bert, dbmdz-bert, xlm-roberta")
    parser.add_argument("--aggregation", type=str, default="simple",
                        choices=["simple", "first", "average", "max"],
                        help="Aggregation strategy for subword tokens (default: simple)")
    parser.add_argument("--json", action="store_true", help="Output as JSON")
    parser.add_argument("--grouped", action="store_true", help="Group entities by type")
    parser.add_argument("--interactive", "-i", action="store_true", help="Start interactive mode")
    parser.add_argument("--list-models", action="store_true", help="List available model shortcuts")

    args = parser.parse_args()

    # List model shortcuts and exit.
    if args.list_models:
        print("\nAvailable model shortcuts:")
        print("-" * 40)
        for shortcut, full_name in TransformerNER.MODELS.items():
            print(f"  {shortcut:<15} -> {full_name}")
        print("\nAggregation strategies: simple, first, average, max")
        return

    # Load the model once; it stays in memory for interactive use.
    ner = TransformerNER(model_name=args.model, aggregation_strategy=args.aggregation)

    if args.interactive:
        interactive_mode(ner)
        return

    # Resolve the input text from --file or --text, otherwise fall back to interactive mode.
    if args.file:
        with open(args.file, 'r', encoding='utf-8') as f:
            text = f.read()
    elif args.text:
        text = args.text
    else:
        interactive_mode(ner)
        return

    if not text.strip():
        logger.error("No text provided")
        return

    if args.grouped:
        entities = ner.extract_entities_by_type(text)
    else:
        entities = ner.extract_entities(text)

    if args.json:
        print(json.dumps(entities, indent=2))
    elif args.grouped:
        print("\n" + "=" * 60)
        print("ENTITIES GROUPED BY TYPE")
        print("=" * 60)
        for entity_type, entity_list in entities.items():
            print(f"\n{entity_type}:")
            for item in entity_list:
                print(f"  • {item}")
    else:
        formatted = ner.format_output(entities, text)
        print(formatted)


if __name__ == "__main__":
    example_sentences = [
        "Apple Inc. was founded by Steve Jobs in Cupertino, California.",
        "Barack Obama was the 44th President of the United States.",
        "The Eiffel Tower in Paris attracts millions of tourists each year.",
        "Google's CEO Sundar Pichai announced new AI features at the conference in San Francisco.",
        "Microsoft and OpenAI partnered to develop ChatGPT in Seattle."
    ]

    # With no CLI arguments, run a small demo; otherwise defer to the argument parser.
    if len(sys.argv) == 1:
        # Re-apply a simpler logging configuration for the demo. force=True is needed
        # because basicConfig() was already called at import time.
        logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s', force=True)

        logging.info("Running demo with example sentences...\n")
        ner = TransformerNER()

        for sentence in example_sentences:
            print("\n" + "=" * 60)
            print(f"Input: {sentence}")
            print("-" * 40)
            entities = ner.extract_entities_by_type(sentence)
            for entity_type, items in entities.items():
                print(f"{entity_type}: {', '.join(items)}")

        print("\n" + "=" * 60)
        print("\nTo analyze your own text, use:")
        print("  python ner_transformer.py --text 'Your text here'")
        print("  python ner_transformer.py --file input.txt")
        print("  python ner_transformer.py --json --grouped")
    else:
        logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s', force=True)
        main()