Added a flag to the LoRA extraction script to do a full transformer extraction.

Jaret Burkett
2025-01-24 09:34:13 -07:00
parent a6a690f796
commit 441474e81f


@@ -9,6 +9,7 @@ parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
parser.add_argument("--output", type=str, required=True, help="Output path for lora") parser.add_argument("--output", type=str, required=True, help="Output path for lora")
parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction") parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction") parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction")
parser.add_argument("--full", action="store_true", help="Do a full transformer extraction, not just transformer blocks")
args = parser.parse_args() args = parser.parse_args()
@@ -76,7 +77,7 @@ def extract_diff(
 lora_name = prefix + "." + name
 # lora_name = lora_name.replace(".", "_")
 layer = module.__class__.__name__
-if 'transformer_blocks' not in lora_name:
+if 'transformer_blocks' not in lora_name and not args.full:
     continue
 if layer in {
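
For readers skimming the diff, here is a minimal, self-contained sketch of what the new flag changes. The standalone parser and the module names below are illustrative only (real names come from iterating the model's modules inside extract_diff()); only the skip condition mirrors the diff. Note that the check is a substring match, so without --full any module whose name does not contain 'transformer_blocks' is skipped, while passing --full keeps every eligible layer.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--full", action="store_true",
                    help="Do a full transformer extraction, not just transformer blocks")
args = parser.parse_args([])            # default run: --full not passed
# args = parser.parse_args(["--full"])  # behavior when the flag is supplied

# Hypothetical module names, for illustration only.
module_names = [
    "transformer_blocks.0.attn.to_q",              # kept with or without --full
    "time_text_embed.timestep_embedder.linear_1",  # skipped unless --full
]

for lora_name in module_names:
    # Mirrors the changed condition in extract_diff(): without --full,
    # anything outside the transformer blocks is skipped.
    if 'transformer_blocks' not in lora_name and not args.full:
        print(f"skip    {lora_name}")
        continue
    print(f"extract {lora_name}")
```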