mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
A lot of pixart sigma training tweaks
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
@@ -6,8 +7,9 @@ from typing import List
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers import T2IAdapter
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from diffusers import StableDiffusionXLImg2ImgPipeline
|
||||
from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
||||
@@ -21,6 +23,7 @@ from toolkit.data_loader import get_dataloader_from_datasets
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from controlnet_aux.midas import MidasDetector
|
||||
from diffusers.utils import load_image
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -28,6 +31,9 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class GenerateConfig:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -103,7 +109,6 @@ class Img2ImgGenerator(BaseExtensionProcess):
|
||||
self.sd.load_model()
|
||||
device = torch.device(self.device)
|
||||
|
||||
|
||||
if self.model_config.is_xl:
|
||||
pipe = StableDiffusionXLImg2ImgPipeline(
|
||||
vae=self.sd.vae,
|
||||
@@ -114,6 +119,8 @@ class Img2ImgGenerator(BaseExtensionProcess):
|
||||
tokenizer_2=self.sd.tokenizer[1],
|
||||
scheduler=get_sampler(self.generate_config.sampler),
|
||||
).to(device, dtype=self.torch_dtype)
|
||||
elif self.model_config.is_pixart:
|
||||
pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
|
||||
else:
|
||||
raise NotImplementedError("Only XL models are supported")
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
@@ -130,6 +137,9 @@ class Img2ImgGenerator(BaseExtensionProcess):
|
||||
for i, batch in enumerate(self.data_loader):
|
||||
batch: DataLoaderBatchDTO = batch
|
||||
|
||||
gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
|
||||
generator = torch.manual_seed(gen_seed)
|
||||
|
||||
file_item: FileItemDTO = batch.file_items[0]
|
||||
img_path = file_item.path
|
||||
img_filename = os.path.basename(img_path)
|
||||
@@ -152,18 +162,76 @@ class Img2ImgGenerator(BaseExtensionProcess):
|
||||
img: torch.Tensor = batch.tensor.clone()
|
||||
image = self.to_pil(img)
|
||||
|
||||
|
||||
# image.save(output_depth_path)
|
||||
pipe: StableDiffusionXLImg2ImgPipeline = pipe
|
||||
if self.model_config.is_pixart:
|
||||
pipe: PixArtSigmaPipeline = pipe
|
||||
|
||||
gen_images = pipe.__call__(
|
||||
prompt=caption,
|
||||
negative_prompt=self.generate_config.neg,
|
||||
image=image,
|
||||
num_inference_steps=self.generate_config.sample_steps,
|
||||
guidance_scale=self.generate_config.guidance_scale,
|
||||
strength=self.generate_config.denoise_strength,
|
||||
).images[0]
|
||||
# Encode the full image once
|
||||
encoded_image = pipe.vae.encode(
|
||||
pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
|
||||
if hasattr(encoded_image, "latent_dist"):
|
||||
latents = encoded_image.latent_dist.sample(generator)
|
||||
elif hasattr(encoded_image, "latents"):
|
||||
latents = encoded_image.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
latents = pipe.vae.config.scaling_factor * latents
|
||||
|
||||
# latents = self.sd.encode_images(img)
|
||||
|
||||
# self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
|
||||
# start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
|
||||
# timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
|
||||
# timestep = timestep.to(device, dtype=torch.int32)
|
||||
# latent = latent.to(device, dtype=self.torch_dtype)
|
||||
# noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
|
||||
# latent = self.sd.add_noise(latent, noise, timestep)
|
||||
# timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
|
||||
batch_size = 1
|
||||
num_images_per_prompt = 1
|
||||
|
||||
shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
|
||||
image.width // pipe.vae_scale_factor)
|
||||
noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
|
||||
|
||||
# noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
|
||||
num_inference_steps = self.generate_config.sample_steps
|
||||
strength = self.generate_config.denoise_strength
|
||||
# Get timesteps
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
|
||||
timesteps = pipe.scheduler.timesteps[t_start:]
|
||||
timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
latents = pipe.scheduler.add_noise(latents, noise, timestep)
|
||||
|
||||
gen_images = pipe.__call__(
|
||||
prompt=caption,
|
||||
negative_prompt=self.generate_config.neg,
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
guidance_scale=self.generate_config.guidance_scale,
|
||||
# strength=self.generate_config.denoise_strength,
|
||||
use_resolution_binning=False,
|
||||
output_type="np"
|
||||
).images[0]
|
||||
gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
|
||||
gen_images = Image.fromarray(gen_images)
|
||||
else:
|
||||
pipe: StableDiffusionXLImg2ImgPipeline = pipe
|
||||
|
||||
gen_images = pipe.__call__(
|
||||
prompt=caption,
|
||||
negative_prompt=self.generate_config.neg,
|
||||
image=image,
|
||||
num_inference_steps=self.generate_config.sample_steps,
|
||||
guidance_scale=self.generate_config.guidance_scale,
|
||||
strength=self.generate_config.denoise_strength,
|
||||
).images[0]
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
gen_images.save(output_path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user