Spaces:
Sleeping
Sleeping
| """ | |
| MaTableGPT MCP Service | |
| ====================== | |
| A Model Context Protocol (MCP) service for extracting table data from | |
| materials science literature using GPT models. | |
| This service provides tools for: | |
| 1. Table Representation: Converting HTML tables to TSV or JSON format | |
| 2. Table Splitting: Breaking down complex tables into simpler components | |
| 3. GPT-based Data Extraction: Using fine-tuning, few-shot, or zero-shot models | |
| 4. Follow-up Questions: Refining extraction results through iterative questioning | |
| 5. Model Evaluation: Assessing extraction quality | |
| """ | |
| import os | |
| import json | |
| import re | |
| import logging | |
| import tempfile | |
| import uuid | |
| from datetime import datetime | |
| from typing import Optional, Dict, List, Any, Union | |
| from dataclasses import dataclass, field | |
| from contextlib import asynccontextmanager | |
| from bs4 import BeautifulSoup | |
| import pandas as pd | |
| # MCP imports | |
| from mcp.server.fastmcp import FastMCP | |
# Service-wide logging: INFO level on the root logger, plus the named logger
# used by every component in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="matablgpt-mcp")
| # ============================================================================= | |
| # Data Classes | |
| # ============================================================================= | |
@dataclass
class TableData:
    """Represents a parsed table structure.

    The ``@dataclass`` decorator is required: without it the
    ``field(default_factory=...)`` defaults below are stored as raw
    ``dataclasses.Field`` objects instead of per-instance empty lists.
    """

    title: str = ""
    caption: str = ""
    tag: str = ""  # HTML table tag
    headers: List[List[str]] = field(default_factory=list)  # header rows of cell strings
    body: List[List[str]] = field(default_factory=list)  # body rows of cell strings
@dataclass
class ExtractionResult:
    """Represents the result of a single GPT extraction run.

    Declared as a dataclass so it can be constructed from its fields
    (``ExtractionResult(session_id, table_name, ...)``).
    """

    session_id: str
    table_name: str
    model_type: str  # 'fine-tuning', 'few-shot', 'zero-shot'
    result: Dict[str, Any]
    timestamp: str
    follow_up_applied: bool = False
@dataclass
class SessionData:
    """Session data for storing extraction results.

    Declared as a dataclass so ``SessionData(session_id=..., created_at=...)``
    works and each session gets its own empty containers.
    """

    session_id: str
    created_at: str
    # String (forward-reference) annotations keep this class independent of
    # whether TableData / ExtractionResult are defined before it.
    tables: Dict[str, "TableData"] = field(default_factory=dict)
    # Keys are "<table_name>_<format_type>" as written by SessionManager.save_representation.
    representations: Dict[str, str] = field(default_factory=dict)
    extractions: List["ExtractionResult"] = field(default_factory=list)
| # ============================================================================= | |
| # Table Processing Classes | |
| # ============================================================================= | |
class TableRepresenter:
    """
    Converts HTML tables to TSV (Tab-Separated Values) representation.
    Handles merged cells, captions, and titles.
    """

    # Exact cell tag names. A substring regex like '(?<!ma)th|td' also matches
    # unrelated tags such as SVG <path> or MathML <mtd>/<mjx-mtd>.
    _CELL_TAGS = ['th', 'td']

    # Tags stripped from each cell, in this order. Tags in _DROP_TAGS are removed
    # together with their content; the rest are unwrapped (content kept).
    _STRIP_TAGS = (
        'img', 'em', 'i', 'p', 'span', 'strong', 'math', 'mi', 'br',
        'script', 'svg', 'mrow', 'mo', 'mn', 'msub', 'msubsup', 'mtext',
        'mjx-container', 'mjx-math', 'mjx-mrow', 'mjx-msub', 'mjx-mi',
        'mjx-c', 'mjx-script', 'mjx-mspace', 'mjx-assistive-mml', 'mspace',
    )
    _DROP_TAGS = frozenset({'img', 'script', 'svg'})

    # Special Unicode whitespace characters (real characters, not backslash
    # text) and bold tags to normalize.
    _CHAR_REPLACEMENTS = (
        ('\xa0', ' '), ('\u2005', ' '), ('\u2009', ' '),
        ('\u202f', ' '), ('\u200b', ''), ('<b>', ''), ('</b>', ''),
    )

    # (compiled pattern, replacement) pairs, applied in order, that unwrap
    # <cap> markers around reference-like text.
    _FILTER_PATTERNS = tuple(
        (re.compile(pattern), repl)
        for pattern, repl in (
            (r'<cap>(\(\d+\)|\d+|\[\d+\]|\d+\,\d+|\d+\,\d+\,\d+|\d+\,\d+\–\d+|\d+\D+|\(\d+\,\s*\d+\)|\(\d+\D+\))</cap>', r'\1'),
            (r'<cap>(\s*ref\.\s\d+.*?)</cap>', r'\1'),
            (r'\(<cap>(\s*(ref\.\s\d+.*?)\s*)</cap>\)', r'\1'),
            (r'<cap>(\s*Ref\.\s\d+.*?)</cap>', r'\1'),
            (r'\(<cap>(\s*(Ref\.\s\d+.*?)\s*)</cap>\)', r'\1'),
            (r'<cap>(\[\d+|\d+\])</cap>', r'\1'),
            (r'<cap>((.*?)et al\..*?)</cap>', r'\1'),
            (r'<cap>((.*?)Fig\..*?)</cap>', r'\1'),
            (r'<cap>(Song and Hu \(2014\))</cap>', r'\1'),
            (r'<div> <cap> </cap> </div> ', ''),
            (r'<cap>(mA\.cm)</cap>', r'\1'),
            (r'<cap>(https.*?)</cap>', r'\1'),
            (r'<cap>(\d+\.\d+\@\d+)</cap>', r'\1'),
        )
    )

    def __init__(self):
        # Cell representation formats
        self.merged_cell = '<merge {}={}>{}</merge>'
        self.both_merged_cell = '<merge {}={} {}={}>{}</merge>'
        # NOTE(review): these are literal backslash sequences ("\\t", "\\n"),
        # not tab/newline characters; kept as-is because downstream models may
        # expect this exact text format. Confirm.
        self.cell = '{}\\t'
        self.line_breaking = '\\n'
        self.table_tag = '<table>{}</table>'
        self.caption_tag = '<caption>{}</caption>'
        self.title_tag = '<title>{}</title>'

    @staticmethod
    def _parse_span(value, default: int) -> int:
        """Parse a colspan/rowspan attribute value.

        Returns *default* when the value is missing, non-numeric, or not
        positive. This means a bad attribute cannot crash the conversion or
        move the column cursor backwards.
        """
        try:
            span = int(value)
        except (TypeError, ValueError):
            return default
        return span if span > 0 else default

    def text_filter(self, text: str) -> str:
        """Remove unnecessary text and HTML tags from the given string."""
        out = text
        for old, new in self._CHAR_REPLACEMENTS:
            out = out.replace(old, new)
        for pattern, repl in self._FILTER_PATTERNS:
            out = pattern.sub(repl, out)
        return out

    def process_table(self, t):
        """Remove unnecessary HTML tags from the table element (mutates and returns *t*)."""
        for tag in self._STRIP_TAGS:
            for element in t.find_all(tag):
                if tag in self._DROP_TAGS:
                    element.decompose()
                else:
                    element.unwrap()
        return t

    def html_to_tsv(self, html_table: str, title: str = "", caption: str = "") -> str:
        """
        Convert HTML table to TSV representation.

        Args:
            html_table: HTML string containing the table
            title: Table title
            caption: Table caption (a dict is rendered as "k: v, ...")

        Returns:
            TSV representation of the table, or the string
            "Error: No table rows found" when no <tr> is present.
        """
        soup = BeautifulSoup(html_table, 'html.parser')
        table = soup.find('table')
        if not table:
            table = soup

        # Grid width comes from the first row (colspans included); height from all rows.
        tbody = table.find('tbody') or table
        first_row = tbody.find('tr')
        if not first_row:
            return "Error: No table rows found"
        width = sum(
            self._parse_span(cell.get('colspan'), 1)
            for cell in first_row.find_all(self._CELL_TAGS)
        )
        rows = table.find_all('tr')
        height = len(rows)

        # '' marks a free slot; '::' marks a slot covered by a merged cell.
        out = [['' for _ in range(width)] for _ in range(height)]

        for i, tr in enumerate(rows):
            j = 0
            for cell in tr.find_all(self._CELL_TAGS):
                # Mark numeric links as references and other links as captions.
                for a_tag in cell.find_all('a'):
                    a_text = a_tag.get_text()
                    if a_text.isdigit():
                        a_tag.string = f"<ref>{a_text}</ref>"
                    else:
                        a_tag.string = f"<cap>{a_text}</cap>"
                cell = self.process_table(cell)

                # Skip slots already covered by rowspans from earlier rows.
                while j < width and out[i][j] != '':
                    j += 1
                if j >= width:
                    break

                refined_text = ''.join(str(element) for element in cell.contents)
                # 0 means "attribute absent or invalid" (not merged in that direction).
                colspan = self._parse_span(cell.get('colspan'), 0)
                rowspan = self._parse_span(cell.get('rowspan'), 0)

                if colspan and rowspan:
                    out[i][j] = self.both_merged_cell.format(
                        'colspan', colspan, 'rowspan', rowspan, self.text_filter(refined_text))
                    for c in range(colspan):
                        for r in range(rowspan):
                            if (c > 0 or r > 0) and i + r < height and j + c < width:
                                out[i + r][j + c] = '::'
                elif colspan:
                    out[i][j] = self.merged_cell.format('colspan', colspan, self.text_filter(refined_text))
                    for c in range(1, colspan):
                        if j + c < width:
                            out[i][j + c] = '::'
                elif rowspan:
                    out[i][j] = self.merged_cell.format('rowspan', rowspan, self.text_filter(refined_text))
                    for r in range(1, rowspan):
                        if i + r < height:
                            out[i + r][j] = '::'
                else:
                    out[i][j] = self.text_filter(refined_text) if refined_text else ' '

                j += colspan if colspan else 1

        # Every row ends with a line break; covered ('::') slots are omitted.
        result = ''.join(
            ''.join(self.cell.format(element) for element in row if element != '::')
            + self.line_breaking
            for row in out
        )

        final_result = self.title_tag.format(title) + self.table_tag.format(result)
        if caption:
            if isinstance(caption, dict):
                caption_str = ', '.join([f"{k}: {v}" for k, v in caption.items()])
            else:
                caption_str = str(caption)
            final_result += '\n' + self.caption_tag.format(caption_str)
        return final_result
