mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-06 11:10:10 +00:00
Added support for traditional LoRa extract using LoCon script
This commit is contained in:
@@ -8,6 +8,8 @@ from toolkit.metadata import get_meta_for_safetensors
|
||||
|
||||
from typing import ForwardRef
|
||||
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
|
||||
class BaseExtractProcess(BaseProcess):
|
||||
process_id: int
|
||||
@@ -26,6 +28,8 @@ class BaseExtractProcess(BaseProcess):
|
||||
self.process_id = process_id
|
||||
self.job = job
|
||||
self.config = config
|
||||
self.dtype = self.get_conf('dtype', self.job.dtype)
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
|
||||
def run(self):
|
||||
# here instead of init because child init needs to go first
|
||||
@@ -70,6 +74,11 @@ class BaseExtractProcess(BaseProcess):
|
||||
# save
|
||||
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(self.torch_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
# having issues with meta
|
||||
save_file(state_dict, self.output_path, save_meta)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user