Harry Pham committed on
Commit
4fef2dd
Β·
0 Parent(s):

init space

Browse files
Files changed (3) hide show
  1. .gitignore +9 -0
  2. app.py +150 -0
  3. src/inference.py +307 -0
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Dataset/
2
+ outputs/
3
+ venv/
4
+ __pycache__/
5
+ *.pyc
6
+ *.pkl
7
+ *.h5
8
+ *.log
9
+ *.json
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py β€” Gradio demo
2
+ # Đặt ở root project: engineering-drawing-ai/app.py
3
+
4
+ import os, sys, json, tempfile
5
+ import gradio as gr
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
# Auto-download the model weights from the HuggingFace Hub when missing.
CHECKPOINT = "best.pt"  # local filename of the detection checkpoint
HF_REPO = "phamha/drawing-model-weights"  # <- update to the real repo id later
13
+
14
def ensure_weights():
    """Download the model checkpoint from the HuggingFace Hub if missing.

    Looks for CHECKPOINT in the working directory; when absent, fetches
    "best.pt" from HF_REPO into the current directory.

    Raises:
        Exception: re-raises whatever hf_hub_download raised, after logging,
            so the app fails fast instead of starting without weights.
    """
    if not os.path.exists(CHECKPOINT):
        print("[INFO] Downloading model weights from HuggingFace...")
        try:
            from huggingface_hub import hf_hub_download
            # NOTE: `local_dir_use_symlinks` is deprecated and ignored in
            # recent huggingface_hub releases; `local_dir` alone is enough.
            hf_hub_download(
                repo_id=HF_REPO,
                filename="best.pt",
                local_dir=".",
            )
            print("[INFO] Weights downloaded.")
        except Exception as e:
            print(f"[ERROR] Cannot download weights: {e}")
            raise

ensure_weights()
31
+
32
+ # Import pipeline SAU KHI Δ‘αΊ£m bαΊ£o weights tα»“n tαΊ‘i
33
+ sys.path.insert(0, ".")
34
+ from src.inference import run_pipeline
35
+
36
+
37
# ── Image processing ──────────────────────────────────────
def process(image: Image.Image):
    """Gradio handler: detect regions and OCR an uploaded drawing.

    Args:
        image: uploaded PIL image, or None when nothing was provided.

    Returns:
        (annotated RGB ndarray or None, JSON string, OCR summary or error
        text). On failure the first element is None and the third carries
        the error description.
    """
    if image is None:
        return None, "{}", "Chưa có ảnh."

    import shutil  # local import: only needed for temp-dir cleanup below

    # Isolated scratch dir so concurrent requests cannot collide.
    tmp_dir = tempfile.mkdtemp()
    tmp_path = os.path.join(tmp_dir, "input.jpg")
    image.save(tmp_path, quality=95)

    try:
        result, vis_path = run_pipeline(
            image_path = tmp_path,
            output_dir = tmp_dir,
            checkpoint = CHECKPOINT,
            conf_thresh = 0.3,
        )

        # Annotated result image (pipeline writes BGR to disk).
        vis_bgr = cv2.imread(vis_path)
        if vis_bgr is None:
            # Visualization missing/unreadable: report instead of crashing
            # inside cvtColor below.
            return None, "{}", f"Lỗi pipeline:\ncannot read visualization: {vis_path}"
        vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB)

        # Clean JSON (drop crop_path — it points into the temp dir).
        clean_objs = []
        for obj in result["objects"]:
            clean_objs.append({
                "id": obj["id"],
                "class": obj["class"],
                "confidence": obj["confidence"],
                "bbox": obj["bbox"],
                "ocr_content": obj["ocr_content"],
            })
        json_str = json.dumps(
            {"image": result["image"], "objects": clean_objs},
            ensure_ascii=False, indent=2
        )

        # Human-readable OCR text for Note / Table regions.
        ocr_parts = []
        for obj in result["objects"]:
            content = obj.get("ocr_content")
            if not content:
                continue
            if isinstance(content, dict):  # Table result: {"rows", "text"}
                content = content.get("text", "")
            if not content.strip():
                continue
            sep = "─" * 46
            ocr_parts.append(
                f"{sep}\n"
                f"[{obj['class']} #{obj['id']}] conf={obj['confidence']}\n"
                f"{sep}\n{content}"
            )
        ocr_text = "\n\n".join(ocr_parts) or "Không phát hiện Note / Table."

        return vis_rgb, json_str, ocr_text
    except Exception:
        # Surface the full traceback in the UI rather than a blank failure.
        import traceback
        return None, "{}", f"Lỗi pipeline:\n{traceback.format_exc()}"
    finally:
        # Everything returned lives in memory, so the scratch dir can go
        # (fixes unbounded temp-dir accumulation, one dir per request).
        shutil.rmtree(tmp_dir, ignore_errors=True)
