artificialguybr committed
Commit 99a0cab · verified · 1 Parent(s): e8e5451

Switch to text-only causal LM loading

Files changed (1)
  1. app.py +13 -6
app.py CHANGED
@@ -5,7 +5,7 @@ import gradio as gr
 import spaces
 import torch
 from transformers import (
-    AutoModelForImageTextToText,
+    AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
     TextIteratorStreamer,
@@ -29,6 +29,7 @@ PLACEHOLDER = (
 MAX_INPUT_TOKENS = 16384
 DEFAULT_MAX_NEW_TOKENS = 4096
 MAX_NEW_TOKENS = 8192
+HF_TOKEN = os.environ.get("HF_TOKEN")
 
 os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -40,21 +41,27 @@ BNB_CONFIG = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16,
 )
 
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
 
-model = AutoModelForImageTextToText.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
     trust_remote_code=True,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
+    token=HF_TOKEN,
+    device_map={"": 0},
+    dtype=torch.bfloat16,
     quantization_config=BNB_CONFIG,
     attn_implementation="sdpa",
+    low_cpu_mem_usage=True,
 )
 model.eval()
 
 
+def model_input_device():
+    return next(model.parameters()).device
+
+
 def estimate_duration(
     message,
     history,
@@ -113,7 +120,7 @@ def stream_chat(
         return_tensors="pt",
         truncation=True,
         max_length=MAX_INPUT_TOKENS,
-    ).to(model.device)
+    ).to(model_input_device())
 
     streamer = TextIteratorStreamer(
         tokenizer,
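
For reference, here is a minimal standalone sketch of the loading pattern this commit switches to: a text-only causal LM loaded in 4-bit via bitsandbytes, pinned to GPU 0, with tokenized inputs moved to the model's parameter device before generation (the role of the new model_input_device() helper). The model id and prompt are placeholders, the BitsAndBytesConfig only repeats the field visible in this diff, and running it assumes a CUDA GPU with bitsandbytes installed; app.py's actual MODEL_ID and BNB_CONFIG live outside the changed lines.

import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_ID = "some-org/some-causal-lm"  # hypothetical placeholder; app.py defines its own MODEL_ID
HF_TOKEN = os.environ.get("HF_TOKEN")  # access token for gated repos, read from the environment as in the diff

# 4-bit quantization config; only the compute dtype is visible in the diff,
# the real BNB_CONFIG in app.py may set additional options.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as padding when the model ships none

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    device_map={"": 0},              # pin all weights to GPU 0 instead of letting accelerate shard or offload
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)
model.eval()

# Move tokenized inputs to wherever the parameters actually live,
# which is what the new model_input_device() helper does in app.py.
device = next(model.parameters()).device
inputs = tokenizer("Hello, world", return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Reading the device from the model's parameters rather than model.device is presumably meant to keep input placement correct regardless of how accelerate dispatched the quantized weights, which pairs with pinning everything to a single GPU via device_map={"": 0}.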