Add push_to_hub to the trainer (#109)

* add push_to_hub

* fix indentation

* indent again

* model_config

* allow samples to not exist

* repo creation fix

* don't show empty [] if widget doesn't exist

* don't submit the config and optimizer

* Unsafe to have tokens saved in the yaml file

* make sure to catch only the latest samples

* change name to slug

* formatting

* formatting

---------

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
apolinário
2024-08-23 04:18:56 +01:00
committed by GitHub
parent b322d05fa3
commit 4d35a29c97
4 changed files with 138 additions and 2 deletions

View File

@@ -21,6 +21,10 @@ config:
      dtype: float16 # precision to save
      save_every: 250 # save every this many steps
      max_step_saves_to_keep: 4 # how many intermittent saves to keep
      push_to_hub: false # change this to true to push your trained model to Hugging Face
      # you can either set up an HF_TOKEN env variable or you'll be prompted to log in
      # hf_repo_id: your-username/your-model-slug
      # hf_private: true # whether the repo is private or public
    datasets:
      # datasets are a folder of images. captions need to be txt files with the same name as the image
      # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently

View File

@@ -21,6 +21,10 @@ config:
      dtype: float16 # precision to save
      save_every: 250 # save every this many steps
      max_step_saves_to_keep: 4 # how many intermittent saves to keep
      push_to_hub: false # change this to true to push your trained model to Hugging Face
      # you can either set up an HF_TOKEN env variable or you'll be prompted to log in
      # hf_repo_id: your-username/your-model-slug
      # hf_private: true # whether the repo is private or public
    datasets:
      # datasets are a folder of images. captions need to be txt files with the same name as the image
      # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
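Because the token is deliberately kept out of the YAML file, authentication must come from the HF_TOKEN environment variable or a cached Hugging Face login. A minimal sketch of both options using the stock huggingface_hub client (the token value shown is a hypothetical placeholder):

import os
from huggingface_hub import login

# Option 1: export HF_TOKEN before launching training; the trainer detects it
# and skips the interactive prompt.
# os.environ["HF_TOKEN"] = "hf_..."  # hypothetical placeholder; never hard-code real tokens

# Option 2: log in once beforehand; the token is cached locally and reused.
if "HF_TOKEN" not in os.environ:
    login()  # prompts for a token interactively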

View File

@@ -6,6 +6,7 @@ import random
import shutil
from collections import OrderedDict
import os
import re
from typing import Union, List, Optional
import numpy as np
@@ -17,6 +18,8 @@ from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader
import torch
import torch.backends.cuda
from huggingface_hub import HfApi, interpreter_login
from toolkit.basic import value_map
from toolkit.clip_vision_adapter import ClipVisionAdapter
@@ -1790,7 +1793,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.sample(self.step_num)
        print("")
        self.save()

        if self.save_config.push_to_hub:
            if "HF_TOKEN" not in os.environ:
                interpreter_login(new_session=False, write_permission=True)
            self.push_to_hub(
                repo_id=self.save_config.hf_repo_id,
                private=self.save_config.hf_private,
            )

        del (
            self.sd,
            unet,
@@ -1802,3 +1811,120 @@ class BaseSDTrainProcess(BaseTrainProcess):
        )
        flush()
    def push_to_hub(
        self,
        repo_id: str,
        private: bool = False,
    ):
        readme_content = self._generate_readme(repo_id)
        readme_path = os.path.join(self.save_root, "README.md")
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write(readme_content)

        api = HfApi()
        api.create_repo(
            repo_id,
            private=private,
            exist_ok=True,
        )
        api.upload_folder(
            repo_id=repo_id,
            folder_path=self.save_root,
            ignore_patterns=["*.yaml", "*.pt"],
            repo_type="model",
        )

    def _generate_readme(self, repo_id: str) -> str:
        """Generates the content of the README.md file."""
        # Gather model info
        base_model = self.model_config.name_or_path
        instance_prompt = self.trigger_word if hasattr(self, "trigger_word") else None
        if base_model == "black-forest-labs/FLUX.1-schnell":
            license = "apache-2.0"
        elif base_model == "black-forest-labs/FLUX.1-dev":
            license = "other"
            license_name = "flux-1-dev-non-commercial-license"
            license_link = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
        else:
            license = "creativeml-openrail-m"
        tags = [
            "text-to-image",
        ]
        if self.model_config.is_xl:
            tags.append("stable-diffusion-xl")
        if self.model_config.is_flux:
            tags.append("flux")
        if self.network_config:
            tags.extend(
                [
                    "lora",
                    "diffusers",
                    "template:sd-lora",
                ]
            )

        # Generate the widget section
        widgets = []
        sample_image_paths = []
        samples_dir = os.path.join(self.save_root, "samples")
        if os.path.isdir(samples_dir):
            for filename in os.listdir(samples_dir):
                # The filenames are structured as 1724085406830__00000500_0.jpg,
                # so we capture the 2nd part (steps) and the 3rd (index matching the prompt)
                match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
                if match:
                    steps, index = int(match.group(1)), int(match.group(2))
                    # Only upload the latest samples: those whose step count
                    # equals the final training step
                    if steps == self.train_config.steps:
                        sample_image_paths.append((index, f"samples/{filename}"))

            # Sort by numeric index
            sample_image_paths.sort(key=lambda x: x[0])

            # Create widgets, matching each prompt with its sample image by index
            for i, prompt in enumerate(self.sample_config.prompts):
                if i < len(sample_image_paths):
                    # Associate prompts with sample image paths based on the extracted index
                    _, image_path = sample_image_paths[i]
                    widgets.append(
                        {
                            "text": prompt,
                            "output": {
                                "url": image_path
                            },
                        }
                    )

        # Construct the README content
        readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if os.path.isdir(samples_dir) else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model}
{"instance_prompt: " + instance_prompt if instance_prompt else ""}
license: {license}
{'license_name: ' + license_name if license == "other" else ""}
{'license_link: ' + license_link if license == "other" else ""}
---

# {self.job.name}

Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)

<Gallery />

## Trigger words

{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}

## Download model

Weights for this model are available in Safetensors format.

[Download](/{repo_id}/tree/main) them in the Files & versions tab.
"""
        return readme_content
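As a standalone sanity check of the sample-filename pattern used in _generate_readme, here it is applied to the example filename cited in the code comments; for 1724085406830__00000500_0.jpg the captured step count and prompt index should be 500 and 0:

import re

# Filenames look like "<timestamp>__<steps>_<prompt index>.jpg".
match = re.search(r"__(\d+)_(\d+)\.jpg$", "1724085406830__00000500_0.jpg")
assert match is not None
steps, index = int(match.group(1)), int(match.group(2))
assert (steps, index) == (500, 0)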

View File

@@ -23,7 +23,9 @@ class SaveConfig:
        self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
        if self.save_format not in ['safetensors', 'diffusers']:
            raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
        self.push_to_hub: bool = kwargs.get("push_to_hub", False)
        self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
        self.hf_private: bool = kwargs.get("hf_private", False)

class LogingConfig:
    def __init__(self, **kwargs):
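A hedged usage sketch for the new SaveConfig keys, mirroring the example YAML above; the import path is an assumption, since this diff does not name the module:

from toolkit.config_modules import SaveConfig  # hypothetical import path

cfg = SaveConfig(
    push_to_hub=True,
    hf_repo_id="your-username/your-model-slug",  # slug from the example config
    hf_private=True,
)
assert cfg.push_to_hub and cfg.hf_private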