95
+
96
+
97
# ── Gradio UI ───────────────────────────────────────────────
# Layout: upload + button next to the annotated detection image,
# then the JSON output side by side with the extracted OCR text.
with gr.Blocks(title="Engineering Drawing Analyzer", theme=gr.themes.Soft()) as demo:

    # App header (user-facing text, intentionally in Vietnamese).
    gr.Markdown("""
# 🔧 Engineering Drawing Analyzer
**Tự động phát hiện và trích xuất văn bản từ bản vẽ kỹ thuật**

Hỗ trợ 3 loại vùng:
- 🟢 **PartDrawing** — vùng bản vẽ chi tiết
- 🟠 **Note** — ghi chú, chú thích
- 🔴 **Table** — bảng dữ liệu kỹ thuật
""")

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="📐 Upload bản vẽ kỹ thuật")
            btn = gr.Button("🔍 Detect & OCR", variant="primary", size="lg")

        with gr.Column(scale=1):
            out_img = gr.Image(label="✅ Kết quả detection")

    with gr.Row():
        with gr.Column(scale=1):
            out_json = gr.Code(
                language="json",
                label="📋 JSON output",
                lines=25,
            )
        with gr.Column(scale=1):
            out_ocr = gr.Textbox(
                label="📝 OCR content (Note & Table)",
                lines=25,
                max_lines=60,
            )

    # Wire the button to the processing handler defined above.
    btn.click(
        fn = process,
        inputs = [inp],
        outputs = [out_img, out_json, out_ocr],
    )

    # Footer with model / metric information.
    gr.Markdown("""
---
**Model:** RT-DETR-L fine-tuned | **OCR:** EasyOCR (vi+en) + PaddleOCR fallback
**mAP50:** 0.942 | **Dataset:** Engineering drawings (Vietnamese technical)
""")
143
+
144
+
145
if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(
        server_name = "0.0.0.0",
        server_port = 7860,
        share = False,  # set True for a temporary public link
    )
src/inference.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# src/inference.py
# ── torch.load patch — MUST RUN BEFORE ANY CHECKPOINT LOAD ──
import torch

# Ultralytics checkpoints contain full pickled objects, which newer
# torch releases reject under the weights_only=True default.  Wrap
# torch.load so callers who do not say otherwise get weights_only=False;
# an explicit weights_only= from the caller still wins.
_original_load = torch.load

def _load_full_objects(*args, **kwargs):
    """torch.load with weights_only defaulting to False."""
    if "weights_only" not in kwargs:
        kwargs["weights_only"] = False
    return _original_load(*args, **kwargs)

torch.load = _load_full_objects
# ───────────────────────────────────────────────────────────
10
+
11
+ import cv2
12
+ import json
13
+ import numpy as np
14
+ from pathlib import Path
15
+ from ultralytics import RTDETR
16
+
17
# ── Device ─────────────────────────────────────────────────
# Prefer Apple-silicon MPS when available, otherwise CPU.
# NOTE(review): CUDA is never probed here — confirm that is intended.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"[INFO] Device: {DEVICE}")

