Fix issue with gradient checkpointing and flux kontext

Jaret Burkett
2025-06-26 19:03:12 -06:00
parent 446b0b6989
commit 4f91cb7148

@@ -385,15 +385,24 @@ class FluxKontextModel(BaseModel):
         with torch.no_grad():
             control_tensor = batch.control_tensor
             if control_tensor is not None:
                 self.vae.to(self.device_torch)
                 # we are not packed here, so we just need to pass them so we can pack them later
                 control_tensor = control_tensor * 2 - 1
                 control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype)
                 # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
-                if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
-                    control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')
+                if batch.tensor is not None:
+                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
+                else:
+                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
+                    target_h = batch.file_items[0].crop_height
+                    target_w = batch.file_items[0].crop_width
+                if control_tensor.shape[2] != target_h or control_tensor.shape[3] != target_w:
+                    control_tensor = F.interpolate(control_tensor, size=(target_h, target_w), mode='bilinear')
                 control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype)
                 latents = torch.cat((latents, control_latent), dim=1)
                 self.vae.to('cpu')
         return latents.detach()
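
For reference, a minimal standalone sketch of the resize-target fallback this diff introduces: when batch.tensor is absent (latent caching), the target size is taken from the file items' crop dimensions instead. The Batch and FileItem classes below are hypothetical stand-ins for the real batch objects; only the fields used in the diff (tensor, file_items, crop_height, crop_width) are assumed.

import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FileItem:  # hypothetical stand-in for a dataset file item
    crop_height: int
    crop_width: int

@dataclass
class Batch:  # hypothetical stand-in for the real batch object
    tensor: Optional[torch.Tensor]
    file_items: List[FileItem]

def resize_control_tensor(control_tensor: torch.Tensor, batch: Batch) -> torch.Tensor:
    # Normal training: the batch carries its image tensor, so match that resolution.
    if batch.tensor is not None:
        target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
    else:
        # Latent caching: batch.tensor is None, so fall back to the crop size
        # recorded on the file items, as the commit does.
        target_h = batch.file_items[0].crop_height
        target_w = batch.file_items[0].crop_width
    if control_tensor.shape[2] != target_h or control_tensor.shape[3] != target_w:
        control_tensor = F.interpolate(control_tensor, size=(target_h, target_w), mode='bilinear')
    return control_tensor

# Caching path: no batch.tensor, so the crop size drives the resize.
control = torch.randn(1, 3, 512, 512)
batch = Batch(tensor=None, file_items=[FileItem(crop_height=768, crop_width=768)])
print(resize_control_tensor(control, batch).shape)  # torch.Size([1, 3, 768, 768])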