Fixed issue that prevented full fine-tuning of flux2 models when using gradient checkpointing

This commit is contained in:
Jaret Burkett
2026-02-06 16:18:43 -07:00
parent 1422789452
commit e82cf6eec2

View File

@@ -170,8 +170,6 @@ class Flux2(nn.Module):
for block in self.double_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
img.requires_grad_(True)
txt.requires_grad_(True)
img, txt = ckpt.checkpoint(
block,
img,
@@ -180,6 +178,7 @@ class Flux2(nn.Module):
pe_ctx,
double_block_mod_img,
double_block_mod_txt,
use_reentrant=False,
)
else:
img, txt = block(
@@ -196,12 +195,12 @@ class Flux2(nn.Module):
for i, block in enumerate(self.single_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
img.requires_grad_(True)
img = ckpt.checkpoint(
block,
img,
pe,
single_block_mod,
use_reentrant=False,
)
else:
img = block(