# app.py — Vers3Dynamics adaptive text-labeling demo (Gradio UI over MiniLM embeddings).
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from collections import deque
import random
from scipy.stats import entropy
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.manifold import TSNE
# Sentence-embedding backbone: MiniLM encoder (384-dim hidden states).
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
base_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
# Prefer GPU when available; all tensors created later are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model.to(device)
class ProjectionHead(nn.Module):
    """Two-layer MLP that maps a sentence embedding into an adapted space.

    Defaults keep the 384-dim interface of the MiniLM encoder so projected
    vectors are drop-in replacements for the raw embeddings.
    """

    def __init__(self, input_dim=384, hidden_dim=128, output_dim=384):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        """Project x of shape (..., input_dim) to (..., output_dim)."""
        return self.projection(x)
# Trainable adapter: the base encoder stays frozen; only this head is tuned
# (see update_projection_head) via a contrastive objective.
projection_head = ProjectionHead().to(device)
optimizer = optim.Adam(projection_head.parameters(), lr=0.001)
# Hierarchical concept structure
class ConceptHierarchy:
    """Fixed two-level taxonomy of labeling categories.

    Four top-level domains each own a handful of sub-categories; a reverse
    index supports O(1) child-to-parent lookups.
    """

    def __init__(self):
        self.hierarchy = {
            "health": ["physical", "mental", "holistic", "preventive"],
            "tech": ["software", "hardware", "AI", "blockchain"],
            "nature": ["ecology", "wildlife", "climate", "conservation"],
            "spirit": ["mindfulness", "philosophy", "religion", "consciousness"],
        }
        # Invert the mapping once so parent lookups never scan the dict.
        self.reverse_lookup = {
            child: parent
            for parent, children in self.hierarchy.items()
            for child in children
        }

    def get_parent(self, subcategory):
        """Return the parent domain, or the name itself when top-level/unknown."""
        return self.reverse_lookup.get(subcategory, subcategory)

    def get_children(self, category):
        """Return the sub-categories of a top-level domain ([] for leaves)."""
        return self.hierarchy.get(category, [])

    def all_categories(self):
        """Return every known name: top-level domains followed by all children."""
        names = list(self.hierarchy)
        for children in self.hierarchy.values():
            names.extend(children)
        return names
concept_hierarchy = ConceptHierarchy()
class CognitiveMemory:
    """Bounded example store with per-label centroids and drift tracking.

    Holds up to ``max_length`` (text, label, embedding) triples, maintains an
    exponential-moving-average embedding per label, and flags labels whose
    recent examples have moved away from their long-run centroid.
    """

    def __init__(self, max_length=100):
        # Oldest samples fall off automatically once the deque is full.
        self.samples = deque(maxlen=max_length)
        self.embeddings_cache = {}      # reserved; not populated anywhere yet
        self.concept_centroids = {}     # label -> EMA embedding (torch tensor)
        self.uncertainty_history = []   # appended to by the inference path
        self.drift_scores = {}          # label -> last measured drift distance

    def add(self, text, label, embedding=None):
        """Store one labeled example, updating centroid and drift stats.

        When ``embedding`` is omitted it is computed via the module-level
        ``embed_text`` (no projection applied).
        """
        if embedding is None:
            embedding = embed_text(text)
        self.samples.append((text, label, embedding))
        if label not in self.concept_centroids:
            self.concept_centroids[label] = embedding
        else:
            # Exponential moving average: heavily favor the running centroid.
            self.concept_centroids[label] = (
                0.9 * self.concept_centroids[label] + 0.1 * embedding
            )
        # Only look for drift once there is a meaningful history.
        if len(self.samples) > 10:
            self._detect_concept_drift()

    def _detect_concept_drift(self):
        """Score each label by the L2 distance between its EMA centroid and
        the mean of its last five stored embeddings."""
        for label in self.concept_centroids:
            recent = [emb for _, lbl, emb in self.samples if lbl == label][-5:]
            if len(recent) > 1:
                recent_centroid = torch.stack(recent).mean(dim=0)
                drift = torch.norm(self.concept_centroids[label] - recent_centroid).item()
                self.drift_scores[label] = drift

    def get_embeddings_labels(self):
        """Return (embeddings, labels, texts) for all stored samples.

        Returns three empty lists when memory is empty.  (Previously returned
        ``(None, None, None)`` — a *truthy* tuple, which silently defeated the
        ``... or ([], [], [])`` fallback used by callers.)
        """
        if not self.samples:
            return [], [], []
        texts, labels, embeddings = zip(*self.samples)
        return embeddings, labels, texts

    def get_drift_report(self):
        """Human-readable summary of the worst currently measured drift."""
        if not self.drift_scores:
            return "No drift detected yet"
        highest_drift = max(self.drift_scores.items(), key=lambda x: x[1])
        # 0.15 is an empirical threshold on embedding-space L2 distance.
        if highest_drift[1] > 0.15:
            return f"Significant concept drift detected in '{highest_drift[0]}' category"
        return "Concept stability maintained across all categories"
