Improved LoRM extraction and training

This commit is contained in:
Jaret Burkett
2023-10-28 08:21:59 -06:00
parent 0a79ac9604
commit 6f3e0d5af2
10 changed files with 559 additions and 196 deletions

View File

@@ -0,0 +1,102 @@
import os
from collections import OrderedDict
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.train_tools import get_torch_dtype
def flush():
    """Free as much GPU memory as possible.

    Runs Python's garbage collector FIRST so that tensors kept alive only by
    reference cycles are actually released, THEN asks the CUDA caching
    allocator to return its unused blocks to the driver. (Calling
    ``empty_cache`` before ``gc.collect`` — as the original did — misses any
    tensors that the collector has not yet reclaimed.)
    """
    gc.collect()
    torch.cuda.empty_cache()
class PureLoraGenerator(BaseExtensionProcess):
    """Extension process that converts a diffusers UNet to a LoRM (low-rank)
    UNet and generates sample images from it — no training is performed.

    Config keys (read via ``get_conf``):
        output_folder (str, required): directory generated images are written to.
        device (str): torch device string, default ``"cuda"``.
        model (dict, required): kwargs for ``ModelConfig``.
        sample (dict, required): kwargs for ``SampleConfig``.
        dtype (str): torch dtype name, default ``"float16"``.
        lorm (dict, optional): kwargs for ``LoRMConfig``; ``None`` leaves the
            conversion to its library defaults.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.device_torch = torch.device(self.device)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        # the 'lorm' section is optional; None means downstream defaults apply
        lorm_config = self.get_conf('lorm', None)
        self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
        self.device_state_preset = get_train_sd_device_state_preset(
            device=torch.device(self.device),
        )
        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )

    def run(self):
        """Load the model, convert its UNet to LoRM, and render every
        configured sample prompt into ``output_folder``."""
        super().run()
        print("Loading model...")
        with torch.no_grad():
            self.sd.load_model()
            self.sd.unet.eval()
            self.sd.unet.to(self.device_torch)
            # text_encoder may be a single module or a list of modules
            # (e.g. SDXL ships two text encoders)
            if isinstance(self.sd.text_encoder, list):
                for te in self.sd.text_encoder:
                    te.eval()
                    te.to(self.device_torch)
            else:
                self.sd.text_encoder.eval()
            self.sd.to(self.device_torch)

            # was: print(f"Converting to LoRM UNet") — f-string had no placeholders
            print("Converting to LoRM UNet")
            # replace the unet weights with their low-rank (LoRM) equivalents in place
            convert_diffusers_unet_to_lorm(
                self.sd.unet,
                config=self.lorm_config,
            )

            # was: os.path.join(self.output_folder) — a single-arg join is a no-op
            sample_folder = self.output_folder
            gen_img_config_list = []
            sample_config = self.generate_config
            start_seed = sample_config.seed
            current_seed = start_seed
            for i, prompt in enumerate(sample_config.prompts):
                if sample_config.walk_seed:
                    # advance the seed per prompt so each image differs
                    current_seed = start_seed + i
                # [time] / [count] placeholders are resolved downstream
                filename = f"[time]_[count].{self.generate_config.ext}"
                output_path = os.path.join(sample_folder, filename)
                gen_img_config_list.append(GenerateImageConfig(
                    prompt=prompt,  # it will autoparse the prompt
                    width=sample_config.width,
                    height=sample_config.height,
                    negative_prompt=sample_config.neg,
                    seed=current_seed,
                    guidance_scale=sample_config.guidance_scale,
                    guidance_rescale=sample_config.guidance_rescale,
                    num_inference_steps=sample_config.sample_steps,
                    network_multiplier=sample_config.network_multiplier,
                    output_path=output_path,
                    output_ext=sample_config.ext,
                    adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
                ))
            # send to be generated
            self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
        print("Done generating images")
        # cleanup: drop the model and release GPU memory
        del self.sd
        gc.collect()
        torch.cuda.empty_cache()

View File

@@ -19,7 +19,24 @@ class AdvancedReferenceGeneratorExtension(Extension):
return ReferenceGenerator
# This is for generic training (LoRA, Dreambooth, FineTuning)
class PureLoraGenerator(Extension):
    """Registers the pure LoRA generator process with the toolkit."""

    # Unique identifier — this is how the toolkit looks the extension up.
    uid = "pure_lora_generator"

    # Human-readable name used when printing available extensions.
    name = "Pure LoRA Generator"

    @classmethod
    def get_process(cls):
        """Return the process class for this extension.

        The import is deliberately deferred to call time so the (heavy)
        process module does not slow down loading the rest of the program.
        """
        from .PureLoraGenerator import PureLoraGenerator
        return PureLoraGenerator
# Registry of every extension exported by this module.
# Fix: AdvancedReferenceGeneratorExtension was listed twice (a duplicate
# registration); each extension should appear exactly once.
AI_TOOLKIT_EXTENSIONS = [
    AdvancedReferenceGeneratorExtension,
    PureLoraGenerator,
]

View File

@@ -32,7 +32,6 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
def before_model_load(self):
    """Hook run just before the model is loaded.

    No-op in this trainer; subclasses may override to prepare state.
    """
    pass
@@ -193,6 +192,15 @@ class SDTrainer(BaseSDTrainProcess):
self.network.is_active = was_network_active
return prior_pred
def before_unet_predict(self):
    """Hook run immediately before the UNet noise prediction.

    No-op here; exists so subclasses can inject behavior.
    """
    pass
def after_unet_predict(self):
    """Hook run immediately after the UNet noise prediction.

    No-op here; exists so subclasses can inject behavior.
    """
    pass
def end_of_training_loop(self):
    """Hook run at the end of each training-loop iteration.

    No-op here; subclasses may override for per-step cleanup or logging.
    """
    pass
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
self.timer.start('preprocess_batch')
@@ -331,7 +339,6 @@ class SDTrainer(BaseSDTrainProcess):
adapter_images_list = [adapter_images]
mask_multiplier_list = [mask_multiplier]
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
noisy_latents_list,
noise_list,
@@ -366,7 +373,8 @@ class SDTrainer(BaseSDTrainProcess):
# flush()
pred_kwargs = {}
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
if has_adapter_img and (
(self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.adapter if self.adapter else self.assistant_adapter
adapter_multiplier = get_adapter_multiplier()
@@ -406,8 +414,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
self.before_unet_predict()
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -416,6 +423,7 @@ class SDTrainer(BaseSDTrainProcess):
guidance_scale=1.0,
**pred_kwargs
)
self.after_unet_predict()
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
@@ -442,7 +450,7 @@ class SDTrainer(BaseSDTrainProcess):
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# flush()
# flush()
with self.timer('optimizer_step'):
# apply gradients
@@ -460,4 +468,6 @@ class SDTrainer(BaseSDTrainProcess):
{'loss': loss.item()}
)
self.end_of_training_loop()
return loss_dict