diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py index 641c8ae9..f421190c 100644 --- a/toolkit/util/quantize.py +++ b/toolkit/util/quantize.py @@ -283,9 +283,9 @@ def quantize_model( all_blocks: List[torch.nn.Module] = [] transformer_block_names = base_model.get_transformer_block_names() for name in transformer_block_names: - block = getattr(model_to_quantize, name, None) - if block is not None: - all_blocks.append(block) + block_list = getattr(model_to_quantize, name, None) + if block_list is not None: + all_blocks += list(block_list) base_model.print_and_status_update( f" - quantizing {len(all_blocks)} transformer blocks" )