# ── Class config ────────────────────────────────────────────
# Index order is assumed to match the checkpoint's label order
# (run_pipeline indexes CLASS_NAMES by the model's class id) — confirm
# against the training config.
CLASS_NAMES = ["note", "part-drawing", "table"]
# Raw class name -> display name used in JSON output and crop filenames.
CLASS_DISPLAY = {
    "note": "Note",
    "part-drawing": "PartDrawing",
    "table": "Table",
}
# Per-class box colors in BGR (OpenCV convention).
COLORS = {
    "note": (0, 165, 255),
    "part-drawing": (0, 200, 0),
    "table": (0, 0, 220),
}
33
+
34
# ───────────────────────────────────────────────────────────
# DETECTION MODEL
# ───────────────────────────────────────────────────────────
_det_model = None

def get_det_model(checkpoint: str = "best.pt") -> RTDETR:
    """Return the RT-DETR detector, loading it once and caching it.

    The checkpoint argument only matters on the first call; later calls
    return the cached instance regardless.
    """
    global _det_model
    if _det_model is not None:
        return _det_model
    print(f"[INFO] Loading detection model: {checkpoint}")
    _det_model = RTDETR(checkpoint)
    return _det_model
45
+
46
+
47
# ───────────────────────────────────────────────────────────
# OCR ENGINES
# ───────────────────────────────────────────────────────────
_easy_reader = None
_paddle_engine = None

def get_easy_reader():
    """Lazily build and cache the EasyOCR reader (Vietnamese + English)."""
    global _easy_reader
    if _easy_reader is not None:
        return _easy_reader
    import easyocr
    print("[INFO] Loading EasyOCR (vi + en)...")
    _easy_reader = easyocr.Reader(
        ["vi", "en"],
        gpu=False,
        verbose=False,
    )
    return _easy_reader
64
+
65
+
66
def get_paddle_engine():
    """Lazily build and cache the PaddleOCR engine (Vietnamese, CPU)."""
    global _paddle_engine
    if _paddle_engine is not None:
        return _paddle_engine
    from paddleocr import PaddleOCR
    print("[INFO] Loading PaddleOCR (vi)...")
    _paddle_engine = PaddleOCR(
        use_angle_cls=True,
        lang="vi",
        show_log=False,
        use_gpu=False,
    )
    return _paddle_engine
78
+
79
+
80
# ───────────────────────────────────────────────────────────
# PREPROCESSING
# ───────────────────────────────────────────────────────────
def preprocess_for_ocr(img_bgr: np.ndarray) -> np.ndarray:
    """Enhance a cropped region before OCR.

    Steps: upscale crops narrower than 800 px, grayscale, local-contrast
    boost (CLAHE), non-local-means denoise, then sharpen.  Returns a
    3-channel BGR image because the OCR engines take color input.
    """
    height, width = img_bgr.shape[:2]

    # Upscale small crops (heuristic 800 px minimum width).
    if width < 800:
        factor = 800 / width
        new_size = (int(width * factor), int(height * factor))
        img_bgr = cv2.resize(img_bgr, new_size, interpolation=cv2.INTER_CUBIC)

    work = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    work = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(work)
    work = cv2.fastNlMeansDenoising(
        work, h=15, templateWindowSize=7, searchWindowSize=21
    )
    # 3x3 sharpening convolution kernel.
    sharpen = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    work = cv2.filter2D(work, -1, sharpen)

    return cv2.cvtColor(work, cv2.COLOR_GRAY2BGR)
105
+
106
+
107
# ───────────────────────────────────────────────────────────
# OCR: NOTE
# ───────────────────────────────────────────────────────────
def ocr_note(img_path: str) -> str:
    """OCR a 'note' crop into plain text, one detected line per row.

    Tries EasyOCR first; when it raises or recognizes nothing, falls
    back to PaddleOCR.  Detections below 0.2 confidence are discarded.
    Returns "" if the image is unreadable or no text is found.
    """
    raw = cv2.imread(img_path)
    if raw is None:
        return ""

    prepared = preprocess_for_ocr(raw)

    # Primary engine: EasyOCR.
    try:
        detections = get_easy_reader().readtext(
            prepared, detail=1, paragraph=False, width_ths=0.7, height_ths=0.7
        )
        texts = [txt for (_, txt, score) in detections if score >= 0.2 and txt.strip()]
        if texts:
            return "\n".join(texts)
    except Exception as e:
        print(f"[WARN] EasyOCR note: {e}")

    # Fallback engine: PaddleOCR.
    try:
        paddle_out = get_paddle_engine().ocr(prepared, cls=True)
        if paddle_out and paddle_out[0]:
            return "\n".join(
                entry[1][0] for entry in paddle_out[0] if entry[1][1] >= 0.2
            )
    except Exception as e:
        print(f"[WARN] PaddleOCR note: {e}")

    return ""
