diff --git a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py
index 07212151..da6acb94 100644
--- a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py
+++ b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py
@@ -385,15 +385,24 @@ class FluxKontextModel(BaseModel):
         with torch.no_grad():
             control_tensor = batch.control_tensor
             if control_tensor is not None:
+                self.vae.to(self.device_torch)
                 # we are not packed here, so we just need to pass them so we can pack them later
                 control_tensor = control_tensor * 2 - 1
                 control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype)
 
                 # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
-                if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
-                    control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')
+                if batch.tensor is not None:
+                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
+                else:
+                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
+                    target_h = batch.file_items[0].crop_height
+                    target_w = batch.file_items[0].crop_width
+
+                if control_tensor.shape[2] != target_h or control_tensor.shape[3] != target_w:
+                    control_tensor = F.interpolate(control_tensor, size=(target_h, target_w), mode='bilinear')
 
                 control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype)
                 latents = torch.cat((latents, control_latent), dim=1)
+                self.vae.to('cpu')
 
-        return latents.detach()
\ No newline at end of file
+        return latents.detach()
\ No newline at end of file
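
For reviewers, the size-fallback plus resize step in this hunk reduces to the standalone pattern below. This is a minimal sketch, not code from the patch; `resize_control_to_target` is a hypothetical helper name, and the tensors follow the (bs, ch, h, w) layout the diff assumes.

```python
import torch
import torch.nn.functional as F

def resize_control_to_target(control_tensor: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    # Bilinear resize only when the control image does not already match the target size.
    if control_tensor.shape[2] != target_h or control_tensor.shape[3] != target_w:
        control_tensor = F.interpolate(control_tensor, size=(target_h, target_w), mode='bilinear')
    return control_tensor

# Example: a 3-channel control image resized from 256x256 to 512x512.
control = torch.rand(1, 3, 256, 256)
resized = resize_control_to_target(control, 512, 512)
assert resized.shape == (1, 3, 512, 512)
```

The `self.vae.to(self.device_torch)` / `self.vae.to('cpu')` pair added by the patch is the usual offloading trade-off: the VAE occupies accelerator memory only for the duration of the control-image encode, at the cost of a host-device transfer per batch.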