Spaces:
Sleeping
Sleeping
| import React, { createContext, useContext, useState, useMemo, useCallback } from 'react' | |
// Context object shared by XRDProvider (producer) and useXRD (consumer).
const XRDContext = createContext()

/**
 * Hook returning the value published by the nearest XRDProvider.
 *
 * @returns {object} The XRD context value.
 * @throws {Error} When the context has no value, i.e. the hook is used
 *   outside of an XRDProvider.
 */
export const useXRD = () => {
  const xrd = useContext(XRDContext)
  if (xrd) {
    return xrd
  }
  throw new Error('useXRD must be used within XRDProvider')
}
| export const XRDProvider = ({ children }) => { | |
// ---- Model input contract (values mirror simulator.yaml) ----
const MODEL_INPUT_SIZE = 8192
const MODEL_WAVELENGTH = 0.6199 // Ångströms (synchrotron)
const MODEL_MIN_2THETA = 5.0 // degrees
const MODEL_MAX_2THETA = 20.0 // degrees

// NOTE: the hooks below must keep this exact order between renders.

// ---- Pattern loaded from an upload or an example file ----
const [rawData, setRawData] = useState(null)
const [filename, setFilename] = useState(null)

// ---- Wavelength bookkeeping ----
// detectedWavelength: value found in the file header (null when none).
// userWavelength: value actually used for conversion; starts at the model's.
// wavelengthSource: one of 'detected', 'user', 'default'.
const [detectedWavelength, setDetectedWavelength] = useState(null)
const [userWavelength, setUserWavelength] = useState(MODEL_WAVELENGTH)
const [wavelengthSource, setWavelengthSource] = useState('default')

// ---- Pre-processing switches ----
const [baselineCorrection, setBaselineCorrection] = useState(false)
const [interpolationEnabled, setInterpolationEnabled] = useState(true)
const [scalingEnabled, setScalingEnabled] = useState(true)
// interpolationStrategy is 'linear' or 'cubic'.
const [interpolationStrategy, setInterpolationStrategy] = useState('linear')

// ---- Messages produced while processing the data ----
const [dataWarnings, setDataWarnings] = useState([])

// ---- Inference results and progress ----
const [modelResults, setModelResults] = useState(null)
const [isLoading, setIsLoading] = useState(false)
// analysisStatus is one of IDLE, PROCESSING, COMPLETE.
const [analysisStatus, setAnalysisStatus] = useState('IDLE')

// ---- UI state ----
const [isLogitDrawerOpen, setIsLogitDrawerOpen] = useState(false)

// ---- Request counter: incremented for every analysis request ----
const [analysisCount, setAnalysisCount] = useState(0)
/**
 * Re-express a Bragg angle measured at one wavelength as the angle the same
 * d-spacing produces at another wavelength.
 *
 * Bragg's law: λ = 2d·sin(θ). For a fixed d-spacing this gives
 * sin(θ₂) = (λ₂/λ₁)·sin(θ₁).
 *
 * The formula operates on θ itself. A caller holding diffraction angles 2θ
 * must halve them before the call and double the result afterwards.
 *
 * @param {number} theta_deg - Bragg angle θ at the source wavelength, in degrees.
 * @param {number} sourceWavelength - Wavelength the angle was measured at.
 * @param {number} targetWavelength - Wavelength to convert to (same unit).
 * @returns {number|null} The converted angle in degrees, or null when the
 *   inputs are not usable (non-finite angle, non-positive or non-finite
 *   wavelength) or the peak is not observable at the target wavelength.
 */
const convertWavelength = (theta_deg, sourceWavelength, targetWavelength) => {
  const isUsableWavelength = (w) => Number.isFinite(w) && w > 0

  // Reject garbage up front so NaN or negative angles never reach the caller.
  if (
    !Number.isFinite(theta_deg) ||
    !isUsableWavelength(sourceWavelength) ||
    !isUsableWavelength(targetWavelength)
  ) {
    return null
  }

  if (Math.abs(sourceWavelength - targetWavelength) < 0.0001) {
    return theta_deg // No conversion needed
  }

  const theta_rad = (theta_deg * Math.PI) / 180
  const sin_theta2 = (targetWavelength / sourceWavelength) * Math.sin(theta_rad)

  // |sin θ₂| > 1 has no real solution: the reflection cannot be observed
  // at the target wavelength.
  if (Math.abs(sin_theta2) > 1) {
    return null
  }

  const theta2_rad = Math.asin(sin_theta2)
  return (theta2_rad * 180) / Math.PI
}
/**
 * Resample (x, y) onto an evenly spaced grid of `targetSize` points.
 *
 * The grid runs from `xMin` to `xMax`; each bound falls back to the data's
 * own minimum/maximum when undefined. Grid points outside the range covered
 * by the data get y = 0 (no extrapolation).
 *
 * `x` is expected to be sorted in ascending order: the bracketing sample is
 * located with a binary search.
 *
 * @param {number[]} x - Sample positions, ascending.
 * @param {number[]} y - Sample values, same length as `x`.
 * @param {number} targetSize - Number of points in the output grid.
 * @param {number} [xMin] - First grid position.
 * @param {number} [xMax] - Last grid position.
 * @param {string} [strategy='linear'] - 'cubic' selects the simplified
 *   Catmull-Rom scheme; any other value uses linear interpolation.
 * @returns {{x: number[], y: number[]}} The resampled data. When `x` already
 *   has `targetSize` points and no `xMin` is given, the inputs are returned
 *   unchanged.
 */
const interpolateData = (x, y, targetSize, xMin, xMax, strategy = 'linear') => {
  if (x.length === targetSize && xMin === undefined) {
    return { x, y }
  }

  // Loop-based extent: Math.min(...x) passes every element as an argument and
  // can exceed the engine's argument limit on large scans.
  let dataMinX = Infinity
  let dataMaxX = -Infinity
  for (const value of x) {
    if (value < dataMinX) dataMinX = value
    if (value > dataMaxX) dataMaxX = value
  }

  const minX = xMin !== undefined ? xMin : dataMinX
  const maxX = xMax !== undefined ? xMax : dataMaxX
  // A one-point grid has no spacing; avoid dividing by zero.
  const step = targetSize > 1 ? (maxX - minX) / (targetSize - 1) : 0
  const newX = Array.from({ length: targetSize }, (_, i) => minX + i * step)
  const newY = new Array(targetSize)

  // Index of the first sample with x >= target (x.length when there is none).
  const lowerBound = (target) => {
    let lo = 0
    let hi = x.length
    while (lo < hi) {
      const mid = (lo + hi) >>> 1
      if (x[mid] < target) {
        lo = mid + 1
      } else {
        hi = mid
      }
    }
    return lo
  }

  const useCubic = strategy === 'cubic'
  const lastIndex = x.length - 1

  for (let i = 0; i < targetSize; i++) {
    const targetX = newX[i]

    // Outside the data (or NaN): zero instead of extrapolating.
    if (!(targetX >= dataMinX && targetX <= dataMaxX)) {
      newY[i] = 0
      continue
    }

    // A single sample has no neighbour to interpolate with.
    if (x.length === 1) {
      newY[i] = y[0]
      continue
    }

    // Bracket the target between samples idx - 1 and idx.
    let idx = lowerBound(targetX)
    if (idx > lastIndex) idx = lastIndex
    if (idx === 0) idx = 1

    const x0 = x[idx - 1]
    const x1 = x[idx]
    const y0 = y[idx - 1]
    const y1 = y[idx]
    const span = x1 - x0

    // Duplicate x positions give a zero-width segment; use the sample value
    // rather than dividing by zero.
    if (span === 0) {
      newY[i] = y0
      continue
    }

    if (!useCubic) {
      newY[i] = y0 + ((targetX - x0) * (y1 - y0)) / span
      continue
    }

    // Simplified Catmull-Rom: the two outer control points are clamped to
    // the ends of the data.
    const t = (targetX - x0) / span
    const t2 = t * t
    const t3 = t2 * t
    const v0 = y[Math.max(0, idx - 2)]
    const v3 = y[Math.min(lastIndex, idx + 1)]

    newY[i] = 0.5 * (
      2 * y0 +
      (-v0 + y1) * t +
      (2 * v0 - 5 * y0 + 4 * y1 - v3) * t2 +
      (-v0 + 3 * y0 - 3 * y1 + v3) * t3
    )
  }

  return { x: newX, y: newY }
}
/**
 * Pre-processing pipeline that turns the raw pattern into model input:
 *   1. wavelength conversion to MODEL_WAVELENGTH (when they differ),
 *   2. optional baseline subtraction (minimum intensity),
 *   3. crop to [MODEL_MIN_2THETA, MODEL_MAX_2THETA],
 *   4. optional 0-100 scaling of the cropped intensities,
 *   5. interpolation onto MODEL_INPUT_SIZE evenly spaced points.
 *
 * The memoised calculation is kept pure: it returns both the data and the
 * warnings it produced, and the warnings are pushed into state from an effect
 * afterwards instead of calling a state setter while rendering.
 *
 * `processedData` is null without raw data, `{ x, y }` on success, and the
 * unmodified raw data when processing throws.
 */
const processingOutcome = useMemo(() => {
  if (!rawData) return { data: null, warnings: null }

  // Loop-based min/max: Math.min(...values) can exceed the engine's argument
  // limit on large scans.
  const extent = (values) => {
    let min = Infinity
    let max = -Infinity
    for (const v of values) {
      if (v < min) min = v
      if (v > max) max = v
    }
    return { min, max }
  }

  try {
    const warnings = []
    let processedY = [...rawData.y]
    let processedX = [...rawData.x]

    // Step 1: Wavelength conversion (if needed).
    // Number() tolerates a wavelength that arrives as a string (e.g. from a
    // text input); null, '' and NaN all end up falsy and skip the conversion.
    const sourceWavelength = Number(userWavelength)
    if (sourceWavelength && Math.abs(sourceWavelength - MODEL_WAVELENGTH) > 0.0001) {
      const convertedX = []
      const convertedY = []
      for (let i = 0; i < processedX.length; i++) {
        // The x axis holds diffraction angles 2θ, while Bragg's law relates
        // θ: halve before converting and double the result.
        const convertedTheta = convertWavelength(processedX[i] / 2, sourceWavelength, MODEL_WAVELENGTH)
        if (Number.isFinite(convertedTheta)) {
          convertedX.push(2 * convertedTheta)
          convertedY.push(processedY[i])
        }
      }
      if (convertedX.length < processedX.length) {
        warnings.push(`${processedX.length - convertedX.length} points outside physical range after wavelength conversion`)
      }
      processedX = convertedX
      processedY = convertedY
      warnings.push(`Converted from ${sourceWavelength.toFixed(4)} Å to ${MODEL_WAVELENGTH} Å`)
    }

    // Step 2: Apply baseline correction if enabled
    if (baselineCorrection) {
      const baseline = extent(processedY).min
      processedY = processedY.map(val => val - baseline)
    }

    // Step 3: Crop to the model's 2θ range
    const inRangeData = []
    for (let i = 0; i < processedX.length; i++) {
      if (processedX[i] >= MODEL_MIN_2THETA && processedX[i] <= MODEL_MAX_2THETA) {
        inRangeData.push({ x: processedX[i], y: processedY[i] })
      }
    }
    if (inRangeData.length === 0) {
      warnings.push(`⚠️ No data points in model range (${MODEL_MIN_2THETA}-${MODEL_MAX_2THETA}°)`)
      // Nothing survives the crop: fall back to the uncropped data, with the warning.
      for (let i = 0; i < processedX.length; i++) {
        inRangeData.push({ x: processedX[i], y: processedY[i] })
      }
    } else if (inRangeData.length < processedX.length) {
      const coverage = (inRangeData.length / processedX.length * 100).toFixed(1)
      warnings.push(`${coverage}% of data in model range (${MODEL_MIN_2THETA}-${MODEL_MAX_2THETA}°)`)
    }
    const croppedX = inRangeData.map(d => d.x)
    let croppedY = inRangeData.map(d => d.y)

    // Step 4: Apply 0-100 scaling if enabled.
    // Scaling happens AFTER cropping so the max peak in the visible range = 100.
    if (scalingEnabled) {
      const { min: minY, max: maxY } = extent(croppedY)
      if (maxY - minY > 0) {
        croppedY = croppedY.map(val => ((val - minY) / (maxY - minY)) * 100)
      }
    }

    // Step 5: Interpolate to model input size with fixed range
    const interpolated = interpolateData(
      croppedX,
      croppedY,
      MODEL_INPUT_SIZE,
      MODEL_MIN_2THETA,
      MODEL_MAX_2THETA,
      interpolationStrategy
    )

    return {
      data: { x: interpolated.x, y: interpolated.y },
      warnings,
    }
  } catch (error) {
    console.error('Error processing data:', error)
    return { data: rawData, warnings: [`Error: ${error.message}`] }
  }
}, [rawData, baselineCorrection, userWavelength, interpolationStrategy, scalingEnabled])

const processedData = processingOutcome.data

// Publish the warnings after render. A null list (no raw data) leaves the
// previous warnings untouched.
React.useEffect(() => {
  if (processingOutcome.warnings !== null) {
    setDataWarnings(processingOutcome.warnings)
  }
}, [processingOutcome])
/**
 * Scan the header text of a CIF/DIF file for metadata.
 *
 * Wavelength: the first line yielding a value wins. An explicit number in
 * the 0.1-3.0 range takes priority; a named radiation (Cu/Mo/Co/Cr Kα) is
 * only used when the line carries no usable number.
 * Cell parameters, space group and crystal system: the last matching line
 * wins. All header keywords are matched case-insensitively.
 *
 * @param {string} text - Full file contents.
 * @returns {{wavelength: (number|null), cellParams: (string|null),
 *   spaceGroup: (string|null), crystalSystem: (string|null)}}
 */
const extractMetadata = (text) => {
  const metadata = {
    wavelength: null,
    cellParams: null,
    spaceGroup: null,
    crystalSystem: null
  }

  // Common wavelength patterns in headers
  const wavelengthPatterns = [
    /wavelength[:\s=]+([0-9.]+)/i,
    /lambda[:\s=]+([0-9.]+)/i,
    /wave[:\s=]+([0-9.]+)/i,
    /_pd_wavelength[:\s]+([0-9.]+)/i, // CIF format
    /_diffrn_radiation_wavelength[:\s]+([0-9.]+)/i, // CIF format
    /radiation.*?([0-9.]+)\s*[AÅ]/i,
  ]

  // Nominal wavelengths for named laboratory sources, checked in this order.
  const namedRadiations = [
    [/Cu\s*K[αa]/i, 1.5406], // Cu Kα
    [/Mo\s*K[αa]/i, 0.7107], // Mo Kα
    [/Co\s*K[αa]/i, 1.7889], // Co Kα
    [/Cr\s*K[αa]/i, 2.2897], // Cr Kα
  ]

  for (const line of text.split(/\r?\n/)) {
    if (!metadata.wavelength) {
      for (const pattern of wavelengthPatterns) {
        const match = line.match(pattern)
        if (match && match[1]) {
          const wavelength = parseFloat(match[1])
          if (wavelength > 0.1 && wavelength < 3.0) { // Reasonable X-ray range
            metadata.wavelength = wavelength
            break
          }
        }
      }

      // Only fall back to the nominal value when no number was found, so
      // "Wavelength: 1.5418 (Cu Ka)" keeps 1.5418.
      if (!metadata.wavelength) {
        const named = namedRadiations.find(([pattern]) => pattern.test(line))
        if (named) {
          metadata.wavelength = named[1]
        }
      }
    }

    // Cell parameters (DIF format)
    const cellMatch = line.match(/CELL PARAMETERS:\s*([\d.\s]+)/i)
    if (cellMatch) {
      metadata.cellParams = cellMatch[1].trim()
    }

    // Space group number
    const spaceGroupMatch = line.match(/(?:SPACE GROUP|_symmetry_Int_Tables_number)[:\s]+(\d+)/i)
    if (spaceGroupMatch) {
      metadata.spaceGroup = spaceGroupMatch[1]
    }

    // Crystal system (numeric code)
    const crystalMatch = line.match(/Crystal System:\s*(\d+)/i)
    if (crystalMatch) {
      metadata.crystalSystem = crystalMatch[1]
    }
  }

  return metadata
}
/**
 * Extract (2θ, intensity) pairs from the `loop_` tables of a CIF file.
 *
 * A loop is used when its header names a 2θ column
 * (`_pd_meas_angle_2theta` / `_pd_calc_angle_2theta`) and an intensity column
 * (`_pd_proc_intensity*` / `_pd_calc_intensity*` / `_pd_meas_counts*`). When
 * several header lines match, the last one wins. Rows whose two values do not
 * parse as numbers are skipped.
 *
 * Loop state is reset on every `loop_` and `data_` keyword, on the first tag
 * line that follows data rows (that tag belongs to the next item, not to this
 * loop), and when rows start in a loop without both columns.
 *
 * NOTE(review): a background column such as `_pd_proc_intensity_bkg_calc`
 * also matches the intensity patterns and, being later in the header, would
 * be selected — confirm whether that is intended.
 *
 * @param {string} text - Full file contents.
 * @returns {{x: number[], y: number[]}} Parsed angles and intensities.
 */
const parseCIF = (text) => {
  const x = []
  const y = []

  let inDataLoop = false
  let columnCount = 0
  let thetaIndex = -1
  let intensityIndex = -1
  let sawRows = false

  // Forget everything about the current loop so indices never leak into the next one.
  const resetLoop = () => {
    inDataLoop = false
    columnCount = 0
    thetaIndex = -1
    intensityIndex = -1
    sawRows = false
  }

  for (const rawLine of text.split(/\r?\n/)) {
    const line = rawLine.trim()
    const keyword = line.toLowerCase() // CIF keywords are case-insensitive

    // Start of a new loop
    if (keyword === 'loop_') {
      resetLoop()
      inDataLoop = true
      continue
    }

    // Outside a loop, and on blank or comment lines, there is nothing to do.
    if (!inDataLoop || line.length === 0 || line.startsWith('#')) {
      continue
    }

    if (line.startsWith('_')) {
      // A tag after data rows means the loop is over.
      if (sawRows) {
        resetLoop()
        continue
      }

      // Loop header: record the position of the columns we care about.
      const columnIndex = columnCount
      columnCount += 1
      if (/_pd_meas_angle_2theta/i.test(line) || /_pd_calc_angle_2theta/i.test(line)) {
        thetaIndex = columnIndex
      }
      if (/_pd_proc_intensity/i.test(line) || /_pd_calc_intensity/i.test(line) || /_pd_meas_counts/i.test(line)) {
        intensityIndex = columnIndex
      }
      continue
    }

    // A new data block, or rows of a loop without both columns, ends the loop.
    if (keyword.startsWith('data_') || thetaIndex < 0 || intensityIndex < 0) {
      resetLoop()
      continue
    }

    // Data row of a diffraction-pattern loop
    sawRows = true
    const parts = line.split(/\s+/)
    if (parts.length >= Math.max(thetaIndex, intensityIndex) + 1) {
      // parseFloat also accepts values with a standard uncertainty, e.g. "200(3)".
      const xVal = parseFloat(parts[thetaIndex])
      const yVal = parseFloat(parts[intensityIndex])
      if (!Number.isNaN(xVal) && !Number.isNaN(yVal)) {
        x.push(xVal)
        y.push(yVal)
      }
    }
  }

  return { x, y }
}
/**
 * Parse two-column numeric text (DIF, XY, CSV, TXT): 2θ in the first column,
 * intensity in the second. Columns may be separated by whitespace and/or
 * commas. Decimal commas are not supported.
 *
 * Blank lines and lines starting with `#`, `_` or a letter are treated as
 * comments/metadata and skipped, as are rows whose first two fields are not
 * finite numbers.
 *
 * @param {string} text - Full file contents.
 * @returns {{x: number[], y: number[]}} Parsed angles and intensities.
 */
const parseDIF = (text) => {
  const x = []
  const y = []

  for (const line of text.split(/\r?\n/)) {
    const trimmed = line.trim()

    // Comment, tag and metadata lines (e.g. "CELL PARAMETERS:", "SPACE GROUP:")
    if (!trimmed || /^[#_a-zA-Z]/.test(trimmed)) {
      continue
    }

    // Whitespace or commas, so "5.0,100" is read as well as "5.0 100".
    const [first, second] = trimmed.split(/[\s,]+/)
    if (second === undefined) {
      continue
    }

    const xVal = parseFloat(first)
    const yVal = parseFloat(second)
    if (Number.isFinite(xVal) && Number.isFinite(yVal)) {
      x.push(xVal)
      y.push(yVal)
    }
  }

  return { x, y }
}
/**
 * Read an uploaded file as text and parse it into (2θ, intensity) arrays.
 *
 * `.cif` files go through the CIF loop parser first and fall back to the
 * simple column parser when no loop data is found; every other extension
 * uses the column parser directly.
 *
 * Wavelength state is only touched once the file is known to contain data,
 * so a rejected file cannot leave the provider with the wavelength of a file
 * that was never loaded. When the file carries no wavelength, the working
 * wavelength is put back to MODEL_WAVELENGTH so it matches the 'default'
 * source label instead of silently reusing the previous file's value.
 *
 * @param {File} file - File selected by the user.
 * @returns {Promise<{x: number[], y: number[]}>} Resolves with the parsed
 *   data; rejects when the file cannot be read or holds no data points.
 */
const parseFile = (file) => {
  return new Promise((resolve, reject) => {
    const reader = new FileReader()

    reader.onload = (e) => {
      try {
        const text = e.target.result

        // Determine file format and parse accordingly
        const fileName = file.name.toLowerCase()
        let data
        if (fileName.endsWith('.cif')) {
          // CIF format - look for loop_ structures
          data = parseCIF(text)
          // Fallback to simple parsing if CIF parsing didn't find data
          if (data.x.length === 0) {
            console.log('CIF loop parsing failed, falling back to simple parser')
            data = parseDIF(text)
          }
        } else {
          // DIF, XY, CSV, TXT - simple column data
          data = parseDIF(text)
        }

        if (data.x.length === 0 || data.y.length === 0) {
          reject(new Error('No valid data points found in file'))
          return
        }

        // The file is usable: now adopt its metadata.
        const metadata = extractMetadata(text)
        if (metadata.wavelength) {
          setDetectedWavelength(metadata.wavelength)
          setUserWavelength(metadata.wavelength)
          setWavelengthSource('detected')
        } else {
          setDetectedWavelength(null)
          setUserWavelength(MODEL_WAVELENGTH)
          setWavelengthSource('default')
        }

        console.log(`Parsed ${data.x.length} data points from ${fileName}`)
        resolve(data)
      } catch (error) {
        reject(error)
      }
    }

    reader.onerror = () => reject(new Error('Failed to read file'))
    reader.readAsText(file)
  })
}
/**
 * Load a user-selected file into the provider.
 *
 * On success the parsed pattern replaces the current one and everything tied
 * to the previous pattern (results, status, logit drawer) is cleared. On
 * failure the error is logged and shown in an alert.
 *
 * @param {File} file - File selected by the user.
 * @returns {Promise<boolean>} true when the file was loaded, false otherwise.
 */
const handleFileUpload = async (file) => {
  try {
    const parsed = await parseFile(file)

    // New pattern
    setRawData(parsed)
    setFilename(file.name)

    // Drop whatever belonged to the previous pattern
    setModelResults(null)
    setAnalysisStatus('IDLE')
    setIsLogitDrawerOpen(false)

    return true
  } catch (uploadError) {
    console.error('Error uploading file:', uploadError)
    alert(`Error loading file: ${uploadError.message}`)
    return false
  }
}
/**
 * Send the processed pattern to `/api/predict` and store the response.
 *
 * Status goes IDLE → PROCESSING → COMPLETE on success and back to IDLE on
 * failure (the error is logged and shown in an alert). `isLoading` is true
 * for the duration of the request. Every call increments the analysis
 * counter and is tagged with a timestamp-based request id.
 *
 * The filename is percent-encoded in the `X-Filename` header: header values
 * must be byte strings, and `fetch` throws on names containing characters
 * outside Latin-1 (or line breaks). The request body carries the filename
 * unmodified.
 *
 * NOTE(review): a response that arrives after another file has been loaded
 * is still stored; consider ignoring stale responses.
 */
const runInference = useCallback(async () => {
  if (!processedData) {
    alert('No data to analyze')
    return
  }

  // Increment analysis counter - tracks button clicks
  const currentCount = analysisCount + 1
  setAnalysisCount(currentCount)
  setIsLoading(true)
  setAnalysisStatus('PROCESSING')

  try {
    const requestTimestamp = Date.now()
    const headerSafeFilename = filename ? encodeURIComponent(filename) : 'unknown'

    const response = await fetch('/api/predict', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        // Anti-caching headers
        'Cache-Control': 'no-cache, no-store, must-revalidate',
        'Pragma': 'no-cache',
        'Expires': '0',
        // Request tracking
        'X-Request-ID': String(requestTimestamp),
        'X-Filename': headerSafeFilename,
      },
      // Explicitly disable caching for this request
      cache: 'no-store',
      body: JSON.stringify({
        x: processedData.x,
        y: processedData.y,
        // Include metadata to help track requests
        metadata: {
          timestamp: requestTimestamp,
          filename: filename,
          analysisCount: currentCount,
        }
      }),
    })

    if (!response.ok) {
      throw new Error(`API error: ${response.status}`)
    }

    const results = await response.json()
    setModelResults(results)
    setAnalysisStatus('COMPLETE')
  } catch (error) {
    console.error('Error running inference:', error)
    alert(`Inference failed: ${error.message}`)
    setAnalysisStatus('IDLE')
  } finally {
    setIsLoading(false)
  }
}, [processedData, analysisCount, filename])
/**
 * Fetch an example pattern from `/api/examples/<filename>` and load it the
 * same way as an uploaded file (examples are parsed with the DIF parser).
 *
 * Wavelength state is only updated once the example is known to contain
 * data. When the example carries no wavelength, the working wavelength is
 * put back to MODEL_WAVELENGTH so it matches the 'default' source label
 * instead of silently reusing the previous file's value.
 *
 * @param {string} filename - Name of the example file to fetch.
 * @returns {Promise<boolean>} true when the example was loaded; false after
 *   a failure (which is logged and shown in an alert).
 */
