diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py index 31eb6986..3888909a 100644 --- a/extensions_built_in/diffusion_models/chroma/chroma_model.py +++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py @@ -411,6 +411,8 @@ class ChromaModel(BaseModel): return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad def save_model(self, output_path, meta, save_dtype): + if not output_path.endswith(".safetensors"): + output_path = output_path + ".safetensors" # only save the unet transformer: Chroma = unwrap_model(self.model) state_dict = transformer.state_dict()