diff --git a/toolkit/unloader.py b/toolkit/unloader.py index 6c45926f..09ce1ba9 100644 --- a/toolkit/unloader.py +++ b/toolkit/unloader.py @@ -58,6 +58,6 @@ def unload_text_encoder(model: "BaseModel"): model.text_encoder = text_encoder_list else: # only has a single text encoder - model.text_encoder = FakeTextEncoder() + model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype) flush()