class TableToJSON:
    """
    Converts HTML tables to JSON representation.
    """

    def process_caption(self, table):
        """Drop <tfoot> and mark links inside cells (mutates and returns *table*).

        A link whose text is a single letter or '*' becomes ``<cap>…</cap>``;
        any other link text becomes ``<ref>…</ref>``.
        """
        for tfoot in table.find_all('tfoot'):
            tfoot.decompose()
        for cell in table.find_all(['td', 'th']):
            for link in cell.find_all('a'):
                link_text = link.get_text()
                if len(link_text) == 1 and (link_text.isalpha() or link_text == '*'):
                    link.string = f"<cap>{link_text}</cap>"
                else:
                    link.string = f"<ref>{link_text}</ref>"
        return table

    def process_sub_sup(self, table):
        """Wrap subscript/superscript text in textual <sub>/<sup> markers so they survive parsing."""
        for cell in table.find_all(['td', 'th']):
            for sup in cell.find_all('sup'):
                sup_text = sup.get_text() or ""
                sup.string = f"<sup>{sup_text}</sup>"
            for sub in cell.find_all('sub'):
                sub_text = sub.get_text() or ""
                sub.string = f"<sub>{sub_text}</sub>"
        return table

    @staticmethod
    def _columns_to_dict(df) -> Dict:
        """Map each DataFrame column to its list of values.

        For MultiIndex columns, the header levels become nested dict keys. For
        example, column ('G', 'a') becomes ``{'G': {'a': [...]}}``.
        """
        result = {}
        multi_level = df.columns.nlevels > 1
        for i, key in enumerate(df.columns):
            values = df.iloc[:, i].tolist()
            if not multi_level:
                result[key] = values
                continue
            current = result
            for k in key[:-1]:
                current = current.setdefault(k, {})
            current[key[-1]] = values
        return result

    def html_to_json(self, html_table: str, title: str = "", caption: str = "") -> Dict:
        """
        Convert HTML table to JSON representation.

        Args:
            html_table: HTML string containing the table
            title: Table title
            caption: Table caption

        Returns:
            JSON dictionary with "Title", "caption" and one entry per column,
            or ``{"error": ...}`` if pandas could not parse the table.
        """
        from io import StringIO

        soup = BeautifulSoup(html_table, 'html.parser')
        table = soup.find('table')
        if not table:
            table = soup

        table = self.process_caption(table)
        table = self.process_sub_sup(table)

        # Fill empty header cells so pandas does not invent "Unnamed" headers for them.
        for th in table.find_all('th'):
            if not th.text.strip():
                th.insert(0, '-')

        try:
            # Passing literal HTML to read_html is deprecated since pandas 2.1;
            # a file-like StringIO works on all versions.
            dfs = pd.read_html(StringIO(str(table)))
            if not dfs:
                return {"error": "Could not parse table"}
            df = dfs[0]
            df.fillna("NaN", inplace=True)
        except Exception as e:
            return {"error": f"Failed to parse table: {str(e)}"}

        return {
            "Title": title,
            "caption": caption,
            **self._columns_to_dict(df),
        }
class TableSplitter:
    """
    Splits complex tables into simpler components for better extraction.
    """

    _CELL_TAGS = ['td', 'th']

    def analyze_table_structure(self, html_table: str) -> Dict:
        """
        Analyze the structure of an HTML table.

        Args:
            html_table: HTML string containing the table

        Returns:
            Dictionary with "total_rows", "has_thead", "has_tbody" and a
            per-row "row_analysis" list (cell_count, cell_types,
            merged_cells, is_header).
        """
        soup = BeautifulSoup(html_table, 'html.parser')
        table = soup.find('table') or soup
        rows = table.find_all('tr')
        return {
            "total_rows": len(rows),
            "has_thead": table.find('thead') is not None,
            "has_tbody": table.find('tbody') is not None,
            "row_analysis": [self._analyze_row(row) for row in rows],
        }

    def _analyze_row(self, row) -> Dict:
        """Summarize one <tr>: cell count/types, merged-cell count, header flag."""
        cells = row.find_all(self._CELL_TAGS)
        return {
            "cell_count": len(cells),
            "cell_types": [cell.name for cell in cells],
            "merged_cells": sum(1 for cell in cells if cell.get('colspan') or cell.get('rowspan')),
            "is_header": self._row_is_header(cells),
        }

    def _row_is_header(self, cells) -> bool:
        """Decide whether a row's *cells* form a header row.

        A row with no cells is never a header. (``all()`` over an empty
        sequence is True, which would otherwise mark empty rows as headers.)
        """
        if not cells:
            return False
        return all(c.name == 'th' for c in cells) or self._is_header_content(cells)

    def _is_header_content(self, cells) -> bool:
        """Check if cells contain header-like content.

        True when every cell has the same non-empty text (likely a spanning
        header). Also true when fewer than half of the cells parse as numbers
        after non-numeric characters are stripped.
        """
        if not cells:
            return False
        texts = [c.get_text().strip() for c in cells]
        if len(set(texts)) == 1 and texts[0]:
            return True
        numeric_count = 0
        for text in texts:
            try:
                float(re.sub(r'[^\d.-]', '', text))
            except ValueError:
                continue
            numeric_count += 1
        return numeric_count < len(texts) / 2

    @staticmethod
    def _body_rows(table, thead, all_rows):
        """Return the rows to scan for internal headers.

        If a <tbody> exists, its rows are used. Otherwise every row that is not
        inside <thead> is used, so header rows are neither duplicated nor
        mis-indexed.
        """
        tbody = table.find('tbody')
        if tbody:
            return tbody.find_all('tr')
        if thead is None:
            return all_rows
        thead_ids = {id(r) for r in thead.find_all('tr')}
        return [r for r in all_rows if id(r) not in thead_ids]

    def split_table(self, html_table: str, title: str = "", caption: str = "") -> List[Dict]:
        """
        Split a complex table into simpler components.

        A new component starts at each internal header row. The original
        <thead> (if any) plus that internal header form the component's
        header. If several header rows are consecutive, only the last one is
        kept.

        Args:
            html_table: HTML string containing the table
            title: Table title
            caption: Table caption

        Returns:
            List of dicts with "html", "title", "caption" and 1-based "index".
        """
        soup = BeautifulSoup(html_table, 'html.parser')
        table = soup.find('table') or soup

        def make_entry(html: str, index: int) -> Dict:
            return {"html": html, "title": title, "caption": caption, "index": index}

        all_rows = table.find_all('tr')
        header_flags = [self._row_is_header(row.find_all(self._CELL_TAGS)) for row in all_rows]

        # Simple table: no header rows other than (possibly) the first one.
        if all(not flag or i == 0 for i, flag in enumerate(header_flags)):
            return [make_entry(str(table), 1)]

        thead = table.find('thead')
        original_header = str(thead) if thead else ""

        split_tables = []
        current_header = None
        current_rows = []
        for row in self._body_rows(table, thead, all_rows):
            # Classify each row directly instead of mapping it to a precomputed
            # analysis by index offset, which breaks without a <tbody>.
            if self._row_is_header(row.find_all(self._CELL_TAGS)):
                if current_rows:
                    split_tables.append(make_entry(
                        self._build_table_html(original_header, current_header, current_rows),
                        len(split_tables) + 1,
                    ))
                current_header = str(row)
                current_rows = []
            else:
                current_rows.append(str(row))

        if current_rows:
            split_tables.append(make_entry(
                self._build_table_html(original_header, current_header, current_rows),
                len(split_tables) + 1,
            ))

        return split_tables or [make_entry(str(table), 1)]

    def _build_table_html(self, original_header: str, sub_header: Optional[str], rows: List[str]) -> str:
        """Build an HTML table from an optional <thead>, an optional sub-header row and body rows."""
        header = original_header
        if sub_header:
            if header:
                header = header.replace('</thead>', sub_header + '</thead>')
            else:
                header = f"<thead>{sub_header}</thead>"
        body = "<tbody>" + "".join(rows) + "</tbody>"
        return f"<table>{header}{body}</table>"
