mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-23 13:53:57 +00:00
Fixed issue that prevented full fine-tuning of flux2 models when using gradient checkpointing
This commit is contained in:
@@ -170,8 +170,6 @@ class Flux2(nn.Module):
|
||||
|
||||
for block in self.double_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
img.requires_grad_(True)
|
||||
txt.requires_grad_(True)
|
||||
img, txt = ckpt.checkpoint(
|
||||
block,
|
||||
img,
|
||||
@@ -180,6 +178,7 @@ class Flux2(nn.Module):
|
||||
pe_ctx,
|
||||
double_block_mod_img,
|
||||
double_block_mod_txt,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
img, txt = block(
|
||||
@@ -196,12 +195,12 @@ class Flux2(nn.Module):
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
img.requires_grad_(True)
|
||||
img = ckpt.checkpoint(
|
||||
block,
|
||||
img,
|
||||
pe,
|
||||
single_block_mod,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
img = block(
|
||||
|
||||
Reference in New Issue
Block a user