diff --git a/extensions_built_in/diffusion_models/chroma/src/model.py b/extensions_built_in/diffusion_models/chroma/src/model.py
index 3b6c29bb..33cdbe62 100644
--- a/extensions_built_in/diffusion_models/chroma/src/model.py
+++ b/extensions_built_in/diffusion_models/chroma/src/model.py
@@ -96,6 +96,7 @@ class Chroma(nn.Module):
         self.params = params
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
+        self.gradient_checkpointing = False
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -162,11 +163,14 @@ class Chroma(nn.Module):
             torch.tensor(list(range(self.mod_index_length)), device="cpu"),
             persistent=False,
         )
-    
+
     @property
     def device(self):
         # Get the device of the module (assumes all parameters are on the same device)
         return next(self.parameters()).device
+
+    def enable_gradient_checkpointing(self, enable: bool = True):
+        self.gradient_checkpointing = enable
 
     def forward(
         self,
@@ -246,8 +250,7 @@ class Chroma(nn.Module):
             txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
             double_mod = [img_mod, txt_mod]
 
-            # just in case in different GPU for simple pipeline parallel
-            if self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 img.requires_grad_(True)
                 img, txt = ckpt.checkpoint(
                     block, img, txt, pe, double_mod, txt_img_mask
@@ -260,7 +263,7 @@ class Chroma(nn.Module):
         img = torch.cat((txt, img), 1)
         for i, block in enumerate(self.single_blocks):
             single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
-            if self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 img.requires_grad_(True)
                 img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
             else:
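
The patch replaces the `if self.training:` guard with an explicit `gradient_checkpointing` flag combined with `torch.is_grad_enabled()`, so checkpointing only runs when it can actually save memory (gradients are being recorded) and never fires during no-grad evaluation. Below is a minimal, self-contained sketch of that gating pattern, not the Chroma model itself; the `TinyBlockStack` module and its dimensions are invented for illustration, while `enable_gradient_checkpointing()`, the flag, and the `img.requires_grad_(True)` / `ckpt.checkpoint(...)` usage mirror the patch above.

import torch
import torch.nn as nn
import torch.utils.checkpoint as ckpt


class TinyBlockStack(nn.Module):
    """Hypothetical stand-in for a transformer block stack, used only to show the gating."""

    def __init__(self, dim: int = 32, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self, enable: bool = True):
        # Same toggle the patch adds to Chroma.
        self.gradient_checkpointing = enable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Mirror the patch: make sure the checkpoint input requires grad,
                # then recompute this block's activations during backward.
                x.requires_grad_(True)
                x = ckpt.checkpoint(block, x)
            else:
                x = block(x)
        return x


model = TinyBlockStack()
model.enable_gradient_checkpointing(True)

x = torch.randn(8, 32, requires_grad=True)
model(x).sum().backward()   # grads enabled -> checkpointed path

with torch.no_grad():
    _ = model(x)            # grads disabled -> plain path, even with the flag set

Compared with keying off `self.training`, this lets a trainer keep the model in train mode for a no-grad preview pass without paying the checkpointing recomputation, and lets inference code leave checkpointing off entirely.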