From a61914ddd979c8fc6046801d5824dd8336cc0287 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 12 Jul 2023 20:02:22 -0600 Subject: [PATCH] Add support for threshold, ratio, and quantile on traditional LoRa --- jobs/process/ExtractLoconProcess.py | 2 +- jobs/process/ExtractLoraProcess.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/jobs/process/ExtractLoconProcess.py b/jobs/process/ExtractLoconProcess.py index 08e9f00c..e0e16506 100644 --- a/jobs/process/ExtractLoconProcess.py +++ b/jobs/process/ExtractLoconProcess.py @@ -35,7 +35,7 @@ class ExtractLoconProcess(BaseExtractProcess): self.disable_cp = self.get_conf('disable_cp', False) # set modes - if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']: + if self.mode not in list(mode_dict.keys()): raise ValueError(f"Unknown mode: {self.mode}") self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type']) self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type']) diff --git a/jobs/process/ExtractLoraProcess.py b/jobs/process/ExtractLoraProcess.py index af1cdde4..70ea7e16 100644 --- a/jobs/process/ExtractLoraProcess.py +++ b/jobs/process/ExtractLoraProcess.py @@ -2,10 +2,23 @@ from collections import OrderedDict from toolkit.lycoris_utils import extract_diff from .BaseExtractProcess import BaseExtractProcess + mode_dict = { 'fixed': { - 'dim': 4, + 'dim': 64, 'type': int + }, + 'threshold': { + 'dim': 0, + 'type': float + }, + 'ratio': { + 'dim': 0.5, + 'type': float + }, + 'quantile': { + 'dim': 0.5, + 'type': float } } @@ -20,7 +33,7 @@ class ExtractLoraProcess(BaseExtractProcess): self.mode = self.get_conf('mode', 'fixed') # set modes - if self.mode not in ['fixed']: + if self.mode not in list(mode_dict.keys()): raise ValueError(f"Unknown mode: {self.mode}") self.dim = self.get_conf('dim', mode_dict[self.mode]['dim'], as_type=mode_dict[self.mode]['type']) self.use_sparse_bias = self.get_conf('use_sparse_bias', False)