Added support for traditional LoRa extract using LoCon script

This commit is contained in:
Jaret Burkett
2023-07-12 19:51:40 -06:00
parent 57f14e5ef2
commit 8d6edae9fd
9 changed files with 111 additions and 8 deletions

View File

@@ -8,6 +8,8 @@ from toolkit.metadata import get_meta_for_safetensors
from typing import ForwardRef
from toolkit.train_tools import get_torch_dtype
class BaseExtractProcess(BaseProcess):
process_id: int
@@ -26,6 +28,8 @@ class BaseExtractProcess(BaseProcess):
self.process_id = process_id
self.job = job
self.config = config
self.dtype = self.get_conf('dtype', self.job.dtype)
self.torch_dtype = get_torch_dtype(self.dtype)
def run(self):
# here instead of init because child init needs to go first
@@ -70,6 +74,11 @@ class BaseExtractProcess(BaseProcess):
# save
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(self.torch_dtype)
state_dict[key] = v
# having issues with meta
save_file(state_dict, self.output_path, save_meta)