Use safer code

This commit is contained in:
layerdiffusion
2024-08-31 10:55:19 -07:00
parent 1f91b35a43
commit 70a555906a
3 changed files with 7 additions and 7 deletions

View File

@@ -264,7 +264,7 @@ def split_state_dict(sd, additional_state_dicts: list = None):
return state_dict, guess
@torch.no_grad()
@torch.inference_mode()
def forge_loader(sd, additional_state_dicts=None):
try:
state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)

View File

@@ -459,7 +459,7 @@ def apply_token_merging(sd_model, token_merging_ratio):
return
@torch.no_grad()
@torch.inference_mode()
def forge_model_reload():
current_hash = str(model_data.forge_loading_parameters)

View File

@@ -151,7 +151,7 @@ class __Quant(ABC):
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
n_blocks = rows.numel() // cls.type_size
blocks = rows.reshape((n_blocks, cls.type_size))
parameter.data = blocks.contiguous()
parameter.data = blocks.clone(memory_format=torch.contiguous_format)
cls.bake_inner(parameter)
parameter.baked = True
return
@@ -312,7 +312,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
d, x = quick_split(blocks, [2])
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
x = change_4bits_order(x).view(torch.uint8)
parameter.data = torch.cat([d, x], dim=-1).contiguous()
parameter.data = torch.cat([d, x], dim=-1).clone(memory_format=torch.contiguous_format)
return
@classmethod
@@ -389,7 +389,7 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
m = m.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
qs = change_4bits_order(qs).view(torch.uint8)
parameter.data = torch.cat([d, m, qs], dim=-1).contiguous()
parameter.data = torch.cat([d, m, qs], dim=-1).clone(memory_format=torch.contiguous_format)
return
@@ -601,7 +601,7 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
d, x = quick_split(blocks, [2])
x = x.view(torch.int8)
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.int8)
parameter.data = torch.cat([d, x], dim=-1).contiguous()
parameter.data = torch.cat([d, x], dim=-1).clone(memory_format=torch.contiguous_format)
return
@classmethod
@@ -808,7 +808,7 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
dm = dm.view(torch.uint8).reshape((n_blocks, -1))
qs = qs.view(torch.uint8)
parameter.data = torch.cat([d, dm, qs], dim=-1).contiguous()
parameter.data = torch.cat([d, dm, qs], dim=-1).clone(memory_format=torch.contiguous_format)
return
@classmethod