# ecflow / digitizer.py
# Author: Bing Yan
# Commit: Fix image digitization: robust axis tick disambiguation and mapping
# 58cb8ee
"""
Image digitizer for extracting data from plot images.
Uses OpenCV to trace curves from uploaded plot images, mapping pixel
coordinates to data coordinates via user-provided axis ranges.
Supports automatic axis-range detection via OCR (easyocr).
"""
import re
import numpy as np
def auto_detect_axis_bounds(image_array):
    """Detect axis bounds and current unit from a plot image using OCR.

    Reads numeric tick labels, clusters them by position into x-axis
    (similar y-coordinate, bottom half) and y-axis (similar x-coordinate,
    left half), and returns the inferred data ranges.

    Also attempts to detect the y-axis unit label by OCR-ing a rotated
    crop of the left margin.

    Args:
        image_array: numpy array (H, W, 3) RGB image

    Returns:
        dict with keys 'x_min', 'x_max', 'y_min', 'y_max' (all float),
        'x_ticks' / 'y_ticks' (lists of (pixel_position, value) pairs),
        and optionally 'y_unit' (str, e.g. 'µA', 'mA', 'A'),
        or None if detection fails.
    """
    try:
        import easyocr
    except ImportError:
        # OCR support is optional; without easyocr we report failure.
        return None
    # Drop an alpha channel if present (RGBA -> RGB).
    if image_array.ndim == 3 and image_array.shape[2] == 4:
        image_array = image_array[:, :, :3]
    H, W = image_array.shape[:2]
    reader = easyocr.Reader(["en"], gpu=False, verbose=False)
    results = reader.readtext(image_array, detail=1)
    # Numeric tick label: optional minus sign (ASCII '-', unicode '−'/'–',
    # or '~' which OCR sometimes emits for '-'), digits, optional decimal
    # part, optional exponent.
    _NUM_RE = re.compile(
        r"^[−\-–~]?\d+\.?\d*(?:[eE][+\-]?\d+)?$"
    )
    detections = []
    all_texts = []
    for bbox, text, conf in results:
        all_texts.append(text.strip())
        # Normalize unicode minus variants so float() can parse the label.
        cleaned = (text.strip().replace(" ", "")
                   .replace("−", "-").replace("–", "-").replace("~", "-"))
        if not _NUM_RE.match(cleaned):
            continue
        try:
            val = float(cleaned)
        except ValueError:
            continue
        # Discard low-confidence OCR reads.
        if conf < 0.2:
            continue
        # Use the bounding-box centroid as the label's pixel position.
        cx = np.mean([p[0] for p in bbox])
        cy = np.mean([p[1] for p in bbox])
        detections.append((cx, cy, val))
    # Need at least two labels per axis to infer a range.
    if len(detections) < 4:
        return None
    x_candidates = [(cx, cy, v) for cx, cy, v in detections if cy > H * 0.65]
    y_candidates = [(cx, cy, v) for cx, cy, v in detections if cx < W * 0.30]
    # Disambiguate detections that appear in both candidate lists.
    # Y-axis labels in the bottom-left corner satisfy both cy > 0.65*H
    # and cx < 0.30*W. Assign each ambiguous detection to the axis
    # whose alignment coordinate (cy for x-axis, cx for y-axis) it
    # matches better.
    if x_candidates and y_candidates:
        ambiguous_indices = set()
        for i, xc in enumerate(x_candidates):
            for yc in y_candidates:
                # Same detection in both lists: centroids within 10 px.
                if abs(xc[0] - yc[0]) < 10 and abs(xc[1] - yc[1]) < 10:
                    ambiguous_indices.add(i)
        if ambiguous_indices:
            # X-axis ticks share a common cy (alignment); compute the
            # median cy of non-ambiguous x-candidates as reference.
            clean_x_cys = [x_candidates[i][1] for i in range(len(x_candidates))
                           if i not in ambiguous_indices]
            clean_y_cxs = [c[0] for c in y_candidates
                           if not any(abs(c[0] - x_candidates[j][0]) < 10
                                      and abs(c[1] - x_candidates[j][1]) < 10
                                      for j in ambiguous_indices)]
            # Fall back to the full-list median when every candidate on an
            # axis is ambiguous.
            ref_x_cy = np.median(clean_x_cys) if clean_x_cys else np.median([c[1] for c in x_candidates])
            ref_y_cx = np.median(clean_y_cxs) if clean_y_cxs else np.median([c[0] for c in y_candidates])
            x_candidates_filtered = []
            for i, c in enumerate(x_candidates):
                if i in ambiguous_indices:
                    dist_to_x_row = abs(c[1] - ref_x_cy)
                    dist_to_y_col = abs(c[0] - ref_y_cx)
                    # Closer to the y-axis column than to the x-axis row:
                    # it belongs to the y-axis, drop it here.
                    if dist_to_x_row > dist_to_y_col:
                        continue
                x_candidates_filtered.append(c)
            y_candidates_filtered = []
            for c in y_candidates:
                is_amb = any(abs(c[0] - x_candidates[j][0]) < 10
                             and abs(c[1] - x_candidates[j][1]) < 10
                             for j in ambiguous_indices)
                if is_amb:
                    dist_to_x_row = abs(c[1] - ref_x_cy)
                    dist_to_y_col = abs(c[0] - ref_y_cx)
                    # Mirror of the x-axis rule: keep only if it aligns
                    # better with the y-axis column.
                    if dist_to_y_col > dist_to_x_row:
                        continue
                y_candidates_filtered.append(c)
            x_candidates = x_candidates_filtered
            y_candidates = y_candidates_filtered
    x_ticks = _extract_axis_ticks(x_candidates, axis="x")
    y_ticks = _extract_axis_ticks(y_candidates, axis="y")
    if len(x_ticks) < 2 or len(y_ticks) < 2:
        return None
    x_vals = [v for _, v in x_ticks]
    y_vals = [v for _, v in y_ticks]
    # Detect y-axis unit from all OCR text and from rotated left margin
    y_unit = _detect_current_unit(image_array, reader, all_texts)
    result = {
        "x_min": float(min(x_vals)),
        "x_max": float(max(x_vals)),
        "y_min": float(min(y_vals)),
        "y_max": float(max(y_vals)),
        "x_ticks": x_ticks,
        "y_ticks": y_ticks,
    }
    if y_unit:
        result["y_unit"] = y_unit
    return result
