mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
21 lines
551 B
Python
21 lines
551 B
Python
from accelerate import Accelerator
|
|
from diffusers.utils.torch_utils import is_compiled_module
|
|
|
|
# Module-level cache for the shared Accelerator instance. It stays None until
# get_accelerator() is first called, which creates the instance and stores it here.
global_accelerator = None
|
|
|
|
|
|
def get_accelerator() -> Accelerator:
    """Return the module-wide ``Accelerator``, creating it on first use.

    The instance is cached in the module-level ``global_accelerator``
    variable, so every call after the first returns the same object.
    """
    global global_accelerator

    # Fast path: an instance was already created by an earlier call.
    if global_accelerator is not None:
        return global_accelerator

    # First call: build the accelerator with default arguments and cache it.
    global_accelerator = Accelerator()
    return global_accelerator
|
|
|
|
def unwrap_model(model):
    """Return ``model`` with accelerator and compile wrappers removed.

    The model is first passed through ``Accelerator.unwrap_model`` of the
    shared accelerator. If ``is_compiled_module`` then reports the result as
    a compiled module, its ``_orig_mod`` attribute is returned instead.

    This is a best-effort helper: if any step raises, the failure is logged
    at debug level and the model is returned in whatever state it had
    reached (unchanged when the very first step fails).

    Args:
        model: The (possibly wrapped) model object.

    Returns:
        The unwrapped model, or the input as-is when unwrapping fails.
    """
    # Local import so this function does not depend on a file-level
    # ``import logging`` being present.
    import logging

    try:
        accelerator = get_accelerator()
        model = accelerator.unwrap_model(model)
        if is_compiled_module(model):
            # Compiled modules expose the original module as `_orig_mod`.
            model = model._orig_mod
    except Exception:
        # Deliberately non-fatal: callers rely on getting a model back.
        # Record the reason instead of discarding it silently.
        logging.getLogger(__name__).debug(
            "unwrap_model: failed to unwrap %s; returning it as-is",
            type(model).__name__,
            exc_info=True,
        )
    return model
|