mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix issue with gradient checkpointing and flux kontext
This commit is contained in:
@@ -385,15 +385,24 @@ class FluxKontextModel(BaseModel):
|
||||
with torch.no_grad():
|
||||
control_tensor = batch.control_tensor
|
||||
if control_tensor is not None:
|
||||
self.vae.to(self.device_torch)
|
||||
# we are not packed here, so we just need to pass them so we can pack them later
|
||||
control_tensor = control_tensor * 2 - 1
|
||||
control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype)
|
||||
|
||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
|
||||
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')
|
||||
if batch.tensor is not None:
|
||||
target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
|
||||
else:
|
||||
# When caching latents, batch.tensor is None. We get the size from the file_items instead.
|
||||
target_h = batch.file_items[0].crop_height
|
||||
target_w = batch.file_items[0].crop_width
|
||||
|
||||
if control_tensor.shape[2] != target_h or control_tensor.shape[3] != target_w:
|
||||
control_tensor = F.interpolate(control_tensor, size=(target_h, target_w), mode='bilinear')
|
||||
|
||||
control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype)
|
||||
latents = torch.cat((latents, control_latent), dim=1)
|
||||
self.vae.to('cpu')
|
||||
|
||||
return latents.detach()
|
||||
return latents.detach()
|
||||
Reference in New Issue
Block a user