Dev Nagaich commited on
Commit
35351a6
·
1 Parent(s): 6f575dc

Improve: Add better error handling and logging for model loading

Browse files
Files changed (1) hide show
  1. model_server.py +31 -8
model_server.py CHANGED
@@ -94,32 +94,55 @@ class ProtectedModelServer:
94
  # Add segment-anything-2 to path (internally only)
95
  base_dir = Path(__file__).parent
96
  sam2_path = base_dir / "segment-anything-2"
 
 
 
 
97
  sys.path.insert(0, str(sam2_path))
98
 
99
- from sam2.build_sam import build_sam2
100
- from sam2.sam2_image_predictor import SAM2ImagePredictor
 
 
 
101
 
102
  # Get paths internally - NEVER sent to client
103
  model_cfg = get_model_config_path()
104
  sam2_checkpoint = get_model_checkpoint_path()
105
- fine_tuned_weights = get_finetuned_weights_path()
106
 
107
- # Load model
108
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
109
 
 
 
110
  self._model = build_sam2(model_cfg, sam2_checkpoint, device=device)
111
  self._predictor = SAM2ImagePredictor(self._model)
112
 
113
- # Load fine-tuned weights
114
- state_dict = torch.load(fine_tuned_weights, map_location=device)
115
- self._predictor.model.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
116
 
117
  # Model is now loaded - weights are NOT accessible to clients
118
  self._predictor.model.eval()
 
119
 
120
  return True
121
  except Exception as e:
122
- raise RuntimeError(f"Model initialization failed") from e
 
 
 
123
 
124
  def predict(self, image: np.ndarray, num_samples: int = 30) -> Tuple[np.ndarray, np.ndarray]:
125
  """
 
94
  # Add segment-anything-2 to path (internally only)
95
  base_dir = Path(__file__).parent
96
  sam2_path = base_dir / "segment-anything-2"
97
+
98
+ if not sam2_path.exists():
99
+ raise FileNotFoundError(f"SAM2 installation not found at {sam2_path}")
100
+
101
  sys.path.insert(0, str(sam2_path))
102
 
103
+ try:
104
+ from sam2.build_sam import build_sam2
105
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
106
+ except ImportError as e:
107
+ raise ImportError("SAM2 not properly installed. Check build logs.") from e
108
 
109
  # Get paths internally - NEVER sent to client
110
  model_cfg = get_model_config_path()
111
  sam2_checkpoint = get_model_checkpoint_path()
 
112
 
113
+ # Load device
114
  device = "cuda" if torch.cuda.is_available() else "cpu"
115
+ print(f"Loading model on device: {device}")
116
 
117
+ # Load base SAM2 model
118
+ print(f"Loading SAM2 from {sam2_checkpoint}")
119
  self._model = build_sam2(model_cfg, sam2_checkpoint, device=device)
120
  self._predictor = SAM2ImagePredictor(self._model)
121
 
122
+ # Try to load fine-tuned weights if available
123
+ try:
124
+ fine_tuned_weights = get_finetuned_weights_path()
125
+ print(f"Loading fine-tuned weights from {fine_tuned_weights}")
126
+ state_dict = torch.load(fine_tuned_weights, map_location=device)
127
+ self._predictor.model.load_state_dict(state_dict)
128
+ print("Fine-tuned weights loaded successfully")
129
+ except FileNotFoundError:
130
+ print("Warning: Fine-tuned weights not found. Using base SAM2 model.")
131
+ print("To use fine-tuned model, upload VREyeSAM_uncertainity_best.torch to Space Files")
132
+ except Exception as e:
133
+ print(f"Warning: Could not load fine-tuned weights: {e}")
134
+ print("Continuing with base SAM2 model")
135
 
136
  # Model is now loaded - weights are NOT accessible to clients
137
  self._predictor.model.eval()
138
+ print("Model loaded successfully")
139
 
140
  return True
141
  except Exception as e:
142
+ print(f"Error loading model: {e}")
143
+ import traceback
144
+ traceback.print_exc()
145
+ raise RuntimeError(f"Model initialization failed: {str(e)}") from e
146
 
147
  def predict(self, image: np.ndarray, num_samples: int = 30) -> Tuple[np.ndarray, np.ndarray]:
148
  """