diff --git a/extensions_built_in/diffusion_models/flux2/src/model.py b/extensions_built_in/diffusion_models/flux2/src/model.py index 211b438c..cffae9f6 100644 --- a/extensions_built_in/diffusion_models/flux2/src/model.py +++ b/extensions_built_in/diffusion_models/flux2/src/model.py @@ -170,8 +170,6 @@ class Flux2(nn.Module): for block in self.double_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - img.requires_grad_(True) - txt.requires_grad_(True) img, txt = ckpt.checkpoint( block, img, @@ -180,6 +178,7 @@ class Flux2(nn.Module): pe_ctx, double_block_mod_img, double_block_mod_txt, + use_reentrant=False, ) else: img, txt = block( @@ -196,12 +195,12 @@ class Flux2(nn.Module): for i, block in enumerate(self.single_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - img.requires_grad_(True) img = ckpt.checkpoint( block, img, pe, single_block_mod, + use_reentrant=False, ) else: img = block(