mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-03-06 02:59:48 +00:00
Added preliminary support for SD3.5-large lora training
97  config/examples/train_lora_sd35_large_24gb.yaml  Normal file
@@ -0,0 +1,97 @@
---
# NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
job: extension
config:
  # this name will be used for the output folder and file names
  name: "my_first_sd3l_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16 # precision to save
        save_every: 250 # save every this many steps
        max_step_saves_to_keep: 4 # how many intermittent saves to keep
        push_to_hub: false # change this to true to push your trained model to Hugging Face
        # You can either set up an HF_TOKEN env variable or you'll be prompted to log in
        # hf_repo_id: your-username/your-model-slug
        # hf_private: true # whether the repo is private or public
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape backslashes with another backslash, so
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
          shuffle_tokens: false # shuffle caption order, split by commas
          cache_latents_to_disk: true # leave this true unless you know what you're doing
          resolution: [ 1024 ]
      train:
        batch_size: 1
        steps: 2000 # total number of steps to train; 500 - 4000 is a good range
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false # may not fully work with SD3 yet
        gradient_checkpointing: true # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch"
        timestep_type: "linear" # linear or sigmoid
        optimizer: "adamw8bit"
        lr: 1e-4
        # uncomment this to skip the pre-training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true
        # uncomment to use the new bell-curved weighting. Experimental, but may produce better results
        # linear_timesteps: true

        # ema will smooth out learning, but could slow it down. Recommended to leave on.
        ema_config:
          use_ema: true
          ema_decay: 0.99

        # you will probably need bf16 for sd3 if your gpu supports it; other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path
        name_or_path: "stabilityai/stable-diffusion-3.5-large"
        is_v3: true
        quantize: true # run 8bit mixed precision
      sample:
        sampler: "flowmatch" # must match train.noise_scheduler
        sample_every: 250 # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: ""
        seed: 42
        walk_seed: true
        guidance_scale: 4
        sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
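A few notes on the config above: linear: 16 with linear_alpha: 16 gives a rank-16 LoRA with the standard alpha/rank output scaling of 1.0, and train.noise_scheduler must stay in sync with sample.sampler since SD3 is a flow-matching model. The caption handling described in the dataset comments (trigger-word injection and caption dropout) boils down to a few string operations; the sketch below is illustrative only, with a hypothetical helper name, not the toolkit's actual implementation:

    import random
    from typing import Optional

    def prepare_caption(caption: str, trigger_word: Optional[str], dropout_rate: float) -> str:
        # sketch of the caption rules described in the config comments above
        if trigger_word is not None:
            if "[trigger]" in caption:
                # a [trigger] placeholder is replaced with the trigger word
                caption = caption.replace("[trigger]", trigger_word)
            elif trigger_word not in caption:
                # otherwise the trigger word is added if it is not already present
                caption = f"{trigger_word}, {caption}"
        if random.random() < dropout_rate:
            # caption_dropout_rate: train on an empty caption this fraction of the time
            return ""
        return caption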
@@ -1907,6 +1907,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             tags.append("stable-diffusion-xl")
         if self.model_config.is_flux:
             tags.append("flux")
+        if self.model_config.is_v3:
+            tags.append("sd3")
         if self.network_config:
             tags.extend(
                 [
@@ -232,7 +232,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.peft_format = peft_format

         # always do peft for flux only for now
-        if self.is_flux:
+        if self.is_flux or self.is_v3:
             self.peft_format = True

         if self.peft_format:
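For orientation: elsewhere in the toolkit, peft_format appears to select PEFT/Diffusers-style key naming for saved LoRA weights rather than the older kohya-style naming, so this change extends the existing flux behavior to SD3. Purely as an illustration of the two styles (the exact key strings below are assumptions, not taken from this commit):

    # illustrative only: the same LoRA projection named in the two common styles
    # (the exact strings are assumptions; peft_format's effect is inferred from context)
    kohya_style_key = "lora_transformer_transformer_blocks_0_attn_to_q.lora_down.weight"
    peft_style_key = "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"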
@@ -326,6 +326,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 if self.transformer_only and self.is_flux and is_unet:
                     if "transformer_blocks" not in lora_name:
                         skip = True
+                if self.transformer_only and self.is_v3 and is_unet:
+                    if "transformer_blocks" not in lora_name:
+                        skip = True

                 if (is_linear or is_conv2d) and not skip:
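The new branch mirrors the flux branch exactly, with only the model flag changed; the two checks could be condensed as below (a refactoring sketch using the same attribute names as the surrounding code):

    # equivalent, condensed form of the two branches above: with transformer_only
    # set, LoRA modules are only created inside the DiT transformer blocks
    if self.transformer_only and (self.is_flux or self.is_v3) and is_unet:
        if "transformer_blocks" not in lora_name:
            skip = True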
@@ -47,7 +47,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
     StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
     StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
-    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
+    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel
 import diffusers
 from diffusers import \
     AutoencoderKL, \
@@ -267,30 +267,84 @@ class StableDiffusion:
                 pipln = self.custom_pipeline
             else:
                 pipln = StableDiffusion3Pipeline

-            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-            model_id = "stabilityai/stable-diffusion-3-medium"
-            text_encoder3 = T5EncoderModel.from_pretrained(
-                model_id,
-                subfolder="text_encoder_3",
-                # quantization_config=quantization_config,
-                revision="refs/pr/26",
-                device_map="cuda"
-            )
+            print("Loading SD3 model")
+            # assume it is the large model
+            base_model_path = "stabilityai/stable-diffusion-3.5-large"
+            print("Loading transformer")
+            subfolder = 'transformer'
+            transformer_path = model_path
+            # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
+            if os.path.exists(transformer_path):
+                subfolder = None
+                transformer_path = os.path.join(transformer_path, 'transformer')
+                # check if the path is a full checkpoint.
+                te_folder_path = os.path.join(model_path, 'text_encoder')
+                # if we have the te, this folder is a full checkpoint; use it as the base
+                if os.path.exists(te_folder_path):
+                    base_model_path = model_path
+            else:
+                # it is remote; use whatever path we were given
+                base_model_path = model_path
+
+            transformer = SD3Transformer2DModel.from_pretrained(
+                transformer_path,
+                subfolder=subfolder,
+                torch_dtype=dtype,
+            )
+            if not self.low_vram:
+                # for low vram, we leave it on the cpu. Quantizes slower, but allows training on the primary gpu
+                transformer.to(torch.device(self.quantize_device), dtype=dtype)
+            flush()
+
+            if self.model_config.lora_path is not None:
+                raise ValueError("LoRA is not supported for SD3 models currently")
+
+            if self.model_config.quantize:
+                quantization_type = qfloat8
+                print("Quantizing transformer")
+                quantize(transformer, weights=quantization_type)
+                freeze(transformer)
+                transformer.to(self.device_torch)
+            else:
+                transformer.to(self.device_torch, dtype=dtype)
+
+            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
+            print("Loading vae")
+            vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
+            flush()
+
+            print("Loading t5")
+            tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype)
+            text_encoder_3 = T5EncoderModel.from_pretrained(
+                base_model_path,
+                subfolder="text_encoder_3",
+                torch_dtype=dtype
+            )
+
+            text_encoder_3.to(self.device_torch, dtype=dtype)
+            flush()
+
+            if self.model_config.quantize:
+                print("Quantizing T5")
+                quantize(text_encoder_3, weights=qfloat8)
+                freeze(text_encoder_3)
+                flush()

             # see if path exists
             if not os.path.exists(model_path) or os.path.isdir(model_path):
                 try:
                     # try to load with default diffusers
                     pipe = pipln.from_pretrained(
-                        model_path,
+                        base_model_path,
                         dtype=dtype,
                         device=self.device_torch,
-                        text_encoder_3=text_encoder3,
+                        tokenizer_3=tokenizer_3,
+                        text_encoder_3=text_encoder_3,
+                        transformer=transformer,
                         # variant="fp16",
                         use_safetensors=True,
-                        revision="refs/pr/26",
                         repo_type="model",
                         ignore_patterns=["*.md", "*..gitattributes"],
                         **load_args
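The quantize/freeze sequence above is the usual two-step pattern for weight-only quantization, which the example config's quantize: true option switches on. A standalone sketch of the same pattern on a stand-in module, assuming the optimum-quanto package (the Linear layer is just a placeholder for the SD3 transformer):

    import torch
    from optimum.quanto import freeze, qfloat8, quantize

    module = torch.nn.Linear(64, 64)   # stand-in for the SD3 transformer
    quantize(module, weights=qfloat8)  # mark the weights for float8 quantization
    freeze(module)                     # materialize quantized weights, freeing the fp originals
    # the module now runs in 8-bit mixed precision: weights are stored as
    # qfloat8 while activations stay in the working dtype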
@@ -302,9 +356,11 @@ class StableDiffusion:
             else:
                 pipe = pipln.from_single_file(
                     model_path,
+                    transformer=transformer,
                     device=self.device_torch,
                     torch_dtype=self.torch_dtype,
-                    text_encoder_3=text_encoder3,
+                    tokenizer_3=tokenizer_3,
+                    text_encoder_3=text_encoder_3,
                     **load_args
                 )
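In both loading paths, the preloaded modules are passed straight into the pipeline constructor, so diffusers reuses them instead of loading those components a second time; this is what allows the transformer and T5 to be quantized before the pipeline ever sees them. A minimal sketch of the pattern (the model id and dtype are illustrative):

    import torch
    from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline

    # load one component yourself (this is where it could be quantized)...
    transformer = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    # ...then hand it to from_pretrained so the pipeline uses it as-is
    # rather than loading that component from the repo again
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )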
@@ -1815,6 +1871,8 @@ class StableDiffusion:
                 pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
                 **kwargs,
             ).sample
+            if isinstance(noise_pred, QTensor):
+                noise_pred = noise_pred.dequantize()
         elif self.is_auraflow:
             # aura uses timestep values between 0 and 1, with t=1 as noise and t=0 as the image
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
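The added isinstance check is needed because quantized modules can return quanto QTensor outputs, which the downstream loss math does not expect; dequantize() converts them back to plain tensors. The same guard as a reusable helper, a sketch with a hypothetical function name (assumes the optimum-quanto package):

    import torch
    from optimum.quanto import QTensor

    def as_plain_tensor(t: torch.Tensor) -> torch.Tensor:
        # quantized modules may hand back a QTensor; later ops expect a plain tensor
        return t.dequantize() if isinstance(t, QTensor) else t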