Hypernova823 committed on
Commit
2a5d903
·
verified ·
1 Parent(s): 3faf505

Upload streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +24 -15
src/streamlit_app.py CHANGED
@@ -110,7 +110,7 @@ div[data-testid="stFileUploader"] * {
110
 
111
  /* ═══════════════════════════════════════════════════════════════ */
112
 
113
- /* Stats & Output Box */
114
  .stat-card { background: #000; padding: 15px; border-radius: 4px; text-align: center; border: 1px solid rgba(143, 245, 255, 0.1); margin-bottom: 10px; }
115
  .stat-val { color: #8ff5ff; font-size: 24px; font-weight: 700; font-family: 'Space Grotesk'; }
116
  .stat-lbl { font-size: 9px; color: #46484d; text-transform: uppercase; letter-spacing: 2px; }
@@ -149,9 +149,12 @@ def load_vision_engine():
149
  @st.cache_resource(show_spinner=False)
150
  def load_trocr_model(model_path):
151
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
152
  proc = TrOCRProcessor.from_pretrained(model_path)
153
 
154
  if os.path.exists(model_path):
 
155
  config = VisionEncoderDecoderConfig.from_pretrained(model_path)
156
  model = VisionEncoderDecoderModel(config)
157
  safe_path = os.path.join(model_path, "model.safetensors")
@@ -163,35 +166,41 @@ def load_trocr_model(model_path):
163
  else:
164
  model.load_state_dict(torch.load(bin_path, map_location="cpu", weights_only=True), strict=False)
165
  else:
 
166
  model = VisionEncoderDecoderModel.from_pretrained(model_path)
167
 
168
  # Push standard registered parameters/buffers to device
169
  model.to(device)
170
 
171
- # ─── AGGRESSIVE ROGUE TENSOR MIGRATION (WITH META SAFEGUARD) ───
172
- # Snapshot dict to avoid runtime size change errors while finding unregistered weights
173
  for module in model.modules():
174
- # 1. Double check parameters (Ensure it's not a meta tensor)
175
  for name, param in list(module._parameters.items()):
176
- if param is not None and not param.is_meta:
177
- module._parameters[name] = torch.nn.Parameter(param.to(device))
178
- # 2. Double check buffers
 
 
179
  for name, buf in list(module._buffers.items()):
180
- if buf is not None and not buf.is_meta:
181
- module._buffers[name] = buf.to(device)
182
- # 3. Hunt down unregistered raw tensors (Fixes the TrOCR positional weights crash)
 
 
183
  for name, attr in list(module.__dict__.items()):
184
- if isinstance(attr, torch.Tensor) and not attr.is_meta:
185
- setattr(module, name, attr.to(device))
 
186
 
187
- # If on GPU, push the entire model to Half precision
188
  if device.type == "cuda":
189
  model = model.half()
190
  # Ensure those unregistered raw tensors are ALSO converted to half precision safely
191
  for module in model.modules():
192
  for name, attr in list(module.__dict__.items()):
193
- if isinstance(attr, torch.Tensor) and not attr.is_meta and attr.is_floating_point():
194
- setattr(module, name, attr.half())
 
195
 
196
  model.eval()
197
  return proc, model, device
 
110
 
111
  /* ═══════════════════════════════════════════════════════════════ */
112
 
113
+ /* Stats & DYNAMIC Output Box */
114
  .stat-card { background: #000; padding: 15px; border-radius: 4px; text-align: center; border: 1px solid rgba(143, 245, 255, 0.1); margin-bottom: 10px; }
115
  .stat-val { color: #8ff5ff; font-size: 24px; font-weight: 700; font-family: 'Space Grotesk'; }
116
  .stat-lbl { font-size: 9px; color: #46484d; text-transform: uppercase; letter-spacing: 2px; }
 
149
  @st.cache_resource(show_spinner=False)
150
  def load_trocr_model(model_path):
151
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152
+
153
+ # Hugging Face natively downloads the processor via the repo ID
154
  proc = TrOCRProcessor.from_pretrained(model_path)
155
 
156
  if os.path.exists(model_path):
157
+ # Local Loading Logic
158
  config = VisionEncoderDecoderConfig.from_pretrained(model_path)
159
  model = VisionEncoderDecoderModel(config)
160
  safe_path = os.path.join(model_path, "model.safetensors")
 
166
  else:
167
  model.load_state_dict(torch.load(bin_path, map_location="cpu", weights_only=True), strict=False)
168
  else:
169
+ # Cloud Loading Logic: Natively pulls your model from the Hugging Face Hub
170
  model = VisionEncoderDecoderModel.from_pretrained(model_path)
171
 
172
  # Push standard registered parameters/buffers to device
173
  model.to(device)
174
 
175
+ # ─── BULLETPROOF TENSOR MIGRATION (WITH EXCEPTIONS CATCHER) ───
 
176
  for module in model.modules():
177
+ # 1. Double check parameters safely
178
  for name, param in list(module._parameters.items()):
179
+ if param is not None:
180
+ try: module._parameters[name] = torch.nn.Parameter(param.to(device))
181
+ except (NotImplementedError, RuntimeError): pass
182
+
183
+ # 2. Double check buffers safely
184
  for name, buf in list(module._buffers.items()):
185
+ if buf is not None:
186
+ try: module._buffers[name] = buf.to(device)
187
+ except (NotImplementedError, RuntimeError): pass
188
+
189
+ # 3. Hunt down unregistered raw tensors safely
190
  for name, attr in list(module.__dict__.items()):
191
+ if isinstance(attr, torch.Tensor):
192
+ try: setattr(module, name, attr.to(device))
193
+ except (NotImplementedError, RuntimeError): pass
194
 
195
+ # If on GPU, push the entire model to Half precision safely
196
  if device.type == "cuda":
197
  model = model.half()
198
  # Ensure those unregistered raw tensors are ALSO converted to half precision safely
199
  for module in model.modules():
200
  for name, attr in list(module.__dict__.items()):
201
+ if isinstance(attr, torch.Tensor) and attr.is_floating_point():
202
+ try: setattr(module, name, attr.half())
203
+ except (NotImplementedError, RuntimeError): pass
204
 
205
  model.eval()
206
  return proc, model, device