Mirror of https://github.com/ostris/ai-toolkit.git
Added a flag to lora extraction script to do a full transformer extraction.
@@ -9,6 +9,7 @@ parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
 parser.add_argument("--output", type=str, required=True, help="Output path for lora")
 parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
 parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction")
+parser.add_argument("--full", action="store_true", help="Do a full transformer extraction, not just transformer blocks")
 
 args = parser.parse_args()
 
@@ -76,7 +77,7 @@ def extract_diff(
         lora_name = prefix + "." + name
         # lora_name = lora_name.replace(".", "_")
         layer = module.__class__.__name__
-        if 'transformer_blocks' not in lora_name:
+        if 'transformer_blocks' not in lora_name and not args.full:
             continue
 
         if layer in {
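
For context, here is a minimal, self-contained sketch of the selection logic the second hunk changes. The module-iteration loop, the function name iter_extractable_modules, and the prefix default are assumptions inferred from the hunk context, not the actual ai-toolkit code; only the filter condition itself comes from the diff.

import argparse

import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument("--full", action="store_true",
                    help="Do a full transformer extraction, not just transformer blocks")
args, _ = parser.parse_known_args()


def iter_extractable_modules(model: nn.Module, prefix: str = "transformer"):
    """Yield (lora_name, module) pairs the extraction loop would keep.

    The prefix value and the use of named_modules() are assumptions for
    illustration; the filter condition mirrors the commit.
    """
    for name, module in model.named_modules():
        lora_name = prefix + "." + name
        # Default behavior: only modules inside transformer_blocks are kept.
        # With --full, the filter is skipped and every module is considered.
        if "transformer_blocks" not in lora_name and not args.full:
            continue
        yield lora_name, module

Without --full the behavior is unchanged (only transformer_blocks modules are extracted); passing --full disables the filter so weights outside the transformer blocks are compared as well.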