Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Bug fixes: read model config from the unwrapped UNet reference (unet_unwrapped) instead of the possibly-wrapped self.unet, and unwrap the Flux transformer before saving.
@@ -123,11 +123,11 @@ class CustomAdapter(torch.nn.Module):
         torch_dtype = get_torch_dtype(self.sd_ref().dtype)
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
-            embed_dim = sd.unet.config['cross_attention_dim']
+            embed_dim = sd.unet_unwrapped.config['cross_attention_dim']
             self.fuse_module = FuseModule(embed_dim)
         elif self.adapter_type == 'clip_fusion':
             sd = self.sd_ref()
-            embed_dim = sd.unet.config['cross_attention_dim']
+            embed_dim = sd.unet_unwrapped.config['cross_attention_dim']

             vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
             if self.config.image_encoder_arch == 'clip':
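Every config-lookup hunk in this commit follows the same pattern: the read moves from the possibly-wrapped self.unet to the plain self.unet_unwrapped reference. Training-time wrappers such as DataParallel/DistributedDataParallel do not forward arbitrary attributes like .config to the module they hold, so the old lookup breaks once the UNet is wrapped. A minimal sketch of the failure mode, with TinyUNet standing in for the real diffusers module (an illustration, not code from this repo):

import torch
from torch.nn.parallel import DataParallel

class TinyUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.config = {'cross_attention_dim': 768}  # mimics a diffusers config lookup
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

unet_unwrapped = TinyUNet()
unet = DataParallel(unet_unwrapped)   # what self.unet can become during training
# unet.config                         # AttributeError: the wrapper has no .config
embed_dim = unet_unwrapped.config['cross_attention_dim']  # works on the plain module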
@@ -288,7 +288,7 @@ class CustomAdapter(torch.nn.Module):
             self.vision_encoder = SAFEVisionModel(
                 in_channels=3,
                 num_tokens=self.config.safe_tokens,
-                num_vectors=sd.unet.config['cross_attention_dim'],
+                num_vectors=sd.unet_unwrapped.config['cross_attention_dim'],
                 reducer_channels=self.config.safe_reducer_channels,
                 channels=self.config.safe_channels,
                 downscale_factor=8
@@ -156,6 +156,7 @@ class StableDiffusion:
         self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
         self.vae: Union[None, 'AutoencoderKL']
         self.unet: Union[None, 'UNet2DConditionModel']
+        self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
         self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
         self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
         self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
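The new attribute gives StableDiffusion a dedicated handle on the plain diffusers module. A hypothetical sketch of how the two references relate (the actual assignment site is elsewhere in the codebase and not shown in this commit):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel()                # stand-in; real code loads pretrained weights
unet_unwrapped = unet                        # what self.unet_unwrapped holds: .config reachable
unet = torch.nn.parallel.DataParallel(unet)  # what self.unet may become for training
num_channels = unet_unwrapped.config['in_channels']  # 4 for SD-style UNets, wrapped or not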
@@ -1505,7 +1506,7 @@ class StableDiffusion:
         if width is None:
             width = pixel_width // VAE_SCALE_FACTOR

-        num_channels = self.unet.config['in_channels']
+        num_channels = self.unet_unwrapped.config['in_channels']
         if self.is_flux:
             # has 64 channels in for some reason
             num_channels = 16
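Context for the Flux special case (my reading; the commit itself only hard-codes the override): FluxTransformer2DModel reports 64 input channels because Flux packs 2x2 patches of the 16-channel VAE latent into each token, so the latent-space channel count the sampler needs here is 16, not the transformer's 64. The packing arithmetic:

import torch

latent = torch.randn(1, 16, 64, 64)           # 16-channel Flux VAE latent
packed = (latent.view(1, 16, 32, 2, 32, 2)    # split H and W into 2x2 patches
                .permute(0, 2, 4, 1, 3, 5)    # group each patch with its channels
                .reshape(1, 32 * 32, 16 * 4)) # one token per patch
print(packed.shape)  # torch.Size([1, 1024, 64]) -> 64 channels per token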
@@ -1813,8 +1814,8 @@ class StableDiffusion:
                 ratios=aspect_ratio_bin)

             added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-            if self.unet.config.sample_size == 128 or (
-                    self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
+            if self.unet_unwrapped.config.sample_size == 128 or (
+                    self.vae_scale_factor == 16 and self.unet_unwrapped.config.sample_size == 64):
                 resolution = torch.tensor([height, width]).repeat(batch_size, 1)
                 aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
                 resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
@@ -1837,7 +1838,7 @@ class StableDiffusion:
             )[0]

             # learned sigma
-            if self.unet.config.out_channels // 2 == self.unet.config.in_channels:
+            if self.unet_unwrapped.config.out_channels // 2 == self.unet_unwrapped.config.in_channels:
                 noise_pred = noise_pred.chunk(2, dim=1)[0]
             else:
                 noise_pred = noise_pred
@@ -1865,7 +1866,7 @@ class StableDiffusion:
             txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)

             # # handle guidance
-            if self.unet.config.guidance_embeds:
+            if self.unet_unwrapped.config.guidance_embeds:
                 if isinstance(guidance_embedding_scale, list):
                     guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
                 else:
@@ -2457,7 +2458,7 @@ class StableDiffusion:
         # diffusers
         if self.is_flux:
             # only save the unet
-            transformer: FluxTransformer2DModel = self.unet
+            transformer: FluxTransformer2DModel = unwrap_model(self.unet)
             transformer.save_pretrained(
                 save_directory=os.path.join(output_file, 'transformer'),
                 safe_serialization=True,
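unwrap_model here recovers the plain FluxTransformer2DModel before calling save_pretrained, a method a wrapper object would not expose. A minimal sketch of what such a helper typically does (an assumption about its behavior; the exact ai-toolkit implementation may differ):

import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # Peel training-time wrappers until the plain module is reached:
    # torch.compile keeps the original module on `_orig_mod`, while
    # DataParallel/DistributedDataParallel keep it on `.module`.
    while True:
        if hasattr(model, '_orig_mod'):
            model = model._orig_mod
        elif isinstance(model, (DataParallel, DistributedDataParallel)):
            model = model.module
        else:
            return model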