138
+
139
+
140
# ───────────────────────────────────────────────────────────
# OCR: TABLE
# ───────────────────────────────────────────────────────────
def _group_rows(items: list) -> list:
    """Cluster OCR fragments into table rows by vertical position.

    Each item is a dict with "text", "x", "y".  Fragments whose y-gap to
    the previous fragment is below an adaptive threshold share a row;
    each row is then ordered left-to-right by x.  Returns a list of
    rows, each a list of text strings.
    """
    if not items:
        return []

    ordered = sorted(items, key=lambda it: it["y"])

    # Adaptive row threshold: 60% of the mean vertical gap, floored at
    # 8 px; fixed 12 px when only one fragment exists.
    ys = [it["y"] for it in ordered]
    gaps = [b - a for a, b in zip(ys, ys[1:])]
    thresh = max(8, (sum(gaps) / len(gaps)) * 0.6) if gaps else 12

    grouped = []
    current = [ordered[0]]
    for frag in ordered[1:]:
        if frag["y"] - current[-1]["y"] < thresh:
            current.append(frag)
        else:
            grouped.append(
                [f["text"] for f in sorted(current, key=lambda it: it["x"])]
            )
            current = [frag]
    grouped.append([f["text"] for f in sorted(current, key=lambda it: it["x"])])
    return grouped
165
+
166
+
167
def _poly_center(pts) -> tuple:
    """Return the (x, y) centroid of a polygon given as (x, y) pairs."""
    n = len(pts)
    return (sum(p[0] for p in pts) / n, sum(p[1] for p in pts) / n)


def ocr_table(img_path: str) -> dict:
    """OCR a 'table' crop into rows of cell text.

    Runs EasyOCR first; when it yields nothing (or raises), retries with
    PaddleOCR.  Fragments below 0.2 confidence are discarded and the
    rest are grouped into rows by _group_rows.

    Returns:
        dict with "rows" (list of lists of cell strings) and "text"
        (rows joined with " | " and newlines); both empty on failure.
    """
    img = cv2.imread(img_path)
    if img is None:
        return {"rows": [], "text": ""}

    img_proc = preprocess_for_ocr(img)
    items = []

    # EasyOCR
    try:
        reader = get_easy_reader()
        results = reader.readtext(img_proc, detail=1, paragraph=False,
                                  width_ths=0.5, height_ths=0.5)
        for (pts, text, conf) in results:
            if conf < 0.2 or not text.strip():
                continue
            # Was hardcoded /4 (assumed quad boxes); now any polygon size.
            cx, cy = _poly_center(pts)
            items.append({"text": text.strip(), "y": cy, "x": cx})
    except Exception as e:
        print(f"[WARN] EasyOCR table: {e}")

    # Fallback PaddleOCR (only when EasyOCR produced no fragments)
    if not items:
        try:
            ocr = get_paddle_engine()
            result = ocr.ocr(img_proc, cls=True)
            if result and result[0]:
                for line in result[0]:
                    pts, (text, conf) = line[0], line[1]
                    if conf < 0.2 or not text.strip():
                        continue
                    cx, cy = _poly_center(pts)
                    items.append({"text": text.strip(), "y": cy, "x": cx})
        except Exception as e:
            print(f"[WARN] PaddleOCR table: {e}")

    if not items:
        return {"rows": [], "text": ""}

    rows = _group_rows(items)
    return {
        "rows": rows,
        "text": "\n".join(" | ".join(r) for r in rows),
    }
