"""CropVLM model code (cropvlm/model.py)."""

from __future__ import annotations  # allow "str | None" annotations on Python < 3.10

from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple

import torch
import torch.nn.functional as F
from PIL import Image

# Default label set used for zero-shot classification.
CROP_CLASSES = [
    "apple",
    "avocado",
    "banana",
    "barley",
    "bell pepper",
    "broccoli",
    "cacao",
    "canola",
    "cauliflower",
    "cherry",
    "chilli",
    "coconut",
    "coffee",
    "corn",
    "cotton",
    "cucumber",
    "eggplant",
    "kiwi",
    "lemon",
    "mango",
    "olive",
    "orange",
    "pear",
    "peas",
    "pineapple",
    "pomegranate",
    "potato",
    "pumpkin",
    "rice",
    "soyabean",
    "strawberry",
    "sugarcane",
    "sunflower",
    "tea",
    "tomato",
    "watermelon",
    "wheat",
]

def _normalize(features: torch.Tensor) -> torch.Tensor:
    """L2-normalize along the last dimension, casting to float32 first."""
    return F.normalize(features.float(), dim=-1)

class CropVLMClassifier:
    """Small zero-shot wrapper around the CropVLM/OpenAI CLIP ViT-B/32 model."""

    def __init__(
        self,
        checkpoint: str,
        class_names: Sequence[str] = CROP_CLASSES,
        device: str | None = None,
        prompt_template: str = "{}",
    ):
        # Imported lazily so the OpenAI CLIP package is only required when the
        # classifier is actually constructed.
        import clip

        self.clip = clip
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.prompt_template = prompt_template
        self.class_names = list(class_names)
        checkpoint_path = Path(checkpoint)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"CropVLM checkpoint not found: {checkpoint_path}")
        # Load the base CLIP ViT-B/32 architecture, then replace its weights
        # with the fine-tuned CropVLM checkpoint.
        self.model, self.preprocess = clip.load(
            "ViT-B/32",
            device=str(self.device),
            download_root=str(Path.home() / ".cache" / "clip"),
        )
        ckpt = torch.load(checkpoint_path, map_location=self.device)
        # Accept a bare state dict or one nested under a common wrapper key.
        state = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
        self.model.load_state_dict(state)
        self.model.eval()
        self.set_classes(self.class_names)

    def set_classes(self, class_names: Sequence[str]) -> None:
        """Re-encode the text prompts, e.g. to swap in a new label set at runtime."""
        self.class_names = [c.strip() for c in class_names if c.strip()]
        prompts = [self.prompt_template.format(c) for c in self.class_names]
        tokens = self.clip.tokenize(prompts, truncate=True).to(self.device)
        with torch.no_grad():
            self.text_features = _normalize(self.model.encode_text(tokens))

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Return the normalized CLIP embedding for one image, shape (1, d)."""
        image = image.convert("RGB")
        batch = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return _normalize(self.model.encode_image(batch))

    def predict(self, image: Image.Image, top_k: int = 5) -> List[Tuple[str, float]]:
        """Top-k (label, probability) pairs, dropping the raw cosine scores."""
        return [
            (label, probability)
            for label, probability, _ in self.predict_with_scores(image, top_k=top_k)
        ]

    def predict_scores(self, image: Image.Image) -> Dict[str, float]:
        """Raw cosine similarity for every class name (no softmax applied)."""
        image_features = self.encode_image(image)
        logits = (image_features @ self.text_features.T).squeeze(0)
        return {name: float(score) for name, score in zip(self.class_names, logits.tolist())}

    def predict_with_scores(self, image: Image.Image, top_k: int = 5) -> List[Tuple[str, float, float]]:
        """Top-k (label, softmax probability, cosine score) triples."""
        image_features = self.encode_image(image)
        cosine_scores = (image_features @ self.text_features.T).squeeze(0)
        # Scale cosines by CLIP's learned temperature before the softmax; the
        # exponentiated scale is clamped at 100, matching CLIP's training setup.
        logit_scale = self.model.logit_scale.exp().clamp(max=100)
        probabilities = (logit_scale * cosine_scores).softmax(dim=-1)
        k = min(top_k, len(self.class_names))
        probs, indices = probabilities.topk(k)
        return [
            (self.class_names[idx], float(prob), float(cosine_scores[idx]))
            for prob, idx in zip(probs.tolist(), indices.tolist())
        ]
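
# A minimal usage sketch (illustrative, not part of the shipped API): the
# checkpoint path, prompt template, and image file below are placeholders.
#
#     clf = CropVLMClassifier("cropvlm.pt", prompt_template="a photo of {}")
#     clf.set_classes(["wheat", "barley", "rice"])
#     top = clf.predict(Image.open("field.jpg"), top_k=3)
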
def load_cropvlm(
    checkpoint: str,
    class_names: Sequence[str] = CROP_CLASSES,
    device: str | None = None,
    prompt_template: str = "{}",
) -> CropVLMClassifier:
    """Convenience wrapper mirroring the CropVLMClassifier constructor."""
    return CropVLMClassifier(
        checkpoint=checkpoint,
        class_names=class_names,
        device=device,
        prompt_template=prompt_template,
    )

def parse_class_names(text: str | Iterable[str]) -> List[str]:
    if isinstance(text, str):
        raw = text.replace(",", "\n").splitlines()
    else:
        raw = list(text)
    return [name.strip() for name in raw if name.strip()]
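
# A minimal CLI sketch, assuming a local checkpoint file; both command-line
# arguments below are placeholders, not files shipped with this repo.
if __name__ == "__main__":
    import sys

    # Usage (hypothetical): python -m cropvlm.model <checkpoint.pt> <image.jpg>
    classifier = load_cropvlm(sys.argv[1])
    with Image.open(sys.argv[2]) as image:
        for label, probability in classifier.predict(image, top_k=5):
            print(f"{label}: {probability:.3f}")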