| import sys |
| import torch |
| import torch.utils._triton |
|
|
| |
| def fake_is_available(): |
| return True |
| def fake_device_capability(*args, **kwargs): |
| return (8, 0) |
| def fake_current_device(): |
| return 0 |
| def fake_device_count(): |
| return 1 |
| def has_triton(): |
| return False |
| def get_fake_stream(*args, **kwargs): |
| return 0 |
|
|
| sys.modules["torch"].cuda.is_available = fake_is_available |
| sys.modules["torch"].cuda.get_device_capability = fake_device_capability |
| sys.modules["torch"].cuda.current_device = fake_current_device |
| sys.modules["torch"].cuda.device_count = fake_device_count |
| sys.modules["torch.utils._triton"].has_triton = has_triton |
| sys.modules["torch._C"]._cuda_getCurrentRawStream = get_fake_stream |
| |
|
|
| import streamlit as st |
| from unsloth.chat_templates import get_chat_template, CHAT_TEMPLATES |
| from unsloth_zoo.dataset_utils import train_on_responses_only |
| from transformers import AutoProcessor |
class DummyArgs:
    """Empty attribute holder standing in for a trainer's ``args`` object.

    It defines no attributes of its own; callers attach whatever settings
    they need (e.g. ``dataset_kwargs``) after construction.
    """
|
|
class DummyDataset:
    """In-memory, single-row stand-in for a dataset object.

    Provides only ``map``, ``len()`` and indexing over one stored row dict.
    """

    def __init__(self, example):
        # The row is stored by reference, not copied, so ``map`` updates the
        # caller's dict in place.
        self.example = [example]

    def map(self, function, *args, **kwargs):
        """Call ``function`` on the stored row and merge its result into it.

        Any extra positional/keyword arguments are accepted and ignored.
        ``function`` receives the row dict itself and must return a mapping,
        whose items are written into the row via ``dict.update``.

        Returns ``self``.
        """
        row = self.example[0]
        updates = function(row)
        row.update(updates)
        return self

    def __len__(self):
        """Return 1: the dataset always holds exactly one row."""
        return 1

    def __getitem__(self, idx):
        """Index the one-element row list.

        Integer indices other than 0 / -1 raise IndexError.
        """
        rows = self.example
        return rows[idx]
|
|
class DummyTrainer:
    """Empty attribute holder standing in for a trainer object.

    It defines nothing itself; attributes such as ``train_dataset``,
    ``tokenizer`` and ``args`` are attached after construction.
    """
|
|
st.title('Train With Response Only Analyzer')

# Centre the badge by placing it in the middle one of three equal columns.
col = st.columns([1, 1, 1])[1]
col.image(
    "https://raw.githubusercontent.com/unslothai/unsloth/main/images/made%20with%20unsloth.png",
    width=200,
)
|
|
# Pick the model (optionally pre-filled from the share URL) and load its
# processor/tokenizer.
# NOTE(review): ``trust_remote_code=True`` executes Python code shipped in the
# repository named by ``model``, and ``model`` can be injected through the
# shareable URL's ``model`` query parameter. Consider an allow-list of models
# or ``trust_remote_code=False`` for a publicly hosted app.
model = st.text_input("Enter HuggingFace model name", st.query_params.get("model", "Qwen/Qwen2-VL-7B-Instruct"))
processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
# If the processor exposes a ``.tokenizer`` attribute, use that; otherwise the
# processor itself is the tokenizer.
text_tokenizer = getattr(processor, "tokenizer", processor)

# The "chat_template_idx" query parameter carries a template *name*, not an
# index. An unknown name (stale or hand-edited share link) falls back to the
# model's default template instead of crashing with ValueError.
chat_template_predefined = st.query_params.get("chat_template_idx", None)
possible_templates = ["model_default"] + sorted(CHAT_TEMPLATES.keys())
if chat_template_predefined in possible_templates:
    chat_template_idx = possible_templates.index(chat_template_predefined)
else:
    chat_template_idx = 0


chat_template_key = st.selectbox("Select chat template", possible_templates, index=chat_template_idx)
|
|
# Resolve the Jinja2 chat template. A template chosen from CHAT_TEMPLATES
# takes precedence. "model_default" (or a registry entry whose template is
# None) falls back to the tokenizer's own template, which may itself be None.
if chat_template_key == "model_default":
    chat_template = None
else:
    chat_template = CHAT_TEMPLATES[chat_template_key][0]

if chat_template is None:
    chat_template = text_tokenizer.chat_template