| # ============================================================================= | |
| # GPT Extraction Classes | |
| # ============================================================================= | |
class GPTExtractor:
    """
    Handles GPT-based extraction of catalyst data from table representations.
    Supports third-party API services with custom base URL (reverse proxy,
    API aggregators like OpenRouter, OneAPI, etc.).

    Environment Variables:
        LLM_API_KEY or OPENAI_API_KEY: Your API key
        LLM_API_BASE or OPENAI_API_BASE: API base URL (required for third-party services)
        LLM_MODEL or OPENAI_MODEL: Model name (default: gpt-4-turbo-preview)
    """

    # Performance types to extract
    PERFORMANCE_LIST = [
        'overpotential', 'tafel_slope', 'Rct', 'stability', 'Cdl',
        'onset_potential', 'current_density', 'potential', 'TOF', 'ECSA',
        'water_splitting_potential', 'mass_activity', 'exchange_current_density',
        'Rs', 'specific_activity', 'onset_overpotential', 'BET', 'surface_area',
        'loading', 'apparent_activation_energy'
    ]

    # Property template
    PROPERTY_TEMPLATE = {
        'electrolyte': '', 'reaction_type': '', 'value': '',
        'current_density': '', 'overpotential': '', 'potential': '',
        'substrate': '', 'versus': '', 'condition': ''
    }

    # Default model
    DEFAULT_MODEL = "gpt-4-turbo-preview"

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: Optional[str] = None):
        """
        Initialize GPT Extractor.

        Args:
            api_key: API key. Falls back to LLM_API_KEY or OPENAI_API_KEY env var.
            base_url: API base URL. Falls back to LLM_API_BASE, OPENAI_API_BASE
                or OPENAI_BASE_URL env var.
            model: Model name. Falls back to LLM_MODEL or OPENAI_MODEL env var,
                then DEFAULT_MODEL.
        """
        # Support multiple env var names for flexibility
        self.api_key = (
            api_key or
            os.environ.get('LLM_API_KEY', '') or
            os.environ.get('OPENAI_API_KEY', '')
        )
        self.base_url = (
            base_url or
            os.environ.get('LLM_API_BASE', '') or
            os.environ.get('OPENAI_API_BASE', '') or
            os.environ.get('OPENAI_BASE_URL', '')
        )
        self.model = (
            model or
            os.environ.get('LLM_MODEL', '') or
            os.environ.get('OPENAI_MODEL', '') or
            self.DEFAULT_MODEL
        )
        self._client = None
        logger.info(f"GPTExtractor initialized with model: {self.model}")
        if self.base_url:
            logger.info(f"Using custom API base URL: {self.base_url}")
        else:
            logger.warning("No API base URL configured - using default OpenAI endpoint")

    @property
    def client(self):
        """Lazily create and cache the OpenAI-compatible client.

        Must be a property: callers use ``self.client.chat...`` without calling it.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI
            except ImportError as exc:
                raise ImportError("OpenAI package not installed. Install with: pip install openai") from exc
            client_kwargs = {"api_key": self.api_key}
            # Add base_url for third-party API services
            if self.base_url:
                client_kwargs["base_url"] = self.base_url
            self._client = OpenAI(**client_kwargs)
            logger.info("API client initialized successfully")
        return self._client

    def get_model(self) -> str:
        """Get the model name to use for API calls."""
        return self.model

    def _chat(self, messages: List[Dict], model: Optional[str] = None) -> str:
        """Send *messages* at temperature 0 and return the stripped reply text.

        *model* defaults to :meth:`get_model` when empty or None.
        """
        response = self.client.chat.completions.create(
            model=model or self.get_model(),
            messages=messages,
            temperature=0
        )
        return response.choices[0].message.content.strip()

    @staticmethod
    def _as_text(value) -> str:
        """Return *value* unchanged if it is a string, otherwise JSON-encode it."""
        return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove Markdown code-fence markers (```, ```json, ```python, ...) and surrounding whitespace."""
        if "```" in text:
            text = re.sub(r"```[A-Za-z]*", "", text)
        return text.strip()

    @classmethod
    def _parse_catalyst_list(cls, answer: str) -> List:
        """Parse the model's "Python list" reply into a list of catalyst names.

        Uses ``ast.literal_eval`` instead of ``eval``: the reply is untrusted
        model output and must never be executed. A bare string reply is
        wrapped in a single-element list.

        Raises:
            ValueError / SyntaxError: if the reply is not a literal list.
        """
        import ast

        parsed = ast.literal_eval(cls._strip_code_fences(answer))
        if isinstance(parsed, str):
            return [parsed]
        if not isinstance(parsed, (list, tuple, set)):
            raise ValueError(f"Expected a list of catalysts, got {type(parsed).__name__}")
        return list(parsed)

    @classmethod
    def _parse_json_or_literal(cls, text: str):
        """Parse a reply as JSON, falling back to a Python literal (e.g. single-quoted dicts).

        Raises:
            ValueError / SyntaxError: if neither parse succeeds.
        """
        import ast

        cleaned = cls._strip_code_fences(text)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return ast.literal_eval(cleaned)

    def get_system_prompt(self, model_type: str) -> str:
        """Get system prompt based on model type ('fine-tuning', 'few-shot', anything else = zero-shot)."""
        if model_type == 'fine-tuning':
            return """This task is to take a string as input and convert it to JSON format.
I want to extract the performance below: [reaction_type, versus, overpotential, substrate, loading,
tafel_slope, onset_potential, current_density, BET, specific_activity, mass_activity, surface_area,
ECSA, apparent_activation_energy, water_splitting_potential, potential, Rs, Rct, Cdl, TOF, stability,
electrolyte, exchange_current_density, onset_overpotential].
If there is information about overpotential and Tafel slope in the input, the output should be:
{
"catalyst_name": {
"overpotential": {"electrolyte": "1.0 M KOH", "reaction_type": "OER", "value": "230 mV", "current_density": "50 mA/cm2"},
"tafel_slope": {"electrolyte": "1.0 M KOH", "reaction_type": "OER", "value": "54 mV/dec"}
}
}
If certain information cannot be found, those keys should not be included in the output.
If there are no values corresponding to performance metrics, simply extract the catalyst name as: {"catalyst_name": {}}"""
        elif model_type == 'few-shot':
            return f"""I will extract the performance information of the catalyst from the table and create a JSON format.
The types of performance to be extracted: performance_list = {self.PERFORMANCE_LIST}
You can only use the names as they are in the performance_list.
The JSON format will have performance within the catalyst, and each performance will include elements present in the table:
reaction type, value, electrolyte, condition, current density, versus (ex: RHE) and substrate.
The output must contain only JSON dictionary. Other sentences or opinions must not be in output."""
        else:  # zero-shot
            return f"""I'm going to convert the information in the table representer into JSON format.
CATALYST_TEMPLATE = {{'catalyst_name': {{'performance_name': {{PROPERTY_TEMPLATE}}}}}}
PROPERTY_TEMPLATE = {self.PROPERTY_TEMPLATE}
performance_list = {self.PERFORMANCE_LIST}
Extract catalyst information following these templates strictly."""

    def extract_zero_shot(self, table_representation: str) -> Dict:
        """
        Extract data using zero-shot approach with step-by-step questioning.

        One conversation is kept for the whole table. Each catalyst goes
        through three follow-up turns: template, fill values, drop empty keys.

        Args:
            table_representation: TSV or JSON representation of the table
                (non-string values are JSON-encoded)

        Returns:
            The single catalyst's JSON when exactly one catalyst is found;
            otherwise ``{"catalysts": [...]}``. Returns ``{"error": ...}`` if
            the catalyst list cannot be obtained. A catalyst whose later steps
            fail contributes ``{name: {"error": ...}}``.
        """
        messages = [{
            "role": "system",
            "content": self.get_system_prompt('zero-shot') + "\n\n" + self._as_text(table_representation),
        }]

        # Step 1: Get catalyst list
        catalyst_q = "Show the catalysts present in the table representer as a Python list. Answer must be ONLY python list."
        messages.append({"role": "user", "content": catalyst_q})
        try:
            catalyst_answer = self._chat(messages)
            catalyst_list = self._parse_catalyst_list(catalyst_answer)
        except Exception as e:
            return {"error": f"Failed to extract catalysts: {str(e)}"}
        messages.append({"role": "assistant", "content": catalyst_answer})

        prop_q = """In PROPERTY_TEMPLATE, maintain all keys, and fill in values that exist in the table representer.
If there are more than two "values" for the same performance, make it into a list. Include units in the values."""
        delete_q = "Remove keys with no values from previous version of CATALYST_TEMPLATE. Output only JSON."

        result = {"catalysts": []}
        for catalyst in catalyst_list:
            # Step 2: Get performance template for each catalyst
            perf_q = f"""Create a CATALYST_TEMPLATE filling in the performance of '{catalyst}' from the table representer,
strictly adhering to these rules:
Rule 1: Only include actual existing performances from the Performance_list.
Rule 2: Set all values of keys in PROPERTY_TEMPLATE to be " ". DO NOT INSERT ANY VALUE.
Rule 3: Answer must be ONLY JSON format."""
            messages.append({"role": "user", "content": perf_q})
            try:
                perf_answer = self._chat(messages)
                messages.append({"role": "assistant", "content": perf_answer})

                # Step 3: Fill in property values
                messages.append({"role": "user", "content": prop_q})
                prop_answer = self._chat(messages)
                messages.append({"role": "assistant", "content": prop_answer})

                # Step 4: Remove empty keys
                messages.append({"role": "user", "content": delete_q})
                final_answer = self._chat(messages)
                # Record the reply so the next catalyst's question follows a
                # complete exchange instead of an unanswered user turn.
                messages.append({"role": "assistant", "content": final_answer})

                result["catalysts"].append(json.loads(self._strip_code_fences(final_answer)))
            except Exception as e:
                # str() keeps the key hashable even if the model returned a non-string item.
                result["catalysts"].append({str(catalyst): {"error": str(e)}})

        return result["catalysts"][0] if len(result["catalysts"]) == 1 else result

    def extract_few_shot(self, table_representation: str, examples: Optional[List[Dict]] = None) -> Dict:
        """
        Extract data using few-shot approach with example pairs.

        Args:
            table_representation: TSV or JSON representation of the table
                (non-string values are JSON-encoded)
            examples: List of {"input": ..., "output": ...} pairs, replayed as
                user/assistant turns (non-string values are JSON-encoded)

        Returns:
            Extracted catalyst data in JSON format. Returns
            ``{"raw_response": ..., "error": ...}`` if the reply is not JSON,
            or ``{"error": ...}`` if the API call fails.
        """
        messages = [{"role": "system", "content": self.get_system_prompt('few-shot')}]
        for ex in examples or []:
            messages.append({"role": "user", "content": self._as_text(ex.get('input', ''))})
            messages.append({"role": "assistant", "content": self._as_text(ex.get('output', ''))})
        messages.append({"role": "user", "content": self._as_text(table_representation)})

        try:
            reply = self._chat(messages)
        except Exception as e:
            return {"error": str(e)}

        cleaned = self._strip_code_fences(reply)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return {"raw_response": cleaned, "error": "Could not parse as JSON"}

    def extract_with_fine_tuned(self, table_representation: str, model_name: str) -> Dict:
        """
        Extract data using a fine-tuned model.

        Args:
            table_representation: TSV or JSON representation of the table
                (sent as ``str(table_representation)``)
            model_name: Name of the fine-tuned model; the configured default
                model is used if it is empty

        Returns:
            Extracted catalyst data parsed as JSON, falling back to a Python
            literal, or ``{"error": ...}`` on failure.
        """
        messages = [
            {"role": "system", "content": self.get_system_prompt('fine-tuning')},
            {"role": "user", "content": str(table_representation)}
        ]
        try:
            reply = self._chat(messages, model=model_name)
            return self._parse_json_or_literal(reply)
        except Exception as e:
            return {"error": str(e)}
