mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
revise kernel
This commit is contained in:
@@ -49,9 +49,11 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
from transformers import CLIPTextConfig, CLIPTextModel
|
from transformers import CLIPTextConfig, CLIPTextModel
|
||||||
config = CLIPTextConfig.from_pretrained(config_path)
|
config = CLIPTextConfig.from_pretrained(config_path)
|
||||||
|
|
||||||
|
to_args = dict(device=memory_management.text_encoder_device(), dtype=memory_management.text_encoder_dtype())
|
||||||
|
|
||||||
with modeling_utils.no_init_weights():
|
with modeling_utils.no_init_weights():
|
||||||
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype()):
|
with using_forge_operations(**to_args):
|
||||||
model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True)
|
model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True).to(**to_args)
|
||||||
|
|
||||||
load_state_dict(model, state_dict, ignore_errors=[
|
load_state_dict(model, state_dict, ignore_errors=[
|
||||||
'transformer.text_projection.weight',
|
'transformer.text_projection.weight',
|
||||||
@@ -60,7 +62,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
], log_name=cls_name)
|
], log_name=cls_name)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
if component_name.startswith('text_encoder') and cls_name in ['T5EncoderModel']:
|
if cls_name == 'T5EncoderModel':
|
||||||
from transformers import T5EncoderModel, T5Config
|
from transformers import T5EncoderModel, T5Config
|
||||||
config = T5Config.from_pretrained(config_path)
|
config = T5Config.from_pretrained(config_path)
|
||||||
|
|
||||||
@@ -80,9 +82,13 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
if cls_name == 'UNet2DConditionModel':
|
if cls_name == 'UNet2DConditionModel':
|
||||||
unet_config = guess.unet_config.copy()
|
unet_config = guess.unet_config.copy()
|
||||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
state_dict_size = memory_management.state_dict_size(state_dict)
|
||||||
unet_config['dtype'] = memory_management.unet_dtype(model_params=state_dict_size)
|
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
|
||||||
|
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
|
||||||
|
|
||||||
with using_forge_operations(device=memory_management.cpu, dtype=unet_config['dtype']):
|
unet_config['dtype'] = ini_dtype
|
||||||
|
unet_config['device'] = ini_device
|
||||||
|
|
||||||
|
with using_forge_operations(device=ini_device, dtype=ini_dtype):
|
||||||
model = IntegratedUNet2DConditionModel.from_config(unet_config)
|
model = IntegratedUNet2DConditionModel.from_config(unet_config)
|
||||||
model._internal_dict = unet_config
|
model._internal_dict = unet_config
|
||||||
|
|
||||||
|
|||||||
@@ -10,11 +10,9 @@ class CLIP:
|
|||||||
|
|
||||||
load_device = memory_management.text_encoder_device()
|
load_device = memory_management.text_encoder_device()
|
||||||
offload_device = memory_management.text_encoder_offload_device()
|
offload_device = memory_management.text_encoder_offload_device()
|
||||||
text_encoder_dtype = memory_management.text_encoder_dtype(load_device)
|
|
||||||
|
|
||||||
self.cond_stage_model = ModuleDict(model_dict)
|
self.cond_stage_model = ModuleDict(model_dict)
|
||||||
self.tokenizer = ObjectDict(tokenizer_dict)
|
self.tokenizer = ObjectDict(tokenizer_dict)
|
||||||
self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
|
|
||||||
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user