Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Commit: A lot of pixart sigma training tweaks
@@ -1,3 +1,4 @@
+import math
 import os
 import random
 from collections import OrderedDict
@@ -6,8 +7,9 @@ from typing import List
 import numpy as np
 from PIL import Image
 from diffusers import T2IAdapter
+from diffusers.utils.torch_utils import randn_tensor
 from torch.utils.data import DataLoader
-from diffusers import StableDiffusionXLImg2ImgPipeline
+from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
 from tqdm import tqdm

 from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
@@ -21,6 +23,7 @@ from toolkit.data_loader import get_dataloader_from_datasets
 from toolkit.train_tools import get_torch_dtype
 from controlnet_aux.midas import MidasDetector
 from diffusers.utils import load_image
+from torchvision.transforms import ToTensor


 def flush():
@@ -28,6 +31,9 @@ def flush():
     gc.collect()


+
+
+
 class GenerateConfig:

     def __init__(self, **kwargs):
@@ -103,7 +109,6 @@ class Img2ImgGenerator(BaseExtensionProcess):
         self.sd.load_model()
         device = torch.device(self.device)
-

         if self.model_config.is_xl:
             pipe = StableDiffusionXLImg2ImgPipeline(
                 vae=self.sd.vae,
@@ -114,6 +119,8 @@ class Img2ImgGenerator(BaseExtensionProcess):
                 tokenizer_2=self.sd.tokenizer[1],
                 scheduler=get_sampler(self.generate_config.sampler),
             ).to(device, dtype=self.torch_dtype)
+        elif self.model_config.is_pixart:
+            pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
         else:
             raise NotImplementedError("Only XL models are supported")
         pipe.set_progress_bar_config(disable=True)
@@ -130,6 +137,9 @@ class Img2ImgGenerator(BaseExtensionProcess):
         for i, batch in enumerate(self.data_loader):
             batch: DataLoaderBatchDTO = batch

+            gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
+            generator = torch.manual_seed(gen_seed)
+
             file_item: FileItemDTO = batch.file_items[0]
             img_path = file_item.path
             img_filename = os.path.basename(img_path)
@@ -152,18 +162,76 @@ class Img2ImgGenerator(BaseExtensionProcess):
             img: torch.Tensor = batch.tensor.clone()
             image = self.to_pil(img)


             # image.save(output_depth_path)
-            pipe: StableDiffusionXLImg2ImgPipeline = pipe
-            gen_images = pipe.__call__(
-                prompt=caption,
-                negative_prompt=self.generate_config.neg,
-                image=image,
-                num_inference_steps=self.generate_config.sample_steps,
-                guidance_scale=self.generate_config.guidance_scale,
-                strength=self.generate_config.denoise_strength,
-            ).images[0]
+            if self.model_config.is_pixart:
+                pipe: PixArtSigmaPipeline = pipe
+
+                # Encode the full image once
+                encoded_image = pipe.vae.encode(
+                    pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
+                if hasattr(encoded_image, "latent_dist"):
+                    latents = encoded_image.latent_dist.sample(generator)
+                elif hasattr(encoded_image, "latents"):
+                    latents = encoded_image.latents
+                else:
+                    raise AttributeError("Could not access latents of provided encoder_output")
+                latents = pipe.vae.config.scaling_factor * latents
+
+                # latents = self.sd.encode_images(img)
+
+                # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
+                # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
+                # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
+                # timestep = timestep.to(device, dtype=torch.int32)
+                # latent = latent.to(device, dtype=self.torch_dtype)
+                # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
+                # latent = self.sd.add_noise(latent, noise, timestep)
+                # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
+                batch_size = 1
+                num_images_per_prompt = 1
+
+                shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
+                         image.width // pipe.vae_scale_factor)
+                noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
+
+                # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
+                num_inference_steps = self.generate_config.sample_steps
+                strength = self.generate_config.denoise_strength
+                # Get timesteps
+                init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+                t_start = max(num_inference_steps - init_timestep, 0)
+                pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
+                timesteps = pipe.scheduler.timesteps[t_start:]
+                timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+                latents = pipe.scheduler.add_noise(latents, noise, timestep)
+
+                gen_images = pipe.__call__(
+                    prompt=caption,
+                    negative_prompt=self.generate_config.neg,
+                    latents=latents,
+                    timesteps=timesteps,
+                    width=image.width,
+                    height=image.height,
+                    num_inference_steps=num_inference_steps,
+                    num_images_per_prompt=num_images_per_prompt,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    # strength=self.generate_config.denoise_strength,
+                    use_resolution_binning=False,
+                    output_type="np"
+                ).images[0]
+                gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
+                gen_images = Image.fromarray(gen_images)
+            else:
+                pipe: StableDiffusionXLImg2ImgPipeline = pipe
+
+                gen_images = pipe.__call__(
+                    prompt=caption,
+                    negative_prompt=self.generate_config.neg,
+                    image=image,
+                    num_inference_steps=self.generate_config.sample_steps,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    strength=self.generate_config.denoise_strength,
+                ).images[0]
             os.makedirs(os.path.dirname(output_path), exist_ok=True)
             gen_images.save(output_path)

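Note: diffusers ships no img2img variant of PixArtSigmaPipeline, so the hunk above builds one inline: VAE-encode the source image, scale by vae.config.scaling_factor, noise the latents up to the schedule position implied by denoise_strength, then hand the pre-noised latents plus the truncated timestep list to the text-to-image pipeline. The strength-to-timestep arithmetic follows the get_timesteps logic used by diffusers' img2img pipelines. A minimal sketch of just that arithmetic, using a made-up linearly spaced schedule in place of the real scheduler:

    num_inference_steps = 30
    strength = 0.6  # denoise_strength from the generate config

    # stand-in for pipe.scheduler.timesteps after set_timesteps(30): 999 .. 0
    timesteps = [round(999 - i * (999 / (num_inference_steps - 1))) for i in range(num_inference_steps)]

    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 18
    t_start = max(num_inference_steps - init_timestep, 0)                          # 12
    used = timesteps[t_start:]

    print(len(used))  # 18 -> strength 0.6 runs the last 60% of the steps
    print(used[0])    # ~586: the noise level applied to the encoded latents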
@@ -1331,6 +1331,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 is_lorm=is_lorm,
                 network_config=self.network_config,
                 network_type=self.network_config.type,
+                transformer_only=self.network_config.transformer_only,
                 **network_kwargs
             )

run.py (2 changed lines)
@@ -1,5 +1,5 @@
 import os
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 import sys
 from typing import Union, OrderedDict
 from dotenv import load_dotenv
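Note: HF_HUB_ENABLE_HF_TRANSFER=1 switches huggingface_hub downloads to the Rust-based hf_transfer backend, and huggingface_hub errors out if the flag is set but the hf_transfer package is not installed, which is a plausible reason to stop forcing it for every user. A sketch of opting back in (the repo id and filename are only illustrative):

    import os

    # must be set before huggingface_hub is imported
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # requires: pip install hf_transfer

    from huggingface_hub import hf_hub_download

    path = hf_hub_download("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", "transformer/config.json")
    print(path)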
@@ -51,6 +51,9 @@ resolutions_1024: List[BucketResolution] = [
     {"width": 512, "height": 1920},
     {"width": 512, "height": 1984},
     {"width": 512, "height": 2048},
+    # extra wides
+    {"width": 8192, "height": 128},
+    {"width": 128, "height": 8192},
 ]

 # Even numbers so they can be patched easier
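Note: the new extreme buckets sit exactly on the 1024x1024 pixel budget of this bucket list, so they widen the reachable aspect ratios without changing the memory envelope:

    # 8192 * 128 == 128 * 8192 == 1024 * 1024 == 1,048,576 pixels
    for w, h in [(512, 2048), (8192, 128), (128, 8192)]:
        assert w * h == 1024 * 1024
        print(f"{w}x{h}: {w * h} px, aspect {w / h:g}")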
@@ -128,6 +128,8 @@ class NetworkConfig:
         if self.lorm_config.do_conv:
             self.conv = 4

+        self.transformer_only = kwargs.get('transformer_only', False)
+


 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']

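Note: NetworkConfig reads the flag straight out of its kwargs, which in ai-toolkit come from the parsed network: section of a job config, so enabling it should be a one-key change. A sketch under that assumption (the keys besides transformer_only are the usual LoRA ones, shown for context):

    network_section = {
        "type": "lora",
        "linear": 16,
        "linear_alpha": 16,
        "transformer_only": True,  # new flag: only wrap transformer_blocks modules
    }

    # mirrors the kwargs.get(...) call in NetworkConfig.__init__
    transformer_only = network_section.get("transformer_only", False)
    print(transformer_only)  # True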
@@ -169,6 +169,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
            network_type: str = "lora",
            full_train_in_out: bool = False,
+           transformer_only: bool = False,
            **kwargs
    ) -> None:
        """
@@ -193,6 +194,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        if ignore_if_contains is None:
            ignore_if_contains = []
        self.ignore_if_contains = ignore_if_contains
+       self.transformer_only = transformer_only

        self.only_if_contains: Union[List, None] = only_if_contains

@@ -271,6 +273,15 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                is_conv2d = child_module.__class__.__name__ in CONV_MODULES
                is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

+
+               lora_name = [prefix, name, child_name]
+               # filter out blank
+               lora_name = [x for x in lora_name if x and x != ""]
+               lora_name = ".".join(lora_name)
+               # if it doesnt have a name, it wil have two dots
+               lora_name.replace("..", ".")
+               lora_name = lora_name.replace(".", "_")
+
                skip = False
                if any([word in child_name for word in self.ignore_if_contains]):
                    skip = True
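Note: one quirk in the hunk above: lora_name.replace("..", ".") discards its result, since str.replace returns a new string rather than mutating in place. It happens to be harmless here because the filtered join already avoids double dots. A corrected sketch of the same name-building:

    def build_lora_name(prefix: str, name: str, child_name: str) -> str:
        parts = [p for p in (prefix, name, child_name) if p]  # filter out blanks
        lora_name = ".".join(parts)
        lora_name = lora_name.replace("..", ".")  # assign the result back
        return lora_name.replace(".", "_")

    print(build_lora_name("lora_unet", "", "transformer_blocks.0.attn1.to_q"))
    # lora_unet_transformer_blocks_0_attn1_to_q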
@@ -279,9 +290,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                    if count_parameters(child_module) < parameter_threshold:
                        skip = True

+               if self.transformer_only and self.is_pixart and is_unet:
+                   if "transformer_blocks" not in lora_name:
+                       skip = True
+
                if (is_linear or is_conv2d) and not skip:
-                   lora_name = prefix + "." + name + "." + child_name
-                   lora_name = lora_name.replace(".", "_")

                    if self.only_if_contains is not None and not any([word in lora_name for word in self.only_if_contains]):
                        continue
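Note: with the name now computed before the skip checks, the new transformer_only rule can filter on it: for PixArt, any UNet-side module whose name lacks "transformer_blocks" (e.g. the caption projection) is skipped, so only the DiT blocks receive LoRA weights. The rule in isolation, as a sketch:

    def skip_for_transformer_only(lora_name: str, transformer_only: bool,
                                  is_pixart: bool, is_unet: bool) -> bool:
        return (transformer_only and is_pixart and is_unet
                and "transformer_blocks" not in lora_name)

    print(skip_for_transformer_only("lora_unet_transformer_blocks_0_attn1_to_q", True, True, True))  # False -> kept
    print(skip_for_transformer_only("lora_unet_caption_projection_linear_1", True, True, True))      # True  -> skipped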
@@ -356,8 +369,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            index = None
            print(f"create LoRA for Text Encoder:")

-           text_encoder_loras, skipped = create_modules(False, index, text_encoder,
-                                                        LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+           replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+           if self.is_pixart:
+               replace_modules = ["T5EncoderModel"]
+
+           text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules)
            self.text_encoder_loras.extend(text_encoder_loras)
            skipped_te += skipped
            print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
@@ -516,6 +516,9 @@ class ToolkitNetworkMixin:
        load_sd = OrderedDict()
        for key, value in weights_sd.items():
            load_key = keymap[key] if key in keymap else key
+           # replace old double __ with single _
+           if self.is_pixart:
+               load_key = load_key.replace('__', '_')
            load_sd[load_key] = value

        # extract extra items from state dict
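Note: the `__` to `_` rewrite is a load-time migration. Older PixArt LoRA files were saved with keys built by the deleted prefix + "." + name + "." + child_name concatenation, which produced a double separator whenever name was empty; the new join-based naming yields a single underscore, so old keys are remapped on load. A sketch with an illustrative key:

    from collections import OrderedDict

    weights_sd = OrderedDict({
        "lora_unet__transformer_blocks_0_attn1_to_q.lora_down.weight": "...tensor...",
    })

    load_sd = OrderedDict()
    for key, value in weights_sd.items():
        load_key = key.replace('__', '_')  # only applied when self.is_pixart
        load_sd[load_key] = value

    print(list(load_sd))  # ['lora_unet_transformer_blocks_0_attn1_to_q.lora_down.weight']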
@@ -169,15 +169,6 @@ class StableDiffusion:
        if self.is_loaded:
            return
        dtype = get_torch_dtype(self.dtype)
-       # sch = KDPM2DiscreteScheduler
-       if self.noise_scheduler is None:
-           scheduler = get_sampler(
-               'ddpm', {
-                   "prediction_type": self.prediction_type,
-               },
-               'sd' if not self.is_pixart else 'pixart'
-           )
-           self.noise_scheduler = scheduler

        # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
        # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
@@ -190,9 +181,10 @@ class StableDiffusion:
            from toolkit.civitai import get_model_path_from_url
            model_path = get_model_path_from_url(self.model_config.name_or_path)

-       load_args = {
-           'scheduler': self.noise_scheduler,
-       }
+       load_args = {}
+       if self.noise_scheduler:
+           load_args['scheduler'] = self.noise_scheduler
+
        if self.model_config.vae_path is not None:
            load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
        if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
@@ -290,6 +282,7 @@ class StableDiffusion:
                device=self.device_torch,
                torch_dtype=self.torch_dtype,
                text_encoder_3=text_encoder3,
+               **load_args
            )

            flush()
@@ -387,6 +380,8 @@ class StableDiffusion:
            tokenizer = pipe.tokenizer

            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+           if self.noise_scheduler is None:
+               self.noise_scheduler = pipe.scheduler


        elif self.model_config.is_auraflow:
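Note: the four StableDiffusion hunks are one refactor: load_model no longer force-creates a DDPM noise scheduler up front; a scheduler is passed to from_pretrained only when one is already configured, and for PixArt the pipeline's own default is adopted afterwards when none was. A condensed sketch of the flow:

    def build_load_args(noise_scheduler, vae=None):
        load_args = {}
        if noise_scheduler:
            load_args['scheduler'] = noise_scheduler  # only override when configured
        if vae is not None:
            load_args['vae'] = vae
        return load_args

    print(build_load_args(None))          # {} -> pipeline keeps its default scheduler
    print(build_load_args("ddpm-sched"))  # {'scheduler': 'ddpm-sched'}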