def _detect_current_unit(image_array, reader, all_texts):
"""Try to detect the current unit from the y-axis label.
First checks all OCR text for unit patterns. If not found,
rotates the left margin 90° CW and re-runs OCR to read
the rotated y-axis label.
"""
import cv2
combined = " ".join(all_texts).lower()
for pattern, unit in [
("µa", "µA"), ("ua", "µA"), ("μa", "µA"),
("(ma)", "mA"), ("ma)", "mA"), ("i/ma", "mA"),
("(na)", "nA"), ("na)", "nA"),
("(a)", "A"),
]:
if pattern in combined:
return unit
H, W = image_array.shape[:2]
left_strip = image_array[:, : int(W * 0.12), :]
rotated = cv2.rotate(left_strip, cv2.ROTATE_90_CLOCKWISE)
try:
rot_results = reader.readtext(rotated, detail=1)
except Exception:
return None
rot_text_raw = " ".join(t.strip() for _, t, _ in rot_results)
rot_text_lower = rot_text_raw.lower()
# Check case-sensitive patterns first (µA vs mA)
for pattern, unit in [
("µA", "µA"), ("µa", "µA"), ("uA", "µA"), ("μA", "µA"),
("MA", "µA"), # OCR often misreads µ as M
("HA", "µA"), # OCR sometimes misreads µ as H
("mA", "mA"),
("nA", "nA"),
]:
if pattern in rot_text_raw:
return unit
# Fallback: case-insensitive but less reliable
if "ua" in rot_text_lower:
return "µA"
return None
def _extract_axis_ticks(candidates, axis="x"):
    """From (cx, cy, val) candidates, extract tick (position, value) pairs.

    Ticks for one axis line up along a shared coordinate (cy for x-axis
    labels, cx for y-axis labels), so candidates are clustered on that
    coordinate and the largest cluster is taken as the tick row/column.
    OCR sign errors and value outliers are then repaired via the helper
    functions.

    Returns:
        list of (position, value) sorted by position, where position is
        cx for x-axis ticks and cy for y-axis ticks.
    """
    if len(candidates) < 2:
        return []
    # x-ticks: aligned on cy, spread along cx; y-ticks are the transpose.
    if axis == "x":
        align_idx, pos_idx = 1, 0
    else:
        align_idx, pos_idx = 0, 1
    # Greedy clustering: each candidate seeds a cluster of everything
    # within 30 px along the alignment coordinate; keep the biggest.
    clusters = [
        [c for c in candidates if abs(c[align_idx] - seed[align_idx]) < 30]
        for seed in candidates
    ]
    best_cluster = max(clusters, key=len)
    if len(best_cluster) < 2:
        return []
    # Order left-to-right for x, top-to-bottom for y.
    ordered = sorted(best_cluster, key=lambda c: c[pos_idx])
    ticks = [(c[pos_idx], c[2]) for c in ordered]
    # Values must be monotonic with position: increasing along cx for x,
    # decreasing along cy for y (pixel rows grow downward). Use this to
    # restore minus signs OCR may have dropped.
    ticks = _fix_missing_negatives(ticks, increasing=(axis == "x"))
    # Drop ticks whose values break the linear pixel-to-value relation
    # (e.g. OCR reading "1.0" as "10").
    return _remove_tick_outliers(ticks)
