Fix gradient checkpointing for hidream o1

This commit is contained in:
Jaret Burkett
2026-05-10 14:41:31 -06:00
parent f47f9f1f2c
commit a8001e9a3f

View File

@@ -2148,6 +2148,10 @@ class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
def visual(self):
return self.model.visual
def enable_gradient_checkpointing(self):
self.model.visual.gradient_checkpointing_enable()
self.model.language_model.gradient_checkpointing_enable()
@check_model_inputs
def forward(
self,