From 4f91cb7148e521a4d1479cccd3e8ad62d777d4e3 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 26 Jun 2025 19:03:12 -0600 Subject: [PATCH] Fix issue with gradient checkpointing and flux kontext --- .../diffusion_models/flux_kontext/flux_kontext.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py index 07212151..da6acb94 100644 --- a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py +++ b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py @@ -385,15 +385,24 @@ class FluxKontextModel(BaseModel): with torch.no_grad(): control_tensor = batch.control_tensor if control_tensor is not None: + self.vae.to(self.device_torch) # we are not packed here, so we just need to pass them so we can pack them later control_tensor = control_tensor * 2 - 1 control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype) # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it - if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]: - control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear') + if batch.tensor is not None: + target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3] + else: + # When caching latents, batch.tensor is None. We get the size from the file_items instead. + target_h = batch.file_items[0].crop_height + target_w = batch.file_items[0].crop_width + + if control_tensor.shape[2] != target_h or control_tensor.shape[3] != target_w: + control_tensor = F.interpolate(control_tensor, size=(target_h, target_w), mode='bilinear') control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype) latents = torch.cat((latents, control_latent), dim=1) + self.vae.to('cpu') - return latents.detach() \ No newline at end of file + return latents.detach() \ No newline at end of file