diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py index 7859d39d..3a7caf15 100644 --- a/ldm_patched/modules/model_management.py +++ b/ldm_patched/modules/model_management.py @@ -510,7 +510,7 @@ def unet_dtype(device=None, model_params=0): return torch.float8_e4m3fn if args.unet_in_fp8_e5m2: return torch.float8_e5m2 - if should_use_fp16(device=device, model_params=model_params): + if should_use_fp16(device=device, model_params=model_params, manual_cast=True): return torch.float16 return torch.float32 @@ -710,7 +710,7 @@ def is_device_mps(device): return True return False -def should_use_fp16(device=None, model_params=0, prioritize_performance=True): +def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled if device is not None: @@ -736,10 +736,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if is_intel_xpu(): return True - if torch.cuda.is_bf16_supported(): + if torch.version.hip: return True props = torch.cuda.get_device_properties("cuda") + if props.major >= 8: + return True + if props.major < 6: return False @@ -752,7 +755,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if x in props.name.lower(): fp16_works = True - if fp16_works: + if fp16_works or manual_cast: free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) if (not prioritize_performance) or model_params * 4 > free_model_memory: return True diff --git a/ldm_patched/modules/utils.py b/ldm_patched/modules/utils.py index f8283a86..18259199 100644 --- a/ldm_patched/modules/utils.py +++ b/ldm_patched/modules/utils.py @@ -413,6 +413,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): + x = max(0, min(s.shape[-1] - overlap, x)) + y = max(0, min(s.shape[-2] - overlap, y)) s_in = s[:,:,y:y+tile_y,x:x+tile_x] ps = function(s_in).to(output_device)