mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Add push_to_hub to the trainer (#109)
* add push_to_hub * fix indentation * indent again * model_config * allow samples to not exist * repo creation fix * dont show empty [] if widget doesnt exist * dont submit the config and optimizer * Unsafe to have tokens saved in the yaml file * make sure to catch only the latest samples * change name to slug * formatting * formatting --------- Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
This commit is contained in:
@@ -21,6 +21,10 @@ config:
|
|||||||
dtype: float16 # precision to save
|
dtype: float16 # precision to save
|
||||||
save_every: 250 # save every this many steps
|
save_every: 250 # save every this many steps
|
||||||
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
||||||
|
push_to_hub: false # set to true to push your trained model to the Hugging Face Hub
|
||||||
|
# Either set the HF_TOKEN environment variable, or you will be prompted to log in
|
||||||
|
# hf_repo_id: your-username/your-model-slug
|
||||||
|
# hf_private: true #whether the repo is private or public
|
||||||
datasets:
|
datasets:
|
||||||
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
||||||
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ config:
|
|||||||
dtype: float16 # precision to save
|
dtype: float16 # precision to save
|
||||||
save_every: 250 # save every this many steps
|
save_every: 250 # save every this many steps
|
||||||
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
||||||
|
push_to_hub: false # set to true to push your trained model to the Hugging Face Hub
|
||||||
|
# Either set the HF_TOKEN environment variable, or you will be prompted to log in
|
||||||
|
# hf_repo_id: your-username/your-model-slug
|
||||||
|
# hf_private: true #whether the repo is private or public
|
||||||
datasets:
|
datasets:
|
||||||
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
||||||
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import random
|
|||||||
import shutil
|
import shutil
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from typing import Union, List, Optional
|
from typing import Union, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -17,6 +18,8 @@ from safetensors.torch import save_file, load_file
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cuda
|
import torch.backends.cuda
|
||||||
|
from huggingface_hub import HfApi, Repository, interpreter_login
|
||||||
|
from huggingface_hub.utils import HfFolder
|
||||||
|
|
||||||
from toolkit.basic import value_map
|
from toolkit.basic import value_map
|
||||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||||
@@ -1790,7 +1793,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.sample(self.step_num)
|
self.sample(self.step_num)
|
||||||
print("")
|
print("")
|
||||||
self.save()
|
self.save()
|
||||||
|
if self.save_config.push_to_hub:
|
||||||
|
if("HF_TOKEN" not in os.environ):
|
||||||
|
interpreter_login(new_session=False, write_permission=True)
|
||||||
|
self.push_to_hub(
|
||||||
|
repo_id=self.save_config.hf_repo_id,
|
||||||
|
private=self.save_config.hf_private
|
||||||
|
)
|
||||||
del (
|
del (
|
||||||
self.sd,
|
self.sd,
|
||||||
unet,
|
unet,
|
||||||
@@ -1802,3 +1811,120 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
)
|
)
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
|
def push_to_hub(
    self,
    repo_id: str,
    private: bool = False,
):
    """Write a model card into the save folder and upload that folder to the Hugging Face Hub.

    The README is generated by ``self._generate_readme`` and written to
    ``<save_root>/README.md`` so it is uploaded together with the weights.

    Args:
        repo_id: Target repository, in ``your-username/your-model-slug`` form.
        private: Passed through to ``HfApi.create_repo``.

    Raises:
        ValueError: If ``repo_id`` is missing, empty, or not a string.
    """
    # Fail before touching the disk or the network: the repo id is optional in the
    # save config, so enabling push_to_hub without it would arrive here as None and
    # only blow up later inside the Hub client with a less helpful message.
    if not isinstance(repo_id, str) or not repo_id.strip():
        raise ValueError(
            "push_to_hub requires a Hugging Face repo id; "
            "set `hf_repo_id: your-username/your-model-slug` in the save config"
        )

    readme_content = self._generate_readme(repo_id)
    readme_path = os.path.join(self.save_root, "README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(readme_content)

    api = HfApi()

    # exist_ok=True lets repeated runs push into the same repository.
    api.create_repo(
        repo_id,
        private=private,
        exist_ok=True,
    )

    # *.yaml (config) and *.pt files are deliberately left out of the upload.
    api.upload_folder(
        repo_id=repo_id,
        folder_path=self.save_root,
        ignore_patterns=["*.yaml", "*.pt"],
        repo_type="model",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_readme(self, repo_id: str) -> str:
    """Generate the content of the README.md (model card) file.

    The card consists of a YAML front matter block (tags, optional widget
    entries, base model, optional instance prompt, license) followed by a short
    Markdown body.

    Args:
        repo_id: Repository id used to build the "Download" link in the body.

    Returns:
        The complete README text.
    """
    # Gather model info
    base_model = self.model_config.name_or_path
    instance_prompt = getattr(self, "trigger_word", None)

    # License metadata depends on the base model; the name/link pair is only
    # filled in for the "other" license.
    license_name = None
    license_link = None
    if base_model == "black-forest-labs/FLUX.1-schnell":
        license_id = "apache-2.0"
    elif base_model == "black-forest-labs/FLUX.1-dev":
        license_id = "other"
        license_name = "flux-1-dev-non-commercial-license"
        license_link = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
    else:
        license_id = "creativeml-openrail-m"

    tags = ["text-to-image"]
    if self.model_config.is_xl:
        tags.append("stable-diffusion-xl")
    if self.model_config.is_flux:
        tags.append("flux")
    if self.network_config:
        tags.extend(["lora", "diffusers", "template:sd-lora"])

    # Collect the samples produced at the final training step, keyed by the
    # prompt index embedded in the filename.
    # Filenames are structured as 1724085406830__00000500_0.jpg: the second part
    # is the step count and the third is the index of the prompt.
    sample_pattern = re.compile(r"__(\d+)_(\d+)\.jpg$")
    samples_dir = os.path.join(self.save_root, "samples")
    sample_by_index = {}
    if os.path.isdir(samples_dir):
        for filename in os.listdir(samples_dir):
            match = sample_pattern.search(filename)
            if not match:
                continue
            steps, index = int(match.group(1)), int(match.group(2))
            # Only the latest samples are shown, i.e. those from the last step.
            if steps != self.train_config.steps:
                continue
            # If several files share an index, keep the greatest filename
            # (presumably the newest, given the leading timestamp — TODO confirm).
            previous = sample_by_index.get(index)
            if previous is None or filename > previous:
                sample_by_index[index] = filename

    # Pair each prompt with the sample whose extracted index equals the prompt's
    # position, so a missing sample cannot shift images onto the wrong prompt.
    widgets = [
        {
            "text": prompt,
            "output": {
                "url": f"samples/{sample_by_index[i]}"
            },
        }
        for i, prompt in enumerate(self.sample_config.prompts)
        if i in sample_by_index
    ]

    # Assemble the front matter line by line so optional keys are simply omitted.
    front_matter_lines = ["tags:", yaml.dump(tags, indent=4).strip()]
    if widgets:
        # The key is emitted only together with entries; a bare "widget:" would
        # parse as a null value.
        front_matter_lines.append("widget:")
        front_matter_lines.append(yaml.dump(widgets, indent=4).strip())
    front_matter_lines.append(f"base_model: {base_model}")
    if instance_prompt:
        # Let the YAML dumper quote the trigger word so characters such as ':'
        # or '#' cannot corrupt the metadata block.
        front_matter_lines.append(
            yaml.dump({"instance_prompt": instance_prompt}, default_flow_style=False).strip()
        )
    front_matter_lines.append(f"license: {license_id}")
    if license_id == "other":
        front_matter_lines.append("license_name: " + license_name)
        front_matter_lines.append("license_link: " + license_link)
    front_matter = "\n".join(front_matter_lines)

    if instance_prompt:
        trigger_section = "You should use `" + instance_prompt + "` to trigger the image generation."
    else:
        trigger_section = "No trigger words defined."

    # Construct the README content
    readme_content = f"""---
{front_matter}
---

# {self.job.name}
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
<Gallery />

## Trigger words

{trigger_section}

## Download model

Weights for this model are available in Safetensors format.

[Download](/{repo_id}/tree/main) them in the Files & versions tab.
"""
    return readme_content
|
||||||
|
|||||||
@@ -23,7 +23,9 @@ class SaveConfig:
|
|||||||
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
|
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
|
||||||
if self.save_format not in ['safetensors', 'diffusers']:
|
if self.save_format not in ['safetensors', 'diffusers']:
|
||||||
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
|
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
|
||||||
|
self.push_to_hub: bool = kwargs.get("push_to_hub", False)
|
||||||
|
self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
|
||||||
|
self.hf_private: Optional[str] = kwargs.get("hf_private", False)
|
||||||
|
|
||||||
class LogingConfig:
|
class LogingConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user