Allow the user to set the attention backend. Add a method to recover from an occasional OOM if it is a rare event. Still exit if it OOMs 3 times in a row.

This commit is contained in:
Jaret Burkett
2025-09-27 08:56:15 -06:00
parent 6da417261c
commit 3b1f7b0948
2 changed files with 40 additions and 10 deletions

View File

@@ -357,6 +357,8 @@ class TrainConfig:
self.dtype: str = kwargs.get('dtype', 'fp32')
self.xformers = kwargs.get('xformers', False)
self.sdp = kwargs.get('sdp', False)
# see https://huggingface.co/docs/diffusers/main/optimization/attention_backends#available-backends for options
self.attention_backend: str = kwargs.get('attention_backend', 'native') # native, flash, _flash_3_hub, _flash_3,
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', False)
self.train_refiner = kwargs.get('train_refiner', True)