Zeldeo commited on
Commit
fd06368
·
verified ·
1 Parent(s): 9f0ff72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -21
app.py CHANGED
@@ -4,7 +4,8 @@ from fastapi.middleware.cors import CORSMiddleware
4
  from PIL import Image
5
  import io
6
  import torch
7
- from transformers import AutoImageProcessor, AutoModelForObjectDetection
 
8
 
9
  app = FastAPI()
10
 
@@ -16,41 +17,84 @@ app.add_middleware(
16
  )
17
 
18
  processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
19
- model = AutoModelForObjectDetection.from_pretrained("czczup/textnet-base")
20
  model.eval()
21
 
22
  @app.post("/detect")
23
  async def detect_text(file: UploadFile = File(...)):
24
  try:
 
25
  image_bytes = await file.read()
26
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
27
 
28
- inputs = processor(images=image, return_tensors="pt")
29
-
30
  with torch.no_grad():
31
  outputs = model(**inputs)
32
 
33
- results = processor.post_process_object_detection(
34
- outputs,
35
- threshold=0.3,
36
- target_sizes=[image.size[::-1]]
37
- )[0]
 
 
 
 
 
 
38
 
 
 
 
 
 
 
 
 
 
39
  boxes = []
40
- for poly, score in zip(results["polygons"], results["scores"]):
41
- poly = poly.tolist()
42
- xs = [p[0] for p in poly]
43
- ys = [p[1] for p in poly]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  boxes.append({
46
- "polygon": poly,
47
- "x": int(min(xs)),
48
- "y": int(min(ys)),
49
- "w": int(max(xs) - min(xs)),
50
- "h": int(max(ys) - min(ys)),
51
- "score": float(score)
52
  })
53
-
 
54
  return JSONResponse({
55
  "image_width": image.width,
56
  "image_height": image.height,
@@ -58,4 +102,4 @@ async def detect_text(file: UploadFile = File(...)):
58
  })
59
 
60
  except Exception as e:
61
- return JSONResponse({"success": False, "error": str(e)}, status_code=500)
 
4
  from PIL import Image
5
  import io
6
  import torch
7
+ from transformers import AutoImageProcessor, AutoBackbone
8
+ import pytesseract # OCR
9
 
10
  app = FastAPI()
11
 
 
17
  )
18
 
19
  processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
20
+ model = AutoBackbone.from_pretrained("czczup/textnet-base")
21
  model.eval()
22
 
23
@app.post("/detect")
async def detect_text(file: UploadFile = File(...)):
    """Detect text regions in an uploaded image and OCR each region.

    Pipeline: run the TextNet backbone to obtain a coarse feature map,
    collapse it into a single-channel "text saliency" heatmap, cluster the
    hot pixels into horizontal line bands, crop each band from the original
    image, and run Tesseract OCR on the crop.

    Returns a JSON object with ``image_width``, ``image_height`` and a
    ``boxes`` list; each box carries pixel coordinates (x, y, w, h) and the
    recognized ``text``. Any failure is surfaced as an HTTP 500 JSON body
    ``{"success": False, "error": ...}``.
    """
    try:
        # Read the upload and normalise to RGB for both the model and OCR.
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # TextNet forward pass (inference only, no gradients needed).
        inputs = processor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        # Collapse the deepest feature map into one heatmap.
        # NOTE(review): the channel-mean of raw backbone features is a
        # heuristic saliency map, not a calibrated detection score.
        fm = outputs.feature_maps[-1][0]  # last stage, first batch item
        heatmap = fm.mean(dim=0).numpy()
        H, W = heatmap.shape
        # Relative threshold: 20% of the hottest activation.
        threshold = heatmap.max() * 0.2

        # Collect "hot" pixel coordinates above the threshold.
        points = [
            (x, y)
            for y in range(H)
            for x in range(W)
            if heatmap[y, x] > threshold
        ]

        # Group hot pixels into coarse text lines using 10-row bands.
        lines = {}
        for x, y in points:
            lines.setdefault(y // 10, []).append((x, y))

        # Scale factors to map heatmap coordinates back to image pixels.
        scale_x = image.width / W
        scale_y = image.height / H
        boxes = []

        for line in lines.values():
            xs = [p[0] for p in line]
            ys = [p[1] for p in line]
            min_x, max_x = min(xs), max(xs)
            min_y, max_y = min(ys), max(ys)

            # Skip clusters too small to contain legible text.
            if (max_x - min_x) < 5 or (max_y - min_y) < 2:
                continue

            # Crop the candidate region from the full-resolution image
            # and let Tesseract read it.
            crop = image.crop((
                int(min_x * scale_x),
                int(min_y * scale_y),
                int(max_x * scale_x),
                int(max_y * scale_y),
            ))
            text = pytesseract.image_to_string(crop, lang='eng').strip()

            # Discard noise: OCR results shorter than two characters.
            if len(text) < 2:
                continue

            boxes.append({
                "x": int(min_x * scale_x),
                "y": int(min_y * scale_y),
                "w": int((max_x - min_x) * scale_x),
                "h": int((max_y - min_y) * scale_y),
                "text": text or "texte non reconnu",
            })

        # Always return the same response shape — including when nothing
        # was detected. (Fixes two defects in the previous version: a bare
        # `[]` was returned on an empty heatmap, and a hard-coded
        # placeholder box was prepended before the first real detection.)
        return JSONResponse({
            "image_width": image.width,
            "image_height": image.height,
            "boxes": boxes,
        })

    except Exception as e:
        # Top-level service boundary: report the error as HTTP 500 JSON.
        return JSONResponse({"success": False, "error": str(e)}, status_code=500)