Bug fixes

Jaret Burkett
2025-01-31 13:23:01 -07:00
parent 15a57bc89f
commit e6180d1e1d
3 changed files with 12 additions and 11 deletions

@@ -123,11 +123,11 @@ class CustomAdapter(torch.nn.Module):
         torch_dtype = get_torch_dtype(self.sd_ref().dtype)
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
-            embed_dim = sd.unet.config['cross_attention_dim']
+            embed_dim = sd.unet_unwrapped.config['cross_attention_dim']
             self.fuse_module = FuseModule(embed_dim)
         elif self.adapter_type == 'clip_fusion':
             sd = self.sd_ref()
-            embed_dim = sd.unet.config['cross_attention_dim']
+            embed_dim = sd.unet_unwrapped.config['cross_attention_dim']
             vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
             if self.config.image_encoder_arch == 'clip':
@@ -288,7 +288,7 @@ class CustomAdapter(torch.nn.Module):
             self.vision_encoder = SAFEVisionModel(
                 in_channels=3,
                 num_tokens=self.config.safe_tokens,
-                num_vectors=sd.unet.config['cross_attention_dim'],
+                num_vectors=sd.unet_unwrapped.config['cross_attention_dim'],
                 reducer_channels=self.config.safe_reducer_channels,
                 channels=self.config.safe_channels,
                 downscale_factor=8
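
The pattern running through this commit: during training the unet may be wrapped (for example by DistributedDataParallel, or by torch.compile), and a wrapper hides diffusers attributes such as .config. The adapter therefore reads configuration from sd.unet_unwrapped, a handle to the bare model. A minimal sketch of what an unwrap helper can look like, assuming the common .module / ._orig_mod wrapper conventions; the repo's own unwrap_model (used later in this commit) may differ:

    import torch

    def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
        # Peel off training wrappers until the bare diffusers module is reached.
        while True:
            if hasattr(model, "module"):       # DistributedDataParallel / DataParallel
                model = model.module
            elif hasattr(model, "_orig_mod"):  # torch.compile's OptimizedModule
                model = model._orig_mod
            else:
                return model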

@@ -156,6 +156,7 @@ class StableDiffusion:
         self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
         self.vae: Union[None, 'AutoencoderKL']
         self.unet: Union[None, 'UNet2DConditionModel']
+        self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
         self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
         self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
         self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
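
For the new attribute to help, unet_unwrapped must keep pointing at the bare diffusers model even after self.unet is replaced by a wrapper. An illustrative sketch of that wiring; the names model_path and accelerator, and the assignment order, are assumptions, not shown in this diff:

    # hypothetical setup sketch, not the repo's actual wiring
    self.unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    self.unet_unwrapped = self.unet  # bare model: .config, .save_pretrained, ...
    self.unet = accelerator.prepare(self.unet)  # may wrap the model for training
    # self.unet.config can now fail on a DDP wrapper;
    # self.unet_unwrapped.config still works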
@@ -1505,7 +1506,7 @@ class StableDiffusion:
         if width is None:
             width = pixel_width // VAE_SCALE_FACTOR
-        num_channels = self.unet.config['in_channels']
+        num_channels = self.unet_unwrapped.config['in_channels']
         if self.is_flux:
             # has 64 channels in for some reason
             num_channels = 16
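
A likely reading of the "64 channels" comment: Flux folds each 2×2 patch of its 16-channel latents into one transformer token, so the transformer's config reports in_channels = 64 while the latents being prepared here are still 16-channel. In short:

    # 16 latent channels folded over a 2x2 patch -> 64 transformer input channels
    latent_channels = 16
    patch_size = 2
    assert latent_channels * patch_size * patch_size == 64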
@@ -1813,8 +1814,8 @@ class StableDiffusion:
                 ratios=aspect_ratio_bin)
             added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-            if self.unet.config.sample_size == 128 or (
-                    self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
+            if self.unet_unwrapped.config.sample_size == 128 or (
+                    self.vae_scale_factor == 16 and self.unet_unwrapped.config.sample_size == 64):
                 resolution = torch.tensor([height, width]).repeat(batch_size, 1)
                 aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
                 resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
@@ -1837,7 +1838,7 @@ class StableDiffusion:
             )[0]
             # learned sigma
-            if self.unet.config.out_channels // 2 == self.unet.config.in_channels:
+            if self.unet_unwrapped.config.out_channels // 2 == self.unet_unwrapped.config.in_channels:
                 noise_pred = noise_pred.chunk(2, dim=1)[0]
             else:
                 noise_pred = noise_pred
@@ -1865,7 +1866,7 @@ class StableDiffusion:
             txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
             # # handle guidance
-            if self.unet.config.guidance_embeds:
+            if self.unet_unwrapped.config.guidance_embeds:
                 if isinstance(guidance_embedding_scale, list):
                     guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
                 else:
@@ -2457,7 +2458,7 @@ class StableDiffusion:
         # diffusers
         if self.is_flux:
             # only save the unet
-            transformer: FluxTransformer2DModel = self.unet
+            transformer: FluxTransformer2DModel = unwrap_model(self.unet)
             transformer.save_pretrained(
                 save_directory=os.path.join(output_file, 'transformer'),
                 safe_serialization=True,
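
Unwrapping matters at save time because save_pretrained lives on the diffusers module, not on a training wrapper; saving the wrapper would either fail or serialize module.-prefixed keys. Assuming the directory layout written above, the saved transformer round-trips as a plain diffusers model:

    # sketch: reload the transformer saved above
    from diffusers import FluxTransformer2DModel
    transformer = FluxTransformer2DModel.from_pretrained(
        os.path.join(output_file, 'transformer')
    )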