def _remove_tick_outliers(ticks):
    """Remove ticks whose values deviate from the expected linear mapping.

    Uses a leave-one-out scheme: each tick is scored by its residual
    against a line fitted to the *other* ticks, which stays robust even
    when a single outlier would skew a global fit (e.g. OCR reading
    "1.0" as "10"). Decimal-point misreads are repaired first via
    _fix_decimal_misreads so they can be kept rather than discarded.
    """
    if len(ticks) < 3:
        return ticks
    n = len(ticks)
    pos_arr = np.array([p for p, _ in ticks], dtype=float)
    val_arr = np.array([v for _, v in ticks], dtype=float)
    # Pass 1: repair dropped-decimal-point misreads before judging outliers.
    ticks = _fix_decimal_misreads(list(zip(pos_arr, val_arr)))
    pos_arr = np.array([p for p, _ in ticks], dtype=float)
    val_arr = np.array([v for _, v in ticks], dtype=float)
    # Pass 2: leave-one-out residual for every tick.
    residuals = np.zeros(n)
    for i in range(n):
        others = np.arange(n) != i
        if others.sum() < 2:
            # Not enough points left to fit a line; treat as inlier.
            residuals[i] = 0.0
            continue
        line = np.polyfit(pos_arr[others], val_arr[others], 1)
        residuals[i] = abs(val_arr[i] - np.polyval(line, pos_arr[i]))
    # Outlier threshold: twice the expected value change between adjacent
    # ticks, derived from the global fit's slope and the pixel spacing.
    slope = np.polyfit(pos_arr, val_arr, 1)[0]
    spacing = abs(slope) * np.median(np.diff(pos_arr))
    if spacing < 1e-12:
        # Degenerate (flat) fit: fall back to the median value step.
        spacing = np.median(np.abs(np.diff(val_arr))) + 1e-12
    keep = residuals < spacing * 2
    # Refuse to prune down below two ticks — a mapping needs at least two.
    if keep.sum() < 2:
        return ticks
    return [(p, v) for p, v, k in zip(pos_arr, val_arr, keep) if k]
