From 441474e81f23c45c560e0ef921e0915920898cc6 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 24 Jan 2025 09:34:13 -0700
Subject: [PATCH] Added a flag to the LoRA extraction script to do a full
 transformer extraction.

---
 scripts/extract_lora_from_flex.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/scripts/extract_lora_from_flex.py b/scripts/extract_lora_from_flex.py
index 908c84e3..c80c8892 100644
--- a/scripts/extract_lora_from_flex.py
+++ b/scripts/extract_lora_from_flex.py
@@ -9,6 +9,7 @@
 parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
 parser.add_argument("--output", type=str, required=True, help="Output path for lora")
 parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
 parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction")
+parser.add_argument("--full", action="store_true", help="Do a full transformer extraction, not just transformer blocks")
 
 args = parser.parse_args()
@@ -76,7 +77,7 @@ def extract_diff(
         lora_name = prefix + "." + name
         # lora_name = lora_name.replace(".", "_")
         layer = module.__class__.__name__
-        if 'transformer_blocks' not in lora_name:
+        if 'transformer_blocks' not in lora_name and not args.full:
             continue
         if layer in {
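
For context, a minimal sketch of what the new flag changes; the `should_extract` helper and the example module names below are illustrative, not part of the script:

```python
# Sketch of the filtering logic around the changed line. `should_extract`
# and the module names in the examples are hypothetical, shown only to
# illustrate how --full widens the set of extracted modules.
def should_extract(lora_name: str, full: bool) -> bool:
    # Before this patch, modules whose name lacks 'transformer_blocks' were
    # always skipped; with --full, every named module now reaches the
    # subsequent layer-type check instead of being filtered out early.
    return full or "transformer_blocks" in lora_name

print(should_extract("transformer_blocks.0.attn.to_q", full=False))  # True
print(should_extract("x_embedder", full=False))  # False: skipped by default
print(should_extract("x_embedder", full=True))   # True: extracted with --full
```

An invocation would then look like `python scripts/extract_lora_from_flex.py --tuned <tuned_model> --output <lora_out> --rank 32 --full`, with any other required arguments as defined at the top of the script.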