| import os |
| import json |
| import glob |
| import xml.etree.ElementTree as ET |
| import numpy as np |
| from PIL import Image |
| from torch.utils.data import Dataset, DataLoader |
| import torchvision.transforms as T |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from shapely.geometry import Polygon |
| from pathlib import Path |
|
|
| |
| |
| |
|
|
| import numpy as np |
| import json |
|
|
def flat_corners_from_mockup(mockup_path):
    """Return the corners of the first print area described in a mockup JSON.

    The file is expected to contain ``background.width``/``background.height``
    and a ``printAreas`` list whose first entry has ``position.x``,
    ``position.y``, ``width``, ``height`` and (optionally) ``rotation`` in
    degrees.  The rectangle is rotated about its own centre.

    Args:
        mockup_path: Path (``str`` or ``Path``) of the mockup JSON file.

    Returns:
        A ``(pixels, normalized)`` tuple of ``float32`` arrays of shape
        ``(4, 2)``.  Rows follow the TL, TR, BR, BL order of the *unrotated*
        rectangle.  ``normalized`` is ``pixels`` with x divided by the
        background width and y divided by the background height.

    Raises:
        KeyError: If a required field (other than ``rotation``) is missing.
        IndexError: If ``printAreas`` is empty.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding.
    mockup = json.loads(Path(mockup_path).read_text(encoding="utf-8"))

    bg_w = mockup["background"]["width"]
    bg_h = mockup["background"]["height"]

    area = mockup["printAreas"][0]
    x, y = area["position"]["x"], area["position"]["y"]
    w, h = area["width"], area["height"]
    # A missing or null rotation means "not rotated" rather than an error.
    angle = area.get("rotation") or 0.0

    half_w, half_h = w / 2.0, h / 2.0
    centre = np.array([x + half_w, y + half_h], dtype=np.float32)

    # Corners relative to the centre, in TL, TR, BR, BL order.
    corners = np.array(
        [[-half_w, -half_h], [half_w, -half_h], [half_w, half_h], [-half_w, half_h]],
        dtype=np.float32,
    )

    theta = np.deg2rad(angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)

    # Row vectors are rotated by right-multiplying with the transpose.
    rot = (corners @ rotation.T) + centre
    norm = rot / np.array([bg_w, bg_h], dtype=np.float32)
    return rot.astype(np.float32), norm.astype(np.float32)
|
|
def parse_xml_points(xml_path):
    """Read the four corner points of the first FourPoint transform in an XML.

    The root element must have a ``<background>`` child with integer
    ``width`` and ``height`` attributes.  The first ``<transform>``
    descendant whose ``type`` attribute is ``"FourPoint"`` supplies the
    corners through ``<point type="TopLeft|TopRight|BottomRight|BottomLeft"
    x=".." y="..">`` descendants; any later FourPoint transforms are ignored.

    Args:
        xml_path: Path of the XML file (anything ``ElementTree.parse`` accepts).

    Returns:
        A ``float32`` array of shape ``(4, 2)`` holding the corners in
        TL, TR, BR, BL order, with x divided by the background width and
        y divided by the background height.

    Raises:
        ValueError: If the ``<background>`` element, a FourPoint transform,
            or any of its four corner points is missing.  Failing here keeps
            a short or mis-ordered array from reaching the caller.
    """
    root = ET.parse(xml_path).getroot()

    background = root.find("background")
    if background is None:
        raise ValueError(f"{xml_path}: missing <background> element")
    bg_w = int(background.get("width"))
    bg_h = int(background.get("height"))

    corner_names = ("TopLeft", "TopRight", "BottomRight", "BottomLeft")
    for transform in root.findall(".//transform"):
        if transform.get("type") != "FourPoint":
            continue

        points = []
        for corner in corner_names:
            node = transform.find(f".//point[@type='{corner}']")
            if node is None:
                # Skipping the corner would shift every later corner up one
                # row, so the fixed TL, TR, BR, BL order would be lost.
                raise ValueError(
                    f"{xml_path}: FourPoint transform has no '{corner}' point"
                )
            points.append(
                [float(node.get("x")) / bg_w, float(node.get("y")) / bg_h]
            )
        # Only the first FourPoint transform is used.
        return np.array(points, dtype=np.float32)

    raise ValueError(f"{xml_path}: no FourPoint transform found")
|
|
class KP4Dataset(Dataset):
    """Dataset pairing a perspective image with its flat background image.

    A sample is built for every ``*_visual*.xml`` file found recursively under
    ``root`` that has, in the same directory:

    * an image with the same name as the XML (``.png``, ``.jpg`` or ``.jpeg``),
    * a flat image whose stem is the XML stem with ``_visual`` replaced by
      ``_background`` (same extensions), and
    * a ``mockup.json`` file.

    Each item is a dict with the transformed images (``persp_img``,
    ``flat_img``), the corner tensors read from the JSON (``flat_pts``) and
    from the XML (``persp_pts``), and the two source paths (``xml``, ``json``).
    """

    # Extensions tried, in order, for both the perspective and the flat image.
    _IMAGE_EXTS = (".png", ".jpg", ".jpeg")

    def __init__(self, root, img_size=512):
        """Index all complete samples under ``root``.

        Args:
            root: Directory that is searched recursively for samples.
            img_size: Side length both images are resized to (square).

        Raises:
            RuntimeError: If no complete sample is found under ``root``.
        """
        self.root = Path(root)
        self.img_size = img_size

        # Resize to a square, convert to a tensor, then normalise each
        # channel with mean 0.5 and std 0.5.
        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.samples = self._find_samples(self.root)
        if not self.samples:
            raise RuntimeError(f"No valid samples found under {root}")

    @classmethod
    def _find_samples(cls, root):
        """Return ``(image, xml, flat_image, json)`` path tuples found under ``root``."""
        samples = []
        # rglob yields files in filesystem order; sorting makes the mapping
        # from index to sample reproducible across runs and machines.
        for xml_file in sorted(root.rglob("*.xml")):
            if "_visual" not in xml_file.stem:
                continue

            img_file = cls._first_existing(
                xml_file.with_suffix(ext) for ext in cls._IMAGE_EXTS
            )
            if img_file is None:
                continue

            flat_stem = xml_file.stem.replace("_visual", "_background")
            flat_img = cls._first_existing(
                xml_file.parent / (flat_stem + ext) for ext in cls._IMAGE_EXTS
            )
            if flat_img is None:
                continue

            json_file = xml_file.parent / "mockup.json"
            if not json_file.exists():
                continue

            samples.append((img_file, xml_file, flat_img, json_file))
        return samples

    @staticmethod
    def _first_existing(candidates):
        """Return the first path in ``candidates`` that exists, else ``None``."""
        return next((path for path in candidates if path.exists()), None)

    def _load_image(self, path):
        """Open ``path``, convert it to RGB and apply ``self.transform``."""
        # The context manager closes the underlying file deterministically.
        with Image.open(path) as image:
            return self.transform(image.convert("RGB"))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the images and corner points of sample ``idx``."""
        img_file, xml_file, flat_img, json_file = self.samples[idx]

        img = self._load_image(img_file)
        flat = self._load_image(flat_img)

        # Normalised corners of the flat print area, read from mockup.json.
        _, flat_norm = flat_corners_from_mockup(json_file)
        flat_pts = torch.tensor(flat_norm, dtype=torch.float32)

        # Normalised corners in the perspective image, read from the XML.
        persp_norm = parse_xml_points(xml_file)
        persp_pts = torch.tensor(persp_norm, dtype=torch.float32)

        return {
            "persp_img": img,
            "flat_img": flat,
            "flat_pts": flat_pts,
            "persp_pts": persp_pts,
            "xml": str(xml_file),
            "json": str(json_file),
        }
