dcorcoran commited on
Commit
6a18d52
·
1 Parent(s): 28fbb7e

Initial commit

Browse files
.dockerignore ADDED
File without changes
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

WORKDIR /app

# System deps: Tesseract for OCR; libgl1/libglib2.0-0 are runtime libs needed
# by opencv-python. --no-install-recommends keeps the image slim, and the apt
# lists are removed in the same layer so they never land in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    tesseract-ocr \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# app/config.py defaults TESSERACT_PATH to a Windows path; BaseSettings reads
# env vars, so point it at the Debian tesseract binary installed above.
ENV TESSERACT_PATH=/usr/bin/tesseract

# Install Python deps before copying sources so code edits don't bust the
# pip layer cache.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app ./app

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
app/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.47 kB). View file
 
app/__pycache__/main.cpython-311.pyc ADDED
Binary file (3.65 kB). View file
 
app/__pycache__/schemas.cpython-311.pyc ADDED
Binary file (1.86 kB). View file
 
app/config.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pydantic_settings import BaseSettings
from pathlib import Path

# --------------------------------
# ----- PATHS --------------------
# --------------------------------

# Directory containing this file (app/); index artifacts live under app/index/.
BASE_DIR = Path(__file__).resolve().parent

class Settings(BaseSettings):
    """Application settings; every field can be overridden via env vars / .env."""

    # Index paths
    FAISS_INDEX_PATH: str = str(BASE_DIR / "index" / "faiss_index.bin")
    METADATA_PATH: str = str(BASE_DIR / "index" / "metadata.json")

    # Model settings
    EMBEDDING_DIM: int = 2048  # expected embedding size (ResNet50 features)
    TOP_K: int = 5  # default number of similar cards returned by search

    # Tesseract path (Windows only)
    # NOTE(review): this default breaks inside a Linux container unless
    # TESSERACT_PATH is overridden via the environment — confirm deploy config.
    TESSERACT_PATH: str = "C:/Program Files/Tesseract-OCR/tesseract.exe"

    class Config:
        env_file = ".env"

# Singleton settings instance imported throughout the app.
settings = Settings()
app/index/faiss_index.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67c8f11faee0a2a2d30eba64f64aa0f924b413983320ad3dd532e5f174b4d35f
3
+ size 4046893
app/index/metadata.json ADDED
The diff for this file is too large to render. See raw diff
 
app/main.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from contextlib import asynccontextmanager
4
+ from PIL import Image
5
+ import io
6
+
7
+ from app.services.embedding_service import EmbeddingService
8
+ from app.services.similarity_service import SimilarityService
9
+ from app.services.ocr_service import OCRService
10
+ from app.schemas import CardResponse
11
+ from app.config import settings
12
+
13
+ # --------------------
14
+ # ----- LIFESPAN -----
15
+ # --------------------
16
+
17
+ @asynccontextmanager
18
+ async def lifespan(app: FastAPI):
19
+ # Load all models and indexes once at startup
20
+ app.state.embedding_service = EmbeddingService()
21
+ app.state.similarity_service = SimilarityService()
22
+ app.state.ocr_service = OCRService()
23
+ print("Models and index loaded.")
24
+ yield
25
+ print("Shutting down.")
26
+
27
+
28
+
29
# ---------------
# ----- APP -----
# ---------------

app = FastAPI(title="Pokemon Card Image Processor", version="1.0.0", lifespan=lifespan)

# Open CORS policy: any origin, method, and header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
45
+
46
+
47
+
48
# --------------------
# ----- ROUTES -------
# --------------------

@app.get("/health")
def health():
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")
55
+
56
+
57
@app.post("/predict", response_model=CardResponse)
async def predict(file: UploadFile = File(...)):
    """Identify a Pokemon card from an uploaded photo.

    Embeds the image, retrieves the most similar indexed cards, OCRs the
    card text, and — for a near-exact index hit — backfills missing OCR
    fields from the top match's metadata.
    """
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Generate embedding and find similar cards first.
    embedding = app.state.embedding_service.embed(image)
    # Use the configured retrieval depth instead of a hard-coded 5 so the
    # route stays consistent with settings.TOP_K (default is also 5).
    similar_cards = app.state.similarity_service.search(embedding, top_k=settings.TOP_K)

    # Use OCR for extraction.
    ocr_data = app.state.ocr_service.extract(image)

    # If the top match is a near-exact hit, trust its metadata to fill OCR gaps.
    if similar_cards and similar_cards[0]["score"] > 0.99:
        top_match = similar_cards[0]
        ocr_data["name"] = ocr_data["name"] or top_match["name"]
        ocr_data["types"] = ocr_data["types"] or top_match["types"]

    return CardResponse(
        name=ocr_data.get("name"),
        hp=ocr_data.get("hp"),
        types=ocr_data.get("types"),
        moves=ocr_data.get("moves"),
        similar_cards=similar_cards,
    )
