mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add push_to_hub to the trainer (#109)
* add push_to_hub * fix indentation * indent again * model_config * allow samples to not exist * repo creation fix * dont show empty [] if widget doesnt exist * dont submit the config and optimizer * Unsafe to have tokens saved in the yaml file * make sure to catch only the latest samples * change name to slug * formatting * formatting --------- Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
This commit is contained in:
@@ -21,6 +21,10 @@ config:
|
||||
dtype: float16 # precision to save
|
||||
save_every: 250 # save every this many steps
|
||||
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
||||
push_to_hub: false # change this to true to push your trained model to Hugging Face.
|
||||
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
||||
# hf_repo_id: your-username/your-model-slug
|
||||
# hf_private: true #whether the repo is private or public
|
||||
datasets:
|
||||
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
||||
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
||||
|
||||
@@ -21,6 +21,10 @@ config:
|
||||
dtype: float16 # precision to save
|
||||
save_every: 250 # save every this many steps
|
||||
max_step_saves_to_keep: 4 # how many intermittent saves to keep
|
||||
push_to_hub: false # change this to true to push your trained model to Hugging Face.
|
||||
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
|
||||
# hf_repo_id: your-username/your-model-slug
|
||||
# hf_private: true #whether the repo is private or public
|
||||
datasets:
|
||||
# datasets are a folder of images. captions need to be txt files with the same name as the image
|
||||
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
|
||||
|
||||
@@ -6,6 +6,7 @@ import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import re
|
||||
from typing import Union, List, Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -17,6 +18,8 @@ from safetensors.torch import save_file, load_file
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
import torch.backends.cuda
|
||||
from huggingface_hub import HfApi, Repository, interpreter_login
|
||||
from huggingface_hub.utils import HfFolder
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
@@ -1790,7 +1793,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sample(self.step_num)
|
||||
print("")
|
||||
self.save()
|
||||
|
||||
if self.save_config.push_to_hub:
|
||||
if("HF_TOKEN" not in os.environ):
|
||||
interpreter_login(new_session=False, write_permission=True)
|
||||
self.push_to_hub(
|
||||
repo_id=self.save_config.hf_repo_id,
|
||||
private=self.save_config.hf_private
|
||||
)
|
||||
del (
|
||||
self.sd,
|
||||
unet,
|
||||
@@ -1802,3 +1811,120 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
|
||||
flush()
|
||||
|
||||
def push_to_hub(
    self,
    repo_id: str,
    private: bool = False,
):
    """Publish the training output directory to the Hugging Face Hub.

    Renders a model-card README into ``self.save_root``, creates the
    target repository if it does not exist yet, then uploads the whole
    save folder.

    :param repo_id: Target repository id, e.g. ``user/model-slug``.
    :param private: Whether a newly created repository should be private.
    """
    # Generate the card first so a rendering failure never leaves a
    # truncated README behind.
    card_text = self._generate_readme(repo_id)
    card_path = os.path.join(self.save_root, "README.md")
    with open(card_path, "w", encoding="utf-8") as card_file:
        card_file.write(card_text)

    hub_api = HfApi()

    # exist_ok makes repeated pushes of the same run idempotent.
    hub_api.create_repo(
        repo_id,
        private=private,
        exist_ok=True
    )

    # Per the commit intent: yaml configs may contain tokens and .pt
    # files are optimizer state — neither should reach the Hub.
    hub_api.upload_folder(
        repo_id=repo_id,
        folder_path=self.save_root,
        ignore_patterns=["*.yaml", "*.pt"],
        repo_type="model",
    )
|
||||
|
||||
|
||||
def _generate_readme(self, repo_id: str) -> str:
    """Generates the content of the README.md file.

    Builds a Hugging Face model card: front-matter metadata (tags,
    license, base model, optional sample-image widgets) followed by a
    short markdown body with trigger-word and download sections.

    :param repo_id: Repository id, used only for the download link.
    :return: The complete README.md text.
    """
    # Gather model info
    base_model = self.model_config.name_or_path
    # Not every trainer subclass defines trigger_word.
    instance_prompt = self.trigger_word if hasattr(self, "trigger_word") else None

    # License metadata is keyed off the base model. Named license_id
    # (not `license`) to avoid shadowing the builtin.
    license_name = None
    license_link = None
    if base_model == "black-forest-labs/FLUX.1-schnell":
        license_id = "apache-2.0"
    elif base_model == "black-forest-labs/FLUX.1-dev":
        license_id = "other"
        license_name = "flux-1-dev-non-commercial-license"
        license_link = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
    else:
        license_id = "creativeml-openrail-m"

    tags = [
        "text-to-image",
    ]
    if self.model_config.is_xl:
        tags.append("stable-diffusion-xl")
    if self.model_config.is_flux:
        tags.append("flux")
    if self.network_config:
        tags.extend(
            [
                "lora",
                "diffusers",
                "template:sd-lora",
            ]
        )

    # Generate the widget section from the final-step sample images.
    widgets = []
    sample_image_paths = []
    samples_dir = os.path.join(self.save_root, "samples")
    if os.path.isdir(samples_dir):
        for filename in os.listdir(samples_dir):
            # The filenames are structured as 1724085406830__00000500_0.jpg
            # so capture the 2nd part (steps) and 3rd (index matching the prompt).
            match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
            if match:
                steps, index = int(match.group(1)), int(match.group(2))
                # Only the latest samples are uploaded — those whose step
                # count equals the total number of training steps.
                if steps == self.train_config.steps:
                    # FIX: use the matched filename; the previous version
                    # emitted a literal placeholder instead of the real path.
                    sample_image_paths.append((index, f"samples/{filename}"))

    # Sort by numeric index so prompts pair with their own image.
    sample_image_paths.sort(key=lambda x: x[0])

    # Create widgets matching prompt with the index
    for i, prompt in enumerate(self.sample_config.prompts):
        if i < len(sample_image_paths):
            # Associate prompts with sample image paths based on the extracted index
            _, image_path = sample_image_paths[i]
            widgets.append(
                {
                    "text": prompt,
                    "output": {
                        "url": image_path
                    },
                }
            )

    # Construct the README content (conditional lines render as empty
    # strings when they do not apply).
    readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if os.path.isdir(samples_dir) else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model}
{"instance_prompt: " + instance_prompt if instance_prompt else ""}
license: {license_id}
{'license_name: ' + license_name if license_id == "other" else ""}
{'license_link: ' + license_link if license_id == "other" else ""}
---

# {self.job.name}
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
<Gallery />

## Trigger words

{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}

## Download model

Weights for this model are available in Safetensors format.

[Download](/{repo_id}/tree/main) them in the Files & versions tab.
"""
    return readme_content
|
||||
|
||||
@@ -23,7 +23,9 @@ class SaveConfig:
|
||||
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
|
||||
if self.save_format not in ['safetensors', 'diffusers']:
|
||||
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
|
||||
|
||||
self.push_to_hub: bool = kwargs.get("push_to_hub", False)
|
||||
self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
|
||||
self.hf_private: Optional[str] = kwargs.get("hf_private", False)
|
||||
|
||||
class LogingConfig:
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user