From b61bf553ea6b3c8f15701a505a2fb816aafbfa76 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:08:47 -0700 Subject: [PATCH] revise inference dtype --- backend/diffusion_engine/flux.py | 3 ++- backend/diffusion_engine/sd15.py | 3 ++- backend/diffusion_engine/sd20.py | 3 ++- backend/diffusion_engine/sdxl.py | 3 ++- backend/patcher/unet.py | 4 ++-- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/backend/diffusion_engine/flux.py b/backend/diffusion_engine/flux.py index bff9525a..08f014a2 100644 --- a/backend/diffusion_engine/flux.py +++ b/backend/diffusion_engine/flux.py @@ -38,7 +38,8 @@ class Flux(ForgeDiffusionEngine): unet = UnetPatcher.from_model( model=huggingface_components['transformer'], diffusers_scheduler=None, - k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000) + k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000), + config=estimated_config ) self.text_processing_engine_l = ClassicTextProcessingEngine( diff --git a/backend/diffusion_engine/sd15.py b/backend/diffusion_engine/sd15.py index 2117c79f..f16d046c 100644 --- a/backend/diffusion_engine/sd15.py +++ b/backend/diffusion_engine/sd15.py @@ -29,7 +29,8 @@ class StableDiffusion(ForgeDiffusionEngine): unet = UnetPatcher.from_model( model=huggingface_components['unet'], - diffusers_scheduler=huggingface_components['scheduler'] + diffusers_scheduler=huggingface_components['scheduler'], + config=estimated_config ) self.text_processing_engine = ClassicTextProcessingEngine( diff --git a/backend/diffusion_engine/sd20.py b/backend/diffusion_engine/sd20.py index 5620fbb7..f9af0b53 100644 --- a/backend/diffusion_engine/sd20.py +++ b/backend/diffusion_engine/sd20.py @@ -29,7 +29,8 @@ class StableDiffusion2(ForgeDiffusionEngine): unet = UnetPatcher.from_model( model=huggingface_components['unet'], - diffusers_scheduler=huggingface_components['scheduler'] + diffusers_scheduler=huggingface_components['scheduler'], + config=estimated_config ) self.text_processing_engine = ClassicTextProcessingEngine( diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py index 6bb6ecd1..fe3a3796 100644 --- a/backend/diffusion_engine/sdxl.py +++ b/backend/diffusion_engine/sdxl.py @@ -32,7 +32,8 @@ class StableDiffusionXL(ForgeDiffusionEngine): unet = UnetPatcher.from_model( model=huggingface_components['unet'], - diffusers_scheduler=huggingface_components['scheduler'] + diffusers_scheduler=huggingface_components['scheduler'], + config=estimated_config ) self.text_processing_engine_l = ClassicTextProcessingEngine( diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py index 48b7c639..e824eddd 100644 --- a/backend/patcher/unet.py +++ b/backend/patcher/unet.py @@ -8,12 +8,12 @@ from backend import memory_management class UnetPatcher(ModelPatcher): @classmethod - def from_model(cls, model, diffusers_scheduler, k_predictor=None): + def from_model(cls, model, diffusers_scheduler, config, k_predictor=None): parameters = memory_management.module_size(model) unet_dtype = memory_management.unet_dtype(model_params=parameters) load_device = memory_management.get_torch_device() initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype) - manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device) + manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device, supported_dtypes=config.supported_inference_dtypes) manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype model.to(device=initial_load_device, dtype=unet_dtype) model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)