Add push_to_hub to the trainer (#109)

* add push_to_hub

* fix indentation

* indent again

* model_config

* allow samples to not exist

* repo creation fix

* dont show empty [] if widget doesnt exist

* dont submit the config and optimizer

* Unsafe to have tokens saved in the yaml file

* make sure to catch only the latest samples

* change name to slug

* formatting

* formatting

---------

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
This commit is contained in:
apolinário
2024-08-23 04:18:56 +01:00
committed by GitHub
parent b322d05fa3
commit 4d35a29c97
4 changed files with 138 additions and 2 deletions

View File

@@ -21,6 +21,10 @@ config:
dtype: float16 # precision to save
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false # change this to true to push your trained model to the Hugging Face Hub
# You can either set an HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true # whether the pushed repo is private (true) or public (false)
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently

View File

@@ -21,6 +21,10 @@ config:
dtype: float16 # precision to save
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false # change this to true to push your trained model to the Hugging Face Hub
# You can either set an HF_TOKEN env variable or you'll be prompted to log in
# hf_repo_id: your-username/your-model-slug
# hf_private: true # whether the pushed repo is private (true) or public (false)
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently

View File

@@ -6,6 +6,7 @@ import random
import shutil
from collections import OrderedDict
import os
import re
from typing import Union, List, Optional
import numpy as np
@@ -17,6 +18,8 @@ from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader
import torch
import torch.backends.cuda
from huggingface_hub import HfApi, Repository, interpreter_login
from huggingface_hub.utils import HfFolder
from toolkit.basic import value_map
from toolkit.clip_vision_adapter import ClipVisionAdapter
@@ -1790,7 +1793,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sample(self.step_num)
print("")
self.save()
# Optionally publish the trained model to the Hugging Face Hub after the
# final save. Runs once, at the end of training.
if self.save_config.push_to_hub:
    # Fall back to an interactive login only when no HF_TOKEN is set in the
    # environment (tokens are deliberately not stored in the yaml config).
    if("HF_TOKEN" not in os.environ):
        interpreter_login(new_session=False, write_permission=True)
    self.push_to_hub(
        repo_id=self.save_config.hf_repo_id,
        private=self.save_config.hf_private
    )
del (
self.sd,
unet,
@@ -1802,3 +1811,120 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
flush()
def push_to_hub(
    self,
    repo_id: str,
    private: bool = False,
):
    """Upload the training output folder to the Hugging Face Hub.

    Writes a generated model card (README.md) into the save root, makes
    sure the target repo exists, then uploads the folder contents.

    Args:
        repo_id: Hub repository id, e.g. "username/model-slug".
        private: Create the repo as private when it does not exist yet.
    """
    # Drop a model card next to the weights before uploading so it is
    # included in the same upload_folder call.
    card = self._generate_readme(repo_id)
    card_path = os.path.join(self.save_root, "README.md")
    with open(card_path, "w", encoding="utf-8") as fh:
        fh.write(card)

    hub = HfApi()
    # exist_ok avoids failing on resumed/repeated pushes to the same repo.
    hub.create_repo(
        repo_id,
        private=private,
        exist_ok=True
    )
    # Config yaml (may reference local paths/tokens) and optimizer state
    # are intentionally kept out of the upload.
    hub.upload_folder(
        repo_id=repo_id,
        folder_path=self.save_root,
        ignore_patterns=["*.yaml", "*.pt"],
        repo_type="model",
    )
def _generate_readme(self, repo_id: str) -> str:
    """Generates the content of the README.md (model card) for the Hub.

    Builds YAML front matter (tags, sample-image widgets, base model,
    license) followed by a short markdown body.

    Args:
        repo_id: Hub repository id, e.g. "username/model-slug", used for
            the download link in the card body.

    Returns:
        The full README.md content as a single string.
    """
    # Gather model info
    base_model = self.model_config.name_or_path
    instance_prompt = self.trigger_word if hasattr(self, "trigger_word") else None

    # FLUX base models carry specific licenses; anything else falls back
    # to CreativeML OpenRAIL-M. license_name/link are only emitted for
    # the "other" (FLUX.1-dev) case below.
    if base_model == "black-forest-labs/FLUX.1-schnell":
        license = "apache-2.0"
    elif base_model == "black-forest-labs/FLUX.1-dev":
        license = "other"
        license_name = "flux-1-dev-non-commercial-license"
        license_link = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
    else:
        license = "creativeml-openrail-m"

    tags = [
        "text-to-image",
    ]
    if self.model_config.is_xl:
        tags.append("stable-diffusion-xl")
    if self.model_config.is_flux:
        tags.append("flux")
    if self.network_config:
        tags.extend(
            [
                "lora",
                "diffusers",
                "template:sd-lora",
            ]
        )

    # Generate the widget section from the most recent sample images.
    widgets = []
    sample_image_paths = []
    samples_dir = os.path.join(self.save_root, "samples")
    if os.path.isdir(samples_dir):
        for filename in os.listdir(samples_dir):
            # The filenames are structured as 1724085406830__00000500_0.jpg
            # so here we capture the 2nd part (steps) and 3rd (index that
            # matches the prompt).
            match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
            if match:
                steps, index = int(match.group(1)), int(match.group(2))
                # Only care about uploading the latest samples: those whose
                # step count matches the final training step.
                if steps == self.train_config.steps:
                    # BUGFIX: interpolate the actual filename (was a broken
                    # literal path), so widget URLs point at real files.
                    sample_image_paths.append((index, f"samples/{filename}"))

    # Sort by numeric index so widgets line up with prompt order.
    sample_image_paths.sort(key=lambda x: x[0])

    # Create widgets matching each prompt with its sample image by index.
    for i, prompt in enumerate(self.sample_config.prompts):
        if i < len(sample_image_paths):
            _, image_path = sample_image_paths[i]
            widgets.append(
                {
                    "text": prompt,
                    "output": {
                        "url": image_path
                    },
                }
            )

    # Construct the README content. The "widget:" key is omitted entirely
    # when there is no samples dir, to avoid an empty [] in the card.
    readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if os.path.isdir(samples_dir) else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model}
{"instance_prompt: " + instance_prompt if instance_prompt else ""}
license: {license}
{'license_name: ' + license_name if license == "other" else ""}
{'license_link: ' + license_link if license == "other" else ""}
---
# {self.job.name}
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
<Gallery />
## Trigger words
{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
## Download model
Weights for this model are available in Safetensors format.
[Download](/{repo_id}/tree/main) them in the Files & versions tab.
"""
    return readme_content

View File

@@ -23,7 +23,9 @@ class SaveConfig:
# Serialization format for checkpoints: a single safetensors file or a
# diffusers-style folder.
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
if self.save_format not in ['safetensors', 'diffusers']:
    raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
# Hugging Face Hub publishing options (consumed by the trainer's
# push_to_hub step after the final save).
self.push_to_hub: bool = kwargs.get("push_to_hub", False)
# Target repo id, e.g. "username/model-slug"; presumably required when
# push_to_hub is enabled — confirm against the trainer.
self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
# Whether the created Hub repo is private; defaults to public.
self.hf_private: Optional[bool] = kwargs.get("hf_private", False)
class LogingConfig:
def __init__(self, **kwargs):