apehex commited on
Commit
f5e040b
·
1 Parent(s): 65fe064

Catch the failure to load the model too.

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -44,7 +44,7 @@ def fetch_model() -> object:
44
  if hasattr(_MODEL, 'name_or_path'):
45
  gradio.Info(title='Info', message='Switched to `{}`.'.format(getattr(_MODEL, 'name_or_path', 'None')), duration=2)
46
  else:
47
- gradio.Warning(title='Warning', message='The GPU time slot expired before the model could be loaded.', duration=4)
48
  # model object or `None`
49
  return _MODEL
50
 
@@ -117,16 +117,17 @@ def compute_logits(
117
  export_str: str,
118
  ) -> object:
119
  __logits = None
120
- # load the model inside the GPU wrapper (not before)
121
- __model = fetch_model()
122
  # the allocation might expire before the calculations are finished
123
  try:
 
 
 
124
  __logits = _ux.update_logits_state(
125
  indices_arr=indices_arr,
126
  export_str=export_str,
127
  model_obj=__model)
128
  except:
129
- gradio.Warning(title='Warning', message='Aborted because the GPU slot expired.', duration=4)
130
  # tensor or None
131
  return __logits
132
 
 
44
  if hasattr(_MODEL, 'name_or_path'):
45
  gradio.Info(title='Info', message='Switched to `{}`.'.format(getattr(_MODEL, 'name_or_path', 'None')), duration=2)
46
  else:
47
+ gradio.Warning(title='Warning', message='Aborted: the GPU slot expired.', duration=4)
48
  # model object or `None`
49
  return _MODEL
50
 
 
117
  export_str: str,
118
  ) -> object:
119
  __logits = None
 
 
120
  # the allocation might expire before the calculations are finished
121
  try:
122
+ # load the model inside the GPU wrapper (not before)
123
+ __model = fetch_model()
124
+ # compute the raw logits
125
  __logits = _ux.update_logits_state(
126
  indices_arr=indices_arr,
127
  export_str=export_str,
128
  model_obj=__model)
129
  except:
130
+ gradio.Warning(title='Warning', message='Aborted: the GPU slot expired.', duration=4)
131
  # tensor or None
132
  return __logits
133