"""
NER-Small Inference Client

A Python client for running inference with the Minibase-NER-Small model.
Handles named entity recognition requests to the local llama.cpp server.
"""
|
|
| import requests |
| import json |
| from typing import Optional, Dict, Any, Tuple, List |
| import time |
| import re |
|
|
|
|
class NERClient:
    """
    Client for the NER-Small named entity recognition model.

    This client communicates with a local llama.cpp server running the
    Minibase-NER-Small model for named entity recognition tasks.
    """

    # Generic words never treated as entities by the parser.
    _STOPWORDS = frozenset({'the', 'and', 'or', 'but', 'for', 'with'})
    # Matches a numbered list item: "1. Entity" or "1. Entity - annotation".
    # Compiled once at class creation instead of on every parsed line.
    _NUMBERED_ITEM = re.compile(r'^\d+\.\s*(.+?)(?:\s*-\s*.+)?$')
    # Strips a single trailing punctuation mark from an extracted entity.
    _TRAILING_PUNCT = re.compile(r'[.,;:!?]$')

    def __init__(self, base_url: str = "http://127.0.0.1:8000", timeout: int = 30):
        """
        Initialize the NER client.

        Args:
            base_url: Base URL of the llama.cpp server (trailing '/' is stripped)
            timeout: Request timeout in seconds
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.default_instruction = "Extract all named entities from the following text. List them as 1. Entity, 2. Entity, etc."

    def _make_request(self, prompt: str, max_tokens: int = 512,
                      temperature: float = 0.1) -> Tuple[str, float]:
        """
        Make a completion request to the model.

        Args:
            prompt: The input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Tuple of (response_text, latency_ms). On failure the text is an
            "Error: ..." string rather than a raised exception.
        """
        payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature
        }
        headers = {'Content-Type': 'application/json'}

        start_time = time.time()
        try:
            response = requests.post(
                f"{self.base_url}/completion",
                json=payload,
                headers=headers,
                timeout=self.timeout
            )
            latency = (time.time() - start_time) * 1000

            if response.status_code == 200:
                result = response.json()
                return result.get('content', ''), latency
            return f"Error: HTTP {response.status_code}", latency

        except requests.exceptions.RequestException as e:
            latency = (time.time() - start_time) * 1000
            return f"Error: {e}", latency

    def extract_entities(self, text: str, instruction: Optional[str] = None,
                         max_tokens: int = 512, temperature: float = 0.1) -> List[Dict[str, Any]]:
        """
        Extract named entities from text.

        Args:
            text: Input text to analyze
            instruction: Custom instruction (uses default if None)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            List of entity dictionaries with text and metadata. Empty list on
            request failure (errors are signaled by the "Error" prefix from
            _make_request, not by exceptions).
        """
        if instruction is None:
            instruction = self.default_instruction

        prompt = f"{instruction}\n\nInput: {text}\n\nResponse: "
        response_text, latency = self._make_request(prompt, max_tokens, temperature)

        if response_text.startswith("Error"):
            return []

        entities = self._parse_entity_response(response_text)

        # The model emits no confidence scores; report 1.0 and attach the
        # request latency so callers can track performance per extraction.
        for entity in entities:
            entity.update({
                'confidence': 1.0,
                'latency_ms': latency
            })

        return entities

    def extract_entities_batch(self, texts: List[str], instruction: Optional[str] = None,
                               max_tokens: int = 512, temperature: float = 0.1) -> List[List[Dict[str, Any]]]:
        """
        Extract named entities from multiple texts (sequential requests).

        Args:
            texts: List of input texts to analyze
            instruction: Custom instruction (uses default if None)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            List of entity lists, one per input text
        """
        return [
            self.extract_entities(text, instruction, max_tokens, temperature)
            for text in texts
        ]

    def _parse_entity_response(self, response_text: str) -> List[Dict[str, Any]]:
        """
        Parse the model's numbered list response into structured entities.

        Args:
            response_text: Raw model response

        Returns:
            List of entity dictionaries. 'start'/'end' are always 0 because
            the model does not report character offsets.
        """
        entities = []

        for line in response_text.strip().split('\n'):
            line = line.strip()
            if not line:
                continue

            numbered_match = self._NUMBERED_ITEM.match(line)
            if not numbered_match:
                continue

            entity_text = numbered_match.group(1).strip()
            entity_text = self._TRAILING_PUNCT.sub('', entity_text).strip()

            # Skip empty/single-character matches and generic stopwords.
            if len(entity_text) > 1 and entity_text.lower() not in self._STOPWORDS:
                entities.append({
                    'text': entity_text,
                    'type': 'ENTITY',
                    'start': 0,
                    'end': 0
                })

        return entities

    def health_check(self) -> bool:
        """
        Check if the model server is healthy and responding.

        Returns:
            True if server is healthy, False otherwise
        """
        try:
            response = requests.get(f"{self.base_url}/health", timeout=5)
            return response.status_code == 200
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt.
        except requests.exceptions.RequestException:
            return False

    def get_model_info(self) -> Optional[Dict[str, Any]]:
        """
        Get information about the loaded model.

        Returns:
            Model information dictionary or None if unavailable
        """
        try:
            response = requests.get(f"{self.base_url}/v1/models", timeout=5)
            if response.status_code == 200:
                return response.json()
        # ValueError covers a 200 response whose body is not valid JSON.
        except (requests.exceptions.RequestException, ValueError):
            pass
        return None
|
|
|
|
def main():
    """
    Command-line interface for NER inference.
    """
    import argparse

    cli = argparse.ArgumentParser(description='NER-Small Inference Client')
    cli.add_argument('text', help='Text to analyze for named entities')
    cli.add_argument('--url', default='http://127.0.0.1:8000',
                     help='Model server URL (default: http://127.0.0.1:8000)')
    cli.add_argument('--max-tokens', type=int, default=512,
                     help='Maximum tokens to generate (default: 512)')
    cli.add_argument('--temperature', type=float, default=0.1,
                     help='Sampling temperature (default: 0.1)')
    opts = cli.parse_args()

    client = NERClient(opts.url)

    # Bail out early if the llama.cpp server is not reachable.
    if not client.health_check():
        print(f"❌ Error: Cannot connect to model server at {opts.url}")
        print("Make sure the llama.cpp server is running with the NER-Small model.")
        return 1

    found = client.extract_entities(
        opts.text,
        max_tokens=opts.max_tokens,
        temperature=opts.temperature
    )

    # Report the input and every extracted entity, one per line.
    print(f"📝 Input Text: {opts.text}")
    print(f"🎯 Found {len(found)} entities:")
    print()

    if not found:
        print("No entities found.")
    else:
        for idx, ent in enumerate(found, 1):
            print(f"{idx}. {ent['text']} (Type: {ent['type']})")

    return 0
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    # SystemExit is used instead of the bare builtin exit(), which comes from
    # the interactive `site` module and is not guaranteed to exist (e.g. -S).
    raise SystemExit(main())
|
|