Davy592 commited on
Commit
20503f2
·
1 Parent(s): 956e720

Removed helper method

Browse files
nygaardcodecommentclassification/api/models.py CHANGED
@@ -151,19 +151,9 @@ class ModelRegistry:
151
  lang.upper(),
152
  catboost_run_name,
153
  )
154
- # Configure torch to map GPU tensors to CPU during deserialization
155
- original_map_location = getattr(torch, '_utils', {}).get('_rebuild_device_tensor_from_numpy', None)
156
- try:
157
- # Set the map_location handler for torch.load to handle GPU-trained models
158
- model = mlflow.sklearn.load_model(model_uri)
159
- except Exception as load_error:
160
- logger.warning(
161
- "[%s] Initial model load failed, attempting CPU mapping: %s",
162
- lang.upper(),
163
- load_error
164
- )
165
- # Fallback: attempt to patch torch to use CPU for deserialization
166
- model = self._load_model_with_cpu_mapping(model_uri, lang)
167
 
168
  # Load the sentence transformer embedder from MLflow
169
  embedder_uri = f"runs:/{embedder_run_id}/model_{lang}"
@@ -172,16 +162,7 @@ class ModelRegistry:
172
  lang.upper(),
173
  embedder_run_name,
174
  )
175
- try:
176
- embedder = mlflow.sklearn.load_model(embedder_uri)
177
- except Exception as embed_error:
178
- logger.warning(
179
- "[%s] Initial embedder load failed, attempting CPU mapping: %s",
180
- lang.upper(),
181
- embed_error
182
- )
183
- embedder = self._load_model_with_cpu_mapping(embedder_uri, lang)
184
-
185
  if hasattr(embedder, "to"):
186
  embedder.to("cpu")
187
 
@@ -200,32 +181,6 @@ class ModelRegistry:
200
  except Exception as e:
201
  logger.error("[%s] Error loading models: %s", lang.upper(), e)
202
 
203
- def _load_model_with_cpu_mapping(self, model_uri: str, lang: str) -> Any:
204
- """Load a model with explicit CPU device mapping to handle GPU-trained models.
205
-
206
- This method attempts to load models trained on GPU on CPU-only machines
207
- by setting torch's device context and using torch.load with map_location.
208
-
209
- Args:
210
- model_uri: MLflow model URI
211
- lang: Programming language code for logging
212
-
213
- Returns:
214
- Loaded model object mapped to CPU
215
-
216
- Raises:
217
- Exception: If all loading attempts fail
218
- """
219
- # Set torch to use CPU for any tensor operations
220
- torch.set_default_device("cpu")
221
-
222
- try:
223
- model = mlflow.sklearn.load_model(model_uri)
224
- return model
225
- finally:
226
- # Reset to default behavior
227
- torch.set_default_device(None)
228
-
229
  def get_model(self, language: str, model_type: str) -> Optional[Dict[str, Any]]:
230
  """Retrieve a loaded model entry by language and type.
231
 
 
151
  lang.upper(),
152
  catboost_run_name,
153
  )
154
+ # Set default torch device to CPU before loading
155
+ torch.set_default_device("cpu")
156
+ model = mlflow.sklearn.load_model(model_uri)
 
 
 
 
 
 
 
 
 
 
157
 
158
  # Load the sentence transformer embedder from MLflow
159
  embedder_uri = f"runs:/{embedder_run_id}/model_{lang}"
 
162
  lang.upper(),
163
  embedder_run_name,
164
  )
165
+ embedder = mlflow.sklearn.load_model(embedder_uri)
 
 
 
 
 
 
 
 
 
166
  if hasattr(embedder, "to"):
167
  embedder.to("cpu")
168
 
 
181
  except Exception as e:
182
  logger.error("[%s] Error loading models: %s", lang.upper(), e)
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def get_model(self, language: str, model_type: str) -> Optional[Dict[str, Any]]:
185
  """Retrieve a loaded model entry by language and type.
186