Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-01-26 19:09:45 +00:00)
Merge branch 'lllyasviel:main' into cmd-arg-for-text-encoders
README.md (17 changed lines)
@@ -6,6 +6,23 @@ The name "Forge" is inspired from "Minecraft Forge". This project is aimed at be
+Forge is currently based on SD-WebUI 1.10.1 at [this commit](https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/82a973c04367123ae98bd9abdf80d9eda9b910e2). (Because the original SD-WebUI is almost static now, Forge will sync with the original WebUI every 90 days, or whenever there are important fixes.)
+
+### Forge Issues & Discussions Are Under Attack Now
+
+Today, a group of attackers flooded the Forge repo's issues and discussions with spam files containing viruses.
+
+As a protective measure, issues and discussions are temporarily disabled. We will bring them back soon.
+
+Screenshots:
+
+(DO NOT download any files from those attackers!)
+
+
+
+
+
+
 
 # Quick List
 
 [Gradio 4 UI Must Read (TLDR: You need to use RIGHT MOUSE BUTTON to move canvas!)](https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/853)
@@ -104,6 +104,11 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
 
+        if storage_dtype in ['gguf']:
+            from backend.operations_gguf import bake_gguf_model
+            model.computation_dtype = torch.float16
+            model = bake_gguf_model(model)
+
         return model
 
     if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
         assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
@@ -162,6 +167,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         model.initial_device = initial_device
         model.offload_device = offload_device
 
+        if storage_dtype in ['gguf']:
+            from backend.operations_gguf import bake_gguf_model
+            model = bake_gguf_model(model)
+
         return model
 
     print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
@@ -405,12 +405,24 @@ class ForgeOperationsGGUF(ForgeOperations):
            self.weight = state_dict[prefix + 'weight']
            if prefix + 'bias' in state_dict:
                self.bias = state_dict[prefix + 'bias']
+           if self.weight is not None and hasattr(self.weight, 'parent'):
+               self.weight.parent = self
+           if self.bias is not None and hasattr(self.bias, 'parent'):
+               self.bias.parent = self
            return
 
        def _apply(self, fn, recurse=True):
            if self.weight is not None:
                self.weight = utils.tensor2parameter(fn(self.weight))
            if self.bias is not None:
                self.bias = utils.tensor2parameter(fn(self.bias))
+           for i in range(5):
+               quant_state_name = f'quant_state_{i}'
+               quant_state = getattr(self, quant_state_name, None)
+               if quant_state is not None:
+                   quant_state = fn(quant_state)
+                   quant_state = utils.tensor2parameter(quant_state)
+                   setattr(self, quant_state_name, quant_state)
            return self
 
        def forward(self, x):
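Note on the `_apply` override above: `Module.to()`, `.cuda()`, and `.half()` all funnel through `_apply`, and this layer replaces it outright so the packed weight, the bias, and every cached `quant_state_{i}` tensor are converted together, presumably to sidestep torch's default parameter rebuilding, which would not preserve a `Parameter` subclass's extra attributes (such as `ParameterGGUF`'s `gguf_*` fields). A self-contained sketch of the pattern (illustrative class and helper, not code from this commit):

```python
import torch


def tensor2parameter(t):
    # assumption: utils.tensor2parameter behaves roughly like this
    return t if isinstance(t, torch.nn.Parameter) else torch.nn.Parameter(t, requires_grad=False)


class QuantLinearSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(4, 4), requires_grad=False)
        self.bias = None
        self.quant_state_0 = torch.nn.Parameter(torch.zeros(4), requires_grad=False)

    def _apply(self, fn, recurse=True):
        # handle weight/bias and the numbered quant states ourselves,
        # then return self instead of delegating to Module._apply
        if self.weight is not None:
            self.weight = tensor2parameter(fn(self.weight))
        if self.bias is not None:
            self.bias = tensor2parameter(fn(self.bias))
        for i in range(5):
            name = f'quant_state_{i}'
            t = getattr(self, name, None)
            if t is not None:
                setattr(self, name, tensor2parameter(fn(t)))
        return self


m = QuantLinearSketch().to(torch.float16)
assert m.weight.dtype == m.quant_state_0.dtype == torch.float16  # moved in lockstep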
@@ -27,6 +27,7 @@ class ParameterGGUF(torch.nn.Parameter):
         self.gguf_type = tensor.tensor_type
         self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))
         self.gguf_cls = quants_mapping.get(self.gguf_type, None)
+        self.parent = None
 
     @property
     def shape(self):
@@ -36,6 +37,9 @@ class ParameterGGUF(torch.nn.Parameter):
         return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad)
 
     def dequantize_as_pytorch_parameter(self):
+        if self.parent is None:
+            self.parent = torch.nn.Module()
+            self.gguf_cls.bake_layer(self.parent, self, computation_dtype=torch.float16)
         return torch.nn.Parameter(dequantize_tensor(self), requires_grad=False)
 
     def to(self, *args, **kwargs):
@@ -43,6 +47,7 @@ class ParameterGGUF(torch.nn.Parameter):
         new.gguf_type = self.gguf_type
         new.gguf_real_shape = self.gguf_real_shape
         new.gguf_cls = self.gguf_cls
+        new.parent = self.parent
         return new
 
     def pin_memory(self, device=None):
@@ -50,17 +55,38 @@ class ParameterGGUF(torch.nn.Parameter):
         new.gguf_type = self.gguf_type
         new.gguf_real_shape = self.gguf_real_shape
         new.gguf_cls = self.gguf_cls
+        new.parent = self.parent
         return new
 
     @classmethod
-    def make(cls, data, gguf_type, gguf_cls, gguf_real_shape):
+    def make(cls, data, gguf_type, gguf_cls, gguf_real_shape, parent):
         new = ParameterGGUF(data, no_init=True)
         new.gguf_type = gguf_type
         new.gguf_real_shape = gguf_real_shape
         new.gguf_cls = gguf_cls
+        new.parent = parent
         return new
 
 
+def bake_gguf_model(model):
+    computation_dtype = model.computation_dtype
+    backed_layer_counter = 0
+
+    for m in model.modules():
+        if hasattr(m, 'weight'):
+            weight = m.weight
+            if hasattr(weight, 'gguf_cls'):
+                gguf_cls = weight.gguf_cls
+                if gguf_cls is not None:
+                    backed_layer_counter += 1
+                    gguf_cls.bake_layer(m, weight, computation_dtype)
+
+    if backed_layer_counter > 0:
+        print(f'GGUF backed {backed_layer_counter} layers.')
+
+    return model
+
+
 def dequantize_tensor(tensor):
     if tensor is None:
         return None
@@ -68,7 +94,7 @@ def dequantize_tensor(tensor):
     if not hasattr(tensor, 'gguf_cls'):
         return tensor
 
-    data = torch.tensor(tensor.data)
+    data = tensor
     gguf_cls = tensor.gguf_cls
     gguf_real_shape = tensor.gguf_real_shape
 
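The idea behind `bake_gguf_model`: GGUF packs per-block scales next to the quantized payload, so a straightforward PyTorch port re-parses those bytes on every dequantization. Baking walks the modules once at load time and lets each quant class's `bake_layer` split the blocks into the raw payload (kept as `weight.data`) and precomputed scale tensors (stored as `quant_state_{i}` on the owning module). A toy Q8_0-style illustration of the split (simplified layout for illustration, not the real GGUF block format):

```python
import torch

# Toy Q8_0-style blocks: 2 bytes of fp16 scale followed by 32 int8 values per row.
n_blocks = 4
d_true = torch.randn(n_blocks, 1).to(torch.float16)
q_true = torch.randint(-128, 128, (n_blocks, 32), dtype=torch.int8)
blocks = torch.cat([d_true.view(torch.uint8), q_true.view(torch.uint8)], dim=1)

# "Bake" once at load time: split scale from payload and pre-cast the scale.
d = blocks[:, :2].view(torch.float16)   # -> would become layer.quant_state_0
q = blocks[:, 2:].view(torch.int8)      # -> would remain as weight.data

# Every later dequantization is now a single cast-and-multiply.
w = q.to(torch.float16) * d
assert torch.equal(w, q_true.to(torch.float16) * d_true)
```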
@@ -421,12 +421,15 @@ class LoraLoader:
            if gguf_cls is not None:
                from backend.operations_gguf import ParameterGGUF
                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
-               utils.set_attr_raw(self.model, key, ParameterGGUF.make(
+               weight = ParameterGGUF.make(
                    data=weight,
                    gguf_type=gguf_type,
                    gguf_cls=gguf_cls,
-                   gguf_real_shape=gguf_real_shape
-               ))
+                   gguf_real_shape=gguf_real_shape,
+                   parent=parent_layer
+               )
+               gguf_cls.bake_layer(parent_layer, weight, gguf_cls.computation_dtype)
+               utils.set_attr_raw(self.model, key, weight)
                continue
 
            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
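Context for the LoRA change above: a LoRA delta cannot be added to packed quantized bytes, so the loader dequantizes the weight, merges the delta, re-quantizes via `quantize_pytorch`, and then re-bakes the layer so the cached quant state matches the new payload (the new `parent` argument keeps the back-pointer intact). A toy round trip with plain int8 symmetric quantization standing in for the GGUF codecs (all helpers below are illustrative):

```python
import torch

def quantize(w):
    scale = w.abs().max() / 127.0
    q = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)
    return q, scale.to(torch.float16)      # payload + "baked" quant state

def dequantize(q, scale):
    return q.to(torch.float32) * scale.to(torch.float32)

w = torch.randn(16, 16)
q, scale = quantize(w)                     # stored form of the base weight

delta = 0.01 * torch.randn(16, 16)         # already-merged LoRA update (e.g. B @ A, scaled)
w_merged = dequantize(q, scale) + delta    # dequantize, then patch
q, scale = quantize(w_merged)              # re-quantize and re-bake the scale
```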
javascript/dragdrop.js (vendored, 21 changed lines)
@@ -26,26 +26,7 @@ function dropReplaceImage(imgWrap, files) {
         }
     };
 
-    if (imgWrap.closest('#pnginfo_image')) {
-        // special treatment for PNG Info tab, wait for fetch request to finish
-        const oldFetch = window.fetch;
-        window.fetch = async(input, options) => {
-            const response = await oldFetch(input, options);
-            if ('api/predict/' === input) {
-                const content = await response.text();
-                window.fetch = oldFetch;
-                window.requestAnimationFrame(() => callback());
-                return new Response(content, {
-                    status: response.status,
-                    statusText: response.statusText,
-                    headers: response.headers
-                });
-            }
-            return response;
-        };
-    } else {
-        window.requestAnimationFrame(() => callback());
-    }
+    window.requestAnimationFrame(() => callback());
 }
 
 function eventHasFiles(e) {
packages_3rdparty/gguf/quants.py (vendored, 180 changed lines)
@@ -8,6 +8,7 @@ from numpy.typing import DTypeLike
 
 from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
 from .lazy import LazyNumpyTensor
+from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpack_4bits_u
 
 import numpy as np
 
@@ -89,6 +90,8 @@ class __Quant(ABC):
     grid_map: tuple[int | float, ...] = ()
     grid_hex: bytes | None = None
 
+    computation_dtype: torch.dtype = torch.bfloat16
+
     def __init__(self):
         return TypeError("Quant conversion classes can't have instances")
 
@@ -141,18 +144,29 @@ class __Quant(ABC):
         return blocks.reshape(original_shape)
 
     @classmethod
-    def dequantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
-        # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
-        block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
+    def bake_layer(cls, layer, weight, computation_dtype):
+        data = weight.data
+        cls.computation_dtype = computation_dtype
+        cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype]
         rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
-        n_blocks = rows.numel() // type_size
-        blocks = rows.reshape((n_blocks, type_size))
-        blocks = cls.dequantize_blocks_pytorch(blocks, block_size, type_size)
+        n_blocks = rows.numel() // cls.type_size
+        blocks = rows.reshape((n_blocks, cls.type_size))
+        weight.data = blocks
+        cls.bake_layer_weight(layer, weight)
+        return
+
+    @classmethod
+    def bake_layer_weight(cls, layer, weight):
+        pass
+
+    @classmethod
+    def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor:
+        blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent)
         return blocks.reshape(original_shape)
 
     @classmethod
     @abstractmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         raise NotImplementedError
 
     @classmethod
@@ -289,14 +303,23 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         return (d * qs.astype(np.float32))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
-        n_blocks = blocks.shape[0]
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
+        d, x = quick_split(blocks, [2])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        x = change_4bits_order(x)
+        weight.data = x
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        return
 
-        d = blocks[:, :2].view(torch.float16)
-        qs = blocks[:, 2:]
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+        d, qs = parent.quant_state_0, blocks
 
-        qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
-        qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
+        if d.device != qs.device:
+            d = d.to(device=qs.device)
 
+        qs = quick_unpack_4bits(qs)
         return d * qs
 
     @classmethod
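Why `bake_layer_weight` calls `change_4bits_order`: the reference Q4 kernel emits all low nibbles of a block followed by all high nibbles, while the new lookup-table unpacker walks bytes sequentially (low, high, low, high, ...). Reordering the packed bytes once at bake time makes the fast sequential unpack land in the reference order. A small self-check, re-deriving the vendored function added later in this commit (16-byte rows as in Q4_0; the `- 8` mirrors the signed Q4_0 offset):

```python
import torch

def change_4bits_order(x):  # as in packages_3rdparty/gguf/quick_4bits_ops.py
    y = torch.stack([x & 15, x >> 4], dim=-2).view(x.size(0), -1)
    return y[:, ::2] | (y[:, 1::2] << 4)

packed = torch.randint(0, 256, (4, 16), dtype=torch.uint8)  # one Q4_0 block per row

# Reference order: all low nibbles of the block, then all high nibbles.
ref = torch.cat([packed & 15, packed >> 4], dim=1).to(torch.int8) - 8

# After reordering, a plain sequential (low, high) per-byte unpack matches it.
reordered = change_4bits_order(packed)
seq = torch.stack([reordered & 15, reordered >> 4], dim=-1).reshape(4, -1).to(torch.int8) - 8
assert torch.equal(seq, ref)
```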
@@ -358,16 +381,31 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
         return (d * qs) + m
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
-        n_blocks = blocks.shape[0]
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
 
-        d = blocks[:, :2].view(torch.float16)
-        m = blocks[:, 2:4].view(torch.float16)
-        qs = blocks[:, 4:]
+        d, m, qs = quick_split(blocks, [2, 2])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        m = m.view(torch.float16).to(cls.computation_dtype)
 
-        qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
-        qs = (qs & 0x0F).reshape(n_blocks, -1)
+        qs = change_4bits_order(qs)
 
+        weight.data = qs
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+        d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
+
+        if d.device != qs.device:
+            d = d.to(device=qs.device)
+
+        if m.device != qs.device:
+            m = m.to(device=qs.device)
+
+        qs = quick_unpack_4bits_u(qs)
         return (d * qs) + m
 
@@ -414,7 +452,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         return (d * qs.astype(np.float32))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         def to_uint32(x):
             # pytorch uint32 by City96 - Apache-2.0
             x = x.view(torch.uint8).to(torch.int32)
@@ -422,11 +460,8 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
 
         n_blocks = blocks.shape[0]
 
-        d = blocks[:, :2]
-        qh = blocks[:, 2:6]
-        qs = blocks[:, 6:]
-
-        d = d.view(torch.float16).to(torch.float32)
+        d, qh, qs = quick_split(blocks, [2, 4])
+        d = d.view(torch.float16).to(cls.computation_dtype)
         qh = to_uint32(qh)
 
         qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -436,7 +471,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         ql = (ql & 0x0F).reshape(n_blocks, -1)
 
         qs = (ql | (qh << 4)).to(torch.int8) - 16
-        return d * qs
+        return (d * qs)
 
     @classmethod
     def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
@@ -520,7 +555,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
         return (d * qs) + m
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         def to_uint32(x):
             # pytorch uint32 by City96 - Apache-2.0
             x = x.view(torch.uint8).to(torch.int32)
@@ -528,11 +563,9 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
 
         n_blocks = blocks.shape[0]
 
-        d = blocks[:, :2].view(torch.float16)
-        m = blocks[:, 2:4].view(torch.float16)
-        qh = blocks[:, 4:8]
-        qs = blocks[:, 8:]
-
+        d, m, qh, qs = quick_split(blocks, [2, 2, 4])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        m = m.view(torch.float16).to(cls.computation_dtype)
         qh = to_uint32(qh)
 
         qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -570,9 +603,23 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         return (x * d)
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
-        d = blocks[:, :2].view(torch.float16)
-        x = blocks[:, 2:].view(torch.int8).to(torch.float16)
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
+        d, x = quick_split(blocks, [2])
+        x = x.view(torch.int8)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        weight.data = x
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+        x = blocks
+        d = parent.quant_state_0
+
+        if d.device != x.device:
+            d = d.to(device=x.device)
+
         return x * d
 
     @classmethod
@@ -613,12 +660,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
         return qs.reshape((n_blocks, -1))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         n_blocks = blocks.shape[0]
         scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         # (n_blocks, 16, 1)
         dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
         ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
@@ -673,11 +720,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
         return (dl * q).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         n_blocks = blocks.shape[0]
         hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12])
-        d = d.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
         lscales, hscales = scales[:, :8], scales[:, 8:]
         lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
         lscales = lscales.reshape((n_blocks, 16))
@@ -754,19 +801,42 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
         return (d * qs - dm).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
-        # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
-        QK_K = 256
+    def bake_layer_weight(cls, layer, weight):  # Only compute one time when model load
+        # Copyright Forge 2024
+
+        blocks = weight.data
         K_SCALE_SIZE = 12
         n_blocks = blocks.shape[0]
         d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         sc, m = Q4_K.get_scale_min_pytorch(scales)
         d = (d * sc).reshape((n_blocks, -1, 1))
-        dm = (dmin * m).reshape((n_blocks, -1, 1))
-        qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
-        qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
+        dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype)
+
+        qs = qs.reshape((n_blocks, -1, 1, 32))
+        qs = change_4bits_order(qs)
+
+        weight.data = qs
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+        # Compute in each diffusion iteration
+
+        n_blocks = blocks.shape[0]
+
+        d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks
+
+        if d.device != qs.device:
+            d = d.to(device=qs.device)
+
+        if dm.device != qs.device:
+            dm = dm.to(device=qs.device)
+
+        qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32))
         return (d * qs - dm).reshape((n_blocks, QK_K))
 
@@ -797,14 +867,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
         return (d * q - dm).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         QK_K = 256
         K_SCALE_SIZE = 12
         n_blocks = blocks.shape[0]
         d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         sc, m = Q4_K.get_scale_min_pytorch(scales)
         d = (d * sc).reshape((n_blocks, -1, 1))
         dm = (dmin * m).reshape((n_blocks, -1, 1))
@@ -839,12 +909,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
         return (d * q).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Written by ChatGPT
         n_blocks = blocks.shape[0]
         ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16])
-        scales = scales.view(torch.int8)
-        d = d.view(torch.float16)
+        scales = scales.view(torch.int8).to(cls.computation_dtype)
+        d = d.view(torch.float16).to(cls.computation_dtype)
         d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
         ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
         ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
packages_3rdparty/gguf/quick_4bits_ops.py (vendored, new file, 61 lines)
@@ -0,0 +1,61 @@
+# By Forge
+
+
+import torch
+
+
+def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
+    x = x.view(torch.uint8).view(x.size(0), -1)
+    unpacked = torch.stack([x & 15, x >> 4], dim=-1)
+    reshaped = unpacked.view(x.size(0), -1)
+    reshaped = reshaped.to(torch.int8) - 8
+    return reshaped.view(torch.int32)
+
+
+def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
+    x = x.view(torch.uint8).view(x.size(0), -1)
+    unpacked = torch.stack([x & 15, x >> 4], dim=-1)
+    reshaped = unpacked.view(x.size(0), -1)
+    return reshaped.view(torch.int32)
+
+
+native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
+native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
+
+
+def quick_unpack_4bits(x):
+    global native_4bits_lookup_table
+
+    s0 = x.size(0)
+    x = x.view(torch.uint16)
+
+    if native_4bits_lookup_table.device != x.device:
+        native_4bits_lookup_table = native_4bits_lookup_table.to(device=x.device)
+
+    y = torch.index_select(input=native_4bits_lookup_table, dim=0, index=x.to(dtype=torch.int32).flatten())
+    y = y.view(torch.int8)
+    y = y.view(s0, -1)
+
+    return y
+
+
+def quick_unpack_4bits_u(x):
+    global native_4bits_lookup_table_u
+
+    s0 = x.size(0)
+    x = x.view(torch.uint16)
+
+    if native_4bits_lookup_table_u.device != x.device:
+        native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=x.device)
+
+    y = torch.index_select(input=native_4bits_lookup_table_u, dim=0, index=x.to(dtype=torch.int32).flatten())
+    y = y.view(torch.uint8)
+    y = y.view(s0, -1)
+
+    return y
+
+
+def change_4bits_order(x):
+    y = torch.stack([x & 15, x >> 4], dim=-2).view(x.size(0), -1)
+    z = y[:, ::2] | (y[:, 1::2] << 4)
+    return z
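The trick in this new file: one uint16 (two packed bytes) holds four 4-bit values, and there are only 2^16 possible byte pairs, so all unpacking can be precomputed into a 65,536-entry int32 table and replayed with a single `index_select`, avoiding shift-and-mask chains on the hot path. A standalone re-derivation with a correctness check (assumes a little-endian host, which the `.view(torch.int32)` packing relies on, and a torch build with the same limited `torch.uint16` support the vendored file uses; names differ from the vendored file):

```python
import torch

def unpack_reference(x):
    # x: (rows, cols) uint8; signed 4-bit values, low nibble first per byte
    return torch.stack([x & 15, x >> 4], dim=-1).reshape(x.size(0), -1).to(torch.int8) - 8

# For every possible uint16 (= two packed bytes), precompute the int32 whose
# four int8 lanes are the four unpacked nibble values.
pairs = torch.arange(256 * 256, dtype=torch.long).to(torch.uint16)
table = unpack_reference(pairs.view(torch.uint8).reshape(-1, 2)).view(torch.int32)[:, 0]

def unpack_lut(x):
    idx = x.view(torch.uint16).to(torch.int32).flatten()
    return table.index_select(0, idx).view(torch.int8).view(x.size(0), -1)

packed = torch.randint(0, 256, (3, 8), dtype=torch.uint8)
assert torch.equal(unpack_lut(packed), unpack_reference(packed))
```

Keeping the table as a module-level global and migrating it to the input's device on first use, as the vendored functions do, keeps the per-call cost down to a single gather.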