|
|
| |
| |
| |
class SimpleTransformer(nn.Module):
    """Map an 8-value input vector to an 8-value output vector.

    The input is projected to ``d_model`` features, passed through a
    Transformer encoder as a sequence of length one, and projected back to
    8 values.

    Args:
        d_model: Feature width used inside the encoder.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.fc_in = nn.Linear(8, d_model)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead, batch_first=True
            ),
            num_layers=num_layers,
        )
        self.fc_out = nn.Linear(d_model, 8)

    def forward(self, x):
        """Run the network on ``x`` (last dimension of size 8)."""
        # Insert a singleton dimension at index 1 so the encoder sees a
        # one-token sequence per batch element.
        embedded = self.fc_in(x)
        sequence = embedded.unsqueeze(dim=1)
        encoded = self.transformer(sequence)
        projected = self.fc_out(encoded)
        # Drop the singleton sequence dimension again.
        return projected.squeeze(dim=1)
|
|
|
|
| |
| |
| |
def mse_loss(pred, gt):
    """Return the mean of the element-wise squared difference of ``pred`` and ``gt``."""
    residual = pred - gt
    squared = residual ** 2
    return squared.mean()
|
|
def mean_corner_error(pred, gt, img_w, img_h):
    """Return the mean Euclidean distance between ``pred`` and ``gt`` points.

    Both tensors are multiplied by ``[img_w, img_h]`` along their last
    dimension before the distance is taken, so normalised (x, y) coordinates
    become pixel coordinates.  The result is a Python float.
    """
    extent = [img_w, img_h]
    pred_pixels = pred * torch.tensor(extent, device=pred.device)
    gt_pixels = gt * torch.tensor(extent, device=gt.device)
    # One distance per point (norm over the last dimension), then the mean.
    distances = torch.norm(pred_pixels - gt_pixels, dim=-1)
    return distances.mean().item()
|
|
def iou_quad(pred, gt):
    """Return the intersection-over-union of two polygons.

    ``pred`` and ``gt`` must provide ``.tolist()`` returning the vertex
    coordinates handed to ``shapely.geometry.Polygon``.  The result is
    ``0.0`` when either polygon is reported invalid by shapely or when the
    union area is not positive.
    """
    polygon_pred = Polygon(pred.tolist())
    polygon_gt = Polygon(gt.tolist())

    # Invalid polygons are scored as zero overlap.
    if not (polygon_pred.is_valid and polygon_gt.is_valid):
        return 0.0

    overlap_area = polygon_pred.intersection(polygon_gt).area
    union_area = polygon_pred.union(polygon_gt).area
    if not union_area > 0:
        return 0.0
    return overlap_area / union_area
|
|
|
|
| |
| |
| |
def train_model(
    train_root,
    test_root,
    epochs=20,
    batch_size=8,
    lr=1e-3,
    img_size=256,
    save_dir="Transformer/checkpoints",
    resume_path=None,
    checkpoint_every=100,
):
    """Train ``SimpleTransformer`` to map flat corners to perspective corners.

    Every epoch runs one pass over the training set with an MSE loss, then
    evaluates MSE, mean corner error in pixels and quadrilateral IoU on the
    validation set.  The model with the best validation IoU is written to
    ``<save_dir>/best_model.pth``; the final weights are written to
    ``<save_dir>/final_model.pth``.

    Args:
        train_root: Directory passed to ``KP4Dataset`` for training data.
        test_root: Directory passed to ``KP4Dataset`` for validation data.
        epochs: Total number of epochs (a resumed run continues up to this).
        batch_size: Training batch size (validation always uses 1).
        lr: Adam learning rate.
        img_size: Square size images are resized to by the dataset.
        save_dir: Directory for checkpoints; created if it does not exist.
        resume_path: Optional checkpoint to resume from; ignored when the
            file does not exist.
        checkpoint_every: Write ``epoch_<n>.pth`` every this many epochs.
            A falsy value disables periodic checkpoints.

    Returns:
        The trained model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_ds = KP4Dataset(train_root, img_size=img_size)
    val_ds = KP4Dataset(test_root, img_size=img_size)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    model = SimpleTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    start_epoch = 0
    best_iou = -1.0

    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, "best_model.pth")

    if resume_path is not None and os.path.exists(resume_path):
        print(f"Loading checkpoint from {resume_path}")
        # NOTE(review): checkpoints written before metrics were stored as
        # plain floats contain NumPy scalars, which torch.load rejects when
        # weights_only=True is in effect.
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        # "epoch" is saved as the number of completed epochs, so it is also
        # the zero-based index of the next epoch to run.
        start_epoch = checkpoint["epoch"]
        # Carry the best score over so a resumed run cannot overwrite
        # best_model.pth with a worse model.
        best_iou = float(checkpoint.get("best_iou", best_iou))
        print(f"Resumed from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        # ---- training pass ----
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            flat_pts = batch["flat_pts"].to(device)
            persp_pts = batch["persp_pts"].to(device)

            # Flatten the (4, 2) corner sets to 8-vectors per sample.
            flat_pts_in = flat_pts.view(flat_pts.size(0), -1)
            target = persp_pts.view(persp_pts.size(0), -1)

            pred = model(flat_pts_in)
            loss = mse_loss(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {total_loss/len(train_loader):.6f}")

        # ---- validation pass (batch size 1) ----
        model.eval()
        mse_all, ce_all, iou_all = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                flat_pts = batch["flat_pts"].to(device)
                persp_pts = batch["persp_pts"].to(device)

                flat_pts_in = flat_pts.view(1, -1)
                target = persp_pts.view(1, -1)

                pred = model(flat_pts_in)
                mse_all.append(mse_loss(pred, target).item())

                pred_quad = pred.view(4, 2).cpu()
                gt_quad = persp_pts.view(4, 2).cpu()

                # Collated images are (batch, channels, height, width), so
                # height and width are the last two dimensions.
                h, w = batch["persp_img"].shape[-2:]
                ce_all.append(mean_corner_error(pred_quad, gt_quad, w, h))
                iou_all.append(iou_quad(pred_quad, gt_quad))

        # Plain Python floats keep NumPy scalar objects out of the pickled
        # checkpoints.
        val_mse = float(np.mean(mse_all))
        val_ce = float(np.mean(ce_all))
        val_iou = float(np.mean(iou_all))

        print(f" Val MSE: {val_mse:.6f}, CornerErr(px): {val_ce:.2f}, IoU: {val_iou:.3f}")

        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            checkpoint_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "val_iou": val_iou,
                # Stored so a run resumed from this file keeps its best score.
                "best_iou": max(best_iou, val_iou),
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

        if val_iou > best_iou:
            best_iou = val_iou
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "best_iou": best_iou,
            }, best_model_path)
            print(f"Best model updated at epoch {epoch+1} (IoU={val_iou:.3f})")

    # The final file holds only the weights (a bare state dict), unlike the
    # checkpoint dicts above.
    final_path = os.path.join(save_dir, "final_model.pth")
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved at {final_path}")
    print(f"Best model saved at {best_model_path} with IoU={best_iou:.3f}")

    return model
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Script entry point: train from scratch on the default data folders.
    model = train_model(
        "Transformer/train",
        "Transformer/test",
        epochs=3000,
        batch_size=4,
        lr=1e-3,
        img_size=256,
        resume_path=None,
    )
|
|