revise inference dtype

This commit is contained in:
layerdiffusion
2024-08-07 17:08:47 -07:00
parent 89ea50a19d
commit b61bf553ea
5 changed files with 10 additions and 6 deletions

View File

@@ -8,12 +8,12 @@ from backend import memory_management
class UnetPatcher(ModelPatcher):
@classmethod
def from_model(cls, model, diffusers_scheduler, k_predictor=None):
def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
parameters = memory_management.module_size(model)
unet_dtype = memory_management.unet_dtype(model_params=parameters)
load_device = memory_management.get_torch_device()
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device, supported_dtypes=config.supported_inference_dtypes)
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
model.to(device=initial_load_device, dtype=unet_dtype)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)