app/models/embedding_model.py ADDED
File without changes
app/models/layout_detector.py ADDED
File without changes
app/models/ocr_model.py ADDED
File without changes
app/schemas.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pydantic import BaseModel
from typing import Optional

# --------------------------------
# ----- SIMILAR CARD -------------
# --------------------------------

class SimilarCard(BaseModel):
    """One nearest-neighbor result from the FAISS similarity search."""
    id: str
    name: str
    set: Optional[str] = None
    types: Optional[list[str]] = None
    rarity: Optional[str] = None
    image_url: Optional[str] = None
    score: float  # similarity score from the index (higher = closer match)



# --------------------------------
# ----- MOVE ---------------------
# --------------------------------

class Move(BaseModel):
    """A single move parsed out of the card's move section by OCR."""
    name: str
    damage: Optional[str] = None  # kept as string: OCR may yield values like "20+"
    text: Optional[str] = None



# --------------------------------
# ----- CARD RESPONSE ------------
# --------------------------------

class CardResponse(BaseModel):
    """Response body for /predict: OCR-extracted fields plus similar cards."""
    name: Optional[str] = None
    hp: Optional[str] = None
    types: Optional[list[str]] = None
    moves: Optional[list[Move]] = None
    # Pydantic deep-copies field defaults per instance, so the mutable
    # [] default is safe here (unlike a plain function default).
    similar_cards: list[SimilarCard] = []
app/services/__pycache__/embedding_service.cpython-311.pyc ADDED
Binary file (2.76 kB). View file
 
app/services/__pycache__/ocr_service.cpython-311.pyc ADDED
Binary file (5.74 kB). View file
 
app/services/__pycache__/similarity_service.cpython-311.pyc ADDED
Binary file (2.85 kB). View file
 
app/services/embedding_service.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from torchvision import models
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
class EmbeddingService:
    """Produces L2-normalized ResNet50 feature vectors for card images."""

    def __init__(self):
        print("Loading embedding model...")

        # Load pretrained ResNet50.
        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13
        # in favor of `weights=models.ResNet50_Weights.IMAGENET1K_V1` (same
        # weights). Kept as-is for compatibility with the torchvision version
        # that built the index — confirm before upgrading.
        model = models.resnet50(pretrained=True)

        # Drop the final classification layer so the forward pass yields the
        # pooled feature map instead of class logits.
        self.model = torch.nn.Sequential(*list(model.children())[:-1])
        self.model.eval()

        # Preprocessing pipeline; must match what was used to build the index
        # (in data_collection/build_faiss_index.py).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        print("Embedding model loaded.")

    def embed(self, image: "Image.Image") -> np.ndarray:
        """Return a unit-norm float32 feature vector for `image`.

        The vector length is whatever the backbone emits (2048 for ResNet50,
        matching EMBEDDING_DIM in config).
        """
        # Preprocess and add a batch dimension.
        tensor = self.transform(image).unsqueeze(0)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            embedding = self.model(tensor)

        # Flatten to 1-D float32 and L2-normalize.
        embedding = embedding.squeeze().numpy().astype("float32")
        norm = np.linalg.norm(embedding)
        if norm > 0.0:
            # Guard against division by zero (degenerate all-zero features
            # would otherwise produce a NaN vector).
            embedding = embedding / norm

        return embedding
