Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 03:01:28 +00:00
Improved lorm extraction and training
extensions_built_in/advanced_generator/PureLoraGenerator.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import os
from collections import OrderedDict

from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.train_tools import get_torch_dtype


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class PureLoraGenerator(BaseExtensionProcess):

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.device_torch = torch.device(self.device)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        lorm_config = self.get_conf('lorm', None)
        self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None

        self.device_state_preset = get_train_sd_device_state_preset(
            device=torch.device(self.device),
        )

        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )

    def run(self):
        super().run()
        print("Loading model...")
        with torch.no_grad():
            self.sd.load_model()
            self.sd.unet.eval()
            self.sd.unet.to(self.device_torch)
            if isinstance(self.sd.text_encoder, list):
                for te in self.sd.text_encoder:
                    te.eval()
                    te.to(self.device_torch)
            else:
                self.sd.text_encoder.eval()
                self.sd.to(self.device_torch)

            print(f"Converting to LoRM UNet")
            # replace the unet with LoRMUnet
            convert_diffusers_unet_to_lorm(
                self.sd.unet,
                config=self.lorm_config,
            )

            sample_folder = os.path.join(self.output_folder)
            gen_img_config_list = []

            sample_config = self.generate_config
            start_seed = sample_config.seed
            current_seed = start_seed
            for i in range(len(sample_config.prompts)):
                if sample_config.walk_seed:
                    current_seed = start_seed + i

                filename = f"[time]_[count].{self.generate_config.ext}"
                output_path = os.path.join(sample_folder, filename)
                prompt = sample_config.prompts[i]
                extra_args = {}
                gen_img_config_list.append(GenerateImageConfig(
                    prompt=prompt,  # it will autoparse the prompt
                    width=sample_config.width,
                    height=sample_config.height,
                    negative_prompt=sample_config.neg,
                    seed=current_seed,
                    guidance_scale=sample_config.guidance_scale,
                    guidance_rescale=sample_config.guidance_rescale,
                    num_inference_steps=sample_config.sample_steps,
                    network_multiplier=sample_config.network_multiplier,
                    output_path=output_path,
                    output_ext=sample_config.ext,
                    adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
                    **extra_args
                ))

            # send to be generated
            self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
            print("Done generating images")
            # cleanup
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()
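The process above is driven entirely by config values read through get_conf() in __init__: output_folder (required), device, dtype, model, sample, and an optional lorm block. The following Python sketch of such a config is illustrative only; the 'type' key and the field names inside the model and lorm blocks are assumptions, while the sample keys mirror the SampleConfig attributes used in run().

from collections import OrderedDict

# Hypothetical config sketch for the pure_lora_generator process; only the
# top-level keys come from __init__ above, the nested field names are
# assumptions about ModelConfig / SampleConfig / LoRMConfig.
example_process_config = OrderedDict({
    'type': 'pure_lora_generator',            # assumed to match the extension uid
    'output_folder': 'output/lorm_samples',   # required by get_conf
    'device': 'cuda',                          # default 'cuda'
    'dtype': 'float16',                        # default 'float16'
    'model': {                                 # unpacked into ModelConfig(**...)
        'name_or_path': 'runwayml/stable-diffusion-v1-5',  # assumed field name
    },
    'sample': {                                # unpacked into SampleConfig(**...)
        'prompts': ['a photo of a cat'],
        'neg': '',
        'width': 512,
        'height': 512,
        'seed': 42,
        'walk_seed': True,
        'guidance_scale': 7.0,
        'guidance_rescale': 0.0,
        'sample_steps': 20,
        'network_multiplier': 1.0,
        'ext': 'png',
        'adapter_conditioning_scale': 1.0,
        'sampler': 'ddpm',
    },
    'lorm': {                                  # optional; unpacked into LoRMConfig(**...)
        'extract_mode': 'ratio',               # assumed field, cf. toolkit.lorm.ExtractMode
    },
})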
@@ -19,7 +19,24 @@ class AdvancedReferenceGeneratorExtension(Extension):
        return ReferenceGenerator


# This is for generic training (LoRA, Dreambooth, FineTuning)
class PureLoraGenerator(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "pure_lora_generator"

    # name is the name of the extension for printing
    name = "Pure LoRA Generator"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .PureLoraGenerator import PureLoraGenerator
        return PureLoraGenerator


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
-    AdvancedReferenceGeneratorExtension,
+    AdvancedReferenceGeneratorExtension, PureLoraGenerator
]
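Each extension here exposes a unique uid and lazily imports its process class through get_process(). A minimal, hypothetical sketch of how a registry like AI_TOOLKIT_EXTENSIONS could be consumed (the actual ai-toolkit loader may differ):

# Hypothetical helper (not part of this commit): map each registered
# extension's uid to its process class, importing lazily via get_process().
def build_process_map(extensions):
    return {ext.uid: ext.get_process() for ext in extensions}

# e.g. build_process_map(AI_TOOLKIT_EXTENSIONS)['pure_lora_generator']
# would return the PureLoraGenerator process class defined in the new file above.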
@@ -32,7 +32,6 @@ class SDTrainer(BaseSDTrainProcess):
        if self.train_config.inverted_mask_prior:
            self.do_prior_prediction = True

    def before_model_load(self):
        pass
@@ -193,6 +192,15 @@ class SDTrainer(BaseSDTrainProcess):
        self.network.is_active = was_network_active
        return prior_pred

    def before_unet_predict(self):
        pass

    def after_unet_predict(self):
        pass

    def end_of_training_loop(self):
        pass

    def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):

        self.timer.start('preprocess_batch')
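The three new no-op methods are hook points: the later hunks call before_unet_predict() just before self.sd.predict_noise(), after_unet_predict() right after it, and end_of_training_loop() before loss_dict is returned. A minimal sketch of a hypothetical subclass overriding them follows; the SDTrainer import path is an assumption.

import torch

from extensions_built_in.sd_trainer.SDTrainer import SDTrainer  # import path assumed


class MemoryLoggingTrainer(SDTrainer):
    # Hypothetical subclass, not part of this commit; shows how the new hook
    # points can be used without modifying hook_train_loop itself.

    def before_unet_predict(self):
        # reset the peak-memory counter right before the UNet forward pass
        torch.cuda.reset_peak_memory_stats()

    def after_unet_predict(self):
        # record peak memory used by predict_noise for this step
        self.last_unet_peak_mem = torch.cuda.max_memory_allocated()

    def end_of_training_loop(self):
        # runs once per batch, just before the loss dict is returned
        pass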
@@ -331,7 +339,6 @@ class SDTrainer(BaseSDTrainProcess):
            adapter_images_list = [adapter_images]
            mask_multiplier_list = [mask_multiplier]

        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
                noisy_latents_list,
                noise_list,
@@ -366,7 +373,8 @@ class SDTrainer(BaseSDTrainProcess):
            # flush()
            pred_kwargs = {}
-            if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
+            if has_adapter_img and (
+                    (self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
                with torch.set_grad_enabled(self.adapter is not None):
                    adapter = self.adapter if self.adapter else self.assistant_adapter
                    adapter_multiplier = get_adapter_multiplier()
@@ -406,8 +414,7 @@ class SDTrainer(BaseSDTrainProcess):
                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
                    conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)

            self.before_unet_predict()
            with self.timer('predict_unet'):
                noise_pred = self.sd.predict_noise(
                    latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -416,6 +423,7 @@ class SDTrainer(BaseSDTrainProcess):
                    guidance_scale=1.0,
                    **pred_kwargs
                )
            self.after_unet_predict()

            with self.timer('calculate_loss'):
                noise = noise.to(self.device_torch, dtype=dtype).detach()
@@ -442,7 +450,7 @@ class SDTrainer(BaseSDTrainProcess):
                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
                # flush()

            with self.timer('optimizer_step'):
                # apply gradients
@@ -460,4 +468,6 @@ class SDTrainer(BaseSDTrainProcess):
                {'loss': loss.item()}
            )

        self.end_of_training_loop()

        return loss_dict