Mirror of https://github.com/ostris/ai-toolkit.git
Commit: Fix an issue with the way Chroma handled gradient checkpointing.
@@ -96,6 +96,7 @@ class Chroma(nn.Module):
         self.params = params
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
+        self.gradient_checkpointing = False
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -162,11 +163,14 @@ class Chroma(nn.Module):
             torch.tensor(list(range(self.mod_index_length)), device="cpu"),
             persistent=False,
         )

     @property
     def device(self):
         # Get the device of the module (assumes all parameters are on the same device)
         return next(self.parameters()).device

+    def enable_gradient_checkpointing(self, enable: bool = True):
+        self.gradient_checkpointing = enable
+
     def forward(
         self,
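Taken together, the flag added in __init__ and the setter added here follow the standard hand-rolled PyTorch pattern for opt-in activation checkpointing: a plain boolean attribute that defaults to off, a one-line setter, and a guard in forward. A minimal, self-contained sketch of that pattern, assuming nothing about the real Chroma blocks (the TinyBlocks module and its dimensions are illustrative only):

import torch
import torch.nn as nn
import torch.utils.checkpoint as ckpt

class TinyBlocks(nn.Module):
    # Hypothetical stand-in for Chroma's block stack, for illustration only.
    def __init__(self, dim: int = 64, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.gradient_checkpointing = False  # off by default, as in Chroma.__init__

    def enable_gradient_checkpointing(self, enable: bool = True):
        self.gradient_checkpointing = enable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Trade compute for memory: recompute this block during backward.
                x = ckpt.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

Keeping the flag as a plain attribute (not a buffer or parameter) keeps it out of the state dict, so checkpoints saved with and without the feature stay compatible.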
@@ -246,8 +250,7 @@ class Chroma(nn.Module):
             txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
             double_mod = [img_mod, txt_mod]

             # just in case in different GPU for simple pipeline parallel
-            if self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 img.requires_grad_(True)
                 img, txt = ckpt.checkpoint(
                     block, img, txt, pe, double_mod, txt_img_mask
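Two details in this hunk are worth spelling out. First, gating on torch.is_grad_enabled() rather than self.training means a model left in train mode but run under torch.no_grad() (e.g. for mid-training sampling) skips the checkpointing path entirely. Second, img.requires_grad_(True) guards against the classic reentrant-checkpoint pitfall: if no input to the checkpointed segment requires grad, the segment is never replayed in backward and the weights inside receive no gradients. A small sketch of the fix, using a bare nn.Linear instead of a Chroma block (an assumption for illustration):

import torch
import torch.nn as nn
import torch.utils.checkpoint as ckpt

layer = nn.Linear(8, 8)
x = torch.randn(2, 8)  # leaf tensor, requires_grad=False, like frozen embeddings

# Marking the input as requiring grad (as the diff does for img) makes the
# reentrant checkpoint participate in autograd, so backward replays the
# segment and the layer's own weights receive gradients.
x.requires_grad_(True)
y = ckpt.checkpoint(layer, x, use_reentrant=True)
y.sum().backward()
assert layer.weight.grad is not None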
@@ -260,7 +263,7 @@ class Chroma(nn.Module):
         img = torch.cat((txt, img), 1)
         for i, block in enumerate(self.single_blocks):
             single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
-            if self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 img.requires_grad_(True)
                 img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
             else:
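A plausible reading of the bug this commit fixes: with the old `if self.training:` guard, a model still in train mode would attempt checkpointing even under torch.no_grad() during sampling. The new guard makes that case fall through to the plain forward path. A self-contained demonstration under that assumption (the Block module here is hypothetical, not part of ai-toolkit):

import torch
import torch.nn as nn
import torch.utils.checkpoint as ckpt

class Block(nn.Module):
    # Hypothetical single-layer module using the same guard as the diff.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)
        self.gradient_checkpointing = True

    def forward(self, x):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            x.requires_grad_(True)
            return ckpt.checkpoint(self.lin, x, use_reentrant=True)
        return self.lin(x)

model = Block().train()  # still in train mode, as during mid-training sampling
with torch.no_grad():
    out = model(torch.randn(2, 8))  # guard falls through to the plain path
assert out.grad_fn is None  # nothing recorded, no checkpointing overhead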