diff --git a/backend/loader.py b/backend/loader.py index f5d251b4..132e3fd6 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -49,9 +49,11 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p from transformers import CLIPTextConfig, CLIPTextModel config = CLIPTextConfig.from_pretrained(config_path) + to_args = dict(device=memory_management.text_encoder_device(), dtype=memory_management.text_encoder_dtype()) + with modeling_utils.no_init_weights(): - with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype()): - model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True) + with using_forge_operations(**to_args): + model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True).to(**to_args) load_state_dict(model, state_dict, ignore_errors=[ 'transformer.text_projection.weight', @@ -60,7 +62,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p ], log_name=cls_name) return model - if component_name.startswith('text_encoder') and cls_name in ['T5EncoderModel']: + if cls_name == 'T5EncoderModel': from transformers import T5EncoderModel, T5Config config = T5Config.from_pretrained(config_path) @@ -80,9 +82,13 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p if cls_name == 'UNet2DConditionModel': unet_config = guess.unet_config.copy() state_dict_size = memory_management.state_dict_size(state_dict) - unet_config['dtype'] = memory_management.unet_dtype(model_params=state_dict_size) + ini_dtype = memory_management.unet_dtype(model_params=state_dict_size) + ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype) - with using_forge_operations(device=memory_management.cpu, dtype=unet_config['dtype']): + unet_config['dtype'] = ini_dtype + unet_config['device'] = ini_device + + with using_forge_operations(device=ini_device, dtype=ini_dtype): model = IntegratedUNet2DConditionModel.from_config(unet_config) model._internal_dict = unet_config diff --git a/backend/patcher/clip.py b/backend/patcher/clip.py index b6c085da..a870d9bb 100644 --- a/backend/patcher/clip.py +++ b/backend/patcher/clip.py @@ -10,11 +10,9 @@ class CLIP: load_device = memory_management.text_encoder_device() offload_device = memory_management.text_encoder_offload_device() - text_encoder_dtype = memory_management.text_encoder_dtype(load_device) self.cond_stage_model = ModuleDict(model_dict) self.tokenizer = ObjectDict(tokenizer_dict) - self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device) self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) def clone(self):