diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 07dc91bb..5a1746b8 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -162,14 +162,16 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer text_cond_models.append(embedder) - if typename == 'FrozenCLIPEmbedder': + elif typename == 'FrozenCLIPEmbedder': embedder.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer embedder.transformer = forge_object.clip.cond_stage_model.clip_l.transformer text_cond_models.append(embedder) - if typename == 'FrozenOpenCLIPEmbedder2': + elif typename == 'FrozenOpenCLIPEmbedder2': embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer text_cond_models.append(embedder) + else: + raise NotImplementedError('Bad Class Name:' + typename) if len(text_cond_models) == 1: sd_model.cond_stage_model = text_cond_models[0] @@ -178,11 +180,11 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder': sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_l.transformer - pass elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_g.transformer - pass + else: + raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__) timer.record("forge set components")