mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Fix gradient checkpointing for hidream o1
This commit is contained in:
@@ -2148,6 +2148,10 @@ class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
|
||||
def visual(self):
    """Return the ``visual`` sub-module held by the wrapped ``self.model``."""
    return self.model.visual
|
||||
|
||||
def enable_gradient_checkpointing(self) -> None:
    """Enable gradient checkpointing on both sub-modules of ``self.model``.

    Calls ``gradient_checkpointing_enable()`` first on ``self.model.visual``
    and then on ``self.model.language_model``.
    """
    self.model.visual.gradient_checkpointing_enable()
    self.model.language_model.gradient_checkpointing_enable()
|
||||
|
||||
@check_model_inputs
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user