Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Fixed a bug where sampling would fail when merging in the LoRA for unquantized models. Quantize non-ARA modules as uint8 when using an accuracy recovery adapter (ARA).
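For context: "merging in" the LoRA at sampling time folds the adapter weights into the base weights at a single multiplier, so the extra network forward passes can be skipped during inference. A minimal sketch of that idea (function and argument names here are illustrative, not the toolkit's actual API):

import torch

@torch.no_grad()
def merged_linear_weight(base_weight, lora_down, lora_up, multiplier, alpha):
    # base_weight: (out, in); lora_down: (rank, in); lora_up: (out, rank)
    # W' = W + multiplier * (alpha / rank) * (up @ down)
    rank = lora_down.shape[0]
    scale = alpha / rank
    return base_weight + multiplier * scale * (lora_up @ lora_down)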
@@ -1145,6 +1145,8 @@ class StableDiffusion:
# the network to drastically speed up inference
unique_network_weights = set([x.network_multiplier for x in image_configs])
if len(unique_network_weights) == 1 and network.can_merge_in:
    # make sure it is on device before merging.
    self.unet.to(self.device_torch)
    can_merge_in = True
    merge_multiplier = unique_network_weights.pop()
    network.merge_in(merge_weight=merge_multiplier)
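The device guard before the merge (moving the unet to self.device_torch) appears to be the fix described in the commit message: with an unquantized model the unet could still be offloaded to CPU, so adding GPU-resident LoRA deltas to its weights would fail. A rough sketch of the guard, with hypothetical names:

import torch

def merge_delta(module: torch.nn.Linear, delta: torch.Tensor, device: torch.device):
    # Hypothetical helper: move the base module to the target device first,
    # otherwise adding a CUDA delta to CPU weights raises a device-mismatch error.
    module.to(device)
    with torch.no_grad():
        module.weight.add_(delta.to(device=module.weight.device, dtype=module.weight.dtype))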
@@ -261,6 +261,7 @@ def quantize_model(
    base_model.accuracy_recovery_adapter = network

    # quantize it
    lora_exclude_modules = []
    quantization_type = get_qtype(base_model.model_config.qtype)
    for lora_module in tqdm(network.unet_loras, desc="Attaching quantization"):
        # the lora has already hijacked the original module
@@ -271,10 +272,20 @@ def quantize_model(
            param.requires_grad = False
        quantize(orig_module, weights=quantization_type)
        freeze(orig_module)
        module_name = lora_module.lora_name.replace('$$', '.').replace('transformer.', '')
        lora_exclude_modules.append(module_name)
        if base_model.model_config.low_vram:
            # move it back to cpu
            orig_module.to("cpu")
        pass
    # quantize additional layers
    print_acc(" - quantizing additional layers")
    quantization_type = get_qtype('uint8')
    quantize(
        model_to_quantize,
        weights=quantization_type,
        exclude=lora_exclude_modules
    )
else:
    # quantize model the original way without an accuracy recovery adapter
    # move and quantize only certain pieces at a time.
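The second half of the change quantizes everything the accuracy recovery adapter does not cover: each ARA-hijacked module keeps the qtype from the model config, its name (the LoRA name with '$$' mapped back to '.') goes into lora_exclude_modules, and the rest of the model is then quantized as uint8 in one pass with those names excluded. A condensed sketch of the same pattern using optimum-quanto, where qint8 is only a stand-in for whatever get_qtype('uint8') resolves to in the toolkit:

from optimum.quanto import quantize, freeze, qint8

def quantize_non_ara_modules(model, ara_module_names):
    # Quantize every remaining module; names collected from the ARA's LoRA
    # modules are excluded so those layers are not quantized a second time.
    quantize(model, weights=qint8, exclude=ara_module_names)
    freeze(model)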