From 70a555906a5fff6002de2de73c1b0c570b1e169e Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Sat, 31 Aug 2024 10:55:19 -0700 Subject: [PATCH] use safer codes --- backend/loader.py | 2 +- modules/sd_models.py | 2 +- packages_3rdparty/gguf/quants.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/backend/loader.py b/backend/loader.py index 24dc4b40..54dbb38e 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -264,7 +264,7 @@ def split_state_dict(sd, additional_state_dicts: list = None): return state_dict, guess -@torch.no_grad() +@torch.inference_mode() def forge_loader(sd, additional_state_dicts=None): try: state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts) diff --git a/modules/sd_models.py b/modules/sd_models.py index 0dcf3bfd..c3085785 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -459,7 +459,7 @@ def apply_token_merging(sd_model, token_merging_ratio): return -@torch.no_grad() +@torch.inference_mode() def forge_model_reload(): current_hash = str(model_data.forge_loading_parameters) diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py index 1fac91ff..ac82f79b 100644 --- a/packages_3rdparty/gguf/quants.py +++ b/packages_3rdparty/gguf/quants.py @@ -151,7 +151,7 @@ class __Quant(ABC): rows = data.reshape((-1, data.shape[-1])).view(torch.uint8) n_blocks = rows.numel() // cls.type_size blocks = rows.reshape((n_blocks, cls.type_size)) - parameter.data = blocks.contiguous() + parameter.data = blocks.clone(memory_format=torch.contiguous_format) cls.bake_inner(parameter) parameter.baked = True return @@ -312,7 +312,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0): d, x = quick_split(blocks, [2]) d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8) x = change_4bits_order(x).view(torch.uint8) - parameter.data = torch.cat([d, x], dim=-1).contiguous() + parameter.data = torch.cat([d, x], dim=-1).clone(memory_format=torch.contiguous_format) return @classmethod @@ -389,7 +389,7 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): m = m.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8) qs = change_4bits_order(qs).view(torch.uint8) - parameter.data = torch.cat([d, m, qs], dim=-1).contiguous() + parameter.data = torch.cat([d, m, qs], dim=-1).clone(memory_format=torch.contiguous_format) return @@ -601,7 +601,7 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0): d, x = quick_split(blocks, [2]) x = x.view(torch.int8) d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.int8) - parameter.data = torch.cat([d, x], dim=-1).contiguous() + parameter.data = torch.cat([d, x], dim=-1).clone(memory_format=torch.contiguous_format) return @classmethod @@ -808,7 +808,7 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): dm = dm.view(torch.uint8).reshape((n_blocks, -1)) qs = qs.view(torch.uint8) - parameter.data = torch.cat([d, dm, qs], dim=-1).contiguous() + parameter.data = torch.cat([d, dm, qs], dim=-1).clone(memory_format=torch.contiguous_format) return @classmethod