mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-08 17:09:59 +00:00
revise inference dtype
This commit is contained in:
@@ -38,7 +38,8 @@ class Flux(ForgeDiffusionEngine):
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['transformer'],
|
||||
diffusers_scheduler=None,
|
||||
k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000)
|
||||
k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000),
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
|
||||
@@ -29,7 +29,8 @@ class StableDiffusion(ForgeDiffusionEngine):
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['unet'],
|
||||
diffusers_scheduler=huggingface_components['scheduler']
|
||||
diffusers_scheduler=huggingface_components['scheduler'],
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine = ClassicTextProcessingEngine(
|
||||
|
||||
@@ -29,7 +29,8 @@ class StableDiffusion2(ForgeDiffusionEngine):
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['unet'],
|
||||
diffusers_scheduler=huggingface_components['scheduler']
|
||||
diffusers_scheduler=huggingface_components['scheduler'],
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine = ClassicTextProcessingEngine(
|
||||
|
||||
@@ -32,7 +32,8 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['unet'],
|
||||
diffusers_scheduler=huggingface_components['scheduler']
|
||||
diffusers_scheduler=huggingface_components['scheduler'],
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
|
||||
@@ -8,12 +8,12 @@ from backend import memory_management
|
||||
|
||||
class UnetPatcher(ModelPatcher):
|
||||
@classmethod
|
||||
def from_model(cls, model, diffusers_scheduler, k_predictor=None):
|
||||
def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
|
||||
parameters = memory_management.module_size(model)
|
||||
unet_dtype = memory_management.unet_dtype(model_params=parameters)
|
||||
load_device = memory_management.get_torch_device()
|
||||
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
|
||||
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device, supported_dtypes=config.supported_inference_dtypes)
|
||||
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
|
||||
model.to(device=initial_load_device, dtype=unet_dtype)
|
||||
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
|
||||
|
||||
Reference in New Issue
Block a user