# File size: 4,311 Bytes | 0a8ae4e
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple
import torch
import torch.nn.functional as F
from PIL import Image
# Default label set used when no explicit class names are supplied.
# Entries are kept in alphabetical order, grouped here by first letter.
CROP_CLASSES = [
    "apple", "avocado",
    "banana", "barley", "bell pepper", "broccoli",
    "cacao", "canola", "cauliflower", "cherry", "chilli", "coconut",
    "coffee", "corn", "cotton", "cucumber",
    "eggplant",
    "kiwi",
    "lemon",
    "mango",
    "olive", "orange",
    "pear", "peas", "pineapple", "pomegranate", "potato", "pumpkin",
    "rice",
    "soyabean", "strawberry", "sugarcane", "sunflower",
    "tea", "tomato",
    "watermelon", "wheat",
]
def _normalize(features: torch.Tensor) -> torch.Tensor:
return F.normalize(features.float(), dim=-1)
class CropVLMClassifier:
    """Small zero-shot wrapper around the CropVLM/OpenAI CLIP ViT-B/32 model.

    The text prompts for the current class list are encoded once (see
    ``set_classes``) and cached as L2-normalised features in
    ``self.text_features``. Each prediction then encodes only the image and
    compares it against the cached text features.
    """

    def __init__(
        self,
        checkpoint: str,
        class_names: Sequence[str] = CROP_CLASSES,
        device: str | None = None,
        prompt_template: str = "{}",
    ):
        """Build the ViT-B/32 model and load its weights from *checkpoint*.

        Args:
            checkpoint: Path to a checkpoint readable by ``torch.load``. The
                state dict is taken from the ``"model_state_dict"`` key, else
                the ``"state_dict"`` key, else the loaded object itself.
            class_names: Labels to classify against.
            device: Torch device string. When omitted, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
            prompt_template: ``str.format`` template that receives each class
                name as its single positional argument.

        Raises:
            FileNotFoundError: If *checkpoint* does not exist.
        """
        # Imported here rather than at module level, so the ``clip`` package
        # is only required once a classifier is actually constructed.
        import clip

        self.clip = clip
        self.device = torch.device(
            device or ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.prompt_template = prompt_template
        self.class_names = list(class_names)

        checkpoint_path = Path(checkpoint)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"CropVLM checkpoint not found: {checkpoint_path}")

        # Build the stock ViT-B/32 model first; its weights are replaced by
        # the checkpoint's state dict below.
        self.model, self.preprocess = clip.load(
            "ViT-B/32",
            device=str(self.device),
            download_root=str(Path.home() / ".cache" / "clip"),
        )

        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source. Consider passing
        # ``weights_only=True`` where the installed torch supports it.
        ckpt = torch.load(checkpoint_path, map_location=self.device)
        state = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
        self.model.load_state_dict(state)
        self.model.eval()

        self.set_classes(self.class_names)

    def set_classes(self, class_names: Sequence[str]) -> None:
        """Replace the class list and re-encode the cached text features.

        Names are stripped of surrounding whitespace; names that are empty
        after stripping are dropped. Each remaining name is formatted with
        ``self.prompt_template`` and tokenised with truncation enabled.
        """
        self.class_names = [c.strip() for c in class_names if c.strip()]
        prompts = [self.prompt_template.format(c) for c in self.class_names]
        tokens = self.clip.tokenize(prompts, truncate=True).to(self.device)
        with torch.no_grad():
            self.text_features = _normalize(self.model.encode_text(tokens))

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Return the L2-normalised image embedding as a batch of one.

        The image is converted to RGB, run through ``self.preprocess``, given
        a leading batch dimension and encoded without gradient tracking.
        """
        image = image.convert("RGB")
        batch = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return _normalize(self.model.encode_image(batch))

    def _cosine_scores(self, image: Image.Image) -> torch.Tensor:
        """Return one cosine similarity per class, in ``self.class_names`` order."""
        image_features = self.encode_image(image)
        # Both sides are unit-norm, so the dot product is the cosine similarity.
        return (image_features @ self.text_features.T).squeeze(0)

    def predict(self, image: Image.Image, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return ``(label, probability)`` pairs for the *top_k* most likely classes."""
        return [
            (label, probability)
            for label, probability, _ in self.predict_with_scores(image, top_k=top_k)
        ]

    def predict_scores(self, image: Image.Image) -> Dict[str, float]:
        """Return the raw cosine similarity for every class, keyed by class name."""
        cosine_scores = self._cosine_scores(image)
        return dict(zip(self.class_names, cosine_scores.tolist()))

    def predict_with_scores(
        self, image: Image.Image, top_k: int = 5
    ) -> List[Tuple[str, float, float]]:
        """Return ``(label, probability, cosine)`` triples for the top classes.

        Probabilities are a softmax over all classes of the cosine scores
        multiplied by the model's exponentiated ``logit_scale``. At most
        ``len(self.class_names)`` entries are returned, in the order produced
        by ``Tensor.topk``.
        """
        # ``logit_scale`` is a model parameter; without no_grad every call
        # would record an autograd graph that inference never uses.
        with torch.no_grad():
            cosine_scores = self._cosine_scores(image)
            # Cap the exponentiated logit scale at 100.
            logit_scale = self.model.logit_scale.exp().clamp(max=100)
            probabilities = (logit_scale * cosine_scores).softmax(dim=-1)
            k = min(top_k, len(self.class_names))
            probs, indices = probabilities.topk(k)

        # Convert once instead of indexing the tensor for every result.
        cosines = cosine_scores.tolist()
        return [
            (self.class_names[idx], prob, cosines[idx])
            for prob, idx in zip(probs.tolist(), indices.tolist())
        ]
def load_cropvlm(
    checkpoint: str,
    class_names: Sequence[str] = CROP_CLASSES,
    device: str | None = None,
    prompt_template: str = "{}",
) -> CropVLMClassifier:
    """Construct a ``CropVLMClassifier``, forwarding every argument unchanged."""
    settings = {
        "checkpoint": checkpoint,
        "class_names": class_names,
        "device": device,
        "prompt_template": prompt_template,
    }
    return CropVLMClassifier(**settings)
def parse_class_names(text: str | Iterable[str]) -> List[str]:
if isinstance(text, str):
raw = text.replace(",", "\n").splitlines()
else:
raw = list(text)
return [name.strip() for name in raw if name.strip()]
|