mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
Bug fixes and little improvements here and there.
This commit is contained in:
@@ -39,7 +39,8 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
|
||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel
|
||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline
|
||||
import diffusers
|
||||
from diffusers import \
|
||||
AutoencoderKL, \
|
||||
@@ -497,6 +498,12 @@ class StableDiffusion:
|
||||
else:
|
||||
Pipe = StableDiffusionAdapterPipeline
|
||||
extra_args['adapter'] = self.adapter
|
||||
elif isinstance(self.adapter, ControlNetModel):
|
||||
if self.is_xl:
|
||||
Pipe = StableDiffusionXLControlNetPipeline
|
||||
else:
|
||||
Pipe = StableDiffusionControlNetPipeline
|
||||
extra_args['controlnet'] = self.adapter
|
||||
elif isinstance(self.adapter, ReferenceAdapter):
|
||||
# pass the noise scheduler to the adapter
|
||||
self.adapter.noise_scheduler = noise_scheduler
|
||||
@@ -588,6 +595,10 @@ class StableDiffusion:
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
extra['image'] = validation_image
|
||||
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
|
||||
if isinstance(self.adapter, ControlNetModel):
|
||||
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
||||
extra['image'] = validation_image
|
||||
extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
|
||||
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
@@ -967,6 +978,16 @@ class StableDiffusion:
|
||||
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
|
||||
kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
|
||||
|
||||
# handle controlnet
|
||||
if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs:
|
||||
# go through each residual and duplicate it along the batch dim (dim 0) if doing cfg and its batch size doesn't match the text embeddings
|
||||
for idx, item in enumerate(kwargs['down_block_additional_residuals']):
|
||||
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
|
||||
kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
|
||||
for idx, item in enumerate(kwargs['mid_block_additional_residual']):
|
||||
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
|
||||
kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0)
|
||||
|
||||
def scale_model_input(model_input, timestep_tensor):
|
||||
if is_input_scaled:
|
||||
return model_input
|
||||
@@ -1383,11 +1404,13 @@ class StableDiffusion:
|
||||
# move to device and dtype
|
||||
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
|
||||
|
||||
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
|
||||
# resize images if not divisible by the VAE scale factor (previously hard-coded as 8)
|
||||
for i in range(len(image_list)):
|
||||
image = image_list[i]
|
||||
if image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0:
|
||||
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
|
||||
if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
|
||||
image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
|
||||
|
||||
images = torch.stack(image_list)
|
||||
if isinstance(self.vae, AutoencoderTiny):
|
||||
@@ -1756,6 +1779,9 @@ class StableDiffusion:
|
||||
elif isinstance(self.adapter, T2IAdapter):
|
||||
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
|
||||
adapter_device = self.adapter.device
|
||||
elif isinstance(self.adapter, ControlNetModel):
|
||||
requires_grad = self.adapter.conv_in.training
|
||||
adapter_device = self.adapter.device
|
||||
elif isinstance(self.adapter, ClipVisionAdapter):
|
||||
requires_grad = self.adapter.embedder.training
|
||||
adapter_device = self.adapter.device
|
||||
|
||||
Reference in New Issue
Block a user