FoodDesert committed on
Commit
39c299f
·
verified ·
1 Parent(s): 376e833

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +13 -1
model.py CHANGED
@@ -87,9 +87,21 @@ def _embed_image(image: Image.Image) -> torch.Tensor:
87
  dy = h // 8
88
  if dy > 0:
89
  image = image.crop((0, dy, w, h - dy))
90
- inputs = _CLIP_PROCESSOR(images=[image.convert("RGB")], return_tensors="pt").to(_DEVICE)
 
 
 
 
91
  with torch.no_grad():
 
92
  feats = _CLIP_MODEL.get_image_features(**inputs)
 
 
 
 
 
 
 
93
  feats = feats / feats.norm(dim=-1, keepdim=True)
94
  return feats # [1, d_in]
95
 
 
87
  dy = h // 8
88
  if dy > 0:
89
  image = image.crop((0, dy, w, h - dy))
90
+
91
+ # Newer HF processor outputs should be moved field-by-field for robustness.
92
+ inputs = _CLIP_PROCESSOR(images=[image.convert("RGB")], return_tensors="pt")
93
+ inputs = {k: v.to(_DEVICE) for k, v in inputs.items()}
94
+
95
  with torch.no_grad():
96
+ # Preferred API: projected CLIP image embeddings as a tensor.
97
  feats = _CLIP_MODEL.get_image_features(**inputs)
98
+
99
+ # Defensive fallback in case an HF-side change returns a structured output.
100
+ if hasattr(feats, "image_embeds"):
101
+ feats = feats.image_embeds
102
+ elif hasattr(feats, "pooler_output"):
103
+ feats = feats.pooler_output
104
+
105
  feats = feats / feats.norm(dim=-1, keepdim=True)
106
  return feats # [1, d_in]
107