diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index dd9c6062..94140c33 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -504,8 +504,11 @@ class StableDiffusion: if not self.is_flux: raise ValueError("Assistant lora is only supported for flux models currently") - # handle downloading from the hub if needed - if not os.path.exists(self.model_config.assistant_lora_path): + if os.path.isdir(self.model_config.assistant_lora_path): + self.model_config.assistant_lora_path = os.path.join( + self.model_config.assistant_lora_path, "pytorch_lora_weights.safetensors" + ) + elif not os.path.exists(self.model_config.assistant_lora_path): print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}") new_lora_path = hf_hub_download( self.model_config.assistant_lora_path,