revise kernel

This commit is contained in:
lllyasviel
2024-08-07 13:28:12 -07:00
committed by GitHub
parent 1ef0844225
commit 14a759b5ca
10 changed files with 317 additions and 420 deletions

View File

@@ -8,7 +8,7 @@ from backend import memory_management
class UnetPatcher(ModelPatcher):
@classmethod
def from_model(cls, model, diffusers_scheduler):
def from_model(cls, model, diffusers_scheduler, k_predictor=None):
parameters = memory_management.module_size(model)
unet_dtype = memory_management.unet_dtype(model_params=parameters)
load_device = memory_management.get_torch_device()
@@ -16,7 +16,7 @@ class UnetPatcher(ModelPatcher):
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
model.to(device=initial_load_device, dtype=unet_dtype)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
def __init__(self, *args, **kwargs):