diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 4934ba93..7190f56d 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -131,7 +131,7 @@ class ToolkitModuleMixin: lora_output_batch_size = lora_output.size(0) multiplier_batch_size = multiplier.size(0) if lora_output_batch_size != multiplier_batch_size: - num_interleaves = (lora_output_batch_size // 2) // multiplier_batch_size + num_interleaves = lora_output_batch_size // multiplier_batch_size multiplier = multiplier.repeat_interleave(num_interleaves) # multiplier = 1.0