217
+
218
+
219
# ───────────────────────────────────────────────────────────
# MAIN PIPELINE
# ───────────────────────────────────────────────────────────
def run_pipeline(
    image_path: str,
    output_dir: str = "outputs",
    checkpoint: str = "best.pt",
    conf_thresh: float = 0.3,
) -> tuple:
    """Detect regions in a drawing, OCR Note/Table crops, save artifacts.

    Args:
        image_path: path of the drawing image to analyze.
        output_dir: root directory for crops, visualization, and JSON.
        checkpoint: detection weights path passed to get_det_model
            (only honored on the first call — the model is cached).
        conf_thresh: minimum detection confidence to keep.

    Returns:
        (result, vis_path): result is {"image": name, "objects": [...]}
        where each object carries id/class/confidence/bbox/crop_path/
        ocr_content; vis_path is the annotated image written to disk.

    Raises:
        ValueError: when OpenCV cannot read the input image.
    """
    image_path = str(image_path)
    img_name = Path(image_path).name
    stem = Path(image_path).stem
    crop_dir = Path(output_dir) / stem / "crops"
    crop_dir.mkdir(parents=True, exist_ok=True)

    # 1. Detect regions with RT-DETR.
    model = get_det_model(checkpoint)
    results = model(image_path, imgsz=1024, conf=conf_thresh,
                    iou=0.5, device=DEVICE, verbose=False)

    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise ValueError(f"Không đọc được ảnh: {image_path}")

    objects = []

    for i, box in enumerate(results[0].boxes):
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        cls_idx = int(box.cls[0])
        conf_val = round(float(box.conf[0]), 4)
        cls_raw = CLASS_NAMES[cls_idx]
        cls_show = CLASS_DISPLAY[cls_raw]

        # 2. Crop the detected region with a small padding margin,
        #    clamped to the image bounds.
        pad = 6
        crop = img_bgr[max(0,y1-pad):min(img_bgr.shape[0],y2+pad),
                       max(0,x1-pad):min(img_bgr.shape[1],x2+pad)]
        crop_path = str(crop_dir / f"{cls_show}_{i+1}.jpg")
        cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 95])

        # 3. OCR — only Note and Table regions carry extractable text;
        #    PartDrawing keeps ocr_content = None.
        ocr_content = None
        if cls_raw == "note":
            print(f"[OCR] Note #{i+1}...")
            ocr_content = ocr_note(crop_path)
            print(f"  → {repr(ocr_content[:80]) if ocr_content else 'EMPTY'}")
        elif cls_raw == "table":
            print(f"[OCR] Table #{i+1}...")
            ocr_content = ocr_table(crop_path)
            print(f"  → {repr(ocr_content.get('text','')[:80]) if ocr_content else 'EMPTY'}")

        objects.append({
            "id": i + 1,
            "class": cls_show,
            "confidence": conf_val,
            "bbox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
            "crop_path": crop_path,
            "ocr_content": ocr_content,
        })

        # 4. Draw the bbox plus a filled label banner above it.
        color = COLORS[cls_raw]
        cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 2)
        label = f"{cls_show} {conf_val:.2f}"
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        cv2.rectangle(img_bgr, (x1, y1-th-10), (x1+tw+8, y1), color, -1)
        cv2.putText(img_bgr, label, (x1+4, y1-4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)

    # 5. Save the annotated visualization.
    vis_path = str(Path(output_dir) / stem / "result_vis.jpg")
    cv2.imwrite(vis_path, img_bgr)

    # 6. Save the JSON result.
    result = {"image": img_name, "objects": objects}
    json_path = str(Path(output_dir) / stem / "result.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"\n[✓] {len(objects)} objects | vis→{vis_path} | json→{json_path}")
    return result, vis_path
300
+
301
+
302
# ── CLI ──────────────────────────────────────────────────────
if __name__ == "__main__":
    import sys
    # Run the pipeline on the given image (default: test.jpg) and dump
    # the resulting JSON to stdout.
    cli_args = sys.argv[1:]
    target = cli_args[0] if cli_args else "test.jpg"
    result, _vis = run_pipeline(target)
    print(json.dumps(result, ensure_ascii=False, indent=2))