From e6180d1e1df9bfd520649ad08f03fcc3301b73f9 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 31 Jan 2025 13:23:01 -0700
Subject: [PATCH] Bug fixes

---
 jobs/process/BaseSDTrainProcess.py |  4 ++--
 toolkit/custom_adapter.py          |  6 +++---
 toolkit/stable_diffusion_model.py  | 13 +++++++------
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 3daa1905..1d8278e7 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -651,9 +651,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.accelerator.even_batches=False

         # # prepare all the models stuff for accelerator (hopefully we dont miss any)
-        if self.sd.vae is not None:
-            self.sd.vae = self.accelerator.prepare(self.sd.vae)
+        self.sd.vae = self.accelerator.prepare(self.sd.vae)
         if self.sd.unet is not None:
+            self.sd.unet_unwrapped = self.sd.unet
             self.sd.unet = self.accelerator.prepare(self.sd.unet)
             # todo always tdo it?
             self.modules_being_trained.append(self.sd.unet)
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 12a4df4b..d73570d2 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -123,11 +123,11 @@ class CustomAdapter(torch.nn.Module):
         torch_dtype = get_torch_dtype(self.sd_ref().dtype)
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
-            embed_dim = sd.unet.config['cross_attention_dim']
+            embed_dim = sd.unet_unwrapped.config['cross_attention_dim']
             self.fuse_module = FuseModule(embed_dim)
         elif self.adapter_type == 'clip_fusion':
             sd = self.sd_ref()
-            embed_dim = sd.unet.config['cross_attention_dim']
+            embed_dim = sd.unet_unwrapped.config['cross_attention_dim']

             vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
             if self.config.image_encoder_arch == 'clip':
@@ -288,7 +288,7 @@ class CustomAdapter(torch.nn.Module):
             self.vision_encoder = SAFEVisionModel(
                 in_channels=3,
                 num_tokens=self.config.safe_tokens,
-                num_vectors=sd.unet.config['cross_attention_dim'],
+                num_vectors=sd.unet_unwrapped.config['cross_attention_dim'],
                 reducer_channels=self.config.safe_reducer_channels,
                 channels=self.config.safe_channels,
                 downscale_factor=8
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 078af580..18ac335d 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -156,6 +156,7 @@ class StableDiffusion:
         self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
         self.vae: Union[None, 'AutoencoderKL']
         self.unet: Union[None, 'UNet2DConditionModel']
+        self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
         self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
         self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
         self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
@@ -1505,7 +1506,7 @@ class StableDiffusion:
         if width is None:
             width = pixel_width // VAE_SCALE_FACTOR

-        num_channels = self.unet.config['in_channels']
+        num_channels = self.unet_unwrapped.config['in_channels']
         if self.is_flux:
             # has 64 channels in for some reason
             num_channels = 16
@@ -1813,8 +1814,8 @@ class StableDiffusion:
                         ratios=aspect_ratio_bin)

                     added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-                    if self.unet.config.sample_size == 128 or (
-                            self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
+                    if self.unet_unwrapped.config.sample_size == 128 or (
+                            self.vae_scale_factor == 16 and self.unet_unwrapped.config.sample_size == 64):
                         resolution = torch.tensor([height, width]).repeat(batch_size, 1)
                         aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
                         resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
@@ -1837,7 +1838,7 @@ class StableDiffusion:
                     )[0]

                     # learned sigma
-                    if self.unet.config.out_channels // 2 == self.unet.config.in_channels:
+                    if self.unet_unwrapped.config.out_channels // 2 == self.unet_unwrapped.config.in_channels:
                         noise_pred = noise_pred.chunk(2, dim=1)[0]
                     else:
                         noise_pred = noise_pred
@@ -1865,7 +1866,7 @@ class StableDiffusion:
                     txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)

                     # # handle guidance
-                    if self.unet.config.guidance_embeds:
+                    if self.unet_unwrapped.config.guidance_embeds:
                         if isinstance(guidance_embedding_scale, list):
                             guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
                         else:
@@ -2457,7 +2458,7 @@ class StableDiffusion:
         # diffusers
         if self.is_flux:
            # only save the unet
-            transformer: FluxTransformer2DModel = self.unet
+            transformer: FluxTransformer2DModel = unwrap_model(self.unet)
             transformer.save_pretrained(
                 save_directory=os.path.join(output_file, 'transformer'),
                 safe_serialization=True,
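
Context for reviewers: the recurring change above swaps direct self.unet.config reads for
self.unet_unwrapped.config, and unwraps the model before saving. Accelerator.prepare() from
the accelerate library can wrap a model (for example in DistributedDataParallel), and that
wrapper does not forward attributes such as .config. Below is a minimal sketch of the pattern,
using a toy UNet config and illustrative names; it is not code from this repository.

    # Minimal sketch (assumptions: toy config values, placeholder output path).
    from accelerate import Accelerator
    from diffusers import UNet2DConditionModel

    accelerator = Accelerator()

    # Tiny UNet so the example constructs quickly; real code loads a checkpoint.
    unet = UNet2DConditionModel(
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
        layers_per_block=1,
        sample_size=32,
    )

    unet_unwrapped = unet              # raw diffusers module; .config stays reachable
    unet = accelerator.prepare(unet)   # may now be a DDP (or other) wrapper

    embed_dim = unet_unwrapped.config['cross_attention_dim']  # works either way
    # unet.config can fail with AttributeError once unet is wrapped by DDP

    # When saving, unwrap rather than serializing the wrapper itself.
    accelerator.unwrap_model(unet).save_pretrained("output/unet")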