diff --git a/ldm_patched/modules/model_base.py b/ldm_patched/modules/model_base.py
index 9e141294..614b0d47 100644
--- a/ldm_patched/modules/model_base.py
+++ b/ldm_patched/modules/model_base.py
@@ -6,6 +6,7 @@ import torch
 from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from ldm_patched.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from ldm_patched.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+import ldm_patched.ldm.modules.attention
 import ldm_patched.modules.model_management
 import ldm_patched.modules.conds
 import ldm_patched.modules.ops
@@ -209,17 +210,18 @@ class BaseModel(torch.nn.Module):
         self.inpaint_model = True
 
     def memory_required(self, input_shape):
+        area = input_shape[0] * input_shape[2] * input_shape[3]
+        dtype = self.manual_cast_dtype if self.manual_cast_dtype is not None else self.get_dtype()
+        dtype_size = ldm_patched.modules.model_management.dtype_size(dtype)
+
         if ldm_patched.modules.model_management.xformers_enabled() or ldm_patched.modules.model_management.pytorch_attention_flash_attention():
-            dtype = self.get_dtype()
-            if self.manual_cast_dtype is not None:
-                dtype = self.manual_cast_dtype
-            #TODO: this needs to be tweaked
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * ldm_patched.modules.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
+            scaler = 1.25
         else:
-            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
-            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
+            scaler = 1.75
+        if ldm_patched.ldm.modules.attention._ATTN_PRECISION == "fp32":
+            dtype_size = 4
+
+        return scaler * area * dtype_size * 16384
 
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
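
For illustration only, a minimal standalone sketch of the estimate the new memory_required computes. The helper name estimate_memory_bytes and the parameters flash_or_xformers and attn_precision are hypothetical stand-ins for the model_management checks and the attention module's _ATTN_PRECISION flag; the numbers in the usage comment assume a batch-1 fp16 UNet at a 128x128 latent.

def estimate_memory_bytes(input_shape, dtype_size, flash_or_xformers, attn_precision="fp16"):
    # area = batch * latent height * latent width
    area = input_shape[0] * input_shape[2] * input_shape[3]
    # Flash/xformers attention is budgeted less headroom than the split/sub-quad paths.
    scaler = 1.25 if flash_or_xformers else 1.75
    # If attention math is forced to fp32, count 4 bytes per element regardless of model dtype.
    if attn_precision == "fp32":
        dtype_size = 4
    return scaler * area * dtype_size * 16384

# Example: batch 1, fp16 (2 bytes), 1024x1024 image -> 128x128 latent, flash attention:
# 1.25 * (1 * 128 * 128) * 2 * 16384 = 671,088,640 bytes, i.e. 640 MiB reserved for the UNet pass.
print(estimate_memory_bytes((1, 4, 128, 128), 2, True) / (1024 ** 2))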