mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Working multi gpu training. Still need a lot of tweaks and testing.
This commit is contained in:
@@ -20,6 +20,7 @@ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, Guid
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
from toolkit.print import print_acc
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
@@ -59,8 +60,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.negative_prompt_pool: Union[List[str], None] = None
|
||||
self.batch_negative_prompt: Union[List[str], None] = None
|
||||
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
||||
|
||||
self.do_grad_scale = True
|
||||
@@ -70,12 +69,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.adapter_config.train:
|
||||
self.do_grad_scale = False
|
||||
|
||||
if self.train_config.dtype in ["fp16", "float16"]:
|
||||
# patch the scaler to allow fp16 training
|
||||
org_unscale_grads = self.scaler._unscale_grads_
|
||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
||||
self.scaler._unscale_grads_ = _unscale_grads_replacer
|
||||
# if self.train_config.dtype in ["fp16", "float16"]:
|
||||
# # patch the scaler to allow fp16 training
|
||||
# org_unscale_grads = self.scaler._unscale_grads_
|
||||
# def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
# return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
||||
# self.scaler._unscale_grads_ = _unscale_grads_replacer
|
||||
|
||||
self.cached_blank_embeds: Optional[PromptEmbeds] = None
|
||||
self.cached_trigger_embeds: Optional[PromptEmbeds] = None
|
||||
@@ -168,11 +167,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
raise ValueError("Cannot unload text encoder if training text encoder")
|
||||
# cache embeddings
|
||||
|
||||
print("\n***** UNLOADING TEXT ENCODER *****")
|
||||
print("This will train only with a blank prompt or trigger word, if set")
|
||||
print("If this is not what you want, remove the unload_text_encoder flag")
|
||||
print("***********************************")
|
||||
print("")
|
||||
print_acc("\n***** UNLOADING TEXT ENCODER *****")
|
||||
print_acc("This will train only with a blank prompt or trigger word, if set")
|
||||
print_acc("If this is not what you want, remove the unload_text_encoder flag")
|
||||
print_acc("***********************************")
|
||||
print_acc("")
|
||||
self.sd.text_encoder_to(self.device_torch)
|
||||
self.cached_blank_embeds = self.sd.encode_prompt("")
|
||||
if self.trigger_word is not None:
|
||||
@@ -484,7 +483,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
|
||||
if torch.isnan(prior_loss).any():
|
||||
print("Prior loss is nan")
|
||||
print_acc("Prior loss is nan")
|
||||
prior_loss = None
|
||||
else:
|
||||
prior_loss = prior_loss.mean([1, 2, 3])
|
||||
@@ -553,7 +552,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
noise=noise,
|
||||
sd=self.sd,
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
scaler=self.scaler,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -668,7 +666,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# loss = self.apply_snr(loss, timesteps)
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
# detach it so parent class can run backward on no grads without throwing error
|
||||
loss = loss.detach()
|
||||
@@ -823,7 +821,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# loss = self.apply_snr(loss, timesteps)
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
# detach it so parent class can run backward on no grads without throwing error
|
||||
loss = loss.detach()
|
||||
@@ -1446,8 +1444,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
quad_count=quad_count
|
||||
)
|
||||
else:
|
||||
print("No Clip Image")
|
||||
print([file_item.path for file_item in batch.file_items])
|
||||
print_acc("No Clip Image")
|
||||
print_acc([file_item.path for file_item in batch.file_items])
|
||||
raise ValueError("Could not find clip image")
|
||||
|
||||
if not self.adapter_config.train_image_encoder:
|
||||
@@ -1625,7 +1623,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
)
|
||||
# check if nan
|
||||
if torch.isnan(loss):
|
||||
print("loss is nan")
|
||||
print_acc("loss is nan")
|
||||
loss = torch.zeros_like(loss).requires_grad_(True)
|
||||
|
||||
with self.timer('backward'):
|
||||
@@ -1640,10 +1638,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# if self.is_bfloat:
|
||||
# loss.backward()
|
||||
# else:
|
||||
if not self.do_grad_scale:
|
||||
loss.backward()
|
||||
else:
|
||||
self.scaler.scale(loss).backward()
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
return loss.detach()
|
||||
# flush()
|
||||
@@ -1668,21 +1663,14 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if not self.is_grad_accumulation_step:
|
||||
# fix this for multi params
|
||||
if self.train_config.optimizer != 'adafactor':
|
||||
if self.do_grad_scale:
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
if isinstance(self.params[0], dict):
|
||||
for i in range(len(self.params)):
|
||||
torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
|
||||
self.accelerator.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
# only step if we are not accumulating
|
||||
with self.timer('optimizer_step'):
|
||||
# self.optimizer.step()
|
||||
if not self.do_grad_scale:
|
||||
self.optimizer.step()
|
||||
else:
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.step()
|
||||
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
|
||||
@@ -61,6 +61,11 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, Netw
|
||||
DecoratorConfig
|
||||
from toolkit.logging import create_logger
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from toolkit.accelerator import get_accelerator
|
||||
from toolkit.print import print_acc
|
||||
from accelerate import Accelerator
|
||||
import transformers
|
||||
import diffusers
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
@@ -71,6 +76,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
|
||||
super().__init__(process_id, job, config)
|
||||
self.accelerator: Accelerator = get_accelerator()
|
||||
if self.accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
self.sd: StableDiffusion
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
@@ -82,8 +95,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.grad_accumulation_step = 1
|
||||
# if true, then we do not do an optimizer step. We are accumulating gradients
|
||||
self.is_grad_accumulation_step = False
|
||||
self.device = self.get_conf('device', self.job.device)
|
||||
self.device_torch = torch.device(self.device)
|
||||
self.device = str(self.accelerator.device)
|
||||
self.device_torch = self.accelerator.device
|
||||
network_config = self.get_conf('network', None)
|
||||
if network_config is not None:
|
||||
self.network_config = NetworkConfig(**network_config)
|
||||
@@ -91,6 +104,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.network_config = None
|
||||
self.train_config = TrainConfig(**self.get_conf('train', {}))
|
||||
model_config = self.get_conf('model', {})
|
||||
self.modules_being_trained: List[torch.nn.Module] = []
|
||||
|
||||
# update modelconfig dtype to match train
|
||||
model_config['dtype'] = self.train_config.dtype
|
||||
@@ -222,6 +236,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return generate_image_config_list
|
||||
|
||||
def sample(self, step=None, is_first=False):
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
flush()
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
gen_img_config_list = []
|
||||
@@ -316,6 +332,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
elif self.model_config.is_xl:
|
||||
o_dict['ss_base_model_version'] = 'sdxl_1.0'
|
||||
elif self.model_config.is_flux:
|
||||
o_dict['ss_base_model_version'] = 'flux.1'
|
||||
else:
|
||||
o_dict['ss_base_model_version'] = 'sd_1.5'
|
||||
|
||||
@@ -344,6 +362,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return info
|
||||
|
||||
def clean_up_saves(self):
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
# remove old saves
|
||||
# get latest saved step
|
||||
latest_item = None
|
||||
@@ -400,7 +420,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
items_to_remove = list(dict.fromkeys(items_to_remove))
|
||||
|
||||
for item in items_to_remove:
|
||||
self.print(f"Removing old save: {item}")
|
||||
print_acc(f"Removing old save: {item}")
|
||||
if os.path.isdir(item):
|
||||
shutil.rmtree(item)
|
||||
else:
|
||||
@@ -418,6 +438,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
pass
|
||||
|
||||
def save(self, step=None):
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
flush()
|
||||
if self.ema is not None:
|
||||
# always save params as ema
|
||||
@@ -594,10 +616,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
state_dict = self.optimizer.state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Could not save optimizer")
|
||||
print_acc(e)
|
||||
print_acc("Could not save optimizer")
|
||||
|
||||
self.print(f"Saved to {file_path}")
|
||||
print_acc(f"Saved to {file_path}")
|
||||
self.clean_up_saves()
|
||||
self.post_save_hook(file_path)
|
||||
|
||||
@@ -619,7 +641,49 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return params
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
self.logger.start()
|
||||
if self.accelerator.is_main_process:
|
||||
self.logger.start()
|
||||
self.prepare_accelerator()
|
||||
|
||||
|
||||
def prepare_accelerator(self):
|
||||
# set some config
|
||||
self.accelerator.even_batches=False
|
||||
|
||||
# # prepare all the models stuff for accelerator (hopefully we dont miss any)
|
||||
if self.sd.vae is not None:
|
||||
self.sd.vae = self.accelerator.prepare(self.sd.vae)
|
||||
if self.sd.unet is not None:
|
||||
self.sd.unet = self.accelerator.prepare(self.sd.unet)
|
||||
# todo always tdo it?
|
||||
self.modules_being_trained.append(self.sd.unet)
|
||||
if self.sd.text_encoder is not None and self.train_config.train_text_encoder:
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
self.sd.text_encoder = [self.accelerator.prepare(model) for model in self.sd.text_encoder]
|
||||
self.modules_being_trained.extend(self.sd.text_encoder)
|
||||
else:
|
||||
self.sd.text_encoder = self.accelerator.prepare(self.sd.text_encoder)
|
||||
self.modules_being_trained.append(self.sd.text_encoder)
|
||||
if self.sd.refiner_unet is not None and self.train_config.train_refiner:
|
||||
self.sd.refiner_unet = self.accelerator.prepare(self.sd.refiner_unet)
|
||||
self.modules_being_trained.append(self.sd.refiner_unet)
|
||||
# todo, do we need to do the network or will "unet" get it?
|
||||
if self.sd.network is not None:
|
||||
self.sd.network = self.accelerator.prepare(self.sd.network)
|
||||
self.modules_being_trained.append(self.sd.network)
|
||||
if self.adapter is not None and self.adapter_config.train:
|
||||
# todo adapters may not be a module. need to check
|
||||
self.adapter = self.accelerator.prepare(self.adapter)
|
||||
self.modules_being_trained.append(self.adapter)
|
||||
|
||||
# prepare other things
|
||||
self.optimizer = self.accelerator.prepare(self.optimizer)
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler)
|
||||
# self.data_loader = self.accelerator.prepare(self.data_loader)
|
||||
# if self.data_loader_reg is not None:
|
||||
# self.data_loader_reg = self.accelerator.prepare(self.data_loader_reg)
|
||||
|
||||
|
||||
def ensure_params_requires_grad(self, force=False):
|
||||
if self.train_config.do_paramiter_swapping and not force:
|
||||
@@ -692,6 +756,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return latest_path
|
||||
|
||||
def load_training_state_from_metadata(self, path):
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
meta = None
|
||||
# if path is folder, then it is diffusers
|
||||
if os.path.isdir(path):
|
||||
@@ -708,7 +774,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if 'epoch' in meta['training_info']:
|
||||
self.epoch_num = meta['training_info']['epoch']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
print_acc(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
def load_weights(self, path):
|
||||
if self.network is not None:
|
||||
@@ -716,7 +782,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.load_training_state_from_metadata(path)
|
||||
return extra_weights
|
||||
else:
|
||||
print("load_weights not implemented for non-network models")
|
||||
print_acc("load_weights not implemented for non-network models")
|
||||
return None
|
||||
|
||||
def apply_snr(self, seperated_loss, timesteps):
|
||||
@@ -747,7 +813,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if 'epoch' in meta['training_info']:
|
||||
self.epoch_num = meta['training_info']['epoch']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
print_acc(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
# def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
|
||||
# self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
|
||||
@@ -1244,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.adapter.to(self.device_torch, dtype=dtype)
|
||||
if latest_save_path is not None and not is_control_net:
|
||||
# load adapter from path
|
||||
print(f"Loading adapter from {latest_save_path}")
|
||||
print_acc(f"Loading adapter from {latest_save_path}")
|
||||
if is_t2i:
|
||||
loaded_state_dict = load_t2i_model(
|
||||
latest_save_path,
|
||||
@@ -1290,7 +1356,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
model_config_to_load.name_or_path = latest_save_path
|
||||
self.load_training_state_from_metadata(latest_save_path)
|
||||
|
||||
@@ -1357,7 +1423,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# block.attn.set_processor(processor)
|
||||
|
||||
# except ImportError:
|
||||
# print("sage attention is not installed. Using SDP instead")
|
||||
# print_acc("sage attention is not installed. Using SDP instead")
|
||||
|
||||
if self.train_config.gradient_checkpointing:
|
||||
if self.sd.is_flux:
|
||||
@@ -1531,8 +1597,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
latest_save_path = self.get_latest_save_path(lora_name)
|
||||
extra_weights = None
|
||||
if latest_save_path is not None:
|
||||
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
self.print(f"Loading from {latest_save_path}")
|
||||
print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
print_acc(f"Loading from {latest_save_path}")
|
||||
extra_weights = self.load_weights(latest_save_path)
|
||||
self.network.multiplier = 1.0
|
||||
|
||||
@@ -1665,17 +1731,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
previous_lrs.append(group['lr'])
|
||||
|
||||
try:
|
||||
print(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||
print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||
optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
del optimizer_state_dict
|
||||
flush()
|
||||
except Exception as e:
|
||||
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||
print(e)
|
||||
print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||
print_acc(e)
|
||||
|
||||
# update the optimizer LR from the params
|
||||
print(f"Updating optimizer LR from params")
|
||||
print_acc(f"Updating optimizer LR from params")
|
||||
if len(previous_lrs) > 0:
|
||||
for i, group in enumerate(optimizer.param_groups):
|
||||
group['lr'] = previous_lrs[i]
|
||||
@@ -1711,24 +1777,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.hook_before_train_loop()
|
||||
|
||||
if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
|
||||
self.print("Generating first sample from first sample config")
|
||||
print_acc("Generating first sample from first sample config")
|
||||
self.sample(0, is_first=True)
|
||||
|
||||
# sample first
|
||||
if self.train_config.skip_first_sample or self.train_config.disable_sampling:
|
||||
self.print("Skipping first sample due to config setting")
|
||||
print_acc("Skipping first sample due to config setting")
|
||||
elif self.step_num <= 1 or self.train_config.force_first_sample:
|
||||
self.print("Generating baseline samples before training")
|
||||
print_acc("Generating baseline samples before training")
|
||||
self.sample(self.step_num)
|
||||
|
||||
self.progress_bar = ToolkitProgressBar(
|
||||
total=self.train_config.steps,
|
||||
desc=self.job.name,
|
||||
leave=True,
|
||||
initial=self.step_num,
|
||||
iterable=range(0, self.train_config.steps),
|
||||
)
|
||||
self.progress_bar.pause()
|
||||
|
||||
if self.accelerator.is_local_main_process:
|
||||
self.progress_bar = ToolkitProgressBar(
|
||||
total=self.train_config.steps,
|
||||
desc=self.job.name,
|
||||
leave=True,
|
||||
initial=self.step_num,
|
||||
iterable=range(0, self.train_config.steps),
|
||||
)
|
||||
self.progress_bar.pause()
|
||||
else:
|
||||
self.progress_bar = None
|
||||
|
||||
if self.data_loader is not None:
|
||||
dataloader = self.data_loader
|
||||
@@ -1753,7 +1822,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
flush()
|
||||
# self.step_num = 0
|
||||
|
||||
# print(f"Compiling Model")
|
||||
# print_acc(f"Compiling Model")
|
||||
# torch.compile(self.sd.unet, dynamic=True)
|
||||
|
||||
# make sure all params require grad
|
||||
@@ -1779,7 +1848,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.is_grad_accumulation_step = True
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
|
||||
self.progress_bar.unpause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
with torch.no_grad():
|
||||
# if is even step and we have a reg dataset, use that
|
||||
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
|
||||
@@ -1802,13 +1872,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
except StopIteration:
|
||||
with self.timer('reset_batch:reg'):
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
trigger_dataloader_setup_epoch(dataloader_reg)
|
||||
|
||||
with self.timer('get_batch:reg'):
|
||||
batch = next(dataloader_iterator_reg)
|
||||
self.progress_bar.unpause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
is_reg_step = True
|
||||
elif dataloader is not None:
|
||||
try:
|
||||
@@ -1817,7 +1889,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
except StopIteration:
|
||||
with self.timer('reset_batch'):
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator = iter(dataloader)
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
self.epoch_num += 1
|
||||
@@ -1827,7 +1900,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.grad_accumulation_step = 0
|
||||
with self.timer('get_batch'):
|
||||
batch = next(dataloader_iterator)
|
||||
self.progress_bar.unpause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
else:
|
||||
batch = None
|
||||
batch_list.append(batch)
|
||||
@@ -1849,8 +1923,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# flush()
|
||||
### HOOK ###
|
||||
|
||||
loss_dict = self.hook_train_loop(batch_list)
|
||||
with self.accelerator.accumulate(self.modules_being_trained):
|
||||
loss_dict = self.hook_train_loop(batch_list)
|
||||
self.timer.stop('train_loop')
|
||||
if not did_first_flush:
|
||||
flush()
|
||||
@@ -1880,7 +1954,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
for key, value in loss_dict.items():
|
||||
prog_bar_string += f" {key}: {value:.3e}"
|
||||
|
||||
self.progress_bar.set_postfix_str(prog_bar_string)
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.set_postfix_str(prog_bar_string)
|
||||
|
||||
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
|
||||
if isinstance(batch, DataLoaderBatchDTO):
|
||||
@@ -1889,8 +1964,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# don't do on first step
|
||||
if self.step_num != self.start_step:
|
||||
if is_sample_step or is_save_step:
|
||||
self.accelerator.wait_for_everyone()
|
||||
if is_sample_step:
|
||||
self.progress_bar.pause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
flush()
|
||||
# print above the progress bar
|
||||
if self.train_config.free_u:
|
||||
@@ -1902,57 +1980,70 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
flush()
|
||||
|
||||
self.ensure_params_requires_grad()
|
||||
self.progress_bar.unpause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
|
||||
if is_save_step:
|
||||
self.accelerator
|
||||
# print above the progress bar
|
||||
self.progress_bar.pause()
|
||||
self.print(f"Saving at step {self.step_num}")
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
print_acc(f"Saving at step {self.step_num}")
|
||||
self.save(self.step_num)
|
||||
self.ensure_params_requires_grad()
|
||||
self.progress_bar.unpause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
|
||||
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
|
||||
self.progress_bar.pause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
with self.timer('log_to_tensorboard'):
|
||||
# log to tensorboard
|
||||
if self.writer is not None:
|
||||
for key, value in loss_dict.items():
|
||||
self.writer.add_scalar(f"{key}", value, self.step_num)
|
||||
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
|
||||
self.progress_bar.unpause()
|
||||
if self.accelerator.is_main_process:
|
||||
if self.writer is not None:
|
||||
for key, value in loss_dict.items():
|
||||
self.writer.add_scalar(f"{key}", value, self.step_num)
|
||||
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
|
||||
# log to logger
|
||||
self.logger.log({
|
||||
'learning_rate': learning_rate,
|
||||
})
|
||||
for key, value in loss_dict.items():
|
||||
if self.accelerator.is_main_process:
|
||||
# log to logger
|
||||
self.logger.log({
|
||||
f'loss/{key}': value,
|
||||
'learning_rate': learning_rate,
|
||||
})
|
||||
elif self.logging_config.log_every is None:
|
||||
# log every step
|
||||
self.logger.log({
|
||||
'learning_rate': learning_rate,
|
||||
})
|
||||
for key, value in loss_dict.items():
|
||||
for key, value in loss_dict.items():
|
||||
self.logger.log({
|
||||
f'loss/{key}': value,
|
||||
})
|
||||
elif self.logging_config.log_every is None:
|
||||
if self.accelerator.is_main_process:
|
||||
# log every step
|
||||
self.logger.log({
|
||||
f'loss/{key}': value,
|
||||
'learning_rate': learning_rate,
|
||||
})
|
||||
for key, value in loss_dict.items():
|
||||
self.logger.log({
|
||||
f'loss/{key}': value,
|
||||
})
|
||||
|
||||
|
||||
if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
|
||||
self.progress_bar.pause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.pause()
|
||||
# print the timers and clear them
|
||||
self.timer.print()
|
||||
self.timer.reset()
|
||||
self.progress_bar.unpause()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.unpause()
|
||||
|
||||
# commit log
|
||||
self.logger.commit(step=self.step_num)
|
||||
if self.accelerator.is_main_process:
|
||||
self.logger.commit(step=self.step_num)
|
||||
|
||||
# sets progress bar to match out step
|
||||
self.progress_bar.update(step - self.progress_bar.n)
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.update(step - self.progress_bar.n)
|
||||
|
||||
#############################
|
||||
# End of step
|
||||
@@ -1966,16 +2057,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
###################################################################
|
||||
## END TRAIN LOOP
|
||||
###################################################################
|
||||
|
||||
self.progress_bar.close()
|
||||
self.accelerator.wait_for_everyone()
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.close()
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.disable_freeu()
|
||||
if not self.train_config.disable_sampling:
|
||||
self.sample(self.step_num)
|
||||
self.logger.commit(step=self.step_num)
|
||||
print("")
|
||||
self.save()
|
||||
self.logger.finish()
|
||||
print_acc("")
|
||||
if self.accelerator.is_main_process:
|
||||
self.save()
|
||||
self.logger.finish()
|
||||
self.accelerator.end_training()
|
||||
|
||||
if self.save_config.push_to_hub:
|
||||
if("HF_TOKEN" not in os.environ):
|
||||
@@ -2001,6 +2095,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
repo_id: str,
|
||||
private: bool = False,
|
||||
):
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
readme_content = self._generate_readme(repo_id)
|
||||
readme_path = os.path.join(self.save_root, "README.md")
|
||||
with open(readme_path, "w", encoding="utf-8") as f:
|
||||
|
||||
23
run.py
23
run.py
@@ -20,20 +20,26 @@ if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
import argparse
|
||||
from toolkit.job import get_job
|
||||
from toolkit.accelerator import get_accelerator
|
||||
from toolkit.print import print_acc
|
||||
|
||||
accelerator = get_accelerator()
|
||||
|
||||
|
||||
def print_end_message(jobs_completed, jobs_failed):
|
||||
if not accelerator.is_main_process:
|
||||
return
|
||||
failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""
|
||||
completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"
|
||||
|
||||
print("")
|
||||
print("========================================")
|
||||
print("Result:")
|
||||
print_acc("")
|
||||
print_acc("========================================")
|
||||
print_acc("Result:")
|
||||
if len(completed_string) > 0:
|
||||
print(f" - {completed_string}")
|
||||
print_acc(f" - {completed_string}")
|
||||
if len(failure_string) > 0:
|
||||
print(f" - {failure_string}")
|
||||
print("========================================")
|
||||
print_acc(f" - {failure_string}")
|
||||
print_acc("========================================")
|
||||
|
||||
|
||||
def main():
|
||||
@@ -70,7 +76,8 @@ def main():
|
||||
jobs_completed = 0
|
||||
jobs_failed = 0
|
||||
|
||||
print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")
|
||||
if accelerator.is_main_process:
|
||||
print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")
|
||||
|
||||
for config_file in config_file_list:
|
||||
try:
|
||||
@@ -79,7 +86,7 @@ def main():
|
||||
job.cleanup()
|
||||
jobs_completed += 1
|
||||
except Exception as e:
|
||||
print(f"Error running job: {e}")
|
||||
print_acc(f"Error running job: {e}")
|
||||
jobs_failed += 1
|
||||
if not args.recover:
|
||||
print_end_message(jobs_completed, jobs_failed)
|
||||
|
||||
3
todo_multigpu.md
Normal file
3
todo_multigpu.md
Normal file
@@ -0,0 +1,3 @@
|
||||
- only do ema on main device? shouldne be needed other than saving and sampling
|
||||
- check when to unwrap model and what it does
|
||||
- disable timer for non main local
|
||||
17
toolkit/accelerator.py
Normal file
17
toolkit/accelerator.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from accelerate import Accelerator
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
global_accelerator = None
|
||||
|
||||
|
||||
def get_accelerator() -> Accelerator:
|
||||
global global_accelerator
|
||||
if global_accelerator is None:
|
||||
global_accelerator = Accelerator()
|
||||
return global_accelerator
|
||||
|
||||
def unwrap_model(model):
|
||||
accelerator = get_accelerator()
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
@@ -20,6 +20,8 @@ from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.print import print_acc
|
||||
from toolkit.accelerator import get_accelerator
|
||||
|
||||
import platform
|
||||
|
||||
@@ -90,7 +92,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
|
||||
|
||||
# this might take a while
|
||||
print(f" - Preprocessing image dimensions")
|
||||
print_acc(f" - Preprocessing image dimensions")
|
||||
new_file_list = []
|
||||
bad_count = 0
|
||||
for file in tqdm(self.file_list):
|
||||
@@ -102,8 +104,8 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
|
||||
self.file_list = new_file_list
|
||||
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
print(f" - Found {bad_count} images that are too small")
|
||||
print_acc(f" - Found {len(self.file_list)} images")
|
||||
print_acc(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.path}"
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
@@ -128,8 +130,8 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
try:
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
except Exception as e:
|
||||
print(f"Error opening image: {img_path}")
|
||||
print(e)
|
||||
print_acc(f"Error opening image: {img_path}")
|
||||
print_acc(e)
|
||||
# make a noise image if we can't open it
|
||||
img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8))
|
||||
|
||||
@@ -140,7 +142,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
if self.random_crop:
|
||||
if self.random_scale and min_img_size > self.resolution:
|
||||
if min_img_size < self.resolution:
|
||||
print(
|
||||
print_acc(
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
|
||||
scale_size = self.resolution
|
||||
else:
|
||||
@@ -243,11 +245,11 @@ class PairedImageDataset(Dataset):
|
||||
matched_files = [t for t in (set(tuple(i) for i in matched_files))]
|
||||
|
||||
self.file_list = matched_files
|
||||
print(f" - Found {len(self.file_list)} matching pairs")
|
||||
print_acc(f" - Found {len(self.file_list)} matching pairs")
|
||||
else:
|
||||
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
|
||||
file.lower().endswith(supported_exts)]
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
print_acc(f" - Found {len(self.file_list)} images")
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
@@ -435,11 +437,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
])
|
||||
|
||||
# this might take a while
|
||||
print(f"Dataset: {self.dataset_path}")
|
||||
print(f" - Preprocessing image dimensions")
|
||||
print_acc(f"Dataset: {self.dataset_path}")
|
||||
print_acc(f" - Preprocessing image dimensions")
|
||||
dataset_folder = self.dataset_path
|
||||
if not os.path.isdir(self.dataset_path):
|
||||
dataset_folder = os.path.dirname(dataset_folder)
|
||||
|
||||
dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
|
||||
dataloader_version = "0.1.1"
|
||||
if os.path.exists(dataset_size_file):
|
||||
@@ -448,12 +451,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.size_database = json.load(f)
|
||||
|
||||
if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version:
|
||||
print("Upgrading size database to new version")
|
||||
print_acc("Upgrading size database to new version")
|
||||
# old version, delete and recreate
|
||||
self.size_database = {}
|
||||
except Exception as e:
|
||||
print(f"Error loading size database: {dataset_size_file}")
|
||||
print(e)
|
||||
print_acc(f"Error loading size database: {dataset_size_file}")
|
||||
print_acc(e)
|
||||
self.size_database = {}
|
||||
else:
|
||||
self.size_database = {}
|
||||
@@ -473,22 +476,22 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
)
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Error processing image: {file}")
|
||||
print(e)
|
||||
print_acc(traceback.format_exc())
|
||||
print_acc(f"Error processing image: {file}")
|
||||
print_acc(e)
|
||||
bad_count += 1
|
||||
|
||||
# save the size database
|
||||
with open(dataset_size_file, 'w') as f:
|
||||
json.dump(self.size_database, f)
|
||||
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
# print(f" - Found {bad_count} images that are too small")
|
||||
print_acc(f" - Found {len(self.file_list)} images")
|
||||
# print_acc(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
||||
|
||||
# handle x axis flips
|
||||
if self.dataset_config.flip_x:
|
||||
print(" - adding x axis flips")
|
||||
print_acc(" - adding x axis flips")
|
||||
current_file_list = [x for x in self.file_list]
|
||||
for file_item in current_file_list:
|
||||
# create a copy that is flipped on the x axis
|
||||
@@ -498,7 +501,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
|
||||
# handle y axis flips
|
||||
if self.dataset_config.flip_y:
|
||||
print(" - adding y axis flips")
|
||||
print_acc(" - adding y axis flips")
|
||||
current_file_list = [x for x in self.file_list]
|
||||
for file_item in current_file_list:
|
||||
# create a copy that is flipped on the y axis
|
||||
@@ -507,7 +510,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.file_list.append(new_file_item)
|
||||
|
||||
if self.dataset_config.flip_x or self.dataset_config.flip_y:
|
||||
print(f" - Found {len(self.file_list)} images after adding flips")
|
||||
print_acc(f" - Found {len(self.file_list)} images after adding flips")
|
||||
|
||||
|
||||
self.setup_epoch()
|
||||
|
||||
@@ -24,6 +24,8 @@ from torchvision import transforms
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from PIL.ImageOps import exif_transpose
|
||||
import albumentations as A
|
||||
from toolkit.print import print_acc
|
||||
from toolkit.accelerator import get_accelerator
|
||||
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
@@ -32,6 +34,8 @@ if TYPE_CHECKING:
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
accelerator = get_accelerator()
|
||||
|
||||
# def get_associated_caption_from_img_path(img_path):
|
||||
# https://demo.albumentations.ai/
|
||||
class Augments:
|
||||
@@ -263,7 +267,7 @@ class BucketsMixin:
|
||||
file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)
|
||||
|
||||
if file_item.crop_y < 0 or file_item.crop_x < 0:
|
||||
print('debug')
|
||||
print_acc('debug')
|
||||
|
||||
# check if bucket exists, if not, create it
|
||||
bucket_key = f'{file_item.crop_width}x{file_item.crop_height}'
|
||||
@@ -275,10 +279,10 @@ class BucketsMixin:
|
||||
self.shuffle_buckets()
|
||||
self.build_batch_indices()
|
||||
if not quiet:
|
||||
print(f'Bucket sizes for {self.dataset_path}:')
|
||||
print_acc(f'Bucket sizes for {self.dataset_path}:')
|
||||
for key, bucket in self.buckets.items():
|
||||
print(f'{key}: {len(bucket.file_list_idx)} files')
|
||||
print(f'{len(self.buckets)} buckets made')
|
||||
print_acc(f'{key}: {len(bucket.file_list_idx)} files')
|
||||
print_acc(f'{len(self.buckets)} buckets made')
|
||||
|
||||
|
||||
class CaptionProcessingDTOMixin:
|
||||
@@ -447,8 +451,8 @@ class ImageProcessingDTOMixin:
|
||||
img = Image.open(self.path)
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Error loading image: {self.path}")
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {self.path}")
|
||||
|
||||
if self.use_alpha_as_mask:
|
||||
# we do this to make sure it does not replace the alpha with another color
|
||||
@@ -462,11 +466,11 @@ class ImageProcessingDTOMixin:
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
print(
|
||||
print_acc(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
print(
|
||||
print_acc(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
|
||||
if self.flip_x:
|
||||
@@ -482,7 +486,7 @@ class ImageProcessingDTOMixin:
|
||||
# crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
|
||||
if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
|
||||
# todo look into this. This still happens sometimes
|
||||
print('size mismatch')
|
||||
print_acc('size mismatch')
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
@@ -501,7 +505,7 @@ class ImageProcessingDTOMixin:
|
||||
if self.dataset_config.random_crop:
|
||||
if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution:
|
||||
if min_img_size < self.dataset_config.resolution:
|
||||
print(
|
||||
print_acc(
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}")
|
||||
scale_size = self.dataset_config.resolution
|
||||
else:
|
||||
@@ -567,8 +571,8 @@ class ControlFileItemDTOMixin:
|
||||
img = Image.open(self.control_path).convert('RGB')
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Error loading image: {self.control_path}")
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {self.control_path}")
|
||||
|
||||
if self.full_size_control_images:
|
||||
# we just scale them to 512x512:
|
||||
@@ -782,8 +786,8 @@ class ClipImageFileItemDTOMixin:
|
||||
except Exception as e:
|
||||
# make a random noise image
|
||||
img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution))
|
||||
print(f"Error: {e}")
|
||||
print(f"Error loading image: {clip_image_path}")
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {clip_image_path}")
|
||||
|
||||
img = img.convert('RGB')
|
||||
|
||||
@@ -981,8 +985,8 @@ class MaskFileItemDTOMixin:
|
||||
img = Image.open(self.mask_path)
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Error loading image: {self.mask_path}")
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {self.mask_path}")
|
||||
|
||||
if self.use_alpha_as_mask:
|
||||
# pipeline expectws an rgb image so we need to put alpha in all channels
|
||||
@@ -999,11 +1003,11 @@ class MaskFileItemDTOMixin:
|
||||
fix_size = False
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
fix_size = True
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
fix_size = True
|
||||
|
||||
if fix_size:
|
||||
@@ -1085,8 +1089,8 @@ class UnconditionalFileItemDTOMixin:
|
||||
img = Image.open(self.unconditional_path)
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Error loading image: {self.mask_path}")
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {self.mask_path}")
|
||||
|
||||
img = img.convert('RGB')
|
||||
w, h = img.size
|
||||
@@ -1166,9 +1170,9 @@ class PoiFileItemDTOMixin:
|
||||
with open(caption_path, 'r', encoding='utf-8') as f:
|
||||
json_data = json.load(f)
|
||||
if 'poi' not in json_data:
|
||||
print(f"Warning: poi not found in caption file: {caption_path}")
|
||||
print_acc(f"Warning: poi not found in caption file: {caption_path}")
|
||||
if self.poi not in json_data['poi']:
|
||||
print(f"Warning: poi not found in caption file: {caption_path}")
|
||||
print_acc(f"Warning: poi not found in caption file: {caption_path}")
|
||||
# poi has, x, y, width, height
|
||||
# do full image if no poi
|
||||
self.poi_x = 0
|
||||
@@ -1242,8 +1246,8 @@ class PoiFileItemDTOMixin:
|
||||
# now we have our random crop, but it may be smaller than resolution. Check and expand if needed
|
||||
current_resolution = get_resolution(poi_width, poi_height)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Error getting resolution: {self.path}")
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error getting resolution: {self.path}")
|
||||
raise e
|
||||
return False
|
||||
if current_resolution >= self.dataset_config.resolution:
|
||||
@@ -1252,7 +1256,7 @@ class PoiFileItemDTOMixin:
|
||||
else:
|
||||
num_loops += 1
|
||||
if num_loops > 100:
|
||||
print(
|
||||
print_acc(
|
||||
f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.")
|
||||
return False
|
||||
|
||||
@@ -1279,7 +1283,7 @@ class PoiFileItemDTOMixin:
|
||||
|
||||
if self.scale_to_width < self.crop_x + self.crop_width or self.scale_to_height < self.crop_y + self.crop_height:
|
||||
# todo look into this. This still happens sometimes
|
||||
print('size mismatch')
|
||||
print_acc('size mismatch')
|
||||
|
||||
return True
|
||||
|
||||
@@ -1373,88 +1377,89 @@ class LatentCachingMixin:
|
||||
self.latent_cache = {}
|
||||
|
||||
def cache_latents_all_latents(self: 'AiToolkitDataset'):
|
||||
print(f"Caching latents for {self.dataset_path}")
|
||||
# cache all latents to disk
|
||||
to_disk = self.is_caching_latents_to_disk
|
||||
to_memory = self.is_caching_latents_to_memory
|
||||
with accelerator.main_process_first():
|
||||
print_acc(f"Caching latents for {self.dataset_path}")
|
||||
# cache all latents to disk
|
||||
to_disk = self.is_caching_latents_to_disk
|
||||
to_memory = self.is_caching_latents_to_memory
|
||||
|
||||
if to_disk:
|
||||
print(" - Saving latents to disk")
|
||||
if to_memory:
|
||||
print(" - Keeping latents in memory")
|
||||
# move sd items to cpu except for vae
|
||||
self.sd.set_device_state_preset('cache_latents')
|
||||
if to_disk:
|
||||
print_acc(" - Saving latents to disk")
|
||||
if to_memory:
|
||||
print_acc(" - Keeping latents in memory")
|
||||
# move sd items to cpu except for vae
|
||||
self.sd.set_device_state_preset('cache_latents')
|
||||
|
||||
# use tqdm to show progress
|
||||
i = 0
|
||||
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
|
||||
# set latent space version
|
||||
if self.sd.model_config.latent_space_version is not None:
|
||||
file_item.latent_space_version = self.sd.model_config.latent_space_version
|
||||
elif self.sd.is_xl:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
elif self.sd.is_v3:
|
||||
file_item.latent_space_version = 'sd3'
|
||||
elif self.sd.is_auraflow:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
elif self.sd.is_flux:
|
||||
file_item.latent_space_version = 'flux1'
|
||||
elif self.sd.model_config.is_pixart_sigma:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
else:
|
||||
file_item.latent_space_version = 'sd1'
|
||||
file_item.is_caching_to_disk = to_disk
|
||||
file_item.is_caching_to_memory = to_memory
|
||||
file_item.latent_load_device = self.sd.device
|
||||
# use tqdm to show progress
|
||||
i = 0
|
||||
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
|
||||
# set latent space version
|
||||
if self.sd.model_config.latent_space_version is not None:
|
||||
file_item.latent_space_version = self.sd.model_config.latent_space_version
|
||||
elif self.sd.is_xl:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
elif self.sd.is_v3:
|
||||
file_item.latent_space_version = 'sd3'
|
||||
elif self.sd.is_auraflow:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
elif self.sd.is_flux:
|
||||
file_item.latent_space_version = 'flux1'
|
||||
elif self.sd.model_config.is_pixart_sigma:
|
||||
file_item.latent_space_version = 'sdxl'
|
||||
else:
|
||||
file_item.latent_space_version = 'sd1'
|
||||
file_item.is_caching_to_disk = to_disk
|
||||
file_item.is_caching_to_memory = to_memory
|
||||
file_item.latent_load_device = self.sd.device
|
||||
|
||||
latent_path = file_item.get_latent_path(recalculate=True)
|
||||
# check if it is saved to disk already
|
||||
if os.path.exists(latent_path):
|
||||
if to_memory:
|
||||
# load it into memory
|
||||
state_dict = load_file(latent_path, device='cpu')
|
||||
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
|
||||
else:
|
||||
# not saved to disk, calculate
|
||||
# load the image first
|
||||
file_item.load_and_process_image(self.transform, only_load_latents=True)
|
||||
dtype = self.sd.torch_dtype
|
||||
device = self.sd.device_torch
|
||||
# add batch dimension
|
||||
try:
|
||||
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
|
||||
latent = self.sd.encode_images(imgs).squeeze(0)
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {file_item.path}")
|
||||
print(f"Error: {str(e)}")
|
||||
raise e
|
||||
# save_latent
|
||||
if to_disk:
|
||||
state_dict = OrderedDict([
|
||||
('latent', latent.clone().detach().cpu()),
|
||||
])
|
||||
# metadata
|
||||
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
|
||||
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
|
||||
save_file(state_dict, latent_path, metadata=meta)
|
||||
latent_path = file_item.get_latent_path(recalculate=True)
|
||||
# check if it is saved to disk already
|
||||
if os.path.exists(latent_path):
|
||||
if to_memory:
|
||||
# load it into memory
|
||||
state_dict = load_file(latent_path, device='cpu')
|
||||
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
|
||||
else:
|
||||
# not saved to disk, calculate
|
||||
# load the image first
|
||||
file_item.load_and_process_image(self.transform, only_load_latents=True)
|
||||
dtype = self.sd.torch_dtype
|
||||
device = self.sd.device_torch
|
||||
# add batch dimension
|
||||
try:
|
||||
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
|
||||
latent = self.sd.encode_images(imgs).squeeze(0)
|
||||
except Exception as e:
|
||||
print_acc(f"Error processing image: {file_item.path}")
|
||||
print_acc(f"Error: {str(e)}")
|
||||
raise e
|
||||
# save_latent
|
||||
if to_disk:
|
||||
state_dict = OrderedDict([
|
||||
('latent', latent.clone().detach().cpu()),
|
||||
])
|
||||
# metadata
|
||||
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
|
||||
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
|
||||
save_file(state_dict, latent_path, metadata=meta)
|
||||
|
||||
if to_memory:
|
||||
# keep it in memory
|
||||
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
|
||||
if to_memory:
|
||||
# keep it in memory
|
||||
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
|
||||
|
||||
del imgs
|
||||
del latent
|
||||
del file_item.tensor
|
||||
del imgs
|
||||
del latent
|
||||
del file_item.tensor
|
||||
|
||||
# flush(garbage_collect=False)
|
||||
file_item.is_latent_cached = True
|
||||
i += 1
|
||||
# flush every 100
|
||||
# if i % 100 == 0:
|
||||
# flush()
|
||||
# flush(garbage_collect=False)
|
||||
file_item.is_latent_cached = True
|
||||
i += 1
|
||||
# flush every 100
|
||||
# if i % 100 == 0:
|
||||
# flush()
|
||||
|
||||
# restore device state
|
||||
self.sd.restore_device_state()
|
||||
# restore device state
|
||||
self.sd.restore_device_state()
|
||||
|
||||
|
||||
class CLIPCachingMixin:
|
||||
@@ -1469,9 +1474,9 @@ class CLIPCachingMixin:
|
||||
if not self.is_caching_clip_vision_to_disk:
|
||||
return
|
||||
with torch.no_grad():
|
||||
print(f"Caching clip vision for {self.dataset_path}")
|
||||
print_acc(f"Caching clip vision for {self.dataset_path}")
|
||||
|
||||
print(" - Saving clip to disk")
|
||||
print_acc(" - Saving clip to disk")
|
||||
# move sd items to cpu except for vae
|
||||
self.sd.set_device_state_preset('cache_clip')
|
||||
|
||||
@@ -1512,7 +1517,7 @@ class CLIPCachingMixin:
|
||||
self.clip_vision_num_unconditional_cache = 1
|
||||
|
||||
# cache unconditionals
|
||||
print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
|
||||
print_acc(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
|
||||
clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')
|
||||
|
||||
unconditional_paths = []
|
||||
|
||||
6
toolkit/print.py
Normal file
6
toolkit/print.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from toolkit.accelerator import get_accelerator
|
||||
|
||||
|
||||
def print_acc(*args, **kwargs):
|
||||
if get_accelerator().is_local_main_process:
|
||||
print(*args, **kwargs)
|
||||
@@ -63,7 +63,9 @@ from huggingface_hub import hf_hub_download
|
||||
from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
|
||||
|
||||
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
|
||||
from toolkit.accelerator import get_accelerator, unwrap_model
|
||||
from typing import TYPE_CHECKING
|
||||
from toolkit.print import print_acc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
@@ -130,18 +132,17 @@ class StableDiffusion:
|
||||
noise_scheduler=None,
|
||||
quantize_device=None,
|
||||
):
|
||||
self.accelerator = get_accelerator()
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.device = device
|
||||
self.device = str(self.accelerator.device)
|
||||
self.dtype = dtype
|
||||
self.torch_dtype = get_torch_dtype(dtype)
|
||||
self.device_torch = torch.device(self.device)
|
||||
self.device_torch = self.accelerator.device
|
||||
|
||||
self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(
|
||||
model_config.vae_device)
|
||||
self.vae_device_torch = self.accelerator.device
|
||||
self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
|
||||
|
||||
self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(
|
||||
model_config.te_device)
|
||||
self.te_device_torch = self.accelerator.device
|
||||
self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
|
||||
|
||||
self.model_config = model_config
|
||||
@@ -186,7 +187,7 @@ class StableDiffusion:
|
||||
if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler):
|
||||
self.is_flow_matching = True
|
||||
|
||||
self.quantize_device = quantize_device if quantize_device is not None else self.device
|
||||
self.quantize_device = self.device_torch
|
||||
self.low_vram = self.model_config.low_vram
|
||||
|
||||
# merge in and preview active with -1 weight
|
||||
@@ -254,8 +255,8 @@ class StableDiffusion:
|
||||
pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
|
||||
|
||||
if self.model_config.experimental_xl:
|
||||
print("Experimental XL mode enabled")
|
||||
print("Loading and injecting alt weights")
|
||||
print_acc("Experimental XL mode enabled")
|
||||
print_acc("Loading and injecting alt weights")
|
||||
# load the mismatched weight and force it in
|
||||
raw_state_dict = load_file(model_path)
|
||||
replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone()
|
||||
@@ -265,17 +266,17 @@ class StableDiffusion:
|
||||
# replace weight with mismatched weight
|
||||
te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
print("Injecting alt weights")
|
||||
print_acc("Injecting alt weights")
|
||||
elif self.model_config.is_v3:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
else:
|
||||
pipln = StableDiffusion3Pipeline
|
||||
|
||||
print("Loading SD3 model")
|
||||
print_acc("Loading SD3 model")
|
||||
# assume it is the large model
|
||||
base_model_path = "stabilityai/stable-diffusion-3.5-large"
|
||||
print("Loading transformer")
|
||||
print_acc("Loading transformer")
|
||||
subfolder = 'transformer'
|
||||
transformer_path = model_path
|
||||
# check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
|
||||
@@ -298,7 +299,7 @@ class StableDiffusion:
|
||||
)
|
||||
if not self.low_vram:
|
||||
# for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu
|
||||
transformer.to(torch.device(self.quantize_device), dtype=dtype)
|
||||
transformer.to(self.quantize_device, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
@@ -306,7 +307,7 @@ class StableDiffusion:
|
||||
|
||||
if self.model_config.quantize:
|
||||
quantization_type = qfloat8
|
||||
print("Quantizing transformer")
|
||||
print_acc("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type)
|
||||
freeze(transformer)
|
||||
transformer.to(self.device_torch)
|
||||
@@ -314,11 +315,11 @@ class StableDiffusion:
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
print("Loading vae")
|
||||
print_acc("Loading vae")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
|
||||
print("Loading t5")
|
||||
print_acc("Loading t5")
|
||||
tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype)
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained(
|
||||
base_model_path,
|
||||
@@ -330,7 +331,7 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize:
|
||||
print("Quantizing T5")
|
||||
print_acc("Quantizing T5")
|
||||
quantize(text_encoder_3, weights=qfloat8)
|
||||
freeze(text_encoder_3)
|
||||
flush()
|
||||
@@ -354,7 +355,7 @@ class StableDiffusion:
|
||||
**load_args
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error loading from pretrained: {e}")
|
||||
print_acc(f"Error loading from pretrained: {e}")
|
||||
raise e
|
||||
|
||||
else:
|
||||
@@ -529,10 +530,10 @@ class StableDiffusion:
|
||||
tokenizer = pipe.tokenizer
|
||||
|
||||
elif self.model_config.is_flux:
|
||||
print("Loading Flux model")
|
||||
print_acc("Loading Flux model")
|
||||
# base_model_path = "black-forest-labs/FLUX.1-schnell"
|
||||
base_model_path = self.model_config.name_or_path_original
|
||||
print("Loading transformer")
|
||||
print_acc("Loading transformer")
|
||||
subfolder = 'transformer'
|
||||
transformer_path = model_path
|
||||
local_files_only = False
|
||||
@@ -559,7 +560,7 @@ class StableDiffusion:
|
||||
|
||||
if not self.low_vram:
|
||||
# for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu
|
||||
transformer.to(torch.device(self.quantize_device), dtype=dtype)
|
||||
transformer.to(self.quantize_device, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
|
||||
@@ -581,7 +582,7 @@ class StableDiffusion:
|
||||
load_lora_path, "pytorch_lora_weights.safetensors"
|
||||
)
|
||||
elif not os.path.exists(load_lora_path):
|
||||
print(f"Grabbing lora from the hub: {load_lora_path}")
|
||||
print_acc(f"Grabbing lora from the hub: {load_lora_path}")
|
||||
new_lora_path = hf_hub_download(
|
||||
load_lora_path,
|
||||
filename="pytorch_lora_weights.safetensors"
|
||||
@@ -604,7 +605,7 @@ class StableDiffusion:
|
||||
self.model_config.lora_path = self.model_config.assistant_lora_path
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
print("Fusing in LoRA")
|
||||
print_acc("Fusing in LoRA")
|
||||
# need the pipe for peft
|
||||
pipe: FluxPipeline = FluxPipeline(
|
||||
scheduler=None,
|
||||
@@ -635,7 +636,7 @@ class StableDiffusion:
|
||||
|
||||
# double blocks
|
||||
transformer.transformer_blocks = transformer.transformer_blocks.to(
|
||||
torch.device(self.quantize_device), dtype=dtype
|
||||
self.quantize_device, dtype=dtype
|
||||
)
|
||||
pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double")
|
||||
pipe.fuse_lora()
|
||||
@@ -646,7 +647,7 @@ class StableDiffusion:
|
||||
|
||||
# single blocks
|
||||
transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
|
||||
torch.device(self.quantize_device), dtype=dtype
|
||||
self.quantize_device, dtype=dtype
|
||||
)
|
||||
pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single")
|
||||
pipe.fuse_lora()
|
||||
@@ -674,7 +675,7 @@ class StableDiffusion:
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
print("Quantizing transformer")
|
||||
print_acc("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
|
||||
freeze(transformer)
|
||||
transformer.to(self.device_torch)
|
||||
@@ -684,11 +685,11 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
print("Loading vae")
|
||||
print_acc("Loading vae")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
|
||||
print("Loading t5")
|
||||
print_acc("Loading t5")
|
||||
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
|
||||
torch_dtype=dtype)
|
||||
@@ -697,17 +698,17 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
print("Quantizing T5")
|
||||
print_acc("Quantizing T5")
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
freeze(text_encoder_2)
|
||||
flush()
|
||||
|
||||
print("Loading clip")
|
||||
print_acc("Loading clip")
|
||||
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
print("making pipe")
|
||||
print_acc("making pipe")
|
||||
pipe: FluxPipeline = FluxPipeline(
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
@@ -720,7 +721,7 @@ class StableDiffusion:
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
pipe.transformer = transformer
|
||||
|
||||
print("preparing")
|
||||
print_acc("preparing")
|
||||
|
||||
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
@@ -836,7 +837,7 @@ class StableDiffusion:
|
||||
self.is_loaded = True
|
||||
|
||||
if self.model_config.assistant_lora_path is not None:
|
||||
print("Loading assistant lora")
|
||||
print_acc("Loading assistant lora")
|
||||
self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
|
||||
self.model_config.assistant_lora_path, self)
|
||||
|
||||
@@ -846,7 +847,7 @@ class StableDiffusion:
|
||||
self.assistant_lora.is_active = False
|
||||
|
||||
if self.model_config.inference_lora_path is not None:
|
||||
print("Loading inference lora")
|
||||
print_acc("Loading inference lora")
|
||||
self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
|
||||
self.model_config.inference_lora_path, self)
|
||||
# disable during training
|
||||
@@ -917,11 +918,12 @@ class StableDiffusion:
|
||||
sampler=None,
|
||||
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
|
||||
):
|
||||
network = unwrap_model(self.network)
|
||||
merge_multiplier = 1.0
|
||||
flush()
|
||||
# if using assistant, unfuse it
|
||||
if self.model_config.assistant_lora_path is not None:
|
||||
print("Unloading assistant lora")
|
||||
print_acc("Unloading assistant lora")
|
||||
if self.invert_assistant_lora:
|
||||
self.assistant_lora.is_active = True
|
||||
# move weights on to the device
|
||||
@@ -930,18 +932,17 @@ class StableDiffusion:
|
||||
self.assistant_lora.is_active = False
|
||||
|
||||
if self.model_config.inference_lora_path is not None:
|
||||
print("Loading inference lora")
|
||||
print_acc("Loading inference lora")
|
||||
self.assistant_lora.is_active = True
|
||||
# move weights on to the device
|
||||
self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
|
||||
|
||||
if self.network is not None:
|
||||
self.network.eval()
|
||||
network = self.network
|
||||
if network is not None:
|
||||
network.eval()
|
||||
# check if we have the same network weight for all samples. If we do, we can merge in th
|
||||
# the network to drastically speed up inference
|
||||
unique_network_weights = set([x.network_multiplier for x in image_configs])
|
||||
if len(unique_network_weights) == 1 and self.network.can_merge_in:
|
||||
if len(unique_network_weights) == 1 and network.can_merge_in:
|
||||
can_merge_in = True
|
||||
merge_multiplier = unique_network_weights.pop()
|
||||
network.merge_in(merge_weight=merge_multiplier)
|
||||
@@ -1119,15 +1120,15 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
start_multiplier = 1.0
|
||||
if self.network is not None:
|
||||
start_multiplier = self.network.multiplier
|
||||
if network is not None:
|
||||
start_multiplier = network.multiplier
|
||||
|
||||
# pipeline.to(self.device_torch)
|
||||
|
||||
with network:
|
||||
with torch.no_grad():
|
||||
if self.network is not None:
|
||||
assert self.network.is_active
|
||||
if network is not None:
|
||||
assert network.is_active
|
||||
|
||||
for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
|
||||
gen_config = image_configs[i]
|
||||
@@ -1164,8 +1165,8 @@ class StableDiffusion:
|
||||
validation_image = validation_image.unsqueeze(0)
|
||||
self.adapter.set_reference_images(validation_image)
|
||||
|
||||
if self.network is not None:
|
||||
self.network.multiplier = gen_config.network_multiplier
|
||||
if network is not None:
|
||||
network.multiplier = gen_config.network_multiplier
|
||||
torch.manual_seed(gen_config.seed)
|
||||
torch.cuda.manual_seed(gen_config.seed)
|
||||
|
||||
@@ -1332,6 +1333,12 @@ class StableDiffusion:
|
||||
**extra
|
||||
).images[0]
|
||||
else:
|
||||
# Fix a bug in diffusers/torch
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
latents = callback_kwargs["latents"]
|
||||
if latents.dtype != self.unet.dtype:
|
||||
latents = latents.to(self.unet.dtype)
|
||||
return {"latents": latents}
|
||||
img = pipeline(
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||
@@ -1343,6 +1350,7 @@ class StableDiffusion:
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
generator=generator,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
**extra
|
||||
).images[0]
|
||||
elif self.is_pixart:
|
||||
@@ -1448,9 +1456,9 @@ class StableDiffusion:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
self.restore_device_state()
|
||||
if self.network is not None:
|
||||
self.network.train()
|
||||
self.network.multiplier = start_multiplier
|
||||
if network is not None:
|
||||
network.train()
|
||||
network.multiplier = start_multiplier
|
||||
|
||||
self.unet.to(self.device_torch, dtype=self.torch_dtype)
|
||||
if network.is_merged_in:
|
||||
@@ -1459,7 +1467,7 @@ class StableDiffusion:
|
||||
|
||||
# refuse loras
|
||||
if self.model_config.assistant_lora_path is not None:
|
||||
print("Loading assistant lora")
|
||||
print_acc("Loading assistant lora")
|
||||
if self.invert_assistant_lora:
|
||||
self.assistant_lora.is_active = False
|
||||
# move weights off the device
|
||||
@@ -1468,7 +1476,7 @@ class StableDiffusion:
|
||||
self.assistant_lora.is_active = True
|
||||
|
||||
if self.model_config.inference_lora_path is not None:
|
||||
print("Unloading inference lora")
|
||||
print_acc("Unloading inference lora")
|
||||
self.assistant_lora.is_active = False
|
||||
# move weights off the device
|
||||
self.assistant_lora.force_to('cpu', self.torch_dtype)
|
||||
@@ -1867,6 +1875,11 @@ class StableDiffusion:
|
||||
bypass_flux_guidance(self.unet)
|
||||
|
||||
cast_dtype = self.unet.dtype
|
||||
# changes from orig implementation
|
||||
if txt_ids.ndim == 3:
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
img_ids = img_ids[0]
|
||||
# with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64]
|
||||
@@ -2513,7 +2526,7 @@ class StableDiffusion:
|
||||
params.append(named_params[diffusers_key])
|
||||
param_data = {"params": params, "lr": unet_lr}
|
||||
trainable_parameters.append(param_data)
|
||||
print(f"Found {len(params)} trainable parameter in unet")
|
||||
print_acc(f"Found {len(params)} trainable parameter in unet")
|
||||
|
||||
if text_encoder:
|
||||
named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True)
|
||||
@@ -2526,7 +2539,7 @@ class StableDiffusion:
|
||||
param_data = {"params": params, "lr": text_encoder_lr}
|
||||
trainable_parameters.append(param_data)
|
||||
|
||||
print(f"Found {len(params)} trainable parameter in text encoder")
|
||||
print_acc(f"Found {len(params)} trainable parameter in text encoder")
|
||||
|
||||
if refiner:
|
||||
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
|
||||
@@ -2541,7 +2554,7 @@ class StableDiffusion:
|
||||
param_data = {"params": params, "lr": refiner_lr}
|
||||
trainable_parameters.append(param_data)
|
||||
|
||||
print(f"Found {len(params)} trainable parameter in refiner")
|
||||
print_acc(f"Found {len(params)} trainable parameter in refiner")
|
||||
|
||||
return trainable_parameters
|
||||
|
||||
|
||||
Reference in New Issue
Block a user