diff --git a/toolkit/accelerator.py b/toolkit/accelerator.py index ebcf0095..0736f016 100644 --- a/toolkit/accelerator.py +++ b/toolkit/accelerator.py @@ -11,7 +11,10 @@ def get_accelerator() -> Accelerator: return global_accelerator def unwrap_model(model): - accelerator = get_accelerator() - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model + try: + accelerator = get_accelerator() + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + except Exception as e: + pass return model