| # ============================================================================= | |
| # Session Management | |
| # ============================================================================= | |
class SessionManager:
    """Manages extraction sessions and data storage.

    Session data is held in the in-memory ``sessions`` dict. A directory per
    session is also created under ``storage_dir``.
    """

    def __init__(self, storage_dir: Optional[str] = None):
        # A fresh temporary directory is created when no storage_dir is given.
        self.storage_dir = storage_dir or tempfile.mkdtemp(prefix="matablgpt_")
        os.makedirs(self.storage_dir, exist_ok=True)
        self.sessions: Dict[str, "SessionData"] = {}

    def create_session(self) -> str:
        """Create a new session and return its ID.

        A single timestamp is used for both the ID and ``created_at``, so the
        two always agree.
        """
        now = datetime.now()
        session_id = f"session_{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
        os.makedirs(os.path.join(self.storage_dir, session_id), exist_ok=True)
        self.sessions[session_id] = SessionData(
            session_id=session_id,
            created_at=now.isoformat()
        )
        return session_id

    def get_session(self, session_id: str) -> Optional["SessionData"]:
        """Get session by ID, or None if it does not exist."""
        return self.sessions.get(session_id)

    def save_table(self, session_id: str, table_name: str, table_data: "TableData") -> bool:
        """Save table data to session. Returns False if the session does not exist."""
        session = self.get_session(session_id)
        if not session:
            return False
        session.tables[table_name] = table_data
        return True

    def save_representation(self, session_id: str, table_name: str, representation: str, format_type: str) -> bool:
        """Save a representation under key "<table_name>_<format_type>". Returns False if the session does not exist."""
        session = self.get_session(session_id)
        if not session:
            return False
        session.representations[f"{table_name}_{format_type}"] = representation
        return True

    def save_extraction(self, session_id: str, result: "ExtractionResult") -> bool:
        """Append an extraction result to the session. Returns False if the session does not exist."""
        session = self.get_session(session_id)
        if not session:
            return False
        session.extractions.append(result)
        return True

    def export_session(self, session_id: str) -> Dict:
        """Export a session summary plus all extraction results as a dictionary.

        Returns ``{"error": "Session not found"}`` for an unknown ID.
        """
        session = self.get_session(session_id)
        if not session:
            return {"error": "Session not found"}
        return {
            "session_id": session.session_id,
            "created_at": session.created_at,
            "tables_count": len(session.tables),
            "representations_count": len(session.representations),
            "extractions_count": len(session.extractions),
            "extractions": [
                {
                    "table_name": e.table_name,
                    "model_type": e.model_type,
                    "result": e.result,
                    "timestamp": e.timestamp,
                    "follow_up_applied": e.follow_up_applied
                }
                for e in session.extractions
            ]
        }
| # ============================================================================= | |
| # MCP Server Definition | |
| # ============================================================================= | |
# Process-wide singletons shared by every MCP tool in this module.
table_representer = TableRepresenter()
table_to_json = TableToJSON()
table_splitter = TableSplitter()
session_manager = SessionManager()
# Left unset here; get_extractor() constructs it on first use.
gpt_extractor: Optional[GPTExtractor] = None
def get_extractor() -> GPTExtractor:
    """Return the shared GPTExtractor, constructing it on the first call."""
    global gpt_extractor
    if gpt_extractor is not None:
        return gpt_extractor
    gpt_extractor = GPTExtractor()
    return gpt_extractor
# The FastMCP server instance for this service.
mcp = FastMCP(name="MaTableGPT-MCP")
| # ============================================================================= | |
| # MCP Tools | |
| # ============================================================================= | |
@mcp.tool()
def create_session() -> Dict:
    """
    Create a new extraction session.

    Returns a session ID that should be used for subsequent operations.
    Sessions help organize and track table processing workflows.
    """
    # Registered with @mcp.tool() so MCP clients can call it; the decorator
    # returns the function unchanged for direct Python callers.
    session_id = session_manager.create_session()
    return {
        "success": True,
        "session_id": session_id,
        "message": "Session created successfully. Use this session_id for subsequent operations."
    }
@mcp.tool()
def html_to_tsv_representation(
    html_table: str,
    title: str = "",
    caption: str = "",
    session_id: str = "",
    table_name: str = ""
) -> Dict:
    """
    Convert an HTML table to TSV (Tab-Separated Values) representation.

    This format is optimized for GPT extraction as it preserves table structure
    including merged cells, headers, and captions in a text format.

    Args:
        html_table: HTML string containing the table element
        title: Optional title of the table
        caption: Optional caption/footnotes of the table
        session_id: Optional session ID to save the representation
        table_name: Optional name for the table (used for saving)

    Returns:
        Dictionary containing the TSV representation. If saving was
        requested but the session does not exist, a "warning" is included
        instead of "saved_to_session".
    """
    try:
        representation = table_representer.html_to_tsv(html_table, title, caption)
        # html_to_tsv reports a table without rows as an "Error: ..." string
        # rather than raising; surface that as a failure.
        if representation.startswith("Error:"):
            return {"success": False, "error": representation}

        result = {
            "success": True,
            "format": "TSV",
            "representation": representation
        }

        if session_id and table_name:
            if session_manager.save_representation(session_id, table_name, representation, "tsv"):
                result["saved_to_session"] = session_id
            else:
                result["warning"] = f"Session '{session_id}' not found; representation was not saved."

        return result
    except Exception as e:
        return {"success": False, "error": str(e)}
@mcp.tool()
def html_to_json_representation(
    html_table: str,
    title: str = "",
    caption: str = "",
    session_id: str = "",
    table_name: str = ""
) -> Dict:
    """
    Convert an HTML table to JSON representation.

    This format converts the table structure into a nested JSON dictionary
    with column headers as keys and cell values as lists.

    Args:
        html_table: HTML string containing the table element
        title: Optional title of the table
        caption: Optional caption/footnotes of the table
        session_id: Optional session ID to save the representation
        table_name: Optional name for the table (used for saving)

    Returns:
        Dictionary containing the JSON representation. If saving was
        requested but the session does not exist, a "warning" is included
        instead of "saved_to_session".
    """
    try:
        representation = table_to_json.html_to_json(html_table, title, caption)
        # html_to_json signals parse failure with a dict holding only "error"
        # (successful results always include "Title" and "caption").
        if set(representation) == {"error"}:
            return {"success": False, "error": representation["error"]}

        result = {
            "success": True,
            "format": "JSON",
            "representation": representation
        }

        if session_id and table_name:
            saved = session_manager.save_representation(
                session_id, table_name, json.dumps(representation), "json"
            )
            if saved:
                result["saved_to_session"] = session_id
            else:
                result["warning"] = f"Session '{session_id}' not found; representation was not saved."

        return result
    except Exception as e:
        return {"success": False, "error": str(e)}
def analyze_table_structure(html_table: str) -> Dict:
    """
    Describe the structure of an HTML table.

    The analysis is produced by ``table_splitter.analyze_table_structure`` and
    is meant to cover row counts, thead/tbody presence, header vs. body rows
    and merged cells. Run it on complex tables before processing them.

    Args:
        html_table: HTML markup containing the table element.

    Returns:
        ``{"success": True, "analysis": ...}`` or
        ``{"success": False, "error": ...}`` if the analysis raises.
    """
    try:
        structure = table_splitter.analyze_table_structure(html_table)
    except Exception as exc:
        return {"success": False, "error": str(exc)}
    return {"success": True, "analysis": structure}