app/services/ocr_service.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytesseract
2
+ import re
3
+ from PIL import Image
4
+
5
+ from app.config import settings
6
+
7
# Point pytesseract at the configured binary at import time (defaults to the
# Windows install path in config; override via TESSERACT_PATH env / .env).
pytesseract.pytesseract.tesseract_cmd = settings.TESSERACT_PATH
8
+
9
+ class OCRService:
10
+
11
+ def __init__(self):
12
+ print("OCR service initialized.")
13
+
14
+ def extract(self, image: Image.Image) -> dict:
15
+ w, h = image.size
16
+
17
+ # --------------------------------
18
+ # ----- CROP REGIONS -------------
19
+ # --------------------------------
20
+
21
+ # Card name — top left area
22
+ name_region = image.crop((0.15 * w, 0.02 * h, 0.75 * w, 0.10 * h))
23
+
24
+ # HP — top right area
25
+ hp_region = image.crop((0.60 * w, 0.02 * h, 0.95 * w, 0.10 * h))
26
+
27
+ # Moves — lower middle section
28
+ moves_region = image.crop((0.00 * w, 0.55 * h, 1.00 * w, 0.85 * h))
29
+
30
+ # Full image for type detection
31
+ full_text = pytesseract.image_to_string(image)
32
+
33
+ # --------------------------------
34
+ # ----- EXTRACT FIELDS -----------
35
+ # --------------------------------
36
+
37
+ return {
38
+ "name": self._extract_name(name_region),
39
+ "hp": self._extract_hp(hp_region),
40
+ "types": self._extract_types(full_text),
41
+ "moves": self._extract_moves(moves_region),
42
+ }
43
+
44
+ # --------------------------------
45
+ # ----- EXTRACTORS ---------------
46
+ # --------------------------------
47
+
48
+ def _extract_name(self, region: Image.Image) -> str | None:
49
+ # Upscale region for better OCR accuracy
50
+ region = region.resize(
51
+ (region.width * 3, region.height * 3),
52
+ Image.LANCZOS
53
+ )
54
+ text = pytesseract.image_to_string(region, config="--psm 7").strip()
55
+ return text if text else None
56
+
57
+ def _extract_hp(self, region: Image.Image) -> str | None:
58
+ region = region.resize(
59
+ (region.width * 3, region.height * 3),
60
+ Image.LANCZOS
61
+ )
62
+ text = pytesseract.image_to_string(region, config="--psm 7")
63
+ match = re.search(r'(\d+)\s*HP|HP\s*(\d+)', text, re.IGNORECASE)
64
+ if match:
65
+ return match.group(1) or match.group(2)
66
+ return None
67
+
68
+ def _extract_types(self, text: str) -> list[str] | None:
69
+ types = [
70
+ "Fire", "Water", "Grass", "Electric", "Psychic",
71
+ "Fighting", "Darkness", "Metal", "Colorless",
72
+ "Dragon", "Fairy", "Lightning", "Normal"
73
+ ]
74
+ found = [t for t in types if t.lower() in text.lower()]
75
+ return found if found else None
76
+
77
+ def _extract_moves(self, region: Image.Image) -> list[dict] | None:
78
+ region = region.resize(
79
+ (region.width * 2, region.height * 2),
80
+ Image.LANCZOS
81
+ )
82
+ text = pytesseract.image_to_string(region)
83
+ lines = [line.strip() for line in text.splitlines() if line.strip()]
84
+
85
+ moves = []
86
+ i = 0
87
+ while i < len(lines):
88
+ # Match move name with damage e.g. "Lightning Flash 20"
89
+ match = re.match(r'^([A-Z][a-zA-Z\s]+?)\s+(\d+\+?)$', lines[i])
90
+ if match:
91
+ moves.append({
92
+ "name": match.group(1).strip(),
93
+ "damage": match.group(2).strip(),
94
+ "text": lines[i + 1] if i + 1 < len(lines) else None
95
+ })
96
+ i += 2
97
+ else:
98
+ i += 1
99
+
100
+ return moves if moves else None
app/services/similarity_service.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import json
3
+ import numpy as np
4
+
5
+ from app.config import settings
6
+
7
class SimilarityService:
    """FAISS-backed nearest-neighbor search over the card embedding index."""

    def __init__(self):
        print("Loading FAISS index...")

        # Load FAISS index built offline.
        self.index = faiss.read_index(settings.FAISS_INDEX_PATH)

        # Load metadata. Explicit UTF-8 so the JSON decodes identically on
        # every platform (the default encoding is locale-dependent).
        # NOTE(review): metadata[i] is assumed to describe index vector i —
        # confirm the builder writes a position-aligned list.
        with open(settings.METADATA_PATH, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

        print(f"FAISS index loaded with {self.index.ntotal} cards.")

    def search(self, embedding: np.ndarray, top_k: int = 5) -> list:
        """Return up to `top_k` card dicts most similar to `embedding`.

        Each result carries id/name/set/types/rarity/image_url from the
        metadata plus a float `score` from the index.
        """
        # FAISS expects a 2-D float32 batch.
        query = embedding.reshape(1, -1).astype("float32")

        # Normalize the query (cosine-style search; assumes the index holds
        # normalized vectors — confirm in the index build script).
        faiss.normalize_L2(query)

        # Search index.
        scores, indices = self.index.search(query, top_k)

        # Map raw hits to metadata records.
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:
                # FAISS pads with -1 when fewer than top_k vectors exist.
                continue

            card = self.metadata[idx]
            results.append({
                "id": card.get("id"),
                "name": card.get("name"),
                "set": card.get("set"),
                "types": card.get("types"),
                "rarity": card.get("rarity"),
                "image_url": card.get("image_url"),
                "score": float(score),
            })

        return results
app/utils.py ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pillow
4
+ numpy
5
+ torch
6
+ torchvision
7
+ faiss-cpu
8
+ pytesseract
9
+ opencv-python
10
+ pydantic
11
+ pydantic-settings
12
+ python-multipart
13
+ requests
training/build_faiss_index.py ADDED
File without changes
training/evaluate_similarity.py ADDED
File without changes
training/ocr_validation.py ADDED
File without changes
training/train_embedding.py ADDED
File without changes