mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Bug fixes. Added some functionality to help with private extensions
This commit is contained in:
@@ -78,7 +78,7 @@ def flush():
|
||||
|
||||
|
||||
UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
|
||||
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||
# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||
|
||||
# if is type checking
|
||||
if typing.TYPE_CHECKING:
|
||||
@@ -471,6 +471,7 @@ class StableDiffusion:
|
||||
batch_size=1,
|
||||
noise_offset=0.0,
|
||||
):
|
||||
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
if height is None and pixel_height is None:
|
||||
raise ValueError("height or pixel_height must be specified")
|
||||
if width is None and pixel_width is None:
|
||||
@@ -493,6 +494,7 @@ class StableDiffusion:
|
||||
return noise
|
||||
|
||||
def get_time_ids_from_latents(self, latents: torch.Tensor):
|
||||
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
if self.is_xl:
|
||||
bs, ch, h, w = list(latents.shape)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user