const loadExampleFile = useCallback(async (filename) => {
  try {
    const response = await fetch(`/api/examples/${encodeURIComponent(filename)}`)
    if (!response.ok) {
      throw new Error(`Failed to fetch example: ${response.status}`)
    }
    const text = await response.text()

    // Parse using the DIF parser (all examples are .dif)
    const data = parseDIF(text)
    if (data.x.length === 0 || data.y.length === 0) {
      throw new Error('No valid data points found in example file')
    }

    // The example is usable: now adopt its metadata.
    const metadata = extractMetadata(text)
    if (metadata.wavelength) {
      setDetectedWavelength(metadata.wavelength)
      setUserWavelength(metadata.wavelength)
      setWavelengthSource('detected')
    } else {
      setDetectedWavelength(null)
      setUserWavelength(MODEL_WAVELENGTH)
      setWavelengthSource('default')
    }

    setRawData(data)
    setFilename(filename)
    setModelResults(null)
    setAnalysisStatus('IDLE')
    setIsLogitDrawerOpen(false)
    return true
  } catch (error) {
    console.error('Error loading example file:', error)
    alert(`Error loading example: ${error.message}`)
    return false
  }
}, [])
/**
 * Return the provider to its "no file loaded" state.
 *
 * Clears the pattern, its results and everything derived from the file:
 * processing warnings and the detected wavelength. The working wavelength
 * goes back to the model default so it agrees with the 'default' source
 * label. Processing switches (baseline, scaling, interpolation) are user
 * preferences and are left as they are.
 */
