Bug fixes and little improvements here and there.

This commit is contained in:
Jaret Burkett
2024-06-08 06:24:20 -06:00
parent 833c833f28
commit 3f3636b788
12 changed files with 358 additions and 117 deletions

View File

@@ -39,7 +39,8 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -497,6 +498,12 @@ class StableDiffusion:
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
elif isinstance(self.adapter, ControlNetModel):
if self.is_xl:
Pipe = StableDiffusionXLControlNetPipeline
else:
Pipe = StableDiffusionControlNetPipeline
extra_args['controlnet'] = self.adapter
elif isinstance(self.adapter, ReferenceAdapter):
# pass the noise scheduler to the adapter
self.adapter.noise_scheduler = noise_scheduler
@@ -588,6 +595,10 @@ class StableDiffusion:
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, ControlNetModel):
validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['image'] = validation_image
extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([
transforms.ToTensor(),
@@ -967,6 +978,16 @@ class StableDiffusion:
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
# handle controlnet
if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs:
# go through each item and concat if doing cfg and it doesnt have the same shape
for idx, item in enumerate(kwargs['down_block_additional_residuals']):
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
for idx, item in enumerate(kwargs['mid_block_additional_residual']):
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0)
def scale_model_input(model_input, timestep_tensor):
if is_input_scaled:
return model_input
@@ -1383,11 +1404,13 @@ class StableDiffusion:
# move to device and dtype
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
# resize images if not divisible by 8
for i in range(len(image_list)):
image = image_list[i]
if image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0:
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
images = torch.stack(image_list)
if isinstance(self.vae, AutoencoderTiny):
@@ -1756,6 +1779,9 @@ class StableDiffusion:
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
adapter_device = self.adapter.device
elif isinstance(self.adapter, ControlNetModel):
requires_grad = self.adapter.conv_in.training
adapter_device = self.adapter.device
elif isinstance(self.adapter, ClipVisionAdapter):
requires_grad = self.adapter.embedder.training
adapter_device = self.adapter.device