Addax-Data-Science commited on
Commit
8817a3b
·
verified ·
1 Parent(s): 3bf0df7

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +22 -0
inference.py CHANGED
@@ -27,6 +27,7 @@ import platform
27
  import sys
28
  from pathlib import Path
29
 
 
30
  import pandas as pd
31
  import torch
32
  import torch.nn as nn
@@ -308,3 +309,24 @@ class ModelInference:
308
  class_names[class_id_str] = class_name
309
 
310
  return class_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  import sys
28
  from pathlib import Path
29
 
30
+ import numpy as np
31
  import pandas as pd
32
  import torch
33
  import torch.nn as nn
 
309
  class_names[class_id_str] = class_name
310
 
311
  return class_names
312
+
313
+ def get_tensor(self, crop: Image.Image):
314
+ """Preprocess a crop into a numpy array for batch inference."""
315
+ tensor = self.preprocess(crop)
316
+ return tensor.numpy()
317
+
318
+ def classify_batch(self, batch):
319
+ """Run inference on a batch of preprocessed numpy arrays."""
320
+ tensor = torch.from_numpy(batch).to(self.device)
321
+ with torch.no_grad():
322
+ output = self.model(tensor)
323
+ probs = F.softmax(output, dim=1).cpu().numpy()
324
+
325
+ results = []
326
+ for p in probs:
327
+ classifications = []
328
+ for i in range(len(p)):
329
+ pred_class = self.classes.iloc[i]['Code']
330
+ classifications.append([pred_class, float(p[i])])
331
+ results.append(classifications)
332
+ return results