ggapar committed on
Commit
0b8b078
·
verified ·
1 Parent(s): 4617712

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -15,7 +15,7 @@ import numpy as np
15
  import gradio as gr
16
 
17
  from collections import Counter
18
- from transformers import AutoModelForCausalLM, AutoTokenizer
19
  from peft import PeftModel
20
 
21
  # ================================================================
@@ -54,15 +54,15 @@ if tokenizer.pad_token is None:
54
  tokenizer.pad_token = tokenizer.eos_token
55
 
56
  print("Loading base model (CPU, float32)...")
57
- # CPU Basic does not support bfloat16/4-bit, so use float32
58
- # The model will be slower (~2-5 minutes/request) but still functional
59
- base_model = AutoModelForCausalLM.from_pretrained(
60
  BASE_MODEL,
61
- torch_dtype = torch.float32, # ← CPU butuh float32
62
  device_map = "cpu",
63
  trust_remote_code = True,
64
  token = HF_TOKEN or None,
65
- low_cpu_mem_usage = True, # ← hemat RAM saat loading
66
  )
67
 
68
  print("Loading LoRA adapter...")
@@ -360,4 +360,4 @@ pip install gradio_client
360
  outputs=outputs)
361
 
362
  if __name__ == "__main__":
363
- demo.launch()
 
15
  import gradio as gr
16
 
17
  from collections import Counter
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Mistral3ForConditionalGeneration
19
  from peft import PeftModel
20
 
21
  # ================================================================
 
54
  tokenizer.pad_token = tokenizer.eos_token
55
 
56
  print("Loading base model (CPU, float32)...")
57
+ # Ministral-3-8B uses the Mistral3 architecture (VLM)
58
+ # Must use Mistral3ForConditionalGeneration, not AutoModelForCausalLM
59
+ base_model = Mistral3ForConditionalGeneration.from_pretrained(
60
  BASE_MODEL,
61
+ dtype = torch.float32, # ← CPU butuh float32
62
  device_map = "cpu",
63
  trust_remote_code = True,
64
  token = HF_TOKEN or None,
65
+ low_cpu_mem_usage = True,
66
  )
67
 
68
  print("Loading LoRA adapter...")
 
360
  outputs=outputs)
361
 
362
  if __name__ == "__main__":
363
+ demo.launch()