| import os |
| import sys |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" |
| os.environ['GRADIO_ANALYTICS_ENABLED'] = '0' |
| sys.path.insert(0, os.getcwd()) |
| sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts')) |
| import subprocess |
| import gradio as gr |
| from PIL import Image |
| import torch |
| import uuid |
| import shutil |
| import json |
| import yaml |
| from slugify import slugify |
| from transformers import AutoProcessor, AutoModelForCausalLM |
| from gradio_logsview import LogsView, LogsViewRunner |
| from huggingface_hub import hf_hub_download, HfApi |
| from library import flux_train_utils, huggingface_util |
| from argparse import Namespace |
| import train_network |
| import toml |
| import re |
| MAX_IMAGES = 150 |
|
|
| with open('models.yaml', 'r') as file: |
| models = yaml.safe_load(file) |
|
|
| def readme(base_model, lora_name, instance_prompt, sample_prompts): |
|
|
| |
| model_config = models[base_model] |
| model_file = model_config["file"] |
| base_model_name = model_config["base"] |
| license = None |
| license_name = None |
| license_link = None |
| license_items = [] |
| if "license" in model_config: |
| license = model_config["license"] |
| license_items.append(f"license: {license}") |
| if "license_name" in model_config: |
| license_name = model_config["license_name"] |
| license_items.append(f"license_name: {license_name}") |
| if "license_link" in model_config: |
| license_link = model_config["license_link"] |
| license_items.append(f"license_link: {license_link}") |
| license_str = "\n".join(license_items) |
| print(f"license_items={license_items}") |
| print(f"license_str = {license_str}") |
|
|
| |
| tags = [ "text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym" ] |
|
|
| |
| widgets = [] |
| sample_image_paths = [] |
| output_name = slugify(lora_name) |
| samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample") |
| try: |
| for filename in os.listdir(samples_dir): |
| |
| match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename) |
| if match: |
| steps, index, timestamp = int(match.group(1)), int(match.group(2)), int(match.group(3)) |
| sample_image_paths.append((steps, index, f"sample/{filename}")) |
|
|
| |
| sample_image_paths.sort(key=lambda x: x[0], reverse=True) |
|
|
| final_sample_image_paths = sample_image_paths[:len(sample_prompts)] |
| final_sample_image_paths.sort(key=lambda x: x[1]) |
| for i, prompt in enumerate(sample_prompts): |
| _, _, image_path = final_sample_image_paths[i] |
| widgets.append( |
| { |
| "text": prompt, |
| "output": { |
| "url": image_path |
| }, |
| } |
| ) |
| except: |
| print(f"no samples") |
| dtype = "torch.bfloat16" |
| |
| readme_content = f"""--- |
| tags: |
| {yaml.dump(tags, indent=4).strip()} |
| {"widget:" if os.path.isdir(samples_dir) else ""} |
| {yaml.dump(widgets, indent=4).strip() if widgets else ""} |
| base_model: {base_model_name} |
| {"instance_prompt: " + instance_prompt if instance_prompt else ""} |
| {license_str} |
| --- |
| |
| # {lora_name} |
| |
| A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym) |
| |
| <Gallery /> |
| |
| ## Trigger words |
| |
| {"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."} |
| |
| ## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc. |
| |
| Weights for this model are available in Safetensors format. |
| |
| """ |
| return readme_content |
|
|
| def account_hf(): |
| try: |
| with open("HF_TOKEN", "r") as file: |
| token = file.read() |
| api = HfApi(token=token) |
| try: |
| account = api.whoami() |
| return { "token": token, "account": account['name'] } |
| except: |
| return None |
| except: |
| return None |
|
|
| """ |
| hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner]) |
| """ |
| def logout_hf(): |
| os.remove("HF_TOKEN") |
| global current_account |
| current_account = account_hf() |
| print(f"current_account={current_account}") |
| return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False) |
|
|
|
|
| """ |
| hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner]) |
| """ |
| def login_hf(hf_token): |
| api = HfApi(token=hf_token) |
| try: |
| account = api.whoami() |
| if account != None: |
| if "name" in account: |
| with open("HF_TOKEN", "w") as file: |
| file.write(hf_token) |
| global current_account |
| current_account = account_hf() |
| return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True) |
| return gr.update(), gr.update(), gr.update(), gr.update() |
| except: |
| print(f"incorrect hf_token") |
| return gr.update(), gr.update(), gr.update(), gr.update() |
|
|
| def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token): |
| src = lora_rows |
| repo_id = f"{repo_owner}/{repo_name}" |
| gr.Info(f"Uploading to Huggingface. Please Stand by...", duration=None) |
| args = Namespace( |
| huggingface_repo_id=repo_id, |
| huggingface_repo_type="model", |
| huggingface_repo_visibility=repo_visibility, |
| huggingface_path_in_repo="", |
| huggingface_token=hf_token, |
| async_upload=False |
| ) |
| print(f"upload_hf args={args}") |
| huggingface_util.upload(args=args, src=src) |
| gr.Info(f"[Upload Complete] https://huggingface.co/{repo_id}", duration=None) |
|
|
| def load_captioning(uploaded_files, concept_sentence): |
| uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')] |
| txt_files = [file for file in uploaded_files if file.endswith('.txt')] |
| txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files} |
| updates = [] |
| if len(uploaded_images) <= 1: |
| raise gr.Error( |
| "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)" |
| ) |
| elif len(uploaded_images) > MAX_IMAGES: |
| raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training") |
| |
| |
| updates.append(gr.update(visible=True)) |
| |
| for i in range(1, MAX_IMAGES + 1): |
| |
| visible = i <= len(uploaded_images) |
|
|
| |
| updates.append(gr.update(visible=visible)) |
|
|
| |
| image_value = uploaded_images[i - 1] if visible else None |
| updates.append(gr.update(value=image_value, visible=visible)) |
|
|
| corresponding_caption = False |
| if(image_value): |
| base_name = os.path.splitext(os.path.basename(image_value))[0] |
| if base_name in txt_files_dict: |
| with open(txt_files_dict[base_name], 'r') as file: |
| corresponding_caption = file.read() |
|
|
| |
| text_value = corresponding_caption if visible and corresponding_caption else concept_sentence if visible and concept_sentence else None |
| updates.append(gr.update(value=text_value, visible=visible)) |
|
|
| |
| updates.append(gr.update(visible=True)) |
| updates.append(gr.update(visible=True)) |
|
|
| return updates |
|
|
| def hide_captioning(): |
| return gr.update(visible=False), gr.update(visible=False) |
|
|
| def resize_image(image_path, output_path, size): |
| with Image.open(image_path) as img: |
| width, height = img.size |
| if width < height: |
| new_width = size |
| new_height = int((size/width) * height) |
| else: |
| new_height = size |
| new_width = int((size/height) * width) |
| print(f"resize {image_path} : {new_width}x{new_height}") |
| img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS) |
| img_resized.save(output_path) |
|
|
| def create_dataset(destination_folder, size, *inputs): |
| print("Creating dataset") |
| images = inputs[0] |
| if not os.path.exists(destination_folder): |
| os.makedirs(destination_folder) |
|
|
| for index, image in enumerate(images): |
| |
| new_image_path = shutil.copy(image, destination_folder) |
|
|
| |
| ext = os.path.splitext(new_image_path)[-1].lower() |
| if ext == '.txt': |
| continue |
|
|
| |
| resize_image(new_image_path, new_image_path, size) |
|
|
| |
|
|
| original_caption = inputs[index + 1] |
|
|
| image_file_name = os.path.basename(new_image_path) |
| caption_file_name = os.path.splitext(image_file_name)[0] + ".txt" |
| caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name)) |
| print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}") |
| |
| if os.path.exists(caption_path): |
| print(f"{caption_path} already exists. use the existing .txt file") |
| else: |
| print(f"{caption_path} create a .txt caption file") |
| with open(caption_path, 'w') as file: |
| file.write(original_caption) |
|
|
| print(f"destination_folder {destination_folder}") |
| return destination_folder |
|
|
|
|
| def run_captioning(images, concept_sentence, *captions): |
| print(f"run_captioning") |
| print(f"concept sentence {concept_sentence}") |
| print(f"captions {captions}") |
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| print(f"device={device}") |
| torch_dtype = torch.float16 |
| model = AutoModelForCausalLM.from_pretrained( |
| "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True |
| ).to(device) |
| processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True) |
|
|
| captions = list(captions) |
| for i, image_path in enumerate(images): |
| print(captions[i]) |
| if isinstance(image_path, str): |
| image = Image.open(image_path).convert("RGB") |
|
|
| prompt = "<DETAILED_CAPTION>" |
| inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype) |
| print(f"inputs {inputs}") |
|
|
| generated_ids = model.generate( |
| input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3 |
| ) |
| print(f"generated_ids {generated_ids}") |
|
|
| generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] |
| print(f"generated_text: {generated_text}") |
| parsed_answer = processor.post_process_generation( |
| generated_text, task=prompt, image_size=(image.width, image.height) |
| ) |
| print(f"parsed_answer = {parsed_answer}") |
| caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "") |
| print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}") |
| if concept_sentence: |
| caption_text = f"{concept_sentence} {caption_text}" |
| captions[i] = caption_text |
|
|
| yield captions |
| model.to("cpu") |
| del model |
| del processor |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
| def recursive_update(d, u): |
| for k, v in u.items(): |
| if isinstance(v, dict) and v: |
| d[k] = recursive_update(d.get(k, {}), v) |
| else: |
| d[k] = v |
| return d |
|
|
| def download(base_model): |
| model = models[base_model] |
| model_file = model["file"] |
| repo = model["repo"] |
|
|
| |
| if base_model == "flux-dev" or base_model == "flux-schnell": |
| unet_folder = "models/unet" |
| else: |
| unet_folder = f"models/unet/{repo}" |
| unet_path = os.path.join(unet_folder, model_file) |
| if not os.path.exists(unet_path): |
| os.makedirs(unet_folder, exist_ok=True) |
| gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None) |
| print(f"download {base_model}") |
| hf_hub_download(repo_id=repo, local_dir=unet_folder, filename=model_file) |
|
|
| |
| vae_folder = "models/vae" |
| vae_path = os.path.join(vae_folder, "ae.sft") |
| if not os.path.exists(vae_path): |
| os.makedirs(vae_folder, exist_ok=True) |
| gr.Info(f"Downloading vae") |
| print(f"downloading ae.sft...") |
| hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_folder, filename="ae.sft") |
|
|
| |
| clip_folder = "models/clip" |
| clip_l_path = os.path.join(clip_folder, "clip_l.safetensors") |
| if not os.path.exists(clip_l_path): |
| os.makedirs(clip_folder, exist_ok=True) |
| gr.Info(f"Downloading clip...") |
| print(f"download clip_l.safetensors") |
| hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="clip_l.safetensors") |
|
|
| |
| t5xxl_path = os.path.join(clip_folder, "t5xxl_fp16.safetensors") |
| if not os.path.exists(t5xxl_path): |
| print(f"download t5xxl_fp16.safetensors") |
| gr.Info(f"Downloading t5xxl...") |
| hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="t5xxl_fp16.safetensors") |
|
|
|
|
| def resolve_path(p): |
| current_dir = os.path.dirname(os.path.abspath(__file__)) |
| norm_path = os.path.normpath(os.path.join(current_dir, p)) |
| return f"\"{norm_path}\"" |
| def resolve_path_without_quotes(p): |
| current_dir = os.path.dirname(os.path.abspath(__file__)) |
| norm_path = os.path.normpath(os.path.join(current_dir, p)) |
| return norm_path |
|
|
| def gen_sh( |
| base_model, |
| output_name, |
| resolution, |
| seed, |
| workers, |
| learning_rate, |
| network_dim, |
| max_train_epochs, |
| save_every_n_epochs, |
| timestep_sampling, |
| guidance_scale, |
| vram, |
| sample_prompts, |
| sample_every_n_steps, |
| *advanced_components |
| ): |
|
|
| print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}") |
|
|
| output_dir = resolve_path(f"outputs/{output_name}") |
| sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt") |
|
|
| line_break = "\\" |
| file_type = "sh" |
| if sys.platform == "win32": |
| line_break = "^" |
| file_type = "bat" |
|
|
| |
| sample = "" |
| if len(sample_prompts) > 0 and sample_every_n_steps > 0: |
| sample = f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}" {line_break}""" |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| if vram == "16G": |
| |
| optimizer = f"""--optimizer_type adafactor {line_break} |
| --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break} |
| --lr_scheduler constant_with_warmup {line_break} |
| --max_grad_norm 0.0 {line_break}""" |
| elif vram == "12G": |
| |
| optimizer = f"""--optimizer_type adafactor {line_break} |
| --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break} |
| --split_mode {line_break} |
| --network_args "train_blocks=single" {line_break} |
| --lr_scheduler constant_with_warmup {line_break} |
| --max_grad_norm 0.0 {line_break}""" |
| else: |
| |
| optimizer = f"--optimizer_type adamw8bit {line_break}" |
|
|
|
|
| |
| model_config = models[base_model] |
| model_file = model_config["file"] |
| repo = model_config["repo"] |
| if base_model == "flux-dev" or base_model == "flux-schnell": |
| model_folder = "models/unet" |
| else: |
| model_folder = f"models/unet/{repo}" |
| model_path = os.path.join(model_folder, model_file) |
| pretrained_model_path = resolve_path(model_path) |
|
|
| clip_path = resolve_path("models/clip/clip_l.safetensors") |
| t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors") |
| ae_path = resolve_path("models/vae/ae.sft") |
| sh = f"""accelerate launch {line_break} |
| --mixed_precision bf16 {line_break} |
| --num_cpu_threads_per_process 1 {line_break} |
| sd-scripts/flux_train_network.py {line_break} |
| --pretrained_model_name_or_path {pretrained_model_path} {line_break} |
| --clip_l {clip_path} {line_break} |
| --t5xxl {t5_path} {line_break} |
| --ae {ae_path} {line_break} |
| --cache_latents_to_disk {line_break} |
| --save_model_as safetensors {line_break} |
| --sdpa --persistent_data_loader_workers {line_break} |
| --max_data_loader_n_workers {workers} {line_break} |
| --seed {seed} {line_break} |
| --gradient_checkpointing {line_break} |
| --mixed_precision bf16 {line_break} |
| --save_precision bf16 {line_break} |
| --network_module networks.lora_flux {line_break} |
| --network_dim {network_dim} {line_break} |
| {optimizer}{sample} |
| --learning_rate {learning_rate} {line_break} |
| --cache_text_encoder_outputs {line_break} |
| --cache_text_encoder_outputs_to_disk {line_break} |
| --fp8_base {line_break} |
| --highvram {line_break} |
| --max_train_epochs {max_train_epochs} {line_break} |
| --save_every_n_epochs {save_every_n_epochs} {line_break} |
| --dataset_config {resolve_path(f"outputs/{output_name}/dataset.toml")} {line_break} |
| --output_dir {output_dir} {line_break} |
| --output_name {output_name} {line_break} |
| --timestep_sampling {timestep_sampling} {line_break} |
| --discrete_flow_shift 3.1582 {line_break} |
| --model_prediction_type raw {line_break} |
| --guidance_scale {guidance_scale} {line_break} |
| --loss_type l2 {line_break}""" |
| |
|
|
|
|
| |
| global advanced_component_ids |
| global original_advanced_component_values |
| |
| |
| print(f"original_advanced_component_values = {original_advanced_component_values}") |
| advanced_flags = [] |
| for i, current_value in enumerate(advanced_components): |
| |
| if original_advanced_component_values[i] != current_value: |
| |
| if current_value == True: |
| |
| advanced_flags.append(advanced_component_ids[i]) |
| else: |
| |
| advanced_flags.append(f"{advanced_component_ids[i]} {current_value}") |
|
|
| if len(advanced_flags) > 0: |
| advanced_flags_str = f" {line_break}\n ".join(advanced_flags) |
| sh = sh + "\n " + advanced_flags_str |
|
|
| return sh |
|
|
| def gen_toml( |
| dataset_folder, |
| resolution, |
| class_tokens, |
| num_repeats |
| ): |
| toml = f"""[general] |
| shuffle_caption = false |
| caption_extension = '.txt' |
| keep_tokens = 1 |
| |
| [[datasets]] |
| resolution = {resolution} |
| batch_size = 1 |
| keep_tokens = 1 |
| |
| [[datasets.subsets]] |
| image_dir = '{resolve_path_without_quotes(dataset_folder)}' |
| class_tokens = '{class_tokens}' |
| num_repeats = {num_repeats}""" |
| return toml |
|
|
| def update_total_steps(max_train_epochs, num_repeats, images): |
| try: |
| num_images = len(images) |
| total_steps = max_train_epochs * num_images * num_repeats |
| print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}") |
| return gr.update(value = total_steps) |
| except: |
| print("") |
|
|
| def set_repo(lora_rows): |
| selected_name = os.path.basename(lora_rows) |
| return gr.update(value=selected_name) |
|
|
| def get_loras(): |
| try: |
| outputs_path = resolve_path_without_quotes(f"outputs") |
| files = os.listdir(outputs_path) |
| folders = [os.path.join(outputs_path, item) for item in files if os.path.isdir(os.path.join(outputs_path, item)) and item != "sample"] |
| folders.sort(key=lambda file: os.path.getctime(file), reverse=True) |
| return folders |
| except Exception as e: |
| return [] |
|
|
| def get_samples(lora_name): |
| output_name = slugify(lora_name) |
| try: |
| samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample") |
| files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)] |
| files.sort(key=lambda file: os.path.getctime(file), reverse=True) |
| return files |
| except: |
| return [] |
|
|
| def start_training( |
| base_model, |
| lora_name, |
| train_script, |
| train_config, |
| sample_prompts, |
| ): |
| |
| if not os.path.exists("models"): |
| os.makedirs("models", exist_ok=True) |
| if not os.path.exists("outputs"): |
| os.makedirs("outputs", exist_ok=True) |
| output_name = slugify(lora_name) |
| output_dir = resolve_path_without_quotes(f"outputs/{output_name}") |
| if not os.path.exists(output_dir): |
| os.makedirs(output_dir, exist_ok=True) |
|
|
| download(base_model) |
|
|
| file_type = "sh" |
| if sys.platform == "win32": |
| file_type = "bat" |
|
|
| sh_filename = f"train.{file_type}" |
| sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}") |
| with open(sh_filepath, 'w', encoding="utf-8") as file: |
| file.write(train_script) |
| gr.Info(f"Generated train script at {sh_filename}") |
|
|
|
|
| dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml") |
| with open(dataset_path, 'w', encoding="utf-8") as file: |
| file.write(train_config) |
| gr.Info(f"Generated dataset.toml") |
|
|
| sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt") |
| with open(sample_prompts_path, 'w', encoding='utf-8') as file: |
| file.write(sample_prompts) |
| gr.Info(f"Generated sample_prompts.txt") |
|
|
| |
| if sys.platform == "win32": |
| command = sh_filepath |
| else: |
| command = f"bash \"{sh_filepath}\"" |
|
|
| |
| env = os.environ.copy() |
| env['PYTHONIOENCODING'] = 'utf-8' |
| env['LOG_LEVEL'] = 'DEBUG' |
| runner = LogsViewRunner() |
| cwd = os.path.dirname(os.path.abspath(__file__)) |
| gr.Info(f"Started training") |
| yield from runner.run_command([command], cwd=cwd) |
| yield runner.log(f"Runner: {runner}") |
|
|
| |
| config = toml.loads(train_config) |
| concept_sentence = config['datasets'][0]['subsets'][0]['class_tokens'] |
| print(f"concept_sentence={concept_sentence}") |
| print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}") |
| sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt") |
| with open(sample_prompts_path, "r", encoding="utf-8") as f: |
| lines = f.readlines() |
| sample_prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] |
| md = readme(base_model, lora_name, concept_sentence, sample_prompts) |
| readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md") |
| with open(readme_path, "w", encoding="utf-8") as f: |
| f.write(md) |
|
|
| gr.Info(f"Training Complete. Check the outputs folder for the LoRA files.", duration=None) |
|
|
|
|
| def update( |
| base_model, |
| lora_name, |
| resolution, |
| seed, |
| workers, |
| class_tokens, |
| learning_rate, |
| network_dim, |
| max_train_epochs, |
| save_every_n_epochs, |
| timestep_sampling, |
| guidance_scale, |
| vram, |
| num_repeats, |
| sample_prompts, |
| sample_every_n_steps, |
| *advanced_components, |
| ): |
| output_name = slugify(lora_name) |
| dataset_folder = str(f"datasets/{output_name}") |
| sh = gen_sh( |
| base_model, |
| output_name, |
| resolution, |
| seed, |
| workers, |
| learning_rate, |
| network_dim, |
| max_train_epochs, |
| save_every_n_epochs, |
| timestep_sampling, |
| guidance_scale, |
| vram, |
| sample_prompts, |
| sample_every_n_steps, |
| *advanced_components, |
| ) |
| toml = gen_toml( |
| dataset_folder, |
| resolution, |
| class_tokens, |
| num_repeats |
| ) |
| return gr.update(value=sh), gr.update(value=toml), dataset_folder |
|
|
| """ |
| demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, hf_account]) |
| """ |
| def loaded(): |
| global current_account |
| current_account = account_hf() |
| print(f"current_account={current_account}") |
| if current_account != None: |
| return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True) |
| else: |
| return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False) |
|
|
| def update_sample(concept_sentence): |
| return gr.update(value=concept_sentence) |
|
|
| def refresh_publish_tab(): |
| loras = get_loras() |
| return gr.Dropdown(label="Trained LoRAs", choices=loras) |
|
|
| def init_advanced(): |
| |
| basic_args = { |
| 'pretrained_model_name_or_path', |
| 'clip_l', |
| 't5xxl', |
| 'ae', |
| 'cache_latents_to_disk', |
| 'save_model_as', |
| 'sdpa', |
| 'persistent_data_loader_workers', |
| 'max_data_loader_n_workers', |
| 'seed', |
| 'gradient_checkpointing', |
| 'mixed_precision', |
| 'save_precision', |
| 'network_module', |
| 'network_dim', |
| 'learning_rate', |
| 'cache_text_encoder_outputs', |
| 'cache_text_encoder_outputs_to_disk', |
| 'fp8_base', |
| 'highvram', |
| 'max_train_epochs', |
| 'save_every_n_epochs', |
| 'dataset_config', |
| 'output_dir', |
| 'output_name', |
| 'timestep_sampling', |
| 'discrete_flow_shift', |
| 'model_prediction_type', |
| 'guidance_scale', |
| 'loss_type', |
| 'optimizer_type', |
| 'optimizer_args', |
| 'lr_scheduler', |
| 'sample_prompts', |
| 'sample_every_n_steps', |
| 'max_grad_norm', |
| 'split_mode', |
| 'network_args' |
| } |
|
|
| |
| |
| parser = train_network.setup_parser() |
| flux_train_utils.add_flux_train_arguments(parser) |
| args_info = {} |
| for action in parser._actions: |
| if action.dest != 'help': |
| |
| args_info[action.dest] = { |
| "action": action.option_strings, |
| "type": action.type, |
| "help": action.help, |
| "default": action.default, |
| "required": action.required |
| } |
| temp = [] |
| for key in args_info: |
| temp.append({ 'key': key, 'action': args_info[key] }) |
| temp.sort(key=lambda x: x['key']) |
| advanced_component_ids = [] |
| advanced_components = [] |
| for item in temp: |
| key = item['key'] |
| action = item['action'] |
| if key in basic_args: |
| print("") |
| else: |
| action_type = str(action['type']) |
| component = None |
| with gr.Column(min_width=300): |
| if action_type == "None": |
| |
| component = gr.Checkbox() |
| |
| |
| |
| |
| |
| |
| |
| |
| else: |
| component = gr.Textbox(value="") |
| if component != None: |
| component.interactive = True |
| component.elem_id = action['action'][0] |
| component.label = component.elem_id |
| component.elem_classes = ["advanced"] |
| if action['help'] != None: |
| component.info = action['help'] |
| advanced_components.append(component) |
| advanced_component_ids.append(component.elem_id) |
| return advanced_components, advanced_component_ids |
|
|
|
|
| theme = gr.themes.Monochrome( |
| text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"), |
| font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"], |
| ) |
| css = """ |
| @keyframes rotate { |
| 0% { |
| transform: rotate(0deg); |
| } |
| 100% { |
| transform: rotate(360deg); |
| } |
| } |
| #advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; } |
| h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;} |
| h3{margin-top: 0} |
| .tabitem{border: 0px} |
| .group_padding{} |
| nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); } |
| nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; } |
| nav img { height: 40px; width: 40px; border-radius: 40px; } |
| nav img.rotate { animation: rotate 2s linear infinite; } |
| .flexible { flex-grow: 1; } |
| .tast-details { margin: 10px 0 !important; } |
| .toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); } |
| .toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; } |
| .toast-body { border: none !important; } |
| #terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); } |
| #terminal .generating { border: none !important; } |
| #terminal label { position: absolute !important; } |
| .tabs { margin-top: 50px; } |
| .hidden { display: none !important; } |
| .codemirror-wrapper .cm-line { font-size: 12px !important; } |
| label { font-weight: bold !important; } |
| #start_training.clicked { background: silver; color: black; } |
| """ |
|
|
| js = """ |
| function() { |
| let autoscroll = document.querySelector("#autoscroll") |
| if (window.iidxx) { |
| window.clearInterval(window.iidxx); |
| } |
| window.iidxx = window.setInterval(function() { |
| let text=document.querySelector(".codemirror-wrapper .cm-line").innerText.trim() |
| let img = document.querySelector("#logo") |
| if (text.length > 0) { |
| autoscroll.classList.remove("hidden") |
| if (autoscroll.classList.contains("on")) { |
| autoscroll.textContent = "Autoscroll ON" |
| window.scrollTo(0, document.body.scrollHeight, { behavior: "smooth" }); |
| img.classList.add("rotate") |
| } else { |
| autoscroll.textContent = "Autoscroll OFF" |
| img.classList.remove("rotate") |
| } |
| } |
| }, 500); |
| console.log("autoscroll", autoscroll) |
| autoscroll.addEventListener("click", (e) => { |
| autoscroll.classList.toggle("on") |
| }) |
| function debounce(fn, delay) { |
| let timeoutId; |
| return function(...args) { |
| clearTimeout(timeoutId); |
| timeoutId = setTimeout(() => fn(...args), delay); |
| }; |
| } |
| |
| function handleClick() { |
| console.log("refresh") |
| document.querySelector("#refresh").click(); |
| } |
| const debouncedClick = debounce(handleClick, 1000); |
| document.addEventListener("input", debouncedClick); |
| |
| document.querySelector("#start_training").addEventListener("click", (e) => { |
| e.target.classList.add("clicked") |
| e.target.innerHTML = "Training..." |
| }) |
| |
| } |
| """ |
|
|
| current_account = account_hf() |
| print(f"current_account={current_account}") |
|
|
| with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo: |
| with gr.Tabs() as tabs: |
| with gr.TabItem("Gym"): |
| output_components = [] |
| with gr.Row(): |
| gr.HTML("""<nav> |
| <img id='logo' src='/file=icon.png' width='80' height='80'> |
| <div class='flexible'></div> |
| <button id='autoscroll' class='on hidden'></button> |
| </nav> |
| """) |
| with gr.Row(elem_id='container'): |
| with gr.Column(): |
| gr.Markdown( |
| """# Step 1. LoRA Info |
| <p style="margin-top:0">Configure your LoRA train settings.</p> |
| """, elem_classes="group_padding") |
| lora_name = gr.Textbox( |
| label="The name of your LoRA", |
| info="This has to be a unique name", |
| placeholder="e.g.: Persian Miniature Painting style, Cat Toy", |
| ) |
| concept_sentence = gr.Textbox( |
| elem_id="--concept_sentence", |
| label="Trigger word/sentence", |
| info="Trigger word or sentence to be used", |
| placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'", |
| interactive=True, |
| ) |
| model_names = list(models.keys()) |
| print(f"model_names={model_names}") |
| base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0]) |
| vram = gr.Radio(["20G", "16G", "12G" ], value="20G", label="VRAM", interactive=True) |
| num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True) |
| max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True) |
| total_steps = gr.Number(0, interactive=False, label="Expected training steps") |
| sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True) |
| sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True) |
| resolution = gr.Number(value=512, precision=0, label="Resize dataset images") |
| with gr.Column(): |
| gr.Markdown( |
| """# Step 2. Dataset |
| <p style="margin-top:0">Make sure the captions include the trigger word.</p> |
| """, elem_classes="group_padding") |
| with gr.Group(): |
| images = gr.File( |
| file_types=["image", ".txt"], |
| label="Upload your images", |
| |
| file_count="multiple", |
| interactive=True, |
| visible=True, |
| scale=1, |
| ) |
| with gr.Group(visible=False) as captioning_area: |
| do_captioning = gr.Button("Add AI captions with Florence-2") |
| output_components.append(captioning_area) |
| |
| caption_list = [] |
| for i in range(1, MAX_IMAGES + 1): |
| locals()[f"captioning_row_{i}"] = gr.Row(visible=False) |
| with locals()[f"captioning_row_{i}"]: |
| locals()[f"image_{i}"] = gr.Image( |
| type="filepath", |
| width=111, |
| height=111, |
| min_width=111, |
| interactive=False, |
| scale=2, |
| show_label=False, |
| show_share_button=False, |
| show_download_button=False, |
| ) |
| locals()[f"caption_{i}"] = gr.Textbox( |
| label=f"Caption {i}", scale=15, interactive=True |
| ) |
|
|
| output_components.append(locals()[f"captioning_row_{i}"]) |
| output_components.append(locals()[f"image_{i}"]) |
| output_components.append(locals()[f"caption_{i}"]) |
| caption_list.append(locals()[f"caption_{i}"]) |
| with gr.Column(): |
| gr.Markdown( |
| """# Step 3. Train |
| <p style="margin-top:0">Press start to start training.</p> |
| """, elem_classes="group_padding") |
| refresh = gr.Button("Refresh", elem_id="refresh", visible=False) |
| start = gr.Button("Start training", visible=False, elem_id="start_training") |
| output_components.append(start) |
| train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True) |
| train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True) |
| with gr.Accordion("Advanced options", elem_id='advanced_options', open=False): |
| with gr.Row(): |
| with gr.Column(min_width=300): |
| seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True) |
| with gr.Column(min_width=300): |
| workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True) |
| with gr.Column(min_width=300): |
| learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True) |
| with gr.Column(min_width=300): |
| save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True) |
| with gr.Column(min_width=300): |
| guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True) |
| with gr.Column(min_width=300): |
| timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True) |
| with gr.Column(min_width=300): |
| network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True) |
| advanced_components, advanced_component_ids = init_advanced() |
| with gr.Row(): |
| terminal = LogsView(label="Train log", elem_id="terminal") |
| with gr.Row(): |
| gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6) |
|
|
| with gr.TabItem("Publish") as publish_tab: |
| hf_token = gr.Textbox(label="Huggingface Token") |
| hf_login = gr.Button("Login") |
| hf_logout = gr.Button("Logout") |
| with gr.Row() as row: |
| gr.Markdown("**LoRA**") |
| gr.Markdown("**Upload**") |
| loras = get_loras() |
| with gr.Row(): |
| lora_rows = refresh_publish_tab() |
| with gr.Column(): |
| with gr.Row(): |
| repo_owner = gr.Textbox(label="Account", interactive=False) |
| repo_name = gr.Textbox(label="Repository Name") |
| repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public") |
| upload_button = gr.Button("Upload to HuggingFace") |
| upload_button.click( |
| fn=upload_hf, |
| inputs=[ |
| base_model, |
| lora_rows, |
| repo_owner, |
| repo_name, |
| repo_visibility, |
| hf_token, |
| ] |
| ) |
| hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner]) |
| hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner]) |
|
|
|
|
| publish_tab.select(refresh_publish_tab, outputs=lora_rows) |
| lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name]) |
|
|
| dataset_folder = gr.State() |
|
|
| listeners = [ |
| base_model, |
| lora_name, |
| resolution, |
| seed, |
| workers, |
| concept_sentence, |
| learning_rate, |
| network_dim, |
| max_train_epochs, |
| save_every_n_epochs, |
| timestep_sampling, |
| guidance_scale, |
| vram, |
| num_repeats, |
| sample_prompts, |
| sample_every_n_steps, |
| *advanced_components |
| ] |
| advanced_component_ids = [x.elem_id for x in advanced_components] |
| original_advanced_component_values = [comp.value for comp in advanced_components] |
| images.upload( |
| load_captioning, |
| inputs=[images, concept_sentence], |
| outputs=output_components |
| ) |
| images.delete( |
| load_captioning, |
| inputs=[images, concept_sentence], |
| outputs=output_components |
| ) |
| images.clear( |
| hide_captioning, |
| outputs=[captioning_area, start] |
| ) |
| max_train_epochs.change( |
| fn=update_total_steps, |
| inputs=[max_train_epochs, num_repeats, images], |
| outputs=[total_steps] |
| ) |
| num_repeats.change( |
| fn=update_total_steps, |
| inputs=[max_train_epochs, num_repeats, images], |
| outputs=[total_steps] |
| ) |
| images.upload( |
| fn=update_total_steps, |
| inputs=[max_train_epochs, num_repeats, images], |
| outputs=[total_steps] |
| ) |
| images.delete( |
| fn=update_total_steps, |
| inputs=[max_train_epochs, num_repeats, images], |
| outputs=[total_steps] |
| ) |
| images.clear( |
| fn=update_total_steps, |
| inputs=[max_train_epochs, num_repeats, images], |
| outputs=[total_steps] |
| ) |
| concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts) |
| start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then( |
| fn=start_training, |
| inputs=[ |
| base_model, |
| lora_name, |
| train_script, |
| train_config, |
| sample_prompts, |
| ], |
| outputs=terminal, |
| ) |
| do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list) |
| demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner]) |
| refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder]) |
| if __name__ == "__main__": |
| cwd = os.path.dirname(os.path.abspath(__file__)) |
| demo.launch(debug=True, show_error=True, allowed_paths=[cwd]) |
|
|