From 78b59c5e99f85efb000d2d011dad2fe27ded6796 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 16 Jul 2023 15:35:14 -0600 Subject: [PATCH] Added support for 3cleir, not fully tested --- config/examples/extract.example.json | 11 +++++++++++ jobs/process/ExtractLoraProcess.py | 22 ++++++++++++++-------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/config/examples/extract.example.json b/config/examples/extract.example.json index f42eaf46..f0d6e32f 100644 --- a/config/examples/extract.example.json +++ b/config/examples/extract.example.json @@ -29,6 +29,17 @@ "mode": "quantile", "linear": 0.5, "conv": 0.5 + }, + { + "type": "lora", + "mode": "fixed", + "linear": 4 + }, + { + "type": "lora", + "mode": "fixed", + "linear": 64, + "conv": 32 } ] }, diff --git a/jobs/process/ExtractLoraProcess.py b/jobs/process/ExtractLoraProcess.py index 343a6bd4..76f0cc94 100644 --- a/jobs/process/ExtractLoraProcess.py +++ b/jobs/process/ExtractLoraProcess.py @@ -5,19 +5,23 @@ from .BaseExtractProcess import BaseExtractProcess mode_dict = { 'fixed': { - 'dim': 64, + 'linear': 4, + 'conv': 0, 'type': int }, 'threshold': { - 'dim': 0, + 'linear': 0, + 'conv': 0, 'type': float }, 'ratio': { - 'dim': 0.5, + 'linear': 0.5, + 'conv': 0, 'type': float }, 'quantile': { - 'dim': 0.5, + 'linear': 0.5, + 'conv': 0, 'type': float } } @@ -35,7 +39,9 @@ class ExtractLoraProcess(BaseExtractProcess): # set modes if self.mode not in list(mode_dict.keys()): raise ValueError(f"Unknown mode: {self.mode}") - self.dim = self.get_conf('dim', mode_dict[self.mode]['dim'], as_type=mode_dict[self.mode]['type']) + self.linear = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) + self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) + self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type']) self.use_sparse_bias = self.get_conf('use_sparse_bias', False) self.sparsity = self.get_conf('sparsity', 0.98) @@ -47,13 +53,13 @@ class ExtractLoraProcess(BaseExtractProcess): self.job.model_base, self.job.model_extract, self.mode, - self.dim, - 0, + self.linear_param, + self.conv_param, self.job.device, self.use_sparse_bias, self.sparsity, small_conv=False, - linear_only=True, + linear_only=self.conv_param > 0.0000000001, extract_unet=self.extract_unet, extract_text_encoder=self.extract_text_encoder )