Spaces:
Sleeping
Sleeping
| """ | |
| Model inference logic for XRD pattern analysis. | |
| Loads the pretrained model from HuggingFace Hub and runs predictions. | |
| """ | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
| import numpy as np | |
| import spglib | |
| import torch | |
class XRDModelInference:
    """Load the AlphaDiffract model from the HuggingFace Hub and run XRD inference.

    The instance starts without a model; call ``load_model()`` once before
    ``predict()``. ``predict()`` never raises: every outcome is reported as a
    dictionary whose ``"status"`` key is ``"success"`` or ``"error"``.
    """

    HF_REPO_ID = "linked-liszt/OpenAlphaDiffract"

    # Class names in the index order used for the "cs_logits" head.
    _CRYSTAL_SYSTEMS = (
        "Triclinic", "Monoclinic", "Orthorhombic", "Tetragonal",
        "Trigonal", "Hexagonal", "Cubic",
    )
    # Labels paired positionally with the values of the "lp" output.
    _LATTICE_LABELS = ("a", "b", "c", "\u03b1", "\u03b2", "\u03b3")

    def _build_sg_to_hall() -> Dict[int, int]:
        """Map each space group number (1-230) to its first Hall number.

        spglib.get_spacegroup_type() is indexed by Hall number (1-530), NOT
        by space group number, so the first Hall number seen for each space
        group is recorded. Failures are reported and leave the table empty
        or partial instead of aborting the import; symbol lookups then fall
        back to the generic "SG<number>" label.
        """
        table: Dict[int, int] = {}
        try:
            for hall in range(1, 531):
                sg_type = spglib.get_spacegroup_type(hall)
                if sg_type is None:
                    continue
                # Depending on the spglib version the result is either an
                # object with attributes or a plain mapping.
                number = (
                    sg_type.number
                    if hasattr(sg_type, "number")
                    else sg_type["number"]
                )
                table.setdefault(number, hall)
        except Exception as exc:
            print(f"Warning: could not build space group lookup table: {exc}")
        return table

    _sg_to_hall: Dict[int, int] = _build_sg_to_hall()
    # The builder is a one-shot class-body helper, not a method.
    del _build_sg_to_hall

    def __init__(self):
        """Create an unloaded instance and select CUDA when available, else CPU."""
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def is_loaded(self) -> bool:
        """Return True once a model object is available."""
        return self.model is not None

    def load_model(self) -> None:
        """Download and load the pretrained model from HuggingFace Hub.

        Any failure is printed together with its traceback and leaves
        ``self.model`` set to None; nothing is raised to the caller.
        """
        try:
            from huggingface_hub import snapshot_download

            print(f"Downloading model from {self.HF_REPO_ID}...")
            model_dir = snapshot_download(self.HF_REPO_ID)
            print(f"Model downloaded to {model_dir}")

            # NOTE(review): the next import executes model.py shipped inside
            # the downloaded snapshot, so the repository contents are trusted
            # code. Consider pinning a revision in snapshot_download().
            # Keep the snapshot first on sys.path without stacking a new
            # duplicate entry on every call.
            if model_dir in sys.path:
                sys.path.remove(model_dir)
            sys.path.insert(0, model_dir)
            from model import AlphaDiffract

            self.model = AlphaDiffract.from_pretrained(
                model_dir, device=str(self.device)
            )
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback

            traceback.print_exc()
            self.model = None

    def preprocess_data(self, x: List[float], y: List[float]) -> torch.Tensor:
        """
        Preprocess XRD data for model input.

        Args:
            x: 2theta values (accepted for interface symmetry; not used here)
            y: Intensity values

        Returns:
            float32 tensor on ``self.device`` with a leading batch axis of 1,
            holding the intensities floored at zero and rescaled to [0, 100]

        Raises:
            ValueError: if ``y`` is empty or contains NaN/Inf values
        """
        intensities = np.array(y, dtype=np.float32)
        if intensities.size == 0:
            raise ValueError("Intensity data is empty.")
        if not np.all(np.isfinite(intensities)):
            raise ValueError(
                "Intensity data contains non-finite values (NaN or Inf)."
            )

        # Floor at zero (remove any negative values)
        intensities = np.maximum(intensities, 0.0)

        # Rescale to [0, 100] (matching training preprocessing); a flat
        # pattern has no range to normalise by and maps to all zeros.
        lowest = np.min(intensities)
        span = np.max(intensities) - lowest
        if span < 1e-10:
            scaled = np.zeros_like(intensities, dtype=np.float32)
        else:
            scaled = (intensities - lowest) / span * 100.0

        return torch.from_numpy(scaled).unsqueeze(0).to(self.device)

    def predict(self, x: List[float], y: List[float]) -> Dict:
        """
        Run inference on XRD data.

        Args:
            x: 2theta values
            y: Intensity values

        Returns:
            On success: ``status``, ``predictions``, ``model_info`` and, when
            at least one head reports a confidence, ``confidence``.
            On failure: ``status`` "error", ``error`` and ``http_status``
            (503 model not loaded, 400 invalid input, 500 anything else).
        """
        if self.model is None:
            return {
                "status": "error",
                "error": "Model not loaded.",
                "http_status": 503,
            }

        try:
            try:
                input_tensor = self.preprocess_data(x, y)
            except (ValueError, TypeError) as e:
                # Malformed input is the caller's fault, not a server error.
                return {
                    "status": "error",
                    "error": str(e),
                    "http_status": 400,
                }

            with torch.no_grad():
                output = self.model(input_tensor)

            processed = self._process_model_output(output)
            overall_confidence = self._compute_overall_confidence(processed)

            predictions = {
                "status": "success",
                "predictions": processed,
                "model_info": {
                    "type": "AlphaDiffract",
                    "device": str(self.device),
                },
            }
            if overall_confidence is not None:
                predictions["confidence"] = overall_confidence
            return predictions
        except Exception as e:
            return {
                "status": "error",
                "error": str(e),
                "http_status": 500,
            }

    def _process_model_output(self, output) -> Dict:
        """Convert raw model output into a JSON-friendly prediction dictionary.

        A dict output is read head by head ("cs_logits", "sg_logits", "lp",
        each optional). A bare tensor is summarised by its shape and, when
        its last axis has more than one entry, by its largest softmax value.
        Any other output type yields empty results.
        """
        if isinstance(output, dict):
            predictions = []
            if "cs_logits" in output:
                predictions.append(
                    self._crystal_system_prediction(output["cs_logits"])
                )
            if "sg_logits" in output:
                predictions.append(
                    self._space_group_prediction(output["sg_logits"])
                )
            if "lp" in output:
                predictions.append(self._lattice_prediction(output["lp"]))
            return {
                "phase_predictions": predictions,
                "intensity_profile": [],
            }

        if isinstance(output, torch.Tensor):
            probs = output.detach().cpu().numpy()
            entry = {
                "phase": "Predicted Phase",
                "details": f"Output shape: {probs.shape}",
            }
            if output.ndim >= 1 and output.shape[-1] > 1:
                entry["confidence"] = float(
                    torch.softmax(output, dim=-1).max().item()
                )
            return {
                "phase_predictions": [entry],
                "intensity_profile": probs.tolist() if probs.ndim == 1 else [],
            }

        return {"phase_predictions": [], "intensity_profile": []}

    def _crystal_system_prediction(self, logits: torch.Tensor) -> Dict:
        """Build the "Crystal System" entry from the 7-class logits."""
        probs = torch.softmax(logits.detach().cpu(), dim=-1)
        best_prob, best_idx = torch.max(probs, dim=-1)
        ranked = sorted(
            (
                {"class_name": name, "probability": float(probs[0, i].item())}
                for i, name in enumerate(self._CRYSTAL_SYSTEMS)
            ),
            key=lambda entry: entry["probability"],
            reverse=True,
        )
        return {
            "phase": "Crystal System",
            "predicted_class": self._CRYSTAL_SYSTEMS[best_idx.item()],
            "confidence": float(best_prob.item()),
            "all_probabilities": ranked,
        }

    def _space_group_prediction(self, logits: torch.Tensor) -> Dict:
        """Build the "Space Group" entry with the top-10 candidates."""
        probs = torch.softmax(logits.detach().cpu(), dim=-1)
        best_prob, best_idx = torch.max(probs, dim=-1)
        # Class index i corresponds to space group number i + 1.
        sg_number = best_idx.item() + 1

        top_k = min(10, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs[0], top_k)
        candidates = [
            {
                "space_group_number": number,
                "space_group_symbol": self._get_space_group_symbol(number),
                "probability": prob,
            }
            for prob, number in zip(
                top_probs.tolist(), (top_indices + 1).tolist()
            )
        ]
        return {
            "phase": "Space Group",
            "predicted_class": f"#{sg_number}",
            "space_group_symbol": self._get_space_group_symbol(sg_number),
            "confidence": float(best_prob.item()),
            "top_probabilities": candidates,
        }

    def _lattice_prediction(self, lp_tensor: torch.Tensor) -> Dict:
        """Build the "Lattice Parameters" entry from the "lp" output."""
        lp_raw = lp_tensor.detach().cpu()
        # Drop the batch axis when it is 1; otherwise squeeze singleton axes.
        values = lp_raw[0] if lp_raw.shape[0] == 1 else lp_raw.squeeze()
        return {
            "phase": "Lattice Parameters",
            "lattice_params": {
                label: float(val)
                for label, val in zip(self._LATTICE_LABELS, values.numpy())
            },
            "is_lattice": True,
        }

    def _get_space_group_symbol(self, sg_number: int) -> str:
        """Get space group symbol from number using spglib.

        Returns the short international symbol, or "SG<number>" when the
        number is outside 1-230, missing from the lookup table, or spglib
        cannot provide the symbol.
        """
        fallback = f"SG{sg_number}"
        if not 1 <= sg_number <= 230:
            return fallback

        hall_number = self._sg_to_hall.get(sg_number)
        if hall_number is None:
            return fallback

        try:
            sg_type = spglib.get_spacegroup_type(hall_number)
            if sg_type is None:
                return fallback
            if hasattr(sg_type, "international_short"):
                return sg_type.international_short
            return sg_type["international_short"]
        except Exception:
            # Best effort: a symbol lookup must never break a prediction.
            return fallback

    def _compute_overall_confidence(self, processed: Dict) -> Optional[float]:
        """Average the per-phase confidences, or return None when there are none."""
        if not isinstance(processed, dict):
            return None

        confidences = [
            float(entry["confidence"])
            for entry in processed.get("phase_predictions", [])
            if isinstance(entry, dict) and entry.get("confidence") is not None
        ]
        if not confidences:
            return None
        return float(np.mean(confidences))