if chat_template is None:
    st.warning("Chat template not found in the tokenizer. Not using any chat template.")
else:
    # Only offer the preview when there is a template to show; otherwise
    # st.code would be handed None.
    with st.expander("Click to see the chat template"):
        st.markdown("#### Chat Template (in Jinja2 format)")
        st.code(chat_template, language="jinja2")
|
|
import ast

# Default conversation offered in the text area.
sample = {"conversations": [{'content': 'Do you like Unsloth?', 'role': 'user'}, {'content': 'Yes', 'role': 'assistant'}, {'content': 'Will you star them on GitHub?', 'role': 'user'}, {'content': 'Sure!', 'role': 'assistant'}]}


message_sample = sample.get("conversations", "")
message = st.text_area("Enter your message here", st.query_params.get("message", str(message_sample)))

# The text area is expected to hold a Python-literal conversation. Parse it
# with ``ast.literal_eval`` rather than ``eval``: the text can arrive via the
# shareable URL's ``message`` query parameter, and ``eval`` would execute any
# code placed there on the server.
raw_message = message
try:
    message = ast.literal_eval(raw_message)
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
    # Not a Python literal: keep the raw text. That is only usable when no
    # chat template is applied (checked below).
    pass


if chat_template is not None:
    # apply_chat_template needs a non-empty sequence of message dicts; report
    # bad input instead of failing with a traceback inside the renderer.
    if not isinstance(message, (list, tuple)) or not message:
        st.error(
            "Could not parse the message as a non-empty list of "
            "{'role': ..., 'content': ...} dicts, which is required to apply the chat template."
        )
        st.stop()
    converted_message = text_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False, chat_template=chat_template)
else:
    # No chat template: the text is tokenized as-is. If the input parsed to a
    # non-string literal (e.g. the default list), use the text as typed so it
    # is still a string.
    converted_message = message if isinstance(message, str) else raw_message
|
|
st.markdown("#### Original Message")
st.code(converted_message, language="html")


# Markers for where user turns and assistant turns start in the rendered
# text. Both can be pre-filled from the share URL's query parameters.
instruction_part = st.text_input(
    "Enter instruction Part",
    st.query_params.get("instruction_part", "<|im_start|>user"),
)
response_part = st.text_input(
    "Enter response Part",
    st.query_params.get("response_part", "<|im_start|>assistant"),
)
|
|
# Wrap the tokenized conversation in the minimal shim objects passed to
# ``train_on_responses_only``. The dataset row holds a batch of one sequence.
trainer = DummyTrainer()
trainer.train_dataset = DummyDataset(
    {"input_ids": [text_tokenizer.encode(converted_message)]}
)
trainer.tokenizer = text_tokenizer
trainer.args = DummyArgs()
trainer.args.dataset_kwargs = {"skip_prepare_dataset": False}


trainer = train_on_responses_only(trainer, instruction_part, response_part)

# Take the labels of the single sequence and replace every -100 (presumably
# the ignore index, i.e. positions excluded from the loss) with the token for
# "x", so the masked spans are visible once decoded.
labels = trainer.train_dataset[0]["labels"]
ids = labels[0]
mask = text_tokenizer.encode("x", add_special_tokens=False)[0]
shown_ids = [mask if token_id == -100 else token_id for token_id in ids]
masked_text = text_tokenizer.decode(shown_ids)


st.markdown("#### Masked Prompt ('x' is the mask token)")
st.code(masked_text, language="html")
|
|
import json

st.markdown("#### Your Unsloth code snippet")
# Render each delimiter as a real Python string literal. json.dumps escapes
# quotes, backslashes and control characters, and the double-quoted result is
# also a valid Python literal. Interpolating the raw text between "..." would
# produce broken or different code whenever the delimiter contains `"` or `\`.
code = f"""from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = {json.dumps(instruction_part, ensure_ascii=False)},
    response_part = {json.dumps(response_part, ensure_ascii=False)},
)
"""
st.code(code, language="python")
|
|
import urllib.parse

st.markdown("#### You may share the following URL with others to show them the results")

# Serialize the current inputs under the same query-parameter keys the widgets
# above read their defaults from, so opening the link reproduces this view.
# urlencode stringifies non-string values (e.g. a parsed message list).
params = {
    "model": model,
    "message": message,
    "instruction_part": instruction_part,
    "response_part": response_part,
    "chat_template_idx": chat_template_key,
}
url = f"https://zeel-twro.hf.space?{urllib.parse.urlencode(params)}"
st.markdown(f"`{url}`")