const handleReset = () => {
  // Pattern and results
  setRawData(null)
  setFilename(null)
  setModelResults(null)
  setAnalysisStatus('IDLE')
  setIsLogitDrawerOpen(false)

  // State derived from the file that is no longer loaded
  setDataWarnings([])
  setDetectedWavelength(null)
  setUserWavelength(MODEL_WAVELENGTH)
  setWavelengthSource('default')
}
| const value = { | |
| rawData, | |
| processedData, | |
| modelResults, | |
| isLoading, | |
| filename, | |
| analysisStatus, | |
| detectedWavelength, | |
| userWavelength, | |
| setUserWavelength, | |
| wavelengthSource, | |
| dataWarnings, | |
| baselineCorrection, | |
| setBaselineCorrection, | |
| interpolationEnabled, | |
| setInterpolationEnabled, | |
| scalingEnabled, | |
| setScalingEnabled, | |
| interpolationStrategy, | |
| setInterpolationStrategy, | |
| isLogitDrawerOpen, | |
| setIsLogitDrawerOpen, | |
| handleFileUpload, | |
| loadExampleFile, | |
| runInference, | |
| handleReset, | |
| MODEL_WAVELENGTH, | |
| MODEL_MIN_2THETA, | |
| MODEL_MAX_2THETA, | |
| MODEL_INPUT_SIZE, | |
| } | |
| return <XRDContext.Provider value={value}>{children}</XRDContext.Provider> | |
| } | |