From e82cf6eec2ebf040c6a2c1ef6bd5f5a2d1ab40e7 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 6 Feb 2026 16:18:43 -0700 Subject: [PATCH] Fixed issue that prevented full fine-tuning of flux2 models when using gradient checkpointing --- extensions_built_in/diffusion_models/flux2/src/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/extensions_built_in/diffusion_models/flux2/src/model.py b/extensions_built_in/diffusion_models/flux2/src/model.py index 211b438c..cffae9f6 100644 --- a/extensions_built_in/diffusion_models/flux2/src/model.py +++ b/extensions_built_in/diffusion_models/flux2/src/model.py @@ -170,8 +170,6 @@ class Flux2(nn.Module): for block in self.double_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - img.requires_grad_(True) - txt.requires_grad_(True) img, txt = ckpt.checkpoint( block, img, @@ -180,6 +178,7 @@ class Flux2(nn.Module): pe_ctx, double_block_mod_img, double_block_mod_txt, + use_reentrant=False, ) else: img, txt = block( @@ -196,12 +195,12 @@ class Flux2(nn.Module): for i, block in enumerate(self.single_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - img.requires_grad_(True) img = ckpt.checkpoint( block, img, pe, single_block_mod, + use_reentrant=False, ) else: img = block(