Fixed a bug where sampling would fail when merging in a LoRA for unquantized models. Quantize non-ARA modules as uint8 when using an ARA

Jaret Burkett
2025-08-25 09:21:40 -06:00
parent f48d21caee
commit ea01a1c7d0
2 changed files with 14 additions and 1 deletion


@@ -1145,6 +1145,8 @@ class StableDiffusion:
            # the network to drastically speed up inference
            unique_network_weights = set([x.network_multiplier for x in image_configs])
            if len(unique_network_weights) == 1 and network.can_merge_in:
                # make sure it is on device before merging.
                self.unet.to(self.device_torch)
                can_merge_in = True
                merge_multiplier = unique_network_weights.pop()
                network.merge_in(merge_weight=merge_multiplier)
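
The added lines above move the UNet onto the sampling device before `network.merge_in(...)` runs. A minimal sketch of why that matters, using an illustrative `LoRALinear` wrapper (not the repo's network class): merging folds the LoRA delta directly into the base weight, so the delta and the base weight must sit on the same device or the in-place add fails.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper; stands in for the repo's network modules."""

    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)

    @torch.no_grad()
    def merge_in(self, merge_weight: float = 1.0):
        # delta has the same shape as base.weight: (out_features, in_features)
        delta = (self.up.weight @ self.down.weight) * merge_weight
        # the in-place add requires both tensors on the same device,
        # which is what moving the model to the target device guarantees
        self.base.weight += delta.to(self.base.weight.device)


device = "cuda" if torch.cuda.is_available() else "cpu"
base = nn.Linear(8, 8).to(device)   # analogous to self.unet.to(self.device_torch)
lora = LoRALinear(base).to(device)
lora.merge_in(merge_weight=0.8)
```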


@@ -261,6 +261,7 @@ def quantize_model(
        base_model.accuracy_recovery_adapter = network
        # quantize it
        lora_exclude_modules = []
        quantization_type = get_qtype(base_model.model_config.qtype)
        for lora_module in tqdm(network.unet_loras, desc="Attaching quantization"):
            # the lora has already hijacked the original module
@@ -271,10 +272,20 @@ def quantize_model(
                param.requires_grad = False
            quantize(orig_module, weights=quantization_type)
            freeze(orig_module)
            module_name = lora_module.lora_name.replace('$$', '.').replace('transformer.', '')
            lora_exclude_modules.append(module_name)
            if base_model.model_config.low_vram:
                # move it back to cpu
                orig_module.to("cpu")
            pass
        # quantize additional layers
        print_acc(" - quantizing additional layers")
        quantization_type = get_qtype('uint8')
        quantize(
            model_to_quantize,
            weights=quantization_type,
            exclude=lora_exclude_modules
        )
    else:
        # quantize model the original way without an accuracy recovery adapter
        # move and quantize only certain pieces at a time.
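
The new bulk pass above quantizes every module the accuracy recovery adapter (ARA) does not wrap: the ARA-hijacked modules are quantized one at a time with the configured qtype, their names collected in `lora_exclude_modules`, and the rest of the model is then quantized with uint8 weights in a single call. A minimal sketch of that pattern, assuming optimum-quanto is the backend behind the repo's `quantize`/`freeze`/`get_qtype` wrappers (the function name and the `qint8` stand-in below are illustrative):

```python
from optimum.quanto import freeze, qint8, quantize
import torch.nn as nn


def quantize_non_ara_modules(model: nn.Module, ara_module_names: list[str]) -> nn.Module:
    # ARA-wrapped modules were already quantized individually, so they are
    # skipped by name; everything else gets 8-bit weights in one pass
    quantize(model, weights=qint8, exclude=ara_module_names)
    # freeze() replaces the float weights with their quantized counterparts
    freeze(model)
    return model
```

Excluding the ARA module names keeps the blanket uint8 pass from re-quantizing modules that were already quantized with the model's configured qtype inside the loop.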