ai-toolkit/jobs/process/TrainLoRAHack.py

# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os
from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class LoRAHack:
    def __init__(self, **kwargs):
        self.type = kwargs.get('type', 'suppression')


class TrainLoRAHack(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.hack_config = LoRAHack(**self.get_conf('hack', {}))
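        # the 'hack' block of the process config is the only input here;
        # 'type' is the only key read, and 'suppression' is the only type
        # handled (see hook_train_loop below)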

    def hook_before_train_loop(self):
        # we don't need the text encoder, so move it to the cpu
        self.sd.text_encoder.to("cpu")
        flush()

        if self.hack_config.type == 'suppression':
            # re-initialize all network params to small random values
            params = self.network.parameters()
            for param in params:
                # zero-centered uniform noise for each param
                # (rand_like, not randn_like: randn is already zero-mean,
                # so subtracting 0.5 from it would bias the noise)
                noise = torch.rand_like(param) - 0.5
                # overwrite the param with the scaled noise
                param.data = noise * 0.001
    # end hook_before_train_loop
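    # Note: the suppression hack never computes a loss or steps an optimizer.
    # It re-initializes the LoRA weights to small noise above, then keeps
    # adding noise each step below, so the weights take an unguided random walk.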

    def suppress_loop(self):
        dtype = get_torch_dtype(self.train_config.dtype)
        loss_dict = OrderedDict(
            {'sup': 0.0}
        )
        # add a little more noise to every param each step
        for param in self.network.parameters():
            # zero-centered uniform noise for each param
            noise = torch.rand_like(param) - 0.5
            # nudge the param by the scaled noise
            param.data = param.data + noise * 0.001
        return loss_dict

    def hook_train_loop(self, batch):
        if self.hack_config.type == 'suppression':
            return self.suppress_loop()
        else:
            raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
    # end hook_train_loop
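
A minimal standalone sketch of what the suppression walk does to a single
weight tensor, outside the trainer. Everything here is illustrative: the
tensor stands in for one LoRA parameter, suppression_walk_demo and its
arguments are made up for this example, and only the 0.001 scale and the
zero-centered uniform noise mirror the file above.

import torch

def suppression_walk_demo(shape=(4, 4), steps=1000, scale=0.001):
    # start from small zero-centered noise, as hook_before_train_loop does
    param = (torch.rand(shape) - 0.5) * scale
    for step in range(steps):
        # each "training" step just adds more scaled noise -- a random walk,
        # so the typical weight magnitude grows roughly with sqrt(steps)
        noise = torch.rand(shape) - 0.5
        param = param + noise * scale
        if (step + 1) % 250 == 0:
            print(f'step {step + 1}: mean abs weight {param.abs().mean().item():.5f}')

suppression_walk_demo()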