# Enhanced embedding with adaptive projection
def embed_text(text, apply_projection=False):
    """Encode *text* with the frozen MiniLM backbone (mean-pooled tokens).

    When ``apply_projection`` is True the learned projection head is applied
    on top, yielding the adapted embedding.  The result lives on the CPU.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        hidden = base_model(**encoded).last_hidden_state
        pooled = hidden.mean(dim=1).squeeze().cpu()
        if apply_projection:
            pooled = projection_head(pooled.to(device)).cpu()
    return pooled
# Initialize memory
# Single module-level store shared by the training and inference handlers.
memory = CognitiveMemory()
# Contrastive learning update
def update_projection_head(pos_examples, neg_examples, temperature=0.1):
    """Run one Adam step on the global projection head.

    Args:
        pos_examples: embeddings sharing the newly-trained label (CPU tensors).
        neg_examples: embeddings carrying other labels.
        temperature: softmax temperature scaling the similarity logits.

    Returns:
        The scalar contrastive loss value for this step.
    """
    projection_head.train()
    optimizer.zero_grad()
    # Prepare positive and negative examples
    pos_embeddings = torch.stack([ex.to(device) for ex in pos_examples])
    neg_embeddings = torch.stack([ex.to(device) for ex in neg_examples])
    # Project embeddings
    pos_projections = projection_head(pos_embeddings)
    neg_projections = projection_head(neg_embeddings)
    # Calculate similarities
    pos_sim = torch.mm(pos_projections, pos_projections.t()) / temperature
    neg_sim = torch.mm(pos_projections, neg_projections.t()) / temperature
    # Create contrastive loss
    logits = torch.cat([pos_sim, neg_sim], dim=1)
    labels = torch.arange(pos_projections.size(0)).to(device)
    # NOTE(review): with labels == arange, each row's "positive" logit is its
    # own self-similarity on the pos_sim diagonal — a simplified variant, not
    # standard pairwise InfoNCE; confirm this degeneracy is intended.
    # Loss calculation (simplified contrastive loss)
    loss = nn.CrossEntropyLoss()(logits, labels)
    loss.backward()
    optimizer.step()
    projection_head.eval()
    return loss.item()
# Active learning sample selection
def get_informative_samples(embeddings, labels, num_samples=3):
    """Suggest what to label next by finding the most confusable categories.

    A category's "uncertainty" is the maximum dot-product similarity between
    its centroid and any embedding from a different category: the closer the
    nearest rival sits, the more ambiguous the boundary.  ``num_samples`` is
    currently unused (kept for interface compatibility).
    """
    if len(set(labels)) < 2:
        return ["Need examples from multiple categories"]
    pairs = list(zip(embeddings, labels))
    uncertainty_scores = {}
    for category in set(labels):
        rivals = [vec for vec, lbl in pairs if lbl != category]
        if not rivals:
            continue
        members = [vec for vec, lbl in pairs if lbl == category]
        centroid = torch.stack(members).mean(dim=0)
        # Highest similarity to any rival marks how close confusion is.
        sims = torch.matmul(centroid.unsqueeze(0), torch.stack(rivals).transpose(0, 1))
        uncertainty_scores[category] = torch.max(sims).item()
    ranked = sorted(uncertainty_scores.items(), key=lambda item: -item[1])
    suggestions = []
    # Prompt the user about the two most ambiguous categories.
    for category, _score in ranked[:2]:
        children = concept_hierarchy.get_children(category)
        if children:
            suggestions.append(f"Need examples distinguishing '{category}' from other categories")
            suggestions.append(f"Consider examples about '{random.choice(children)}'")
    return suggestions
# Uncertainty quantification
def calculate_uncertainty(similarities):
    """Return the normalized entropy (0..1) of a similarity vector.

    Treats similarities as unnormalized class weights: 1.0 means the mass is
    spread uniformly (maximal ambiguity), 0.0 means one candidate dominates.

    Cosine similarities can be negative, which previously produced an invalid
    probability vector (NaN entropy); negatives are clipped to zero.  An
    all-zero vector (no positive evidence) is treated as maximally uncertain.
    Behavior is unchanged for the usual nonnegative inputs.
    """
    sims = np.clip(np.asarray(similarities, dtype=float), 0.0, None)
    n = sims.size
    if n < 2:
        # log(1) == 0: a single candidate has no spread to measure.
        return 0
    total = sims.sum()
    if total == 0:
        return 1.0
    probs = sims / total
    # Calculate entropy (higher means more uncertain), normalized by the
    # maximum possible entropy for n outcomes.
    uncertainty = entropy(probs)
    max_entropy = np.log(n)
    return uncertainty / max_entropy
# Counterfactual explanation generation
def generate_counterfactual(text_embedding, predicted_label, labels, embeddings):
    """Explain what would flip the prediction to the nearest rival class.

    Finds the stored example of a *different* label closest (L2 distance) to
    the input embedding and names that label as the direction of change.

    Returns a sentence, or a fixed notice when no other class exists.
    """
    # Candidates: every stored example whose label differs from the prediction.
    rivals = [
        (emb, lbl) for emb, lbl in zip(embeddings, labels) if lbl != predicted_label
    ]
    if not rivals:
        return "No alternative classes available for counterfactual"
    # Nearest rival by Euclidean distance in embedding space.
    distances = [torch.norm(text_embedding - emb).item() for emb, _ in rivals]
    _nearest_embed, nearest_label = rivals[int(np.argmin(distances))]
    # (A previous version also computed a normalized direction vector and its
    # top-10 dimensions but never used them; torch.topk(k=10) would also crash
    # for embeddings with fewer than 10 dimensions, so that dead code is gone.
    # labels[original_idx] in the old message was provably == nearest_label.)
    return (
        f"To change classification from '{predicted_label}' to '{nearest_label}', "
        f"the text would need more emphasis on concepts found in '{nearest_label}'"
    )
# Advanced inference with uncertainty and counterfactuals
def infer_with_insights(text):
    """Classify *text* by nearest stored example and attach cognitive insights.

    Returns a 5-tuple of display strings for the Gradio outputs:
    (classification with confidence, hierarchy/drift insight, uncertainty
    level, visualization code snippet, counterfactual explanation).
    """
    if len(memory.samples) < 5:
        return "Label: Unknown", "Insight: Need more training examples (at least 5)", "Uncertainty: High", "Visualization not available", "No counterfactual available"
    # Embed the query through the current (trained) projection head.
    input_embedding = embed_text(text, apply_projection=True)
    # Re-embed stored texts so comparisons reflect the latest projection
    # weights (the embeddings cached in memory were stored without projection).
    memory_embeddings, memory_labels, memory_texts = memory.get_embeddings_labels()
    memory_embeddings = [embed_text(mem_text, apply_projection=True) for mem_text in memory_texts]
    # Cosine similarity between the query and every stored example.
    input_vec_np = input_embedding.unsqueeze(0).numpy()
    memory_vecs_np = torch.stack(memory_embeddings).numpy()
    sims = cosine_similarity(input_vec_np, memory_vecs_np)[0]
    # Nearest-neighbor prediction.
    best_idx = np.argmax(sims)
    confidence = sims[best_idx]
    predicted_label = memory_labels[best_idx]
    # Entropy-based uncertainty over the whole similarity vector.
    uncertainty = calculate_uncertainty(sims)
    uncertainty_level = "High" if uncertainty > 0.8 else "Medium" if uncertainty > 0.5 else "Low"
    counterfactual = generate_counterfactual(input_embedding, predicted_label, memory_labels, memory_embeddings)
    # Frame the prediction within the two-level concept hierarchy.
    parent_category = concept_hierarchy.get_parent(predicted_label)
    subcategories = concept_hierarchy.get_children(parent_category)
    if predicted_label in subcategories:
        insight = f"This concept falls under '{parent_category}' with specific focus on '{predicted_label}' aspects."
    else:
        subcategory_text = ", ".join(subcategories[:2]) + ("..." if len(subcategories) > 2 else "")
        insight = f"This concept broadly relates to '{predicted_label}' which includes aspects like {subcategory_text}."
    # The snippet below is a static placeholder shown to the user.  (A previous
    # version also instantiated TSNE and assembled embedding/label lists here,
    # but never used them — that dead code has been removed.)
    vis_code = """