def split_complex_table(
    html_table: str,
    title: str = "",
    caption: str = ""
) -> Dict:
    """
    Break a complex table into simpler sub-tables.

    Tables with several internal headers or embedded sub-tables are split by
    ``table_splitter.split_table`` into pieces that are easier to extract from.

    Args:
        html_table: HTML markup containing the table element.
        title: Optional title of the table.
        caption: Optional caption/footnotes of the table.

    Returns:
        ``{"success": True, "table_count": n, "tables": [...]}`` or
        ``{"success": False, "error": ...}`` if splitting raises.
    """
    try:
        pieces = table_splitter.split_table(html_table, title, caption)
        piece_count = len(pieces)
    except Exception as exc:
        return {"success": False, "error": str(exc)}
    return {"success": True, "table_count": piece_count, "tables": pieces}
def extract_catalyst_data_zero_shot(
    table_representation: str,
    session_id: str = "",
    table_name: str = ""
) -> Dict:
    """
    Extract catalyst data from a table representation with zero-shot GPT.

    The work is done by the extractor's ``extract_zero_shot`` method, which is
    intended to identify catalysts, their performance metrics and property
    values, and then clean up the result.

    Args:
        table_representation: TSV or JSON representation of the table.
        session_id: Optional session ID; when given the result is saved there.
        table_name: Optional name for the table (defaults to "unnamed").

    Returns:
        ``{"success": True, "model_type": "zero-shot", "extraction": ...}`` or
        ``{"success": False, "error": ...}`` if anything raises.
    """
    try:
        extracted = get_extractor().extract_zero_shot(table_representation)
        record = ExtractionResult(
            session_id=session_id if session_id else "no_session",
            table_name=table_name if table_name else "unnamed",
            model_type="zero-shot",
            result=extracted,
            timestamp=datetime.now().isoformat(),
        )
        if session_id:
            session_manager.save_extraction(session_id, record)
    except Exception as exc:
        return {"success": False, "error": str(exc)}
    return {"success": True, "model_type": "zero-shot", "extraction": extracted}
def extract_catalyst_data_few_shot(
    table_representation: str,
    examples: List[Dict] = None,
    session_id: str = "",
    table_name: str = ""
) -> Dict:
    """
    Extract catalyst data from a table representation with few-shot GPT.

    Example input/output pairs guide the model; with no examples an empty
    list is forwarded to the extractor.

    Args:
        table_representation: TSV or JSON representation of the table.
        examples: List of ``{"input": ..., "output": ...}`` example pairs.
        session_id: Optional session ID; when given the result is saved there.
        table_name: Optional name for the table (defaults to "unnamed").

    Returns:
        ``{"success": True, "model_type": "few-shot", "extraction": ...}`` or
        ``{"success": False, "error": ...}`` if anything raises.
    """
    try:
        shots = examples if examples else []
        extracted = get_extractor().extract_few_shot(table_representation, shots)
        record = ExtractionResult(
            session_id=session_id if session_id else "no_session",
            table_name=table_name if table_name else "unnamed",
            model_type="few-shot",
            result=extracted,
            timestamp=datetime.now().isoformat(),
        )
        if session_id:
            session_manager.save_extraction(session_id, record)
    except Exception as exc:
        return {"success": False, "error": str(exc)}
    return {"success": True, "model_type": "few-shot", "extraction": extracted}
def extract_catalyst_data_fine_tuned(
    table_representation: str,
    model_name: str,
    session_id: str = "",
    table_name: str = ""
) -> Dict:
    """
    Extract catalyst data with a fine-tuned GPT model.

    Args:
        table_representation: TSV or JSON representation of the table.
        model_name: Name of the fine-tuned OpenAI model to call.
        session_id: Optional session ID; when given the result is saved there.
        table_name: Optional name for the table (defaults to "unnamed").

    Returns:
        ``{"success": True, "model_type": "fine-tuning", "model_name": ...,
        "extraction": ...}`` or ``{"success": False, "error": ...}`` if
        anything raises.
    """
    try:
        extracted = get_extractor().extract_with_fine_tuned(table_representation, model_name)
        record = ExtractionResult(
            session_id=session_id if session_id else "no_session",
            table_name=table_name if table_name else "unnamed",
            model_type="fine-tuning",
            result=extracted,
            timestamp=datetime.now().isoformat(),
        )
        if session_id:
            session_manager.save_extraction(session_id, record)
    except Exception as exc:
        return {"success": False, "error": str(exc)}
    return {
        "success": True,
        "model_type": "fine-tuning",
        "model_name": model_name,
        "extraction": extracted,
    }
def get_session_data(session_id: str) -> Dict:
    """
    Return everything stored in a session.

    Delegates to ``session_manager.export_session``; the session is meant to
    hold tables, representations and extractions.

    Args:
        session_id: ID of the session to export.

    Returns:
        Whatever ``session_manager.export_session`` returns for that session.
    """
    exported = session_manager.export_session(session_id)
    return exported
def list_performance_types() -> Dict:
    """
    List the performance types and property template used for extraction.

    Returns:
        ``{"success": True, "performance_types": GPTExtractor.PERFORMANCE_LIST,
        "property_template": GPTExtractor.PROPERTY_TEMPLATE}``.
    """
    performance_types = GPTExtractor.PERFORMANCE_LIST
    property_template = GPTExtractor.PROPERTY_TEMPLATE
    return {
        "success": True,
        "performance_types": performance_types,
        "property_template": property_template,
    }
def validate_extraction_result(extraction: Dict) -> Dict:
    """
    Validate an extraction result against the expected schema.

    Expected shape: ``{catalyst_name: {performance_type: {property: value}}}``.
    The top-level bookkeeping keys ``"error"``, ``"raw_response"`` and
    ``"catalysts"`` are not treated as catalysts.

    Args:
        extraction: The extraction result to validate.

    Returns:
        ``{"valid": bool, "issues": [...], "warnings": [...]}`` for every
        input. ``valid`` is False only when there are issues (non-dict input,
        or an ``"error"`` key). Unknown performance types, unknown property
        keys and non-dict catalyst entries are reported as warnings.
    """
    if not isinstance(extraction, dict):
        # Return the same keys as the normal path so callers can read
        # "warnings" unconditionally.
        return {"valid": False, "issues": ["Extraction must be a dictionary"], "warnings": []}

    issues = []
    warnings = []

    if "error" in extraction:
        issues.append(f"Extraction contains error: {extraction['error']}")

    valid_performance_types = set(GPTExtractor.PERFORMANCE_LIST)
    known_property_keys = GPTExtractor.PROPERTY_TEMPLATE

    for catalyst_name, performances in extraction.items():
        if catalyst_name in ["error", "raw_response", "catalysts"]:
            continue
        if not isinstance(performances, dict):
            warnings.append(f"Catalyst '{catalyst_name}' should have dict of performances")
            continue
        for perf_name, properties in performances.items():
            if perf_name not in valid_performance_types:
                warnings.append(f"Unknown performance type: {perf_name}")
            if isinstance(properties, dict):
                for prop_key in properties:
                    if prop_key not in known_property_keys:
                        warnings.append(f"Unknown property key: {prop_key}")

    return {
        "valid": not issues,
        "issues": issues,
        "warnings": warnings
    }
def get_extraction_code_template(representation_format: str = "tsv", model_type: str = "zero-shot") -> Dict:
    """
    Get a Python script template for running the extraction locally.

    The returned script calls the OpenAI chat API directly, without the MCP
    service.

    Args:
        representation_format: Either 'tsv' or 'json'; only used to label the
            paste area in the template.
        model_type: One of 'zero-shot', 'few-shot', or 'fine-tuning'. For
            'fine-tuning' the script targets a placeholder fine-tuned model
            name instead of the base model.

    Returns:
        ``{"success": True, "code": str, "instructions": [str, ...]}``.
    """
    # A fine-tuned model has to be addressed by its own name; the other modes
    # run against the default base model.
    if model_type == "fine-tuning":
        model_name = "YOUR_FINE_TUNED_MODEL_NAME"
    else:
        model_name = "gpt-4-turbo-preview"

    code = f'''"""
MaTableGPT Local Extraction Template
Model Type: {model_type}
Representation Format: {representation_format}
"""
from openai import OpenAI
import json

# Initialize client
client = OpenAI(api_key="YOUR_API_KEY")

# Performance types to extract
PERFORMANCE_LIST = [
    'overpotential', 'tafel_slope', 'Rct', 'stability', 'Cdl',
    'onset_potential', 'current_density', 'potential', 'TOF', 'ECSA',
    'water_splitting_potential', 'mass_activity', 'exchange_current_density',
    'Rs', 'specific_activity', 'onset_overpotential', 'BET', 'surface_area',
    'loading', 'apparent_activation_energy'
]

# Your table representation
table_representation = """
# Paste your {representation_format.upper()} representation here
"""

# System prompt
system_prompt = """I will extract catalyst performance information from the table and create JSON format.
Performance types: """ + str(PERFORMANCE_LIST) + """
The JSON format will have performance within the catalyst, with elements:
reaction type, value, electrolyte, condition, current density, versus, substrate.
Output must contain only JSON dictionary."""

# Extract
response = client.chat.completions.create(
    model="{model_name}",
    messages=[
        {{"role": "system", "content": system_prompt}},
        {{"role": "user", "content": table_representation}}
    ],
    temperature=0
)

result = response.choices[0].message.content.strip()
# Chat models often wrap JSON in markdown code fences; strip them before parsing
result = result.replace("```json", "").replace("```", "").strip()
print(json.dumps(json.loads(result), indent=2))
'''

    # Build the steps first and number them afterwards, so mode-specific
    # steps slot in without breaking the numbering.
    steps = [
        "Install openai package: pip install openai",
        "Replace YOUR_API_KEY with your OpenAI API key",
    ]
    if model_type == "fine-tuning":
        steps.append("Replace YOUR_FINE_TUNED_MODEL_NAME with the name of your fine-tuned model")
    elif model_type == "few-shot":
        steps.append("Optionally add example user/assistant message pairs before the table message")
    steps.extend([
        "Paste your table representation in the designated area",
        "Run the script",
    ])

    return {
        "success": True,
        "code": code,
        "instructions": [f"{number}. {step}" for number, step in enumerate(steps, start=1)]
    }
