# YAML Metadata Warning: empty or missing YAML metadata in repo card.
# Check out the documentation for more information.
import torch
from PIL import Image
from safetensors.torch import load_file as safetensors_load_file
from torch import nn
from torchvision import transforms
from transformers import ViTModel, ViTConfig
# Preprocessing applied to every input image: resize to 224x224 pixels (the
# resolution named in the 'google/vit-base-patch16-224' backbone id), then
# convert the PIL image to a float tensor via ToTensor().
# NOTE(review): no transforms.Normalize step is applied here; confirm this
# matches the preprocessing used when the checkpoint was trained.
transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
class ViTSalesModel(nn.Module):
    """Regress a single sales value from an image using a ViT backbone.

    The final-layer embedding of the first ([CLS]) token is passed through a
    linear head that emits one value per image.
    """

    def __init__(self, model_name='google/vit-base-patch16-224'):
        """Create the backbone and the regression head.

        Args:
            model_name: Hub id or local directory passed to
                ``ViTModel.from_pretrained``. Defaults to the backbone this
                model was originally built on.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        # A single output unit: one regression target per image.
        self.classifier = nn.Linear(self.vit.config.hidden_size, 1)

    def forward(self, pixel_values, labels=None):
        """Run the model and, if labels are given, compute the MSE loss.

        Args:
            pixel_values: Image batch forwarded to the ViT backbone
                (presumably (batch, 3, 224, 224) floats — TODO confirm).
            labels: Optional regression targets, one per image.

        Returns:
            ``sales`` with shape (batch, 1) when ``labels`` is None; otherwise
            the tuple ``(loss, sales)``, where ``loss`` is the mean squared
            error between the flattened predictions and the flattened labels.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Position 0 along the sequence axis is the [CLS] token.
        cls_output = outputs.last_hidden_state[:, 0, :]
        sales = self.classifier(cls_output)
        if labels is None:
            return sales
        # reshape() instead of view(): view() raises on non-contiguous labels.
        loss = nn.MSELoss()(sales.reshape(-1), labels.reshape(-1))
        return loss, sales
# Build the architecture; its weights are replaced by the checkpoint below.
model = ViTSalesModel()

# Fine-tuned weights written by the training run, read from a safetensors file.
checkpoint_path = "/content/results/checkpoint-940/model.safetensors"
state_dict = safetensors_load_file(checkpoint_path)

# strict=True (the default, spelled out): checkpoint keys must match the model.
model.load_state_dict(state_dict, strict=True)

# Switch the module to evaluation mode before running predictions.
model.eval()
# Factor that maps the model's normalized output back to sales units. It must
# equal the maximum sales value used to normalize the targets during training.
# TODO: replace this placeholder with the actual training-time maximum.
max_sales_value = 100_000
def predict_sales(image_path, *, verbose=False):
    """Predict the sales figure for a single image.

    The image is loaded, converted to RGB, preprocessed with the module-level
    ``transform`` and run through the module-level ``model``. The model's
    single output is multiplied by ``max_sales_value`` to undo the scaling
    applied to the targets during training.

    Args:
        image_path: Path to the image file, as accepted by ``PIL.Image.open``.
        verbose: If True, print the raw (still normalized) model output.

    Returns:
        float: The de-normalized sales prediction.
    """
    # Use a context manager so the file handle is released. convert() reads
    # the pixel data before the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')

    # Add a batch dimension and place the input on the model's device.
    device = next(model.parameters()).device
    pixel_values = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(pixel_values)
    if verbose:
        print(prediction)

    # One value for the single image: scale it back to sales units.
    return prediction.item() * max_sales_value
# Example: score a single image and report the de-normalized estimate.
image_path = "/content/0000.png"
predicted_sales = predict_sales(image_path=image_path)
print(f"Predicted sales: {predicted_sales}")
# Inference Providers (new):
# This model isn't deployed by any Inference Provider. 🙋 Ask for provider support.