```python
# Load this code in a notebook to visualize
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
# Your embeddings and labels would go here
# This is a placeholder visualization
plt.figure(figsize=(10, 8))
for label in set(labels[:-1]):
    indices = [i for i, l in enumerate(labels[:-1]) if l == label]
    plt.scatter(coords[indices, 0], coords[indices, 1], label=label)
# Highlight the input point
plt.scatter(coords[-1, 0], coords[-1, 1], color='red',
            s=100, marker='*', label='Current Input')
plt.legend()
plt.title("Concept Map Visualization")
plt.show()
```
"""
    # Track uncertainty over time and surface any concept drift.
    memory.uncertainty_history.append(uncertainty)
    drift_report = memory.get_drift_report()
    full_insight = f"{insight}\n\n{drift_report}"
    return f"Label: {predicted_label} (Confidence: {confidence:.2f})", full_insight, f"Uncertainty: {uncertainty_level} ({uncertainty:.2f})", vis_code, counterfactual
# Enhanced training with contrastive learning
def train_sample(text, label):
    """Store a labeled example, optionally run a contrastive update, and
    return a status string with active-learning suggestions.

    Args:
        text: the raw example text.
        label: its category (main or sub-category name).
    """
    # Snapshot memory *before* adding, so the contrastive pos/neg sets exclude
    # the new sample.  Normalize to lists: get_embeddings_labels yields tuples
    # (or an empty sentinel), and the old `tuple + [item]` concatenation below
    # raised TypeError as soon as 5 samples existed.  The `or` fallback also
    # never fired, because the empty sentinel tuple is truthy.
    embeddings, labels, _ = memory.get_embeddings_labels() or ([], [], [])
    embeddings = list(embeddings) if embeddings else []
    labels = list(labels) if labels else []
    text_embedding = embed_text(text)
    # Add to memory
    memory.add(text, label, embedding=text_embedding)
    # A contrastive update only makes sense with at least two known categories
    # and prior examples of this label.
    contrastive_msg = ""
    unique_labels = set(labels)
    if label in unique_labels and len(unique_labels) > 1:
        pos_examples = [e for e, l in zip(embeddings, labels) if l == label]
        neg_examples = [e for e, l in zip(embeddings, labels) if l != label]
        if pos_examples and neg_examples:
            # Cap both sides at 5 examples to keep the update cheap.
            loss = update_projection_head(pos_examples[:5], neg_examples[:5])
            contrastive_msg = f" • Updated adaptive projection (loss: {loss:.4f})"
    # Active-learning suggestions once the memory is big enough.
    if len(memory.samples) >= 5:
        active_suggestions = get_informative_samples(embeddings + [text_embedding], labels + [label])
        active_msg = "\n\nSuggested next examples:\n" + "\n".join([f"• {s}" for s in active_suggestions])
    else:
        active_msg = "\n\nAdd " + str(5 - len(memory.samples)) + " more examples to enable active learning."
    return f"Stored '{text}' as '{label}' | Total samples: {len(memory.samples)}{contrastive_msg}{active_msg}"
# Gradio UI
# Layout: an inference row (analyze free text), a training section (store a
# labeled example), and a system-status readout — all sharing module globals.
with gr.Blocks() as app:
    gr.Markdown("# Vers3Dynamics Labeling System")
    gr.Markdown("### This system features meta-learning, active learning, uncertainty quantification, and concept drift detection")
    with gr.Row():
        text_input = gr.Textbox(label="Input Text", placeholder="Type a concept like 'Blockchain for healthcare records'...")
        infer_btn = gr.Button("Analyze with Cognitive Insights")
    with gr.Row():
        label_output = gr.Textbox(label="Classification Result")
        insight_output = gr.Textbox(label="Cognitive Insight")
    with gr.Row():
        uncertainty_output = gr.Textbox(label="Uncertainty Analysis")
        counterfactual_output = gr.Textbox(label="Counterfactual Explanation")
    visualization_output = gr.Code(label="Visualization Code", language="python")
    # Order of outputs must match infer_with_insights' 5-tuple return.
    infer_btn.click(
        fn=infer_with_insights,
        inputs=text_input,
        outputs=[label_output, insight_output, uncertainty_output, visualization_output, counterfactual_output]
    )
    gr.Markdown("### Cognitive Training")
    with gr.Row():
        train_text = gr.Textbox(label="Training Example")
    with gr.Row():
        main_categories = gr.Radio(list(concept_hierarchy.hierarchy.keys()), label="Main Category")
        sub_categories = gr.Dropdown([], label="Sub-Category (Optional)")
    def update_subcategories(main_category):
        # Repopulate the sub-category dropdown when the main category changes.
        # NOTE(review): gr.Dropdown.update was removed in Gradio 4.x (the 4.x
        # idiom is returning gr.Dropdown(choices=...)) — confirm the pinned
        # Gradio version supports this call.
        if main_category:
            return gr.Dropdown.update(choices=[""] + concept_hierarchy.get_children(main_category))
        return gr.Dropdown.update(choices=[])
    main_categories.change(fn=update_subcategories, inputs=main_categories, outputs=sub_categories)
    train_btn = gr.Button("Store & Learn From Example")
    train_output = gr.Textbox(label="Training Status & Suggestions")
    def handle_training(text, main_category, sub_category):
        # Use subcategory if provided, otherwise use main category
        final_category = sub_category if sub_category else main_category
        return train_sample(text, final_category)
    train_btn.click(
        fn=handle_training,
        inputs=[train_text, main_categories, sub_categories],
        outputs=train_output
    )
    # System status section
    gr.Markdown("### Vers3Dynamics System Status")
    def get_system_status():
        # Summarize per-category sample counts plus drift/meta-learning state.
        if len(memory.samples) == 0:
            return "System initialized - no training data yet"
        num_samples = len(memory.samples)
        _, labels, _ = memory.get_embeddings_labels()
        category_counts = {}
        for label in labels:
            if label in category_counts:
                category_counts[label] += 1
            else:
                category_counts[label] = 1
        categories_info = ", ".join([f"{k}: {v}" for k, v in category_counts.items()])
        adaptations = "Meta-learning projection: " + ("Active" if len(memory.samples) > 5 else "Not yet active")
        drift_info = memory.get_drift_report()
        return f"System Status:\n• Samples: {num_samples}\n• Categories: {categories_info}\n• {adaptations}\n• {drift_info}"
    status_btn = gr.Button("Check System Status")
    status_output = gr.Textbox(label="Current System Status")
    status_btn.click(fn=get_system_status, outputs=status_output)
if __name__ == "__main__":
    app.launch()