def apply_follow_up_questions(
    extraction_result: Dict,
    table_representation: str,
    session_id: str = "",
    table_name: str = ""
) -> Dict:
    """
    Apply follow-up questions to refine and validate an extraction result.

    Runs a five-turn conversation with the configured chat model. All turns
    share one message history and use temperature 0:

    1. List the catalyst names in the table representation.
    2. List the catalyst names in the extraction JSON.
    3. Reconcile the two lists.
    4. Check whether the title/caption mentions reaction type, electrolyte
       and substrate.
    5. Ask for the refined JSON.

    Args:
        extraction_result: Initial extraction result to refine.
        table_representation: Original table representation used for verification.
        session_id: Optional session ID. When given, the refined result is
            saved as a "follow-up-refined" extraction.
        table_name: Optional table name for the saved record.

    Returns:
        On success, a dict with these keys:

        - ``success``: True.
        - ``original``: the input extraction.
        - ``refined``: the refined extraction.
        - ``follow_up_applied``: True.
        - ``refinement_parsed``: False when the model's final answer was not
          a JSON object, in which case ``refined`` falls back to the original.
        - ``verification_steps``: the raw answers to questions 1-4.

        On any exception: ``success`` False, ``error``, ``original`` and
        ``follow_up_applied`` False.
    """
    try:
        extractor = get_extractor()

        system_prompt = """You need to modify the JSON representing the table.
JSON template: {'catalyst_name': {'performance_name': {property_template}}}
property_template: {'electrolyte': '', 'reaction_type': '', 'value': '', 'current_density': '', 'overpotential': '', 'potential': '', 'substrate': '', 'versus': '', 'condition': ''}
performance_list = """ + str(GPTExtractor.PERFORMANCE_LIST) + """
Replace 'catalyst_name' and 'performance_name' with actual names from the table."""
        messages = [{"role": "system", "content": system_prompt}]

        def ask(question: str) -> str:
            """Send one user turn, record the assistant's reply in the history, and return it stripped."""
            messages.append({"role": "user", "content": question})
            response = extractor.client.chat.completions.create(
                model=extractor.get_model(),
                messages=messages,
                temperature=0
            )
            answer = response.choices[0].message.content.strip()
            messages.append({"role": "assistant", "content": answer})
            return answer

        # ensure_ascii=False keeps non-ASCII catalyst names (e.g. "η", "–")
        # verbatim, so the model can match them against the table text.
        extraction_json = json.dumps(extraction_result, ensure_ascii=False)

        # Step 1: catalysts present in the table
        catalysts_in_table = ask(f"""<input representation>
{table_representation}
Question 1: List all catalyst names in the table representation as a Python list. Only output the Python list.""")

        # Step 2: catalysts present in the extraction
        catalysts_in_json = ask(f"""<input json>
{extraction_json}
Question 2: List all catalyst names from the input json as a Python list. Only output the Python list.""")

        # Step 3: reconcile the two lists
        reconciled_catalysts = ask("""Question 3: Based on answers to Question 1 and 2, modify or remove any catalysts
from Question 2 that don't match Question 1. Output the corrected Python list.""")

        # Step 4: which metadata does the title/caption provide?
        metadata_check = ask(f"""<input representation>
{table_representation}
Question 4: Check the title and caption of the table.
- Is there reaction type info (OER, HER, oxygen evolution, hydrogen evolution)?
- Is there electrolyte info?
- Is there substrate info?
Answer in format: {{"reaction_type": "yes/no", "electrolyte": "yes/no", "substrate": "yes/no"}}""")

        # Step 5: request the refined JSON
        refined_text = ask(f"""<input json>
{extraction_json}
Based on the above analysis:
1. Keep only catalysts that exist in the table
2. Remove any 'NA', 'unknown', or empty values
3. If title/caption lacks reaction_type/electrolyte/substrate info, remove those keys
4. Output the refined JSON only. No explanation.""")

        # Models often wrap JSON answers in markdown code fences.
        refined_text = refined_text.replace("```json", "").replace("```", "").strip()
        try:
            refined_json = json.loads(refined_text)
        except json.JSONDecodeError:
            refined_json = None

        # The prompt's template is a JSON object; anything else is unusable.
        refinement_parsed = isinstance(refined_json, dict)
        if not refinement_parsed:
            logger.warning(
                "Follow-up refinement for table %r was not a JSON object; keeping original extraction",
                table_name or "unnamed",
            )
            refined_json = extraction_result

        if session_id:
            extraction_record = ExtractionResult(
                session_id=session_id,
                table_name=table_name or "unnamed",
                model_type="follow-up-refined",
                result=refined_json,
                timestamp=datetime.now().isoformat(),
                follow_up_applied=True
            )
            session_manager.save_extraction(session_id, extraction_record)

        return {
            "success": True,
            "original": extraction_result,
            "refined": refined_json,
            "follow_up_applied": True,
            "refinement_parsed": refinement_parsed,
            "verification_steps": {
                "catalysts_in_table": catalysts_in_table,
                "catalysts_in_json": catalysts_in_json,
                "reconciled": reconciled_catalysts,
                "metadata_check": metadata_check
            }
        }
    except Exception as e:
        return {
            "success": False,
            "error": str(e),
            "original": extraction_result,
            "follow_up_applied": False
        }
def evaluate_extraction(
    prediction: Dict,
    ground_truth: Dict,
    evaluation_type: str = "both"
) -> Dict:
    """
    Evaluate extraction results against ground truth.

    Metrics follow the original MaTableGPT evaluation:

    - Structure F1: compares the sets of flattened key paths
      (``parent//child``, list items as ``parent[i]``), ignoring any path
      that contains ``'condition'``.
    - Value accuracy: fraction of ground-truth leaf values whose key path
      exists in the prediction with an equal normalized value. Normalization
      unifies dashes, drops ``<sup>`` tags and whitespace, canonicalizes the
      units ``m2 g-1`` and ``mA cm-2``, and lowercases.

    Args:
        prediction: The extracted/predicted result.
        ground_truth: The expected correct result.
        evaluation_type: "structure", "value", or "both".

    Returns:
        ``{"success": True, ...}`` with ``structure_f1``/``structure_details``
        and/or ``value_accuracy``/``value_details``. For "both" the result
        also has ``overall_score`` (the mean of the two). Returns
        ``{"success": False, "error": ...}`` for an unknown
        ``evaluation_type`` or if evaluation raises.
    """
    import re
    import unicodedata

    if evaluation_type not in ("structure", "value", "both"):
        return {
            "success": False,
            "error": f"Unknown evaluation_type: {evaluation_type!r} (expected 'structure', 'value' or 'both')"
        }

    def normalize_text(text: str) -> str:
        """Normalize text for comparison."""
        if not isinstance(text, str):
            return str(text)
        # Fold unicode compatibility forms (e.g. superscript digits -> digits)
        text = unicodedata.normalize('NFKD', text)
        # Unify dash variants first, so the unit rules below only need to
        # match a plain ASCII '-'.
        text = re.sub(r'–|−', '-', text)
        text = re.sub(r'<sup>|</sup>', '', text)
        text = re.sub(r'm2\s*g-1', 'm2/g', text, flags=re.IGNORECASE)
        text = re.sub(r'mA\s*cm-2', 'mA/cm2', text, flags=re.IGNORECASE)
        text = re.sub(r'\s+', '', text)
        return text.lower()

    def get_all_keys(d: Dict, parent_key: str = '', sep: str = '//') -> List[str]:
        """Recursively collect every key path of a nested dict/list."""
        keys = []
        if isinstance(d, dict):
            for k, v in d.items():
                new_key = f"{parent_key}{sep}{k}" if parent_key else k
                keys.append(new_key)
                keys.extend(get_all_keys(v, new_key, sep))
        elif isinstance(d, list):
            for i, item in enumerate(d):
                keys.extend(get_all_keys(item, f"{parent_key}[{i}]", sep))
        return keys

    def get_key_value_pairs(d: Dict, parent_key: str = '') -> List[tuple]:
        """Collect (key path, normalized leaf value) pairs of a nested dict/list."""
        pairs = []
        if isinstance(d, dict):
            for k, v in d.items():
                new_key = f"{parent_key}//{k}" if parent_key else k
                if isinstance(v, (dict, list)):
                    pairs.extend(get_key_value_pairs(v, new_key))
                else:
                    pairs.append((new_key, normalize_text(str(v))))
        elif isinstance(d, list):
            for i, item in enumerate(d):
                pairs.extend(get_key_value_pairs(item, f"{parent_key}[{i}]"))
        return pairs

    results = {"success": True}

    try:
        pred_keys = get_all_keys(prediction)
        gt_keys = get_all_keys(ground_truth)

        # Structure F1 Score
        if evaluation_type in ["structure", "both"]:
            # 'condition' keys are excluded, as in the original evaluation
            pred_keys = [k for k in pred_keys if 'condition' not in k]
            gt_keys = [k for k in gt_keys if 'condition' not in k]

            pred_set = set(pred_keys)
            gt_set = set(gt_keys)
            tp = len(pred_set & gt_set)
            fp = len(pred_set - gt_set)
            fn = len(gt_set - pred_set)

            if tp + fp + fn > 0:
                f1_score = tp / (tp + 0.5 * (fp + fn))
            else:
                f1_score = 1.0 if len(gt_keys) == 0 else 0.0

            results["structure_f1"] = round(f1_score, 4)
            results["structure_details"] = {
                "true_positives": tp,
                "false_positives": fp,
                "false_negatives": fn,
                "matched_keys": list(pred_set & gt_set)[:10],  # Sample
                "missing_keys": list(gt_set - pred_set)[:10],
                "extra_keys": list(pred_set - gt_set)[:10]
            }

        # Value Accuracy
        if evaluation_type in ["value", "both"]:
            pred_pairs = get_key_value_pairs(prediction)
            gt_pairs = get_key_value_pairs(ground_truth)

            correct = 0
            total = len(gt_pairs)
            pred_dict = dict(pred_pairs)

            for key, value in gt_pairs:
                # Re-normalizing is harmless and catches unit patterns that only
                # become contiguous once whitespace has been removed.
                if key in pred_dict and normalize_text(pred_dict[key]) == normalize_text(value):
                    correct += 1

            value_accuracy = correct / total if total > 0 else 1.0
            results["value_accuracy"] = round(value_accuracy, 4)
            results["value_details"] = {
                "correct_values": correct,
                "total_values": total,
                "accuracy_percentage": round(value_accuracy * 100, 2)
            }

        if evaluation_type == "both":
            results["overall_score"] = round(
                (results["structure_f1"] + results["value_accuracy"]) / 2, 4
            )

    except Exception as e:
        results["success"] = False
        results["error"] = str(e)

    return results
