diff --git a/config/examples/train_lora_flux_24gb.yaml b/config/examples/train_lora_flux_24gb.yaml
index c3fd119a..8e29402b 100644
--- a/config/examples/train_lora_flux_24gb.yaml
+++ b/config/examples/train_lora_flux_24gb.yaml
@@ -21,6 +21,10 @@ config:
         dtype: float16 # precision to save
         save_every: 250 # save every this many steps
         max_step_saves_to_keep: 4 # how many intermittent saves to keep
+        push_to_hub: false # change this to true to push your trained model to Hugging Face.
+        # You can either set an HF_TOKEN env variable or you'll be prompted to log in
+#        hf_repo_id: your-username/your-model-slug
+#        hf_private: true # whether the repo is private or public
       datasets:
         # datasets are a folder of images. captions need to be txt files with the same name as the image
         # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
diff --git a/config/examples/train_lora_flux_schnell_24gb.yaml b/config/examples/train_lora_flux_schnell_24gb.yaml
index c6ef95e4..a4aef078 100644
--- a/config/examples/train_lora_flux_schnell_24gb.yaml
+++ b/config/examples/train_lora_flux_schnell_24gb.yaml
@@ -21,6 +21,10 @@ config:
         dtype: float16 # precision to save
         save_every: 250 # save every this many steps
         max_step_saves_to_keep: 4 # how many intermittent saves to keep
+        push_to_hub: false # change this to true to push your trained model to Hugging Face.
+        # You can either set an HF_TOKEN env variable or you'll be prompted to log in
+#        hf_repo_id: your-username/your-model-slug
+#        hf_private: true # whether the repo is private or public
       datasets:
         # datasets are a folder of images. captions need to be txt files with the same name as the image
         # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 474457d8..0392933c 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -6,6 +6,7 @@ import random
 import shutil
 from collections import OrderedDict
 import os
+import re
 from typing import Union, List, Optional
 
 import numpy as np
@@ -17,6 +18,8 @@ from safetensors.torch import save_file, load_file
 from torch.utils.data import DataLoader
 import torch
 import torch.backends.cuda
+from huggingface_hub import HfApi, Repository, interpreter_login
+from huggingface_hub.utils import HfFolder
 
 from toolkit.basic import value_map
 from toolkit.clip_vision_adapter import ClipVisionAdapter
@@ -1790,7 +1793,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.sample(self.step_num)
         print("")
         self.save()
-
+        if self.save_config.push_to_hub:
+            if "HF_TOKEN" not in os.environ:
+                interpreter_login(new_session=False, write_permission=True)
+            self.push_to_hub(
+                repo_id=self.save_config.hf_repo_id,
+                private=self.save_config.hf_private
+            )
         del (
             self.sd,
             unet,
@@ -1802,3 +1811,120 @@ class BaseSDTrainProcess(BaseTrainProcess):
         )
 
         flush()
+
+    def push_to_hub(
+        self,
+        repo_id: str,
+        private: bool = False,
+    ):
+        readme_content = self._generate_readme(repo_id)
+        readme_path = os.path.join(self.save_root, "README.md")
+        with open(readme_path, "w", encoding="utf-8") as f:
+            f.write(readme_content)
+
+        api = HfApi()
+
+        api.create_repo(
+            repo_id,
+            private=private,
+            exist_ok=True
+        )
+
+        api.upload_folder(
+            repo_id=repo_id,
+            folder_path=self.save_root,
+            ignore_patterns=["*.yaml", "*.pt"],
+            repo_type="model",
+        )
+
+
+    def _generate_readme(self, repo_id: str) -> str:
+        """Generates the content of the README.md file."""
+
+        # Gather model info
+        base_model = self.model_config.name_or_path
+        instance_prompt = self.trigger_word if hasattr(self, "trigger_word") else None
+        if base_model == "black-forest-labs/FLUX.1-schnell":
+            license = "apache-2.0"
+        elif base_model == "black-forest-labs/FLUX.1-dev":
+            license = "other"
+            license_name = "flux-1-dev-non-commercial-license"
+            license_link = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
+        else:
+            license = "creativeml-openrail-m"
+        tags = [
+            "text-to-image",
+        ]
+        if self.model_config.is_xl:
+            tags.append("stable-diffusion-xl")
+        if self.model_config.is_flux:
+            tags.append("flux")
+        if self.network_config:
+            tags.extend(
+                [
+                    "lora",
+                    "diffusers",
+                    "template:sd-lora",
+                ]
+            )
+
+        # Generate the widget section
+        widgets = []
+        sample_image_paths = []
+        samples_dir = os.path.join(self.save_root, "samples")
+        if os.path.isdir(samples_dir):
+            for filename in os.listdir(samples_dir):
+                # The filenames are structured as 1724085406830__00000500_0.jpg,
+                # so here we capture the 2nd part (steps) and the 3rd (the index that matches the prompt)
+                match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
+                if match:
+                    steps, index = int(match.group(1)), int(match.group(2))
+                    # We only care about uploading the latest samples, i.e. those whose step count matches the final step
+                    if steps == self.train_config.steps:
+                        sample_image_paths.append((index, f"samples/{filename}"))
+
+            # Sort by numeric index
+            sample_image_paths.sort(key=lambda x: x[0])
+
+            # Create widgets matching prompt with the index
+            for i, prompt in enumerate(self.sample_config.prompts):
+                if i < len(sample_image_paths):
+                    # Associate prompts with sample image paths based on the extracted index
+                    _, image_path = sample_image_paths[i]
+                    widgets.append(
+                        {
+                            "text": prompt,
+                            "output": {
+                                "url": image_path
+                            },
+                        }
+                    )
+
+        # Construct the README content
+        readme_content = f"""---
+tags:
+{yaml.dump(tags, indent=4).strip()}
+{"widget:" if os.path.isdir(samples_dir) else ""}
+{yaml.dump(widgets, indent=4).strip() if widgets else ""}
+base_model: {base_model}
+{"instance_prompt: " + instance_prompt if instance_prompt else ""}
+license: {license}
+{'license_name: ' + license_name if license == "other" else ""}
+{'license_link: ' + license_link if license == "other" else ""}
+---
+
+# {self.job.name}
+Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
+
+
+## Trigger words
+
+{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
+
+## Download model
+
+Weights for this model are available in Safetensors format.
+
+[Download](/{repo_id}/tree/main) them in the Files & versions tab.
+"""
+        return readme_content
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index f7111a32..b794dd2f 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -23,7 +23,9 @@ class SaveConfig:
         self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
         if self.save_format not in ['safetensors', 'diffusers']:
             raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
-
+        self.push_to_hub: bool = kwargs.get("push_to_hub", False)
+        self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
+        self.hf_private: bool = kwargs.get("hf_private", False)
 
 class LogingConfig:
     def __init__(self, **kwargs):
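For reviewers who want to sanity-check the sample-selection logic in `_generate_readme` without running a training job, here is a minimal, self-contained sketch. It applies the same `__(\d+)_(\d+)\.jpg$` pattern to a made-up list of filenames, keeps only images rendered at the final step, and pairs them with prompts by index. The filenames, step count, and prompts below are invented for illustration; only the regex and the pairing rule come from the patch.

```python
import re

# Hypothetical sample filenames of the form <timestamp>__<step>_<prompt index>.jpg,
# as written to the training run's samples/ folder.
filenames = [
    "1724085406830__00000250_0.jpg",  # intermediate save, should be skipped
    "1724085406830__00000500_0.jpg",  # final step, prompt 0
    "1724085406831__00000500_1.jpg",  # final step, prompt 1
]
total_steps = 500  # stands in for train_config.steps
prompts = ["a photo of sks dog", "sks dog in a bucket"]  # stands in for sample_config.prompts

sample_image_paths = []
for filename in filenames:
    match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
    if match:
        steps, index = int(match.group(1)), int(match.group(2))
        # keep only images rendered at the final training step
        if steps == total_steps:
            sample_image_paths.append((index, f"samples/{filename}"))

# sort by the prompt index extracted from the filename
sample_image_paths.sort(key=lambda x: x[0])

# pair each prompt with the sample image that carries the same index
widgets = [
    {"text": prompt, "output": {"url": sample_image_paths[i][1]}}
    for i, prompt in enumerate(prompts)
    if i < len(sample_image_paths)
]
print(widgets)
```

Running this prints two widget entries, one per prompt, each pointing at a `samples/...` path relative to the uploaded folder, which is the shape the README front matter expects.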
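Separately, if a run finished before `push_to_hub: true` was set, the same upload can be done by hand with `huggingface_hub`. This is a sketch under assumptions: the repo id and output folder are placeholders, and it simply mirrors the calls the patch makes (`create_repo` with `exist_ok=True`, then `upload_folder` skipping `*.yaml` and `*.pt`). `HfApi` picks up `HF_TOKEN` from the environment, or you can run `huggingface-cli login` first.

```python
from huggingface_hub import HfApi

# Placeholder values; point these at your own repo and training output folder.
repo_id = "your-username/your-model-slug"
folder_path = "output/my_first_flux_lora_v1"

api = HfApi()  # uses HF_TOKEN from the environment if set
api.create_repo(repo_id, private=True, exist_ok=True)
api.upload_folder(
    repo_id=repo_id,
    folder_path=folder_path,
    ignore_patterns=["*.yaml", "*.pt"],  # keep configs and optimizer states local, as the trainer does
    repo_type="model",
)
```

The `ignore_patterns` mirror the trainer's defaults so config YAMLs and `.pt` checkpoints are not published alongside the safetensors weights.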