Fix the way Chroma handled gradient checkpointing.

Jaret Burkett
2025-05-28 08:41:47 -06:00
parent 34f4c14cd6
commit ffaf2f154a


@@ -96,6 +96,7 @@ class Chroma(nn.Module):
         self.params = params
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
+        self.gradient_checkpointing = False
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -162,11 +163,14 @@ class Chroma(nn.Module):
             torch.tensor(list(range(self.mod_index_length)), device="cpu"),
             persistent=False,
         )
 
     @property
     def device(self):
         # Get the device of the module (assumes all parameters are on the same device)
         return next(self.parameters()).device
 
+    def enable_gradient_checkpointing(self, enable: bool = True):
+        self.gradient_checkpointing = enable
+
     def forward(
         self,
@@ -246,8 +250,7 @@ class Chroma(nn.Module):
             txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
             double_mod = [img_mod, txt_mod]
             # just in case in different GPU for simple pipeline parallel
-            if self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 img.requires_grad_(True)
                 img, txt = ckpt.checkpoint(
                     block, img, txt, pe, double_mod, txt_img_mask
@@ -260,7 +263,7 @@ class Chroma(nn.Module):
         img = torch.cat((txt, img), 1)
         for i, block in enumerate(self.single_blocks):
             single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
-            if self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 img.requires_grad_(True)
                 img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
             else:
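
For reference, a minimal, self-contained sketch of the gating pattern this diff introduces, reproduced on a toy module. ToyBlockStack, the layer sizes, and the use_reentrant=False argument are illustrative assumptions here, not part of the Chroma code.

import torch
import torch.nn as nn
import torch.utils.checkpoint as ckpt


class ToyBlockStack(nn.Module):
    # Toy module mirroring the gating pattern from the diff above.
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self, enable: bool = True):
        self.gradient_checkpointing = enable

    def forward(self, x):
        for block in self.blocks:
            # Checkpoint only when autograd is recording AND the caller
            # explicitly asked for it -- not merely when self.training is True.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x.requires_grad_(True)
                x = ckpt.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


model = ToyBlockStack()
model.enable_gradient_checkpointing(True)

x = torch.randn(2, 8)
y = model(x)            # grad enabled by default -> checkpointed forward
y.sum().backward()      # activations are recomputed during backward

with torch.no_grad():   # e.g. validation or sampling
    _ = model(x)        # is_grad_enabled() is False -> plain forward, no checkpointing

With this gating, checkpointing is skipped inside torch.no_grad() even while the module is in train mode, and it never turns on just because self.training happens to be True; it has to be requested explicitly via enable_gradient_checkpointing().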