Pixel shuffle adapter. Some bug fixes thrown in

This commit is contained in:
Jaret Burkett
2025-03-29 21:15:01 -06:00
parent b94d7aafea
commit 860d892214
10 changed files with 594 additions and 11 deletions

View File

@@ -249,6 +249,17 @@ class StableDiffusion:
@property
def unet_unwrapped(self):
return unwrap_model(self.unet)
def get_bucket_divisibility(self):
if self.vae is None:
return 8
divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)
# flux packs this again,
if self.is_flux:
divisibility = divisibility * 4
return divisibility
def load_model(self):
if self.is_loaded:
@@ -1721,6 +1732,7 @@ class StableDiffusion:
pixel_width=None,
batch_size=1,
noise_offset=0.0,
num_channels=None,
):
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
if height is None and pixel_height is None:
@@ -1732,10 +1744,11 @@ class StableDiffusion:
if width is None:
width = pixel_width // VAE_SCALE_FACTOR
num_channels = self.unet_unwrapped.config['in_channels']
if self.is_flux:
# has 64 channels in for some reason
num_channels = 16
if num_channels is None:
num_channels = self.unet_unwrapped.config['in_channels']
if self.is_flux:
# it gets packed, unpack it
num_channels = num_channels // 4
noise = torch.randn(
(
batch_size,