mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Working multi gpu training. Still need a lot of tweaks and testing.
This commit is contained in:
17
toolkit/accelerator.py
Normal file
17
toolkit/accelerator.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from accelerate import Accelerator
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
global_accelerator = None
|
||||
|
||||
|
||||
def get_accelerator() -> Accelerator:
|
||||
global global_accelerator
|
||||
if global_accelerator is None:
|
||||
global_accelerator = Accelerator()
|
||||
return global_accelerator
|
||||
|
||||
def unwrap_model(model):
|
||||
accelerator = get_accelerator()
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
Reference in New Issue
Block a user