Working multi gpu training. Still need a lot of tweaks and testing.

This commit is contained in:
Jaret Burkett
2025-01-25 16:46:20 -07:00
parent 441474e81f
commit 5e663746b8
9 changed files with 432 additions and 294 deletions

17
toolkit/accelerator.py Normal file
View File

@@ -0,0 +1,17 @@
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module
global_accelerator = None
def get_accelerator() -> Accelerator:
global global_accelerator
if global_accelerator is None:
global_accelerator = Accelerator()
return global_accelerator
def unwrap_model(model):
accelerator = get_accelerator()
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model