import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer


def compute_memory_used_pct(device):
    """Return peak GPU memory allocated on `device` as a percentage of its total memory."""
    memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
    memory_pct = (
        memory_used
        / (torch.cuda.get_device_properties(device).total_memory / (1024**3))
        * 100
    )
    return memory_pct


# Directory containing the fine-tuned checkpoint.
model_path = "./out"

# Thought/talk lookahead lengths; max_thoughts below is n_ahead + n_ahead_talk + 1.
n_ahead = 8
n_ahead_talk = 4
merged_talk_heads = True


# Load the checkpoint together with its custom modeling code; the thought/talk
# head flags must match the configuration the model was trained with.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    max_thoughts=n_ahead + n_ahead_talk + 1,
    merged_talk_heads=merged_talk_heads,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
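# trust_remote_code executes the modeling code shipped with the checkpoint,
# which is what defines the thought/talk keyword arguments above;
# device_map="auto" lets accelerate place the weights across available devices.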


# Attach the tokenizer to the model so the generation loop below can read
# pad/eos token ids from it.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.tokenizer = tokenizer


# Inference-time attributes read by the model's custom forward pass.
model.use_end_thought_token = True
model.use_start_thought_token = True
model.wandb_enabled = True
model.n_ahead = n_ahead
model.n_passes = 2
model.eval_mode = True
model.first_run = False
model.kill_after = 100
model.rm_initialized = True
model.original_mode = False


def custom_generate(model, input_ids, attention_mask, max_new_tokens, streamer, **kwargs):
    with torch.no_grad():
        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
        for cur_token_idx in range(max_new_tokens):
            # Run a forward pass over the sequences that are still generating.
            new_ids = model(
                input_ids[~finished_generating],
                attention_mask=attention_mask[~finished_generating]
            )['logits']
            # Mask the logits of tokens added beyond the base vocabulary
            # (the thought markers) so they are never sampled into the output.
            new_ids[:, :, model.tokenizer.vocab_size:] = -float("inf")
            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
                # Find the last non-padding position in this sequence.
                base_answer_ids = input_ids[answer_idx]
                new_answer_ids = new_ids[list_idx]
                last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()

                # Sample the next token from the temperature-scaled distribution.
                new_ids_sampled = torch.multinomial(
                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / kwargs.get("temperature", 1.0), dim=-1), 1)

                # Grow input_ids and attention_mask by one column of padding if needed.
                if last_token_idx + 1 >= len(base_answer_ids):
                    new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long,
                                             device=input_ids.device)
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
                attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
                # Stop this sequence on EOS/BOS/PAD or on the end-of-assistant marker.
                if new_ids_sampled == model.tokenizer.eos_token_id or new_ids_sampled == model.tokenizer.bos_token_id or new_ids_sampled == model.tokenizer.pad_token_id:
                    finished_generating[answer_idx] = 1
                if new_ids_sampled == model.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
                    finished_generating[answer_idx] = 1
            if finished_generating.all():
                break
            # Stream the most recently sampled token (this script generates a single sequence).
            streamer.put(new_ids_sampled)
    return input_ids, attention_mask
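# Note: as written, each step re-encodes the full prefix (no KV cache is passed
# between calls), so decoding cost grows quadratically with output length.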


# Llama-2/Mistral-style instruction format.
prompt_template = "[INST] {prompt} [/INST]"

prompt = "You're standing on the surface of the Earth. "\
    "You walk one mile south, one mile west and one mile north. "\
    "You end up exactly where you started. Where are you?"


# Tokenize the formatted prompt and move it to the model's device.
tokens = tokenizer(prompt_template.format(prompt=prompt), return_tensors='pt').input_ids.to(model.device)

# Mask out padding positions (the single unpadded prompt yields all ones).
attention_mask = torch.where(tokens != tokenizer.pad_token_id, torch.ones_like(tokens), torch.zeros_like(tokens)).to(model.device)

# Stream decoded text to stdout as tokens are generated.
streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)


# Generate up to 512 new tokens with temperature sampling.
output_ids, _ = custom_generate(
    model,
    input_ids=tokens,
    attention_mask=attention_mask,
    max_new_tokens=512,
    streamer=streamer,
    temperature=0.9,
)
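# Report peak GPU memory via the helper defined above (assumes device 0 holds
# the model; skipped when no CUDA device is present).
if torch.cuda.is_available():
    print(f"Peak GPU memory used: {compute_memory_used_pct(0):.1f}%")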


# Flush any text still buffered in the streamer, then decode the full output
# (the streamer has already printed it incrementally).
streamer.end()
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print()

# Release cached GPU memory now that generation is finished.
torch.cuda.empty_cache()