def batch_extract_tables(
    tables: List[Dict],
    model_type: str = "zero-shot",
    apply_follow_up: bool = False,
    session_id: str = "",
    model_name: str = ""
) -> Dict:
    """
    Extract data from multiple tables in batch.

    Each table is converted to a TSV representation and extracted with the
    chosen model. Extractions without an ``"error"`` key are saved to the
    session, like the single-table tools do, and optionally refined with
    follow-up questions.

    Args:
        tables: List of ``{"html": ..., "title": ..., "caption": ..., "name": ...}``.
            A missing ``name`` defaults to ``table_<n>`` (1-based).
        model_type: "zero-shot", "few-shot", or "fine-tuning".
        apply_follow_up: Whether to apply follow-up questions for refinement.
        session_id: Optional session ID; a new session is created when empty.
        model_name: Fine-tuned model name; required when
            ``model_type == "fine-tuning"``.

    Returns:
        ``{"success": True, "session_id", "total_tables", "extractions": [...],
        "successful_extractions", "failed_extractions"}``. Each entry has
        ``table_name`` and ``success``. Entries whose extraction reported an
        error are marked unsuccessful and carry an ``error`` message.
    """
    from copy import deepcopy

    if not session_id:
        session_id = session_manager.create_session()

    results = {
        "success": True,
        "session_id": session_id,
        "total_tables": len(tables),
        "extractions": []
    }

    for index, table_info in enumerate(tables, start=1):
        html = table_info.get("html", "")
        title = table_info.get("title", "")
        caption = table_info.get("caption", "")
        table_name = table_info.get("name", f"table_{index}")

        try:
            representation = table_representer.html_to_tsv(html, title, caption)

            extractor = get_extractor()
            if model_type == "zero-shot":
                extraction = extractor.extract_zero_shot(representation)
            elif model_type == "few-shot":
                # No examples are available in batch mode; pass an explicit
                # empty list, as extract_catalyst_data_few_shot does.
                extraction = extractor.extract_few_shot(representation, [])
            elif model_type == "fine-tuning":
                if model_name:
                    extraction = extractor.extract_with_fine_tuned(representation, model_name)
                else:
                    extraction = {"error": "Fine-tuning requires model_name parameter"}
            else:
                extraction = {"error": f"Unknown model_type: {model_type}"}

            failed = isinstance(extraction, dict) and "error" in extraction

            if not failed:
                # Persist the raw extraction so the returned session is usable
                # for export even when no follow-up is applied.
                session_manager.save_extraction(session_id, ExtractionResult(
                    session_id=session_id,
                    table_name=table_name,
                    model_type=model_type,
                    result=extraction,
                    timestamp=datetime.now().isoformat()
                ))

            if apply_follow_up and not failed:
                follow_up_result = apply_follow_up_questions(
                    deepcopy(extraction),
                    representation,
                    session_id,
                    table_name
                )
                if follow_up_result.get("success"):
                    extraction = follow_up_result.get("refined", extraction)

            entry = {
                "table_name": table_name,
                "success": not failed,
                "extraction": extraction
            }
            if failed:
                entry["error"] = str(extraction["error"])
            results["extractions"].append(entry)

        except Exception as e:
            results["extractions"].append({
                "table_name": table_name,
                "success": False,
                "error": str(e)
            })

    results["successful_extractions"] = sum(1 for e in results["extractions"] if e["success"])
    results["failed_extractions"] = results["total_tables"] - results["successful_extractions"]

    return results
def _extraction_cell_text(value: Any) -> str:
    """Render one property value as table-cell text; list values are joined with '; '."""
    if isinstance(value, list):
        return "; ".join(str(v) for v in value)
    # Numeric zero is real data (e.g. a 0 V onset potential), so only
    # non-numeric empty/None values render as blank.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return str(value)
    return str(value) if value else ""


def _extraction_rows(catalysts: Dict, skip_keys: frozenset = frozenset()) -> List[Dict[str, str]]:
    """Flatten ``{catalyst: {performance: {property: value}}}`` into one row per performance."""
    rows = []
    for catalyst_name, performances in catalysts.items():
        if catalyst_name in skip_keys or not isinstance(performances, dict):
            continue
        for perf_name, properties in performances.items():
            row = {"Catalyst": catalyst_name, "Performance": perf_name}
            if isinstance(properties, dict):
                for prop_key, prop_val in properties.items():
                    row[prop_key.capitalize()] = _extraction_cell_text(prop_val)
            else:
                row["Value"] = str(properties)
            rows.append(row)
    return rows


def _dataframe_to_markdown(df: pd.DataFrame) -> str:
    """Render a DataFrame as a pipe-style markdown table without the optional `tabulate` package."""
    import math

    def cell(value: Any) -> str:
        # Missing columns for a row show up as NaN; render them blank, and
        # escape characters that would break the markdown row.
        if value is None or (isinstance(value, float) and math.isnan(value)):
            return ""
        return str(value).replace("|", "\\|").replace("\n", " ")

    headers = [cell(h) for h in df.columns.tolist()]
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join(["---"] * len(headers)) + " |",
    ]
    for values in df.itertuples(index=False, name=None):
        lines.append("| " + " | ".join(cell(v) for v in values) + " |")
    return "\n".join(lines)


def format_extraction_as_table(
    extraction: Dict,
    output_format: str = "markdown",
    save_path: str = ""
) -> Dict:
    """
    Format extraction results as a readable table and optionally save it to a file.

    Flattens the nested extraction into one row per (catalyst, performance)
    pair, with one column per property (capitalized key). Accepted inputs:

    - a dict of catalysts;
    - a ``{"catalysts": ...}`` wrapper around such data;
    - a list of catalyst dicts.

    The top-level keys "error", "raw_response", "success" and "model_type"
    of a catalyst dict are ignored.

    Args:
        extraction: The extraction result from any extract_catalyst_data_* tool.
        output_format: "markdown", "csv", "json" or "html". Any other value
            gives a plain-text table.
        save_path: Optional file path. The file format follows its extension:
            .csv, .json, .html, .xlsx or .md; anything else is saved as CSV.

    Returns:
        On success: ``success`` (True), ``format``, ``row_count``,
        ``columns`` and ``table``. When ``save_path`` is given, also
        ``save_success`` plus either ``saved_to`` or ``save_error``.

        On failure: ``{"success": False, "error": ..., "raw_extraction": ...}``,
        including when no catalyst rows are found.
    """
    try:
        catalysts_data = extraction
        if isinstance(extraction, dict) and "catalysts" in extraction:
            catalysts_data = extraction["catalysts"]

        rows = []
        if isinstance(catalysts_data, list):
            for item in catalysts_data:
                if isinstance(item, dict):
                    rows.extend(_extraction_rows(item))
        elif isinstance(catalysts_data, dict):
            rows = _extraction_rows(
                catalysts_data,
                frozenset({"error", "raw_response", "success", "model_type"})
            )

        if not rows:
            return {
                "success": False,
                "error": "No catalyst data found in extraction",
                "raw_extraction": extraction
            }

        df = pd.DataFrame(rows)

        if output_format == "markdown":
            formatted_table = _dataframe_to_markdown(df)
        elif output_format == "csv":
            formatted_table = df.to_csv(index=False)
        elif output_format == "json":
            formatted_table = df.to_json(orient="records", indent=2)
        elif output_format == "html":
            formatted_table = df.to_html(index=False, classes="catalyst-table")
        else:
            formatted_table = df.to_string(index=False)

        result = {
            "success": True,
            "format": output_format,
            "row_count": len(rows),
            "columns": df.columns.tolist(),
            "table": formatted_table
        }

        if save_path:
            try:
                ext = os.path.splitext(save_path)[1].lower()
                if ext == ".csv":
                    df.to_csv(save_path, index=False)
                elif ext == ".json":
                    df.to_json(save_path, orient="records", indent=2)
                elif ext == ".html":
                    df.to_html(save_path, index=False)
                elif ext == ".xlsx":
                    df.to_excel(save_path, index=False)
                elif ext == ".md":
                    with open(save_path, "w", encoding="utf-8") as f:
                        f.write(_dataframe_to_markdown(df))
                else:
                    df.to_csv(save_path, index=False)

                result["saved_to"] = save_path
                result["save_success"] = True
            except Exception as e:
                result["save_success"] = False
                result["save_error"] = str(e)

        return result

    except Exception as e:
        return {
            "success": False,
            "error": str(e),
            "raw_extraction": extraction
        }
