Jaavid25 commited on
Commit
e0ebeb3
·
verified ·
1 Parent(s): 2dd2b06

Update modeling_apex.py

Browse files
Files changed (1) hide show
  1. modeling_apex.py +7 -6
modeling_apex.py CHANGED
@@ -65,12 +65,13 @@ class APEXModel(PreTrainedModel):
65
  self.mert_processor = AutoProcessor.from_pretrained(
66
  config.mert_model_name, trust_remote_code=True
67
  )
68
- self.mert = AutoModel.from_pretrained(
69
- config.mert_model_name,
70
- trust_remote_code = True,
71
- device_map = None,
72
- low_cpu_mem_usage = False
73
- )
 
74
  self.mert.eval()
75
  for param in self.mert.parameters():
76
  param.requires_grad = False
 
65
  self.mert_processor = AutoProcessor.from_pretrained(
66
  config.mert_model_name, trust_remote_code=True
67
  )
68
+ with torch.device('cpu'):
69
+ self.mert = AutoModel.from_pretrained(
70
+ config.mert_model_name,
71
+ trust_remote_code = True,
72
+ device_map = None,
73
+ low_cpu_mem_usage = False
74
+ )
75
  self.mert.eval()
76
  for param in self.mert.parameters():
77
  param.requires_grad = False