def _fix_decimal_misreads(ticks):
"""Fix OCR misreads where the decimal point is dropped.
E.g. "1.0" read as "10", "0.5" read as "5". Detects these by
checking if the value/pixel ratio between adjacent ticks is
inconsistent, and whether dividing a value by 10 fixes it.
"""
if len(ticks) < 3:
return ticks
positions = np.array([t[0] for t in ticks], dtype=float)
values = np.array([t[1] for t in ticks], dtype=float)
n = len(ticks)
# Compute the value step per pixel step for each adjacent pair
dv = np.diff(values)
dp = np.diff(positions)
ratios = dv / (dp + 1e-12)
# The median ratio represents the "true" scale
med_ratio = np.median(ratios)
if abs(med_ratio) < 1e-12:
return ticks
# For each tick, check if replacing its value with value/10
# produces a more consistent set of ratios
improved = True
max_iters = 5
while improved and max_iters > 0:
improved = False
max_iters -= 1
for i in range(n):
# Compute current residual from linear fit
coeffs = np.polyfit(positions, values, 1)
predicted = np.polyval(coeffs, positions[i])
current_res = abs(values[i] - predicted)
# Try value / 10
test_values = values.copy()
test_values[i] = values[i] / 10.0
coeffs_test = np.polyfit(positions, test_values, 1)
predicted_test = np.polyval(coeffs_test, positions[i])
test_res = abs(test_values[i] - predicted_test)
# Also compute overall fit quality
current_total = np.sum((values - np.polyval(coeffs, positions)) ** 2)
test_total = np.sum((test_values - np.polyval(coeffs_test, positions)) ** 2)
if test_total < current_total * 0.3:
values[i] = test_values[i]
improved = True
return list(zip(positions, values))
def _fix_missing_negatives(ticks, increasing=True):
"""Fix OCR-dropped minus signs using spatial monotonicity.
Tick labels on a plot axis must be monotonically ordered. If OCR drops
a minus sign, we'll see a value that breaks monotonicity. We can fix
this by negating values that should be negative.
Args:
ticks: list of (position, value) sorted by position
increasing: True if values should increase with position (x-axis),
False if values should decrease with position (y-axis)
"""
if len(ticks) < 2:
return ticks
positions, values = zip(*ticks)
values = list(values)
n = len(values)
# Try to find a consistent evenly-spaced sequence.
# Common tick spacings: the spacing between adjacent ticks should be constant.
# Strategy: find the most common absolute difference between adjacent values
# after sorting by position, then reconstruct the sequence.
# First, check if values are already monotonic
if increasing:
is_ok = all(values[i] <= values[i + 1] for i in range(n - 1))
else:
is_ok = all(values[i] >= values[i + 1] for i in range(n - 1))
if is_ok:
return ticks
# Values are NOT monotonic — OCR likely dropped minus signs.
# Strategy: for each value, try both +val and -val, and find the
# assignment that produces the best evenly-spaced monotonic sequence.
abs_vals = [abs(v) for v in values]
best_score = float("inf")
best_assignment = values[:]
# Try all 2^n sign assignments (feasible for typical n <= 10 ticks)
for mask in range(1 << n):
candidate = [(-abs_vals[i] if (mask >> i) & 1 else abs_vals[i])
for i in range(n)]
# Check monotonicity
if increasing:
if not all(candidate[i] <= candidate[i + 1] for i in range(n - 1)):
continue
else:
if not all(candidate[i] >= candidate[i + 1] for i in range(n - 1)):
continue
# Score: prefer evenly spaced ticks (low variance in step sizes)
steps = [candidate[i + 1] - candidate[i] for i in range(n - 1)]
if len(steps) > 1:
mean_step = np.mean(steps)
score = np.var(steps) / (mean_step ** 2 + 1e-12)
else:
score = 0.0
if score < best_score:
best_score = score
best_assignment = candidate
return list(zip(positions, best_assignment))
def _detect_plot_region(gray):
    """Detect the plot area (axes bounding box) from a grayscale image.

    Uses Hough line detection to find the axis lines, then infers
    the plot boundaries from the longest horizontal and vertical lines.

    Args:
        gray: 2-D grayscale image array.

    Returns:
        (px_left, px_right, py_top, py_bottom) in pixel coordinates,
        or None if detection fails.
    """
    try:
        import cv2
    except ImportError:
        return None
    H, W = gray.shape
    edges = cv2.Canny(gray, 50, 150)
    # Find long line segments — axis lines span a good fraction of the image.
    min_len = min(W, H) // 4
    lines = cv2.HoughLinesP(edges, 1, np.pi / 180,
                            threshold=80, minLineLength=min_len, maxLineGap=10)
    if lines is None:
        return None
    h_lines = []  # (y, x1, x2, length)
    v_lines = []  # (x, y1, y2, length)
    for l in lines:
        x1, y1, x2, y2 = l[0]
        # Segment orientation in degrees; near 0/180 = horizontal,
        # near 90 = vertical (5-degree tolerance either way).
        angle = abs(np.degrees(np.arctan2(y2 - y1, x2 - x1)))
        length = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
        if angle < 5 or angle > 175:
            y_avg = (y1 + y2) / 2
            h_lines.append((y_avg, min(x1, x2), max(x1, x2), length))
        elif abs(angle - 90) < 5:
            x_avg = (x1 + x2) / 2
            v_lines.append((x_avg, min(y1, y2), max(y1, y2), length))
    if not h_lines or not v_lines:
        return None
    # Bottom axis: longest horizontal line in the lower half
    h_bottom = [l for l in h_lines if l[0] > H * 0.4]
    if h_bottom:
        best_h = max(h_bottom, key=lambda l: l[3])
        py_bottom = int(best_h[0])
        # Use the horizontal line's x-extent for right boundary
        h_x_right = int(best_h[2])
    else:
        # No horizontal axis detected: fall back to fixed fractional margins.
        py_bottom = int(H * 0.85)
        h_x_right = int(W * 0.92)
    # Left axis: longest vertical line in the left half
    v_left = [l for l in v_lines if l[0] < W * 0.5]
    if v_left:
        best_v = max(v_left, key=lambda l: l[3])
        px_left = int(best_v[0])
        # Use the vertical line's y-extent for top boundary
        v_y_top = int(best_v[1])
    else:
        # No vertical axis detected: same fallback strategy.
        px_left = int(W * 0.12)
        v_y_top = int(H * 0.08)
    # Right boundary: from horizontal axis line extent, or right edge
    px_right = h_x_right if h_x_right > px_left + W * 0.2 else int(W * 0.92)
    # Top boundary: from vertical axis line extent, or top edge
    py_top = v_y_top if v_y_top < py_bottom - H * 0.2 else int(H * 0.08)
    # Sanity check: reject implausibly small plot regions.
    if px_right - px_left < W * 0.15 or py_bottom - py_top < H * 0.15:
        return None
    return px_left, px_right, py_top, py_bottom