def export_session_results(
    session_id: str,
    output_format: str = "csv",
    save_dir: str = ""
) -> Dict:
    """
    Export all extraction results from a session as one combined table.

    Each stored extraction is flattened with ``format_extraction_as_table`` and
    tagged with Table_Name, Model_Type, Timestamp and Follow_Up columns. The
    combined table is written to ``extraction_<session>_<timestamp>.<ext>``,
    together with a ``..._summary.json`` file.

    Args:
        session_id: The session ID to export.
        output_format: "csv", "json", "markdown" or "excel"; unknown values
            fall back to CSV.
        save_dir: Directory to write into (created if missing). When empty,
            a new temporary directory is used.

    Returns:
        On success: ``success`` (True), ``session_id``, ``export_dir``,
        ``files``, ``summary`` and ``preview`` (the first 10 rows).

        On failure: ``{"success": False, "error": ...}``. This covers an
        unknown session, a session with no extractions, no formattable
        extractions, or any exception.
    """
    import io

    try:
        session = session_manager.get_session(session_id)
        if not session:
            return {"success": False, "error": f"Session not found: {session_id}"}

        if not session.extractions:
            return {"success": False, "error": "No extractions in this session"}

        if not save_dir:
            save_dir = tempfile.mkdtemp(prefix="matablgpt_export_")
        os.makedirs(save_dir, exist_ok=True)

        all_rows = []
        exported_files = []

        for extraction in session.extractions:
            # CSV is used internally so every extraction flattens the same way
            format_result = format_extraction_as_table(
                extraction.result,
                output_format="csv"
            )
            if format_result.get("success") and "table" in format_result:
                # Read every cell back as text. Default CSV parsing would turn
                # values such as "NA"/"null" into NaN and "1.0" into floats.
                df = pd.read_csv(
                    io.StringIO(format_result["table"]),
                    dtype=str,
                    keep_default_na=False
                )
                df["Table_Name"] = extraction.table_name
                df["Model_Type"] = extraction.model_type
                df["Timestamp"] = extraction.timestamp
                df["Follow_Up"] = extraction.follow_up_applied
                all_rows.append(df)

        if not all_rows:
            return {"success": False, "error": "No valid extractions to export"}

        combined_df = pd.concat(all_rows, ignore_index=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_name = f"extraction_{session_id}_{timestamp}"

        if output_format == "csv":
            file_path = os.path.join(save_dir, f"{base_name}.csv")
            combined_df.to_csv(file_path, index=False)
        elif output_format == "json":
            file_path = os.path.join(save_dir, f"{base_name}.json")
            combined_df.to_json(file_path, orient="records", indent=2)
        elif output_format == "excel":
            file_path = os.path.join(save_dir, f"{base_name}.xlsx")
            combined_df.to_excel(file_path, index=False)
        elif output_format == "markdown":
            file_path = os.path.join(save_dir, f"{base_name}.md")
            with open(file_path, "w", encoding="utf-8") as f:
                f.write("# Extraction Results\n\n")
                f.write(f"Session: {session_id}\n\n")
                f.write(f"Exported: {timestamp}\n\n")
                try:
                    f.write(combined_df.to_markdown(index=False))
                except ImportError:
                    # DataFrame.to_markdown needs the optional `tabulate`
                    # package; fall back to a plain-text table.
                    f.write(combined_df.to_string(index=False))
        else:
            file_path = os.path.join(save_dir, f"{base_name}.csv")
            combined_df.to_csv(file_path, index=False)

        exported_files.append(file_path)

        summary = {
            "session_id": session_id,
            "total_extractions": len(session.extractions),
            "total_rows": len(combined_df),
            "catalysts": combined_df["Catalyst"].unique().tolist() if "Catalyst" in combined_df.columns else [],
            "performances": combined_df["Performance"].unique().tolist() if "Performance" in combined_df.columns else []
        }

        summary_path = os.path.join(save_dir, f"{base_name}_summary.json")
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)
        exported_files.append(summary_path)

        return {
            "success": True,
            "session_id": session_id,
            "export_dir": save_dir,
            "files": exported_files,
            "summary": summary,
            "preview": combined_df.head(10).to_dict(orient="records")
        }

    except Exception as e:
        return {"success": False, "error": str(e)}
def get_environment_requirements() -> Dict:
    """
    Get the required environment setup for MaTableGPT.

    Returns package requirements, the environment variables used to
    configure an OpenAI-compatible endpoint, setup instructions, and
    copy-paste configuration snippets for third-party API services
    (reverse proxies, API aggregators such as OneAPI).

    Returns:
        Dictionary with the keys ``success``, ``python_version``,
        ``required_packages``, ``optional_packages``,
        ``environment_variables``, ``setup_instructions`` and
        ``third_party_api_example``.
    """
    # Placeholder values shown in every example snippet. They are defined once
    # so the PowerShell/cmd/bash/docker/HF variants cannot drift apart.
    example_settings = [
        ("LLM_API_KEY", "sk-xxxx"),
        ("LLM_API_BASE", "https://api.your-service.com/v1"),
        ("LLM_MODEL", "gpt-4-turbo-preview"),
    ]
    run_command = "python start_mcp.py"

    def snippet(template: str, *, with_run: bool) -> List[str]:
        """Render one line per example setting, optionally followed by the launch command."""
        lines = [
            template.format(name=name, value=value)
            for name, value in example_settings
        ]
        if with_run:
            lines.append(run_command)
        return lines

    return {
        "success": True,
        # The `mcp` SDK (provider of mcp.server.fastmcp) requires Python 3.10+,
        # so advertising an older interpreter would be misleading.
        "python_version": ">=3.10",
        "required_packages": [
            "openai>=1.0.0 # OpenAI-compatible client, works with third-party APIs",
            "beautifulsoup4>=4.9.0",
            "pandas>=1.0.0",
            "lxml>=4.0.0",
            # FastMCP (mcp.server.fastmcp) is only part of the official SDK from 1.2.0.
            "mcp>=1.2.0",
        ],
        "optional_packages": [
            "nltk>=3.6.0 # For table splitting analysis",
            # pandas delegates these export formats to optional libraries and
            # raises ImportError when they are not installed.
            "openpyxl # For Excel (.xlsx) export via DataFrame.to_excel",
            "tabulate # For Markdown export via DataFrame.to_markdown",
        ],
        "environment_variables": {
            "LLM_API_KEY": "(Required) Your API key from third-party service",
            "LLM_API_BASE": "(Required) API base URL, e.g., https://api.your-service.com/v1",
            "LLM_MODEL": "(Optional) Model name, default: gpt-4-turbo-preview",
            # Pseudo-entry acting as a visual separator in the returned payload.
            "---": "--- Alternative variable names (also supported) ---",
            "OPENAI_API_KEY": "Alternative to LLM_API_KEY",
            "OPENAI_API_BASE": "Alternative to LLM_API_BASE",
            "OPENAI_MODEL": "Alternative to LLM_MODEL",
        },
        "setup_instructions": [
            "1. Create virtual environment: python -m venv venv",
            "2. Activate: venv\\Scripts\\activate (Windows) or source venv/bin/activate (Unix)",
            "3. Install: pip install -r requirements.txt",
            "4. Set environment variables (use your API provider's info):",
            " - LLM_API_KEY=your_api_key (Required)",
            " - LLM_API_BASE=https://api.your-service.com/v1 (Required)",
            " - LLM_MODEL=gpt-4-turbo-preview (Optional)",
            f"5. Run: {run_command}",
        ],
        "third_party_api_example": {
            "description": "Configuration for third-party API services (reverse proxy, OneAPI, etc.)",
            "windows_powershell": snippet("$env:{name} = '{value}'", with_run=True),
            "windows_cmd": snippet("set {name}={value}", with_run=True),
            "unix_bash": snippet("export {name}={value}", with_run=True),
            "docker_env": snippet("-e {name}={value}", with_run=False),
            "huggingface_secrets": snippet("{name} = {value}", with_run=False),
        },
    }
| # ============================================================================= | |
| # Server Entry Point | |
| # ============================================================================= | |
def main():
    """
    Script entry point: start serving the module-level ``mcp`` server object.

    ``mcp`` is defined elsewhere in this file (presumably the FastMCP
    instance); ``run()`` is called with its default arguments.
    """
    server = mcp
    server.run()


# Launch the server only when executed directly, not when imported.
if __name__ == "__main__":
    main()