diffusion in fp8 landed

lllyasviel
2024-08-06 16:47:39 -07:00
committed by GitHub
parent dd8997ee2e
commit 71c94799d1
7 changed files with 96 additions and 46 deletions


@@ -8,7 +8,7 @@ import platform
 from enum import Enum
 from backend import stream
-from backend.args import args
+from backend.args import args, dynamic_args

 class VRAMState(Enum):
@@ -289,9 +289,8 @@ if 'rtx' in torch_device_name.lower():
 current_loaded_models = []

-def module_size(module, exclude_device=None):
+def state_dict_size(sd, exclude_device=None):
     module_mem = 0
-    sd = module.state_dict()
     for k in sd:
         t = sd[k]
@@ -303,6 +302,10 @@ def module_size(module, exclude_device=None):
     return module_mem

+def module_size(module, exclude_device=None):
+    return state_dict_size(module.state_dict(), exclude_device=exclude_device)
+
 class LoadedModel:
     def __init__(self, model, memory_required):
         self.model = model
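
The refactor splits memory accounting into a state-dict half and a thin module wrapper, so checkpoints loaded as bare state dicts can be sized without first wrapping them in a module. A minimal sketch of the pair, assuming the loop body elided by the hunk accumulates t.nelement() * t.element_size() and skips tensors on exclude_device:

import torch
import torch.nn as nn

def state_dict_size(sd, exclude_device=None):
    module_mem = 0
    for k in sd:
        t = sd[k]
        # Assumed from context: skip tensors already on the excluded device.
        if exclude_device is not None and t.device == exclude_device:
            continue
        module_mem += t.nelement() * t.element_size()
    return module_mem

def module_size(module, exclude_device=None):
    return state_dict_size(module.state_dict(), exclude_device=exclude_device)

# 1M fp32 weights -> 4 bytes each -> prints 4194304.
print(module_size(nn.Linear(1024, 1024, bias=False)))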
@@ -563,20 +566,31 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+    unet_storage_dtype_overwrite = dynamic_args.get('forge_unet_storage_dtype')
+
+    if unet_storage_dtype_overwrite is not None:
+        return unet_storage_dtype_overwrite
+
     if args.unet_in_bf16:
         return torch.bfloat16
     if args.unet_in_fp16:
         return torch.float16
     if args.unet_in_fp8_e4m3fn:
         return torch.float8_e4m3fn
     if args.unet_in_fp8_e5m2:
         return torch.float8_e5m2
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
     if should_use_bf16(device, model_params=model_params, manual_cast=True):
         if torch.bfloat16 in supported_dtypes:
             return torch.bfloat16
     return torch.float32
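
Because the dynamic_args lookup sits above the --unet-in-* flag checks and the fp16/bf16 heuristics, a dtype injected at runtime now takes precedence over everything else, which is what lets the fp8 storage dtype be toggled per-run instead of per-launch. A hedged usage sketch, assuming dynamic_args is a plain dict (as the .get() call suggests) and an assumed import path for the memory-management module changed in this diff:

import torch
from backend.args import dynamic_args
from backend import memory_management  # assumed location of unet_dtype

# Runtime override: store the UNet in fp8 (e4m3fn) regardless of CLI flags.
dynamic_args['forge_unet_storage_dtype'] = torch.float8_e4m3fn
assert memory_management.unet_dtype() == torch.float8_e4m3fn

# Clearing the override falls back to the flag and heuristic chain.
dynamic_args['forge_unet_storage_dtype'] = None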