mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
Bug fixes
This commit is contained in:
@@ -156,6 +156,7 @@ class StableDiffusion:
|
||||
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
|
||||
self.vae: Union[None, 'AutoencoderKL']
|
||||
self.unet: Union[None, 'UNet2DConditionModel']
|
||||
self.unet_unwrapped: Union[None, 'UNet2DConditionModel']
|
||||
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
|
||||
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
|
||||
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
|
||||
@@ -1505,7 +1506,7 @@ class StableDiffusion:
|
||||
if width is None:
|
||||
width = pixel_width // VAE_SCALE_FACTOR
|
||||
|
||||
num_channels = self.unet.config['in_channels']
|
||||
num_channels = self.unet_unwrapped.config['in_channels']
|
||||
if self.is_flux:
|
||||
# has 64 channels in for some reason
|
||||
num_channels = 16
|
||||
@@ -1813,8 +1814,8 @@ class StableDiffusion:
|
||||
ratios=aspect_ratio_bin)
|
||||
|
||||
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
||||
if self.unet.config.sample_size == 128 or (
|
||||
self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
|
||||
if self.unet_unwrapped.config.sample_size == 128 or (
|
||||
self.vae_scale_factor == 16 and self.unet_unwrapped.config.sample_size == 64):
|
||||
resolution = torch.tensor([height, width]).repeat(batch_size, 1)
|
||||
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
|
||||
resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
|
||||
@@ -1837,7 +1838,7 @@ class StableDiffusion:
|
||||
)[0]
|
||||
|
||||
# learned sigma
|
||||
if self.unet.config.out_channels // 2 == self.unet.config.in_channels:
|
||||
if self.unet_unwrapped.config.out_channels // 2 == self.unet_unwrapped.config.in_channels:
|
||||
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
||||
else:
|
||||
noise_pred = noise_pred
|
||||
@@ -1865,7 +1866,7 @@ class StableDiffusion:
|
||||
txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
|
||||
|
||||
# # handle guidance
|
||||
if self.unet.config.guidance_embeds:
|
||||
if self.unet_unwrapped.config.guidance_embeds:
|
||||
if isinstance(guidance_embedding_scale, list):
|
||||
guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
|
||||
else:
|
||||
@@ -2457,7 +2458,7 @@ class StableDiffusion:
|
||||
# diffusers
|
||||
if self.is_flux:
|
||||
# only save the unet
|
||||
transformer: FluxTransformer2DModel = self.unet
|
||||
transformer: FluxTransformer2DModel = unwrap_model(self.unet)
|
||||
transformer.save_pretrained(
|
||||
save_directory=os.path.join(output_file, 'transformer'),
|
||||
safe_serialization=True,
|
||||
|
||||
Reference in New Issue
Block a user