mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-05-01 03:31:30 +00:00
speedup swap/loading of all quant types
This commit is contained in:
@@ -54,7 +54,7 @@ class ForgeParams4bit(Params4bit):
|
||||
if device is not None and device.type == "cuda" and not self.bnb_quantized:
|
||||
return self._quantize(device)
|
||||
else:
|
||||
n = ForgeParams4bit(
|
||||
return ForgeParams4bit(
|
||||
torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
|
||||
requires_grad=self.requires_grad,
|
||||
quant_state=copy_quant_state(self.quant_state, device),
|
||||
@@ -63,10 +63,7 @@ class ForgeParams4bit(Params4bit):
|
||||
quant_type=self.quant_type,
|
||||
quant_storage=self.quant_storage,
|
||||
bnb_quantized=self.bnb_quantized,
|
||||
module=self.module
|
||||
)
|
||||
self.module.quant_state = n.quant_state
|
||||
return n
|
||||
|
||||
|
||||
class ForgeLoader4Bit(torch.nn.Module):
|
||||
@@ -74,10 +71,16 @@ class ForgeLoader4Bit(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
|
||||
self.weight = None
|
||||
self.quant_state = None
|
||||
self.bias = None
|
||||
self.quant_type = quant_type
|
||||
|
||||
def _apply(self, fn, recurse=True):
|
||||
if self.weight is not None:
|
||||
self.weight = fn(self.weight)
|
||||
if self.bias is not None:
|
||||
self.bias = torch.nn.Parameter(fn(self.bias), requires_grad=False)
|
||||
return self
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
quant_state = getattr(self.weight, "quant_state", None)
|
||||
@@ -97,9 +100,7 @@ class ForgeLoader4Bit(torch.nn.Module):
|
||||
quantized_stats=quant_state_dict,
|
||||
requires_grad=False,
|
||||
device=self.dummy.device,
|
||||
module=self
|
||||
)
|
||||
self.quant_state = self.weight.quant_state
|
||||
|
||||
if prefix + 'bias' in state_dict:
|
||||
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
|
||||
@@ -114,9 +115,7 @@ class ForgeLoader4Bit(torch.nn.Module):
|
||||
blocksize=64,
|
||||
quant_type=self.quant_type,
|
||||
quant_storage=torch.uint8,
|
||||
module=self,
|
||||
)
|
||||
self.quant_state = self.weight.quant_state
|
||||
|
||||
if prefix + 'bias' in state_dict:
|
||||
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
|
||||
@@ -133,8 +132,6 @@ class ForgeLoader4Bit(torch.nn.Module):
|
||||
blocksize=self.weight.blocksize,
|
||||
quant_type=self.weight.quant_type,
|
||||
quant_storage=self.weight.quant_storage,
|
||||
module=self,
|
||||
bnb_quantized=False
|
||||
)
|
||||
self.quant_state = self.weight.quant_state
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user