def _robust_tick_fit(ticks):
"""Fit a linear pixel→value mapping that handles missing intermediate ticks.
A simple polyfit fails when OCR misses some tick labels, because the
value-gap between detected ticks no longer matches the pixel-gap.
For example, ticks at values [0.5, -1.0, -1.5] with equal pixel
spacing means OCR missed 0.0 and -0.5 between 0.5 and -1.0.
Strategy: find the minimum |Δvalue/Δpixel| ratio among adjacent
tick pairs — this corresponds to the pair where no ticks are missing.
Use that ratio as the true scale, then anchor the mapping at the
tick pair that defines it.
"""
if len(ticks) < 2:
return np.array([0.0, 0.0])
positions = np.array([t[0] for t in ticks], dtype=float)
values = np.array([t[1] for t in ticks], dtype=float)
if len(ticks) == 2:
return np.polyfit(positions, values, 1)
# Compute |Δvalue / Δpixel| for each adjacent pair
dp = np.diff(positions)
dv = np.diff(values)
ratios = dv / (dp + 1e-12)
abs_ratios = np.abs(ratios)
# The minimum absolute ratio corresponds to the pair with no missing
# ticks between them (smallest value change per pixel step).
min_idx = np.argmin(abs_ratios)
true_ratio = ratios[min_idx]
# Check if all ratios are consistent (within 50% of each other).
# If so, just use polyfit — no missing ticks.
if abs_ratios.max() < abs_ratios.min() * 1.8:
return np.polyfit(positions, values, 1)
# Use the true ratio and anchor at the midpoint of the best pair
anchor_px = (positions[min_idx] + positions[min_idx + 1]) / 2
anchor_val = (values[min_idx] + values[min_idx + 1]) / 2
intercept = anchor_val - true_ratio * anchor_px
return np.array([true_ratio, intercept])
def digitize_plot(image_array, x_min, x_max, y_min, y_max,
                  threshold=0, min_contour_length=50,
                  x_ticks=None, y_ticks=None):
    """
    Extract (x, y) data points from a plot image.

    Uses axis detection to find the plot region, then maps pixel
    coordinates to data coordinates. If tick positions (pixel, value)
    are provided, uses them for a more accurate linear mapping that
    correctly handles data extending beyond the last tick mark.

    Args:
        image_array: numpy array (H, W, 3) RGB image
        x_min, x_max: data-space x-axis range (from tick labels)
        y_min, y_max: data-space y-axis range (from tick labels)
        threshold: binarization threshold (0-255); <= 0 selects Otsu auto
        min_contour_length: minimum contour length to consider
            (currently unused by the column-scan implementation)
        x_ticks: list of (pixel_x, value) from OCR tick detection
        y_ticks: list of (pixel_y, value) from OCR tick detection

    Returns:
        x_data, y_data: 1-D float32 arrays of extracted data points

    Raises:
        ImportError: if OpenCV is not installed.
        ValueError: if no curve pixels or too few points are found.
    """
    try:
        import cv2
    except ImportError:
        raise ImportError("opencv-python-headless is required for image digitization")
    # Drop an alpha channel if present (RGBA -> RGB).
    if image_array.ndim == 3 and image_array.shape[2] == 4:
        image_array = image_array[:, :, :3]
    gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    H, W = gray.shape
    # Detect plot region from axis lines
    region = _detect_plot_region(gray)
    if region is not None:
        px_left, px_right, py_top, py_bottom = region
    else:
        # Axis detection failed: assume typical fractional margins.
        px_left = int(W * 0.12)
        px_right = int(W * 0.92)
        py_top = int(H * 0.08)
        py_bottom = int(H * 0.85)
    # Build pixel-to-data mapping from tick positions if available.
    # This allows correct extrapolation for data beyond the last tick.
    if x_ticks and len(x_ticks) >= 2:
        x_slope = _robust_tick_fit(x_ticks)
        # Evaluate the fitted line at the plot edges to get effective ranges.
        eff_x_min = float(np.polyval(x_slope, px_left))
        eff_x_max = float(np.polyval(x_slope, px_right))
    else:
        eff_x_min, eff_x_max = x_min, x_max
    if y_ticks and len(y_ticks) >= 2:
        y_slope = _robust_tick_fit(y_ticks)
        eff_y_min = float(np.polyval(y_slope, py_bottom))  # bottom = y_min
        eff_y_max = float(np.polyval(y_slope, py_top))  # top = y_max
    else:
        eff_y_min, eff_y_max = y_min, y_max
    binary = _binarize(gray, threshold)
    # Restrict to plot region
    plot_binary = np.zeros_like(binary)
    # Small inset so the axis lines themselves are not traced as data.
    margin = 4
    r0 = py_top + margin
    r1 = py_bottom - margin
    c0 = px_left + margin
    c1 = px_right - margin
    plot_binary[r0:r1, c0:c1] = binary[r0:r1, c0:c1]
    if plot_binary.sum() == 0:
        raise ValueError("No curves detected in image. Try adjusting the threshold.")
    x_data, y_data = _column_scan(
        plot_binary, px_left, px_right, py_top, py_bottom,
        eff_x_min, eff_x_max, eff_y_min, eff_y_max,
    )
    return x_data.astype(np.float32), y_data.astype(np.float32)
def _binarize(gray, threshold=0):
    """Binarize a grayscale image so dark features become white foreground.

    A non-positive threshold selects Otsu auto-thresholding; a positive
    value is used as a fixed cutoff. A 3x3 morphological close then fills
    small gaps in the traced curves.
    """
    import cv2
    if threshold > 0:
        _, mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY_INV)
    else:
        _, mask = cv2.threshold(gray, 0, 255,
                                cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    structuring = np.ones((3, 3), np.uint8)
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, structuring, iterations=1)
def _column_scan(binary, px_left, px_right, py_top, py_bottom,
x_min, x_max, y_min, y_max):
"""Extract curve data by scanning each pixel column.
For each column in the plot region, finds clusters of dark pixels
and takes their midline. If two clusters exist (CV closed loop),
produces a forward-then-reverse trace matching standard potentiostat
output (high E → low E → high E).
"""
px_range = float(px_right - px_left)
py_range = float(py_bottom - py_top)
x_vals = []
y_branch_a = [] # higher-y branch (lower current in image coords)
y_branch_b = [] # lower-y branch (higher current in image coords)
has_two_branches = []
for col in range(px_left + 5, px_right - 5):
col_data = binary[py_top + 5 : py_bottom - 5, col] > 0
rows = np.where(col_data)[0] + py_top + 5
if len(rows) < 1:
continue
x_data = x_min + (col - px_left) / px_range * (x_max - x_min)
gaps = np.diff(rows)
gap_thr = max(5, np.median(gaps) * 3) if len(gaps) > 0 else 5
split_points = np.where(gaps > gap_thr)[0]
if len(split_points) >= 1:
c1 = rows[: split_points[0] + 1]
c2 = rows[split_points[0] + 1 :]
y1 = y_max - (np.mean(c1) - py_top) / py_range * (y_max - y_min)
y2 = y_max - (np.mean(c2) - py_top) / py_range * (y_max - y_min)
x_vals.append(x_data)
y_branch_a.append(max(y1, y2))
y_branch_b.append(min(y1, y2))
has_two_branches.append(True)
else:
y_mean = y_max - (np.mean(rows) - py_top) / py_range * (y_max - y_min)
x_vals.append(x_data)
y_branch_a.append(y_mean)
y_branch_b.append(y_mean)
has_two_branches.append(False)
x_vals = np.array(x_vals)
y_branch_a = np.array(y_branch_a)
y_branch_b = np.array(y_branch_b)
if len(x_vals) < 5:
raise ValueError("Too few data points extracted from the image.")
two_branch_frac = np.mean(has_two_branches)
if two_branch_frac > 0.3:
# CV closed loop: construct high-E → low-E → high-E trace
# Branch a = anodic (higher current), branch b = cathodic (lower current)
# Standard potentiostat: start at high E, sweep to low E (cathodic),
# then back to high E (anodic)
x_trace = np.concatenate([x_vals[::-1], x_vals])
y_trace = np.concatenate([y_branch_b[::-1], y_branch_a])
return x_trace, y_trace
else:
return x_vals, y_branch_a
def digitize_multiple_curves(image_array, x_min, x_max, y_min, y_max,
                             n_curves=1, threshold=0, min_contour_length=50):
    """
    Extract curves from a plot image using column-scan approach.

    For single-curve extraction (n_curves=1), delegates to digitize_plot.

    NOTE(review): n_curves is currently ignored — every call delegates to
    digitize_plot and returns exactly one (x, y) pair wrapped in a list;
    multi-curve separation is not implemented here. Tick-based mapping
    (x_ticks/y_ticks) is also not forwarded by this wrapper.

    Args:
        image_array: numpy array (H, W, 3) RGB image
        x_min, x_max, y_min, y_max: data-space axis ranges
        n_curves: number of curves expected (currently unused)
        threshold: binarization threshold, forwarded to digitize_plot
        min_contour_length: forwarded to digitize_plot

    Returns:
        list containing one (x_data, y_data) tuple.
    """
    x_data, y_data = digitize_plot(
        image_array, x_min, x_max, y_min, y_max,
        threshold=threshold, min_contour_length=min_contour_length,
    )
    return [(x_data, y_data)]