Fix tiled vae for ace step 1.5 (#12253)

This commit is contained in:
comfyanonymous
2026-02-03 11:40:45 -08:00
committed by GitHub
parent ab1050bec3
commit b8315e66cb

View File

@@ -554,6 +554,8 @@ class VAE:
elif "decoder.layers.1.layers.0.beta" in sd:
config = {}
param_key = None
self.upscale_ratio = 2048
self.downscale_ratio = 2048
if "decoder.layers.2.layers.1.weight_v" in sd:
param_key = "decoder.layers.2.layers.1.weight_v"
if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
@@ -562,6 +564,8 @@ class VAE:
if sd[param_key].shape[-1] == 12:
config["strides"] = [2, 4, 4, 6, 10]
self.audio_sample_rate = 48000
self.upscale_ratio = 1920
self.downscale_ratio = 1920
self.first_stage_model = AudioOobleckVAE(**config)
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
@@ -569,8 +573,6 @@ class VAE:
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
@@ -870,7 +872,7 @@ class VAE:
/ 3.0)
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else: