from flask import Flask, render_template_string, request, jsonify from flask_cors import CORS from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline import os import sys import threading import time app = Flask(__name__) CORS(app) # Model loading state (thread-safe) model_name = "OpenMed/privacy-filter-nemotron" classifier = None tokenizer = None model_loading = False model_error = None model_thread = None # Background model loading def load_model_async(): global classifier, tokenizer, model_loading, model_error model_loading = True print("="*60, flush=True) print("BACKGROUND: Loading OpenMed Privacy Filter model...", flush=True) print("="*60, flush=True) try: print(f"Loading tokenizer and model: {model_name}", flush=True) print("This may take 5-10 minutes on first run...", flush=True) # Use AutoModelForTokenClassification directly for better performance tokenizer = AutoTokenizer.from_pretrained( model_name, cache_dir="/app/.cache/huggingface" ) model = AutoModelForTokenClassification.from_pretrained( model_name, cache_dir="/app/.cache/huggingface" ) global classifier classifier = pipeline( task="token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="first", #none simple first average max device=-1 ) print("✓ Model loaded successfully!", flush=True) model_error = None except Exception as e: model_error = str(e) print(f"✗ ERROR loading model: {e}", flush=True) import traceback traceback.print_exc() finally: model_loading = False # Start model loading in background model_thread = threading.Thread(target=load_model_async, daemon=True) model_thread.start() def escape_html(text): """Escape HTML special characters to prevent XSS""" return (text .replace("&", "&") .replace("<", "<") .replace(">", ">") .replace('"', """) .replace("'", "'")) def create_line_chunks(text, max_tokens=2048): """Split text into chunks that respect line boundaries. Groups lines together based on max token limit, never cutting mid-line. """ global tokenizer if tokenizer is None: return [(0, text, len(text.split()))] lines = text.split('\n') chunks = [] current_lines = [] current_token_count = 0 current_char_start = 0 for line in lines: line_tokens = tokenizer(line, add_special_tokens=False)['input_ids'] line_token_count = len(line_tokens) # If this single line exceeds max_tokens, we have to include it anyway if current_token_count + line_token_count > max_tokens and current_lines: # Save current chunk chunk_text = '\n'.join(current_lines) chunks.append((current_char_start, chunk_text, current_token_count)) # Start new chunk with this line current_lines = [line] current_token_count = line_token_count current_char_start = text.find(line, current_char_start + len('\n'.join(current_lines[:-1])) if current_lines[:-1] else 0) else: current_lines.append(line) current_token_count += line_token_count # Add final chunk if current_lines: chunk_text = '\n'.join(current_lines) chunks.append((current_char_start, chunk_text, current_token_count)) return chunks def merge_adjacent_entities(entities): """Merge adjacent entities of the same type that are likely from tokenization splits.""" if not entities: return entities # Sort by start position sorted_entities = sorted(entities, key=lambda x: x.get('start', 0)) merged = [] i = 0 while i < len(sorted_entities): current = sorted_entities[i] current_label = current.get('entity_group') or current.get('entity', 'unknown') current_end = current.get('end', 0) current_text = current.get('word', '') current_score = current.get('score', 0) # Look ahead for adjacent same-type entities j = i + 1 while j < len(sorted_entities): next_entity = sorted_entities[j] next_label = next_entity.get('entity_group') or next_entity.get('entity', 'unknown') next_start = next_entity.get('start', 0) # Check if same label and adjacent (or overlapping/nearby) if next_label == current_label and next_start <= current_end + 5: # Merge next_end = next_entity.get('end', 0) next_text = next_entity.get('word', '') next_score = next_entity.get('score', 0) # Combine text (remove overlap if any) if next_start <= current_end: current_text = current_text[:next_start - current.get('start', 0)] + next_text else: current_text = current_text + ' ' + next_text current_end = max(current_end, next_end) current_score = max(current_score, next_score) # Use highest score j += 1 else: break merged.append({ 'entity_group': current_label, 'entity': current_label, 'word': current_text, 'start': current.get('start', 0), 'end': current_end, 'score': current_score }) i = j return merged # HTML Template with proper loading states HTML_TEMPLATE = '''
PII Detection & Masking Demo using Flask