mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 19:21:21 +00:00
diffusion in fp8 landed
This commit is contained in:
@@ -8,7 +8,7 @@ import platform
|
||||
|
||||
from enum import Enum
|
||||
from backend import stream
|
||||
from backend.args import args
|
||||
from backend.args import args, dynamic_args
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
@@ -289,9 +289,8 @@ if 'rtx' in torch_device_name.lower():
|
||||
# Module-level list, created empty at import time.
# NOTE(review): presumably tracks LoadedModel instances that are currently loaded — confirm against the load/unload code.
current_loaded_models = []
|
||||
|
||||
|
||||
def module_size(module, exclude_device=None):
|
||||
def state_dict_size(sd, exclude_device=None):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
|
||||
@@ -303,6 +302,10 @@ def module_size(module, exclude_device=None):
|
||||
return module_mem
|
||||
|
||||
|
||||
def module_size(module, exclude_device=None):
    """Return what ``state_dict_size`` reports for ``module``'s state dict.

    Thin convenience wrapper: takes ``module.state_dict()`` and hands it to
    ``state_dict_size``, forwarding ``exclude_device`` unchanged.
    """
    weights = module.state_dict()
    return state_dict_size(weights, exclude_device=exclude_device)
|
||||
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model, memory_required):
|
||||
self.model = model
|
||||
@@ -563,20 +566,31 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Pick the dtype for the UNet.

    Resolution order (first match wins):

    1. ``dynamic_args['forge_unet_storage_dtype']`` when it is not ``None``.
    2. The explicit command-line flags on ``args``: ``unet_in_bf16``,
       ``unet_in_fp16``, ``unet_in_fp8_e4m3fn``, ``unet_in_fp8_e5m2``.
    3. ``torch.float16`` if ``should_use_fp16`` approves and it is listed in
       ``supported_dtypes``.
    4. ``torch.bfloat16`` if ``should_use_bf16`` approves and it is listed in
       ``supported_dtypes``.
    5. ``torch.float32`` as the final fallback.

    Args:
        device: Forwarded to ``should_use_fp16`` / ``should_use_bf16``.
        model_params: Forwarded to ``should_use_fp16`` / ``should_use_bf16``.
        supported_dtypes: Collection of dtypes the automatic fp16/bf16 choice
            may return. Only membership tests are performed on it, so any
            container works. The default is an immutable tuple so the shared
            default object can never be mutated between calls.

    Returns:
        The selected dtype. For the runtime override this is whatever value
        was stored under ``'forge_unet_storage_dtype'``, returned as-is.
    """
    # A runtime override beats every command-line flag and heuristic.
    unet_storage_dtype_overwrite = dynamic_args.get('forge_unet_storage_dtype')
    if unet_storage_dtype_overwrite is not None:
        return unet_storage_dtype_overwrite

    # Explicit flags are honoured regardless of supported_dtypes. They are
    # kept as separate checks so the torch.float8_* attributes are only
    # touched when the corresponding flag is actually set.
    if args.unet_in_bf16:
        return torch.bfloat16

    if args.unet_in_fp16:
        return torch.float16

    if args.unet_in_fp8_e4m3fn:
        return torch.float8_e4m3fn

    if args.unet_in_fp8_e5m2:
        return torch.float8_e5m2

    # Automatic selection: the should_use_* helper is consulted first, then
    # the result is filtered by supported_dtypes (same order as before).
    if should_use_fp16(device=device, model_params=model_params, manual_cast=True) and torch.float16 in supported_dtypes:
        return torch.float16

    if should_use_bf16(device, model_params=model_params, manual_cast=True) and torch.bfloat16 in supported_dtypes:
        return torch.bfloat16

    return torch.float32
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user