diff --git a/toolkit/unloader.py b/toolkit/unloader.py index 09ce1ba..31f75e1 100644 --- a/toolkit/unloader.py +++ b/toolkit/unloader.py @@ -47,6 +47,7 @@ def unload_text_encoder(model: "BaseModel"): if hasattr(pipe, "text_encoder"): te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype) text_encoder_list.append(te) + pipe.text_encoder.to('cpu') pipe.text_encoder = te i = 2