mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-26 17:29:27 +00:00
Pixel shuffle adapter. Some bug fixes thrown in
This commit is contained in:
@@ -249,6 +249,17 @@ class StableDiffusion:
|
||||
@property
|
||||
def unet_unwrapped(self):
|
||||
return unwrap_model(self.unet)
|
||||
|
||||
def get_bucket_divisibility(self):
|
||||
if self.vae is None:
|
||||
return 8
|
||||
divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
|
||||
# flux packs this again,
|
||||
if self.is_flux:
|
||||
divisibility = divisibility * 4
|
||||
return divisibility
|
||||
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
@@ -1721,6 +1732,7 @@ class StableDiffusion:
|
||||
pixel_width=None,
|
||||
batch_size=1,
|
||||
noise_offset=0.0,
|
||||
num_channels=None,
|
||||
):
|
||||
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
if height is None and pixel_height is None:
|
||||
@@ -1732,10 +1744,11 @@ class StableDiffusion:
|
||||
if width is None:
|
||||
width = pixel_width // VAE_SCALE_FACTOR
|
||||
|
||||
num_channels = self.unet_unwrapped.config['in_channels']
|
||||
if self.is_flux:
|
||||
# has 64 channels in for some reason
|
||||
num_channels = 16
|
||||
if num_channels is None:
|
||||
num_channels = self.unet_unwrapped.config['in_channels']
|
||||
if self.is_flux:
|
||||
# it gets packed, unpack it
|
||||
num_channels = num_channels // 4
|
||||
noise = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
|
||||
Reference in New Issue
Block a user