From 5e663746b83de64d5ca239fa699cf59a8a268ba5 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 25 Jan 2025 16:46:20 -0700 Subject: [PATCH] Working multi gpu training. Still need a lot of tweaks and testing. --- extensions_built_in/sd_trainer/SDTrainer.py | 56 ++--- jobs/process/BaseSDTrainProcess.py | 242 ++++++++++++++------ run.py | 23 +- todo_multigpu.md | 3 + toolkit/accelerator.py | 17 ++ toolkit/data_loader.py | 45 ++-- toolkit/dataloader_mixins.py | 213 ++++++++--------- toolkit/print.py | 6 + toolkit/stable_diffusion_model.py | 121 +++++----- 9 files changed, 432 insertions(+), 294 deletions(-) create mode 100644 todo_multigpu.md create mode 100644 toolkit/accelerator.py create mode 100644 toolkit/print.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 3747972d..98cec16a 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -20,6 +20,7 @@ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, Guid from toolkit.image_utils import show_tensors, show_latents from toolkit.ip_adapter import IPAdapter from toolkit.custom_adapter import CustomAdapter +from toolkit.print import print_acc from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds from toolkit.reference_adapter import ReferenceAdapter from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork @@ -59,8 +60,6 @@ class SDTrainer(BaseSDTrainProcess): self.negative_prompt_pool: Union[List[str], None] = None self.batch_negative_prompt: Union[List[str], None] = None - self.scaler = torch.cuda.amp.GradScaler() - self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" self.do_grad_scale = True @@ -70,12 +69,12 @@ class SDTrainer(BaseSDTrainProcess): if self.adapter_config.train: self.do_grad_scale = False - if self.train_config.dtype in ["fp16", "float16"]: - # patch the scaler to allow fp16 training - org_unscale_grads = self.scaler._unscale_grads_ - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - self.scaler._unscale_grads_ = _unscale_grads_replacer + # if self.train_config.dtype in ["fp16", "float16"]: + # # patch the scaler to allow fp16 training + # org_unscale_grads = self.scaler._unscale_grads_ + # def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + # return org_unscale_grads(optimizer, inv_scale, found_inf, True) + # self.scaler._unscale_grads_ = _unscale_grads_replacer self.cached_blank_embeds: Optional[PromptEmbeds] = None self.cached_trigger_embeds: Optional[PromptEmbeds] = None @@ -168,11 +167,11 @@ class SDTrainer(BaseSDTrainProcess): raise ValueError("Cannot unload text encoder if training text encoder") # cache embeddings - print("\n***** UNLOADING TEXT ENCODER *****") - print("This will train only with a blank prompt or trigger word, if set") - print("If this is not what you want, remove the unload_text_encoder flag") - print("***********************************") - print("") + print_acc("\n***** UNLOADING TEXT ENCODER *****") + print_acc("This will train only with a blank prompt or trigger word, if set") + print_acc("If this is not what you want, remove the unload_text_encoder flag") + print_acc("***********************************") + print_acc("") self.sd.text_encoder_to(self.device_torch) self.cached_blank_embeds = self.sd.encode_prompt("") if self.trigger_word is not None: @@ -484,7 +483,7 @@ 
class SDTrainer(BaseSDTrainProcess): prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier if torch.isnan(prior_loss).any(): - print("Prior loss is nan") + print_acc("Prior loss is nan") prior_loss = None else: prior_loss = prior_loss.mean([1, 2, 3]) @@ -553,7 +552,6 @@ class SDTrainer(BaseSDTrainProcess): noise=noise, sd=self.sd, unconditional_embeds=unconditional_embeds, - scaler=self.scaler, **kwargs ) @@ -668,7 +666,7 @@ class SDTrainer(BaseSDTrainProcess): # loss = self.apply_snr(loss, timesteps) loss = loss.mean() - loss.backward() + self.accelerator.backward(loss) # detach it so parent class can run backward on no grads without throwing error loss = loss.detach() @@ -823,7 +821,7 @@ class SDTrainer(BaseSDTrainProcess): # loss = self.apply_snr(loss, timesteps) loss = loss.mean() - loss.backward() + self.accelerator.backward(loss) # detach it so parent class can run backward on no grads without throwing error loss = loss.detach() @@ -1446,8 +1444,8 @@ class SDTrainer(BaseSDTrainProcess): quad_count=quad_count ) else: - print("No Clip Image") - print([file_item.path for file_item in batch.file_items]) + print_acc("No Clip Image") + print_acc([file_item.path for file_item in batch.file_items]) raise ValueError("Could not find clip image") if not self.adapter_config.train_image_encoder: @@ -1625,7 +1623,7 @@ class SDTrainer(BaseSDTrainProcess): ) # check if nan if torch.isnan(loss): - print("loss is nan") + print_acc("loss is nan") loss = torch.zeros_like(loss).requires_grad_(True) with self.timer('backward'): @@ -1640,10 +1638,7 @@ class SDTrainer(BaseSDTrainProcess): # if self.is_bfloat: # loss.backward() # else: - if not self.do_grad_scale: - loss.backward() - else: - self.scaler.scale(loss).backward() + self.accelerator.backward(loss) return loss.detach() # flush() @@ -1668,21 +1663,14 @@ class SDTrainer(BaseSDTrainProcess): if not self.is_grad_accumulation_step: # fix this for multi params if self.train_config.optimizer != 'adafactor': - if self.do_grad_scale: - self.scaler.unscale_(self.optimizer) if isinstance(self.params[0], dict): for i in range(len(self.params)): - torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) + self.accelerator.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) else: - torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm) # only step if we are not accumulating with self.timer('optimizer_step'): - # self.optimizer.step() - if not self.do_grad_scale: - self.optimizer.step() - else: - self.scaler.step(self.optimizer) - self.scaler.update() + self.optimizer.step() self.optimizer.zero_grad(set_to_none=True) if self.adapter and isinstance(self.adapter, CustomAdapter): diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 3b210154..058ba375 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -61,6 +61,11 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, Netw DecoratorConfig from toolkit.logging import create_logger from diffusers import FluxTransformer2DModel +from toolkit.accelerator import get_accelerator +from toolkit.print import print_acc +from accelerate import Accelerator +import transformers +import diffusers def flush(): torch.cuda.empty_cache() @@ -71,6 +76,14 @@ class BaseSDTrainProcess(BaseTrainProcess): def __init__(self, process_id: 
int, job, config: OrderedDict, custom_pipeline=None): super().__init__(process_id, job, config) + self.accelerator: Accelerator = get_accelerator() + if self.accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + self.sd: StableDiffusion self.embedding: Union[Embedding, None] = None @@ -82,8 +95,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.grad_accumulation_step = 1 # if true, then we do not do an optimizer step. We are accumulating gradients self.is_grad_accumulation_step = False - self.device = self.get_conf('device', self.job.device) - self.device_torch = torch.device(self.device) + self.device = str(self.accelerator.device) + self.device_torch = self.accelerator.device network_config = self.get_conf('network', None) if network_config is not None: self.network_config = NetworkConfig(**network_config) @@ -91,6 +104,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.network_config = None self.train_config = TrainConfig(**self.get_conf('train', {})) model_config = self.get_conf('model', {}) + self.modules_being_trained: List[torch.nn.Module] = [] # update modelconfig dtype to match train model_config['dtype'] = self.train_config.dtype @@ -222,6 +236,8 @@ class BaseSDTrainProcess(BaseTrainProcess): return generate_image_config_list def sample(self, step=None, is_first=False): + if not self.accelerator.is_main_process: + return flush() sample_folder = os.path.join(self.save_root, 'samples') gen_img_config_list = [] @@ -316,6 +332,8 @@ class BaseSDTrainProcess(BaseTrainProcess): elif self.model_config.is_xl: o_dict['ss_base_model_version'] = 'sdxl_1.0' + elif self.model_config.is_flux: + o_dict['ss_base_model_version'] = 'flux.1' else: o_dict['ss_base_model_version'] = 'sd_1.5' @@ -344,6 +362,8 @@ class BaseSDTrainProcess(BaseTrainProcess): return info def clean_up_saves(self): + if not self.accelerator.is_main_process: + return # remove old saves # get latest saved step latest_item = None @@ -400,7 +420,7 @@ class BaseSDTrainProcess(BaseTrainProcess): items_to_remove = list(dict.fromkeys(items_to_remove)) for item in items_to_remove: - self.print(f"Removing old save: {item}") + print_acc(f"Removing old save: {item}") if os.path.isdir(item): shutil.rmtree(item) else: @@ -418,6 +438,8 @@ class BaseSDTrainProcess(BaseTrainProcess): pass def save(self, step=None): + if not self.accelerator.is_main_process: + return flush() if self.ema is not None: # always save params as ema @@ -594,10 +616,10 @@ class BaseSDTrainProcess(BaseTrainProcess): state_dict = self.optimizer.state_dict() torch.save(state_dict, file_path) except Exception as e: - print(e) - print("Could not save optimizer") + print_acc(e) + print_acc("Could not save optimizer") - self.print(f"Saved to {file_path}") + print_acc(f"Saved to {file_path}") self.clean_up_saves() self.post_save_hook(file_path) @@ -619,7 +641,49 @@ class BaseSDTrainProcess(BaseTrainProcess): return params def hook_before_train_loop(self): - self.logger.start() + if self.accelerator.is_main_process: + self.logger.start() + self.prepare_accelerator() + + + def prepare_accelerator(self): + # set some config + self.accelerator.even_batches=False + + # # prepare all the models stuff for accelerator (hopefully we dont miss any) + if self.sd.vae is not None: + self.sd.vae = self.accelerator.prepare(self.sd.vae) + if self.sd.unet is not None: + self.sd.unet = 
self.accelerator.prepare(self.sd.unet) + # todo always tdo it? + self.modules_being_trained.append(self.sd.unet) + if self.sd.text_encoder is not None and self.train_config.train_text_encoder: + if isinstance(self.sd.text_encoder, list): + self.sd.text_encoder = [self.accelerator.prepare(model) for model in self.sd.text_encoder] + self.modules_being_trained.extend(self.sd.text_encoder) + else: + self.sd.text_encoder = self.accelerator.prepare(self.sd.text_encoder) + self.modules_being_trained.append(self.sd.text_encoder) + if self.sd.refiner_unet is not None and self.train_config.train_refiner: + self.sd.refiner_unet = self.accelerator.prepare(self.sd.refiner_unet) + self.modules_being_trained.append(self.sd.refiner_unet) + # todo, do we need to do the network or will "unet" get it? + if self.sd.network is not None: + self.sd.network = self.accelerator.prepare(self.sd.network) + self.modules_being_trained.append(self.sd.network) + if self.adapter is not None and self.adapter_config.train: + # todo adapters may not be a module. need to check + self.adapter = self.accelerator.prepare(self.adapter) + self.modules_being_trained.append(self.adapter) + + # prepare other things + self.optimizer = self.accelerator.prepare(self.optimizer) + if self.lr_scheduler is not None: + self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler) + # self.data_loader = self.accelerator.prepare(self.data_loader) + # if self.data_loader_reg is not None: + # self.data_loader_reg = self.accelerator.prepare(self.data_loader_reg) + def ensure_params_requires_grad(self, force=False): if self.train_config.do_paramiter_swapping and not force: @@ -692,6 +756,8 @@ class BaseSDTrainProcess(BaseTrainProcess): return latest_path def load_training_state_from_metadata(self, path): + if not self.accelerator.is_main_process: + return meta = None # if path is folder, then it is diffusers if os.path.isdir(path): @@ -708,7 +774,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if 'epoch' in meta['training_info']: self.epoch_num = meta['training_info']['epoch'] self.start_step = self.step_num - print(f"Found step {self.step_num} in metadata, starting from there") + print_acc(f"Found step {self.step_num} in metadata, starting from there") def load_weights(self, path): if self.network is not None: @@ -716,7 +782,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.load_training_state_from_metadata(path) return extra_weights else: - print("load_weights not implemented for non-network models") + print_acc("load_weights not implemented for non-network models") return None def apply_snr(self, seperated_loss, timesteps): @@ -747,7 +813,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if 'epoch' in meta['training_info']: self.epoch_num = meta['training_info']['epoch'] self.start_step = self.step_num - print(f"Found step {self.step_num} in metadata, starting from there") + print_acc(f"Found step {self.step_num} in metadata, starting from there") # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): # self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch) @@ -1244,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.adapter.to(self.device_torch, dtype=dtype) if latest_save_path is not None and not is_control_net: # load adapter from path - print(f"Loading adapter from {latest_save_path}") + print_acc(f"Loading adapter from {latest_save_path}") if is_t2i: loaded_state_dict = load_t2i_model( latest_save_path, @@ -1290,7 +1356,7 @@ class BaseSDTrainProcess(BaseTrainProcess): latest_save_path = 
self.get_latest_save_path() if latest_save_path is not None: - print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") model_config_to_load.name_or_path = latest_save_path self.load_training_state_from_metadata(latest_save_path) @@ -1357,7 +1423,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # block.attn.set_processor(processor) # except ImportError: - # print("sage attention is not installed. Using SDP instead") + # print_acc("sage attention is not installed. Using SDP instead") if self.train_config.gradient_checkpointing: if self.sd.is_flux: @@ -1531,8 +1597,8 @@ class BaseSDTrainProcess(BaseTrainProcess): latest_save_path = self.get_latest_save_path(lora_name) extra_weights = None if latest_save_path is not None: - self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") - self.print(f"Loading from {latest_save_path}") + print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + print_acc(f"Loading from {latest_save_path}") extra_weights = self.load_weights(latest_save_path) self.network.multiplier = 1.0 @@ -1665,17 +1731,17 @@ class BaseSDTrainProcess(BaseTrainProcess): previous_lrs.append(group['lr']) try: - print(f"Loading optimizer state from {optimizer_state_file_path}") + print_acc(f"Loading optimizer state from {optimizer_state_file_path}") optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) optimizer.load_state_dict(optimizer_state_dict) del optimizer_state_dict flush() except Exception as e: - print(f"Failed to load optimizer state from {optimizer_state_file_path}") - print(e) + print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}") + print_acc(e) # update the optimizer LR from the params - print(f"Updating optimizer LR from params") + print_acc(f"Updating optimizer LR from params") if len(previous_lrs) > 0: for i, group in enumerate(optimizer.param_groups): group['lr'] = previous_lrs[i] @@ -1711,24 +1777,27 @@ class BaseSDTrainProcess(BaseTrainProcess): self.hook_before_train_loop() if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling: - self.print("Generating first sample from first sample config") + print_acc("Generating first sample from first sample config") self.sample(0, is_first=True) # sample first if self.train_config.skip_first_sample or self.train_config.disable_sampling: - self.print("Skipping first sample due to config setting") + print_acc("Skipping first sample due to config setting") elif self.step_num <= 1 or self.train_config.force_first_sample: - self.print("Generating baseline samples before training") + print_acc("Generating baseline samples before training") self.sample(self.step_num) - - self.progress_bar = ToolkitProgressBar( - total=self.train_config.steps, - desc=self.job.name, - leave=True, - initial=self.step_num, - iterable=range(0, self.train_config.steps), - ) - self.progress_bar.pause() + + if self.accelerator.is_local_main_process: + self.progress_bar = ToolkitProgressBar( + total=self.train_config.steps, + desc=self.job.name, + leave=True, + initial=self.step_num, + iterable=range(0, self.train_config.steps), + ) + self.progress_bar.pause() + else: + self.progress_bar = None if self.data_loader is not None: dataloader = self.data_loader @@ -1753,7 +1822,7 @@ class BaseSDTrainProcess(BaseTrainProcess): flush() # self.step_num = 0 - # print(f"Compiling Model") + # print_acc(f"Compiling Model") # torch.compile(self.sd.unet, dynamic=True) # make sure 
all params require grad @@ -1779,7 +1848,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.is_grad_accumulation_step = True if self.train_config.free_u: self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2) - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() with torch.no_grad(): # if is even step and we have a reg dataset, use that # todo improve this logic to send one of each through if we can buckets and batch size might be an issue @@ -1802,13 +1872,15 @@ class BaseSDTrainProcess(BaseTrainProcess): except StopIteration: with self.timer('reset_batch:reg'): # hit the end of an epoch, reset - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() dataloader_iterator_reg = iter(dataloader_reg) trigger_dataloader_setup_epoch(dataloader_reg) with self.timer('get_batch:reg'): batch = next(dataloader_iterator_reg) - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() is_reg_step = True elif dataloader is not None: try: @@ -1817,7 +1889,8 @@ class BaseSDTrainProcess(BaseTrainProcess): except StopIteration: with self.timer('reset_batch'): # hit the end of an epoch, reset - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() dataloader_iterator = iter(dataloader) trigger_dataloader_setup_epoch(dataloader) self.epoch_num += 1 @@ -1827,7 +1900,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.grad_accumulation_step = 0 with self.timer('get_batch'): batch = next(dataloader_iterator) - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() else: batch = None batch_list.append(batch) @@ -1849,8 +1923,8 @@ class BaseSDTrainProcess(BaseTrainProcess): # flush() ### HOOK ### - - loss_dict = self.hook_train_loop(batch_list) + with self.accelerator.accumulate(self.modules_being_trained): + loss_dict = self.hook_train_loop(batch_list) self.timer.stop('train_loop') if not did_first_flush: flush() @@ -1880,7 +1954,8 @@ class BaseSDTrainProcess(BaseTrainProcess): for key, value in loss_dict.items(): prog_bar_string += f" {key}: {value:.3e}" - self.progress_bar.set_postfix_str(prog_bar_string) + if self.progress_bar is not None: + self.progress_bar.set_postfix_str(prog_bar_string) # if the batch is a DataLoaderBatchDTO, then we need to clean it up if isinstance(batch, DataLoaderBatchDTO): @@ -1889,8 +1964,11 @@ class BaseSDTrainProcess(BaseTrainProcess): # don't do on first step if self.step_num != self.start_step: + if is_sample_step or is_save_step: + self.accelerator.wait_for_everyone() if is_sample_step: - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() flush() # print above the progress bar if self.train_config.free_u: @@ -1902,57 +1980,70 @@ class BaseSDTrainProcess(BaseTrainProcess): flush() self.ensure_params_requires_grad() - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() if is_save_step: + self.accelerator # print above the progress bar - self.progress_bar.pause() - self.print(f"Saving at step {self.step_num}") + if self.progress_bar is not None: + self.progress_bar.pause() + print_acc(f"Saving at step {self.step_num}") self.save(self.step_num) self.ensure_params_requires_grad() - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: - self.progress_bar.pause() + if 
self.progress_bar is not None: + self.progress_bar.pause() with self.timer('log_to_tensorboard'): # log to tensorboard - if self.writer is not None: - for key, value in loss_dict.items(): - self.writer.add_scalar(f"{key}", value, self.step_num) - self.writer.add_scalar(f"lr", learning_rate, self.step_num) - self.progress_bar.unpause() + if self.accelerator.is_main_process: + if self.writer is not None: + for key, value in loss_dict.items(): + self.writer.add_scalar(f"{key}", value, self.step_num) + self.writer.add_scalar(f"lr", learning_rate, self.step_num) + if self.progress_bar is not None: + self.progress_bar.unpause() - # log to logger - self.logger.log({ - 'learning_rate': learning_rate, - }) - for key, value in loss_dict.items(): + if self.accelerator.is_main_process: + # log to logger self.logger.log({ - f'loss/{key}': value, + 'learning_rate': learning_rate, }) - elif self.logging_config.log_every is None: - # log every step - self.logger.log({ - 'learning_rate': learning_rate, - }) - for key, value in loss_dict.items(): + for key, value in loss_dict.items(): + self.logger.log({ + f'loss/{key}': value, + }) + elif self.logging_config.log_every is None: + if self.accelerator.is_main_process: + # log every step self.logger.log({ - f'loss/{key}': value, + 'learning_rate': learning_rate, }) + for key, value in loss_dict.items(): + self.logger.log({ + f'loss/{key}': value, + }) if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0: - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() # print the timers and clear them self.timer.print() self.timer.reset() - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() # commit log - self.logger.commit(step=self.step_num) + if self.accelerator.is_main_process: + self.logger.commit(step=self.step_num) # sets progress bar to match out step - self.progress_bar.update(step - self.progress_bar.n) + if self.progress_bar is not None: + self.progress_bar.update(step - self.progress_bar.n) ############################# # End of step @@ -1966,16 +2057,19 @@ class BaseSDTrainProcess(BaseTrainProcess): ################################################################### ## END TRAIN LOOP ################################################################### - - self.progress_bar.close() + self.accelerator.wait_for_everyone() + if self.progress_bar is not None: + self.progress_bar.close() if self.train_config.free_u: self.sd.pipeline.disable_freeu() if not self.train_config.disable_sampling: self.sample(self.step_num) self.logger.commit(step=self.step_num) - print("") - self.save() - self.logger.finish() + print_acc("") + if self.accelerator.is_main_process: + self.save() + self.logger.finish() + self.accelerator.end_training() if self.save_config.push_to_hub: if("HF_TOKEN" not in os.environ): @@ -2001,6 +2095,8 @@ class BaseSDTrainProcess(BaseTrainProcess): repo_id: str, private: bool = False, ): + if not self.accelerator.is_main_process: + return readme_content = self._generate_readme(repo_id) readme_path = os.path.join(self.save_root, "README.md") with open(readme_path, "w", encoding="utf-8") as f: diff --git a/run.py b/run.py index 6f133081..9a3e57fd 100644 --- a/run.py +++ b/run.py @@ -20,20 +20,26 @@ if os.environ.get("DEBUG_TOOLKIT", "0") == "1": torch.autograd.set_detect_anomaly(True) import argparse from toolkit.job import get_job +from toolkit.accelerator import get_accelerator +from toolkit.print import print_acc + +accelerator = 
get_accelerator() def print_end_message(jobs_completed, jobs_failed): + if not accelerator.is_main_process: + return failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" - print("") - print("========================================") - print("Result:") + print_acc("") + print_acc("========================================") + print_acc("Result:") if len(completed_string) > 0: - print(f" - {completed_string}") + print_acc(f" - {completed_string}") if len(failure_string) > 0: - print(f" - {failure_string}") - print("========================================") + print_acc(f" - {failure_string}") + print_acc("========================================") def main(): @@ -70,7 +76,8 @@ def main(): jobs_completed = 0 jobs_failed = 0 - print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") + if accelerator.is_main_process: + print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") for config_file in config_file_list: try: @@ -79,7 +86,7 @@ def main(): job.cleanup() jobs_completed += 1 except Exception as e: - print(f"Error running job: {e}") + print_acc(f"Error running job: {e}") jobs_failed += 1 if not args.recover: print_end_message(jobs_completed, jobs_failed) diff --git a/todo_multigpu.md b/todo_multigpu.md new file mode 100644 index 00000000..02d5abda --- /dev/null +++ b/todo_multigpu.md @@ -0,0 +1,3 @@ +- only do ema on main device? shouldne be needed other than saving and sampling +- check when to unwrap model and what it does +- disable timer for non main local \ No newline at end of file diff --git a/toolkit/accelerator.py b/toolkit/accelerator.py new file mode 100644 index 00000000..ebcf0095 --- /dev/null +++ b/toolkit/accelerator.py @@ -0,0 +1,17 @@ +from accelerate import Accelerator +from diffusers.utils.torch_utils import is_compiled_module + +global_accelerator = None + + +def get_accelerator() -> Accelerator: + global global_accelerator + if global_accelerator is None: + global_accelerator = Accelerator() + return global_accelerator + +def unwrap_model(model): + accelerator = get_accelerator() + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 5285b371..1fd6a3b8 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -20,6 +20,8 @@ from toolkit.buckets import get_bucket_for_image_size, BucketResolution from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.print import print_acc +from toolkit.accelerator import get_accelerator import platform @@ -90,7 +92,7 @@ class ImageDataset(Dataset, CaptionMixin): file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))] # this might take a while - print(f" - Preprocessing image dimensions") + print_acc(f" - Preprocessing image dimensions") new_file_list = [] bad_count = 0 for file in tqdm(self.file_list): @@ -102,8 +104,8 @@ class ImageDataset(Dataset, CaptionMixin): self.file_list = new_file_list - print(f" - Found {len(self.file_list)} images") - print(f" - Found {bad_count} images that are too small") + print_acc(f" - Found {len(self.file_list)} images") + 
print_acc(f" - Found {bad_count} images that are too small") assert len(self.file_list) > 0, f"no images found in {self.path}" self.transform = transforms.Compose([ @@ -128,8 +130,8 @@ class ImageDataset(Dataset, CaptionMixin): try: img = exif_transpose(Image.open(img_path)).convert('RGB') except Exception as e: - print(f"Error opening image: {img_path}") - print(e) + print_acc(f"Error opening image: {img_path}") + print_acc(e) # make a noise image if we can't open it img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)) @@ -140,7 +142,7 @@ class ImageDataset(Dataset, CaptionMixin): if self.random_crop: if self.random_scale and min_img_size > self.resolution: if min_img_size < self.resolution: - print( + print_acc( f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}") scale_size = self.resolution else: @@ -243,11 +245,11 @@ class PairedImageDataset(Dataset): matched_files = [t for t in (set(tuple(i) for i in matched_files))] self.file_list = matched_files - print(f" - Found {len(self.file_list)} matching pairs") + print_acc(f" - Found {len(self.file_list)} matching pairs") else: self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if file.lower().endswith(supported_exts)] - print(f" - Found {len(self.file_list)} images") + print_acc(f" - Found {len(self.file_list)} images") self.transform = transforms.Compose([ transforms.ToTensor(), @@ -435,11 +437,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti ]) # this might take a while - print(f"Dataset: {self.dataset_path}") - print(f" - Preprocessing image dimensions") + print_acc(f"Dataset: {self.dataset_path}") + print_acc(f" - Preprocessing image dimensions") dataset_folder = self.dataset_path if not os.path.isdir(self.dataset_path): dataset_folder = os.path.dirname(dataset_folder) + dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json') dataloader_version = "0.1.1" if os.path.exists(dataset_size_file): @@ -448,12 +451,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti self.size_database = json.load(f) if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version: - print("Upgrading size database to new version") + print_acc("Upgrading size database to new version") # old version, delete and recreate self.size_database = {} except Exception as e: - print(f"Error loading size database: {dataset_size_file}") - print(e) + print_acc(f"Error loading size database: {dataset_size_file}") + print_acc(e) self.size_database = {} else: self.size_database = {} @@ -473,22 +476,22 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti ) self.file_list.append(file_item) except Exception as e: - print(traceback.format_exc()) - print(f"Error processing image: {file}") - print(e) + print_acc(traceback.format_exc()) + print_acc(f"Error processing image: {file}") + print_acc(e) bad_count += 1 # save the size database with open(dataset_size_file, 'w') as f: json.dump(self.size_database, f) - print(f" - Found {len(self.file_list)} images") - # print(f" - Found {bad_count} images that are too small") + print_acc(f" - Found {len(self.file_list)} images") + # print_acc(f" - Found {bad_count} images that are too small") assert len(self.file_list) > 0, f"no images found in {self.dataset_path}" # handle x axis flips if self.dataset_config.flip_x: - print(" - adding x axis flips") + print_acc(" - adding x axis 
flips") current_file_list = [x for x in self.file_list] for file_item in current_file_list: # create a copy that is flipped on the x axis @@ -498,7 +501,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti # handle y axis flips if self.dataset_config.flip_y: - print(" - adding y axis flips") + print_acc(" - adding y axis flips") current_file_list = [x for x in self.file_list] for file_item in current_file_list: # create a copy that is flipped on the y axis @@ -507,7 +510,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti self.file_list.append(new_file_item) if self.dataset_config.flip_x or self.dataset_config.flip_y: - print(f" - Found {len(self.file_list)} images after adding flips") + print_acc(f" - Found {len(self.file_list)} images after adding flips") self.setup_epoch() diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 1bba4431..57626c3f 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -24,6 +24,8 @@ from torchvision import transforms from PIL import Image, ImageFilter, ImageOps from PIL.ImageOps import exif_transpose import albumentations as A +from toolkit.print import print_acc +from toolkit.accelerator import get_accelerator from toolkit.train_tools import get_torch_dtype @@ -32,6 +34,8 @@ if TYPE_CHECKING: from toolkit.data_transfer_object.data_loader import FileItemDTO from toolkit.stable_diffusion_model import StableDiffusion +accelerator = get_accelerator() + # def get_associated_caption_from_img_path(img_path): # https://demo.albumentations.ai/ class Augments: @@ -263,7 +267,7 @@ class BucketsMixin: file_item.crop_y = int((file_item.scale_to_height - new_height) / 2) if file_item.crop_y < 0 or file_item.crop_x < 0: - print('debug') + print_acc('debug') # check if bucket exists, if not, create it bucket_key = f'{file_item.crop_width}x{file_item.crop_height}' @@ -275,10 +279,10 @@ class BucketsMixin: self.shuffle_buckets() self.build_batch_indices() if not quiet: - print(f'Bucket sizes for {self.dataset_path}:') + print_acc(f'Bucket sizes for {self.dataset_path}:') for key, bucket in self.buckets.items(): - print(f'{key}: {len(bucket.file_list_idx)} files') - print(f'{len(self.buckets)} buckets made') + print_acc(f'{key}: {len(bucket.file_list_idx)} files') + print_acc(f'{len(self.buckets)} buckets made') class CaptionProcessingDTOMixin: @@ -447,8 +451,8 @@ class ImageProcessingDTOMixin: img = Image.open(self.path) img = exif_transpose(img) except Exception as e: - print(f"Error: {e}") - print(f"Error loading image: {self.path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.path}") if self.use_alpha_as_mask: # we do this to make sure it does not replace the alpha with another color @@ -462,11 +466,11 @@ class ImageProcessingDTOMixin: w, h = img.size if w > h and self.scale_to_width < self.scale_to_height: # throw error, they should match - print( + print_acc( f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") elif h > w and self.scale_to_height < self.scale_to_width: # throw error, they should match - print( + print_acc( f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") if self.flip_x: @@ -482,7 +486,7 @@ class ImageProcessingDTOMixin: # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height if 
img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height: # todo look into this. This still happens sometimes - print('size mismatch') + print_acc('size mismatch') img = img.crop(( self.crop_x, self.crop_y, @@ -501,7 +505,7 @@ class ImageProcessingDTOMixin: if self.dataset_config.random_crop: if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution: if min_img_size < self.dataset_config.resolution: - print( + print_acc( f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}") scale_size = self.dataset_config.resolution else: @@ -567,8 +571,8 @@ class ControlFileItemDTOMixin: img = Image.open(self.control_path).convert('RGB') img = exif_transpose(img) except Exception as e: - print(f"Error: {e}") - print(f"Error loading image: {self.control_path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.control_path}") if self.full_size_control_images: # we just scale them to 512x512: @@ -782,8 +786,8 @@ class ClipImageFileItemDTOMixin: except Exception as e: # make a random noise image img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution)) - print(f"Error: {e}") - print(f"Error loading image: {clip_image_path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {clip_image_path}") img = img.convert('RGB') @@ -981,8 +985,8 @@ class MaskFileItemDTOMixin: img = Image.open(self.mask_path) img = exif_transpose(img) except Exception as e: - print(f"Error: {e}") - print(f"Error loading image: {self.mask_path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.mask_path}") if self.use_alpha_as_mask: # pipeline expectws an rgb image so we need to put alpha in all channels @@ -999,11 +1003,11 @@ class MaskFileItemDTOMixin: fix_size = False if w > h and self.scale_to_width < self.scale_to_height: # throw error, they should match - print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") fix_size = True elif h > w and self.scale_to_height < self.scale_to_width: # throw error, they should match - print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") fix_size = True if fix_size: @@ -1085,8 +1089,8 @@ class UnconditionalFileItemDTOMixin: img = Image.open(self.unconditional_path) img = exif_transpose(img) except Exception as e: - print(f"Error: {e}") - print(f"Error loading image: {self.mask_path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.mask_path}") img = img.convert('RGB') w, h = img.size @@ -1166,9 +1170,9 @@ class PoiFileItemDTOMixin: with open(caption_path, 'r', encoding='utf-8') as f: json_data = json.load(f) if 'poi' not in json_data: - print(f"Warning: poi not found in caption file: {caption_path}") + print_acc(f"Warning: poi not found in caption file: {caption_path}") if self.poi not in json_data['poi']: - print(f"Warning: poi not found in caption file: {caption_path}") + print_acc(f"Warning: poi not 
found in caption file: {caption_path}") # poi has, x, y, width, height # do full image if no poi self.poi_x = 0 @@ -1242,8 +1246,8 @@ class PoiFileItemDTOMixin: # now we have our random crop, but it may be smaller than resolution. Check and expand if needed current_resolution = get_resolution(poi_width, poi_height) except Exception as e: - print(f"Error: {e}") - print(f"Error getting resolution: {self.path}") + print_acc(f"Error: {e}") + print_acc(f"Error getting resolution: {self.path}") raise e return False if current_resolution >= self.dataset_config.resolution: @@ -1252,7 +1256,7 @@ class PoiFileItemDTOMixin: else: num_loops += 1 if num_loops > 100: - print( + print_acc( f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.") return False @@ -1279,7 +1283,7 @@ class PoiFileItemDTOMixin: if self.scale_to_width < self.crop_x + self.crop_width or self.scale_to_height < self.crop_y + self.crop_height: # todo look into this. This still happens sometimes - print('size mismatch') + print_acc('size mismatch') return True @@ -1373,88 +1377,89 @@ class LatentCachingMixin: self.latent_cache = {} def cache_latents_all_latents(self: 'AiToolkitDataset'): - print(f"Caching latents for {self.dataset_path}") - # cache all latents to disk - to_disk = self.is_caching_latents_to_disk - to_memory = self.is_caching_latents_to_memory + with accelerator.main_process_first(): + print_acc(f"Caching latents for {self.dataset_path}") + # cache all latents to disk + to_disk = self.is_caching_latents_to_disk + to_memory = self.is_caching_latents_to_memory - if to_disk: - print(" - Saving latents to disk") - if to_memory: - print(" - Keeping latents in memory") - # move sd items to cpu except for vae - self.sd.set_device_state_preset('cache_latents') + if to_disk: + print_acc(" - Saving latents to disk") + if to_memory: + print_acc(" - Keeping latents in memory") + # move sd items to cpu except for vae + self.sd.set_device_state_preset('cache_latents') - # use tqdm to show progress - i = 0 - for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): - # set latent space version - if self.sd.model_config.latent_space_version is not None: - file_item.latent_space_version = self.sd.model_config.latent_space_version - elif self.sd.is_xl: - file_item.latent_space_version = 'sdxl' - elif self.sd.is_v3: - file_item.latent_space_version = 'sd3' - elif self.sd.is_auraflow: - file_item.latent_space_version = 'sdxl' - elif self.sd.is_flux: - file_item.latent_space_version = 'flux1' - elif self.sd.model_config.is_pixart_sigma: - file_item.latent_space_version = 'sdxl' - else: - file_item.latent_space_version = 'sd1' - file_item.is_caching_to_disk = to_disk - file_item.is_caching_to_memory = to_memory - file_item.latent_load_device = self.sd.device + # use tqdm to show progress + i = 0 + for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): + # set latent space version + if self.sd.model_config.latent_space_version is not None: + file_item.latent_space_version = self.sd.model_config.latent_space_version + elif self.sd.is_xl: + file_item.latent_space_version = 'sdxl' + elif self.sd.is_v3: + file_item.latent_space_version = 'sd3' + elif self.sd.is_auraflow: + file_item.latent_space_version = 'sdxl' + elif self.sd.is_flux: + file_item.latent_space_version = 'flux1' + elif self.sd.model_config.is_pixart_sigma: + file_item.latent_space_version = 'sdxl' + else: + file_item.latent_space_version = 'sd1' + 
file_item.is_caching_to_disk = to_disk + file_item.is_caching_to_memory = to_memory + file_item.latent_load_device = self.sd.device - latent_path = file_item.get_latent_path(recalculate=True) - # check if it is saved to disk already - if os.path.exists(latent_path): - if to_memory: - # load it into memory - state_dict = load_file(latent_path, device='cpu') - file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) - else: - # not saved to disk, calculate - # load the image first - file_item.load_and_process_image(self.transform, only_load_latents=True) - dtype = self.sd.torch_dtype - device = self.sd.device_torch - # add batch dimension - try: - imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) - latent = self.sd.encode_images(imgs).squeeze(0) - except Exception as e: - print(f"Error processing image: {file_item.path}") - print(f"Error: {str(e)}") - raise e - # save_latent - if to_disk: - state_dict = OrderedDict([ - ('latent', latent.clone().detach().cpu()), - ]) - # metadata - meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) - os.makedirs(os.path.dirname(latent_path), exist_ok=True) - save_file(state_dict, latent_path, metadata=meta) + latent_path = file_item.get_latent_path(recalculate=True) + # check if it is saved to disk already + if os.path.exists(latent_path): + if to_memory: + # load it into memory + state_dict = load_file(latent_path, device='cpu') + file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) + else: + # not saved to disk, calculate + # load the image first + file_item.load_and_process_image(self.transform, only_load_latents=True) + dtype = self.sd.torch_dtype + device = self.sd.device_torch + # add batch dimension + try: + imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) + latent = self.sd.encode_images(imgs).squeeze(0) + except Exception as e: + print_acc(f"Error processing image: {file_item.path}") + print_acc(f"Error: {str(e)}") + raise e + # save_latent + if to_disk: + state_dict = OrderedDict([ + ('latent', latent.clone().detach().cpu()), + ]) + # metadata + meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) + os.makedirs(os.path.dirname(latent_path), exist_ok=True) + save_file(state_dict, latent_path, metadata=meta) - if to_memory: - # keep it in memory - file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) + if to_memory: + # keep it in memory + file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) - del imgs - del latent - del file_item.tensor + del imgs + del latent + del file_item.tensor - # flush(garbage_collect=False) - file_item.is_latent_cached = True - i += 1 - # flush every 100 - # if i % 100 == 0: - # flush() + # flush(garbage_collect=False) + file_item.is_latent_cached = True + i += 1 + # flush every 100 + # if i % 100 == 0: + # flush() - # restore device state - self.sd.restore_device_state() + # restore device state + self.sd.restore_device_state() class CLIPCachingMixin: @@ -1469,9 +1474,9 @@ class CLIPCachingMixin: if not self.is_caching_clip_vision_to_disk: return with torch.no_grad(): - print(f"Caching clip vision for {self.dataset_path}") + print_acc(f"Caching clip vision for {self.dataset_path}") - print(" - Saving clip to disk") + print_acc(" - Saving clip to disk") # move sd items to cpu except for vae self.sd.set_device_state_preset('cache_clip') @@ -1512,7 +1517,7 @@ class CLIPCachingMixin: self.clip_vision_num_unconditional_cache = 1 # cache unconditionals - print(f" - Caching 
{self.clip_vision_num_unconditional_cache} unconditional clip vision to disk") + print_acc(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk") clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache') unconditional_paths = [] diff --git a/toolkit/print.py b/toolkit/print.py new file mode 100644 index 00000000..0ada4102 --- /dev/null +++ b/toolkit/print.py @@ -0,0 +1,6 @@ +from toolkit.accelerator import get_accelerator + + +def print_acc(*args, **kwargs): + if get_accelerator().is_local_main_process: + print(*args, **kwargs) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 286d2410..66469a9f 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -63,7 +63,9 @@ from huggingface_hub import hf_hub_download from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING +from toolkit.print import print_acc if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -130,18 +132,17 @@ class StableDiffusion: noise_scheduler=None, quantize_device=None, ): + self.accelerator = get_accelerator() self.custom_pipeline = custom_pipeline - self.device = device + self.device = str(self.accelerator.device) self.dtype = dtype self.torch_dtype = get_torch_dtype(dtype) - self.device_torch = torch.device(self.device) + self.device_torch = self.accelerator.device - self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device( - model_config.vae_device) + self.vae_device_torch = self.accelerator.device self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype) - self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device( - model_config.te_device) + self.te_device_torch = self.accelerator.device self.te_torch_dtype = get_torch_dtype(model_config.te_dtype) self.model_config = model_config @@ -186,7 +187,7 @@ class StableDiffusion: if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): self.is_flow_matching = True - self.quantize_device = quantize_device if quantize_device is not None else self.device + self.quantize_device = self.device_torch self.low_vram = self.model_config.low_vram # merge in and preview active with -1 weight @@ -254,8 +255,8 @@ class StableDiffusion: pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) if self.model_config.experimental_xl: - print("Experimental XL mode enabled") - print("Loading and injecting alt weights") + print_acc("Experimental XL mode enabled") + print_acc("Loading and injecting alt weights") # load the mismatched weight and force it in raw_state_dict = load_file(model_path) replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone() @@ -265,17 +266,17 @@ class StableDiffusion: # replace weight with mismatched weight te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype) flush() - print("Injecting alt weights") + print_acc("Injecting alt weights") elif self.model_config.is_v3: if self.custom_pipeline is not None: pipln = self.custom_pipeline else: pipln = StableDiffusion3Pipeline - print("Loading SD3 model") + print_acc("Loading SD3 model") # assume 
it is the large model base_model_path = "stabilityai/stable-diffusion-3.5-large" - print("Loading transformer") + print_acc("Loading transformer") subfolder = 'transformer' transformer_path = model_path # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set @@ -298,7 +299,7 @@ class StableDiffusion: ) if not self.low_vram: # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu - transformer.to(torch.device(self.quantize_device), dtype=dtype) + transformer.to(self.quantize_device, dtype=dtype) flush() if self.model_config.lora_path is not None: @@ -306,7 +307,7 @@ class StableDiffusion: if self.model_config.quantize: quantization_type = qfloat8 - print("Quantizing transformer") + print_acc("Quantizing transformer") quantize(transformer, weights=quantization_type) freeze(transformer) transformer.to(self.device_torch) @@ -314,11 +315,11 @@ class StableDiffusion: transformer.to(self.device_torch, dtype=dtype) scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print("Loading vae") + print_acc("Loading vae") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() - print("Loading t5") + print_acc("Loading t5") tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype) text_encoder_3 = T5EncoderModel.from_pretrained( base_model_path, @@ -330,7 +331,7 @@ class StableDiffusion: flush() if self.model_config.quantize: - print("Quantizing T5") + print_acc("Quantizing T5") quantize(text_encoder_3, weights=qfloat8) freeze(text_encoder_3) flush() @@ -354,7 +355,7 @@ class StableDiffusion: **load_args ) except Exception as e: - print(f"Error loading from pretrained: {e}") + print_acc(f"Error loading from pretrained: {e}") raise e else: @@ -529,10 +530,10 @@ class StableDiffusion: tokenizer = pipe.tokenizer elif self.model_config.is_flux: - print("Loading Flux model") + print_acc("Loading Flux model") # base_model_path = "black-forest-labs/FLUX.1-schnell" base_model_path = self.model_config.name_or_path_original - print("Loading transformer") + print_acc("Loading transformer") subfolder = 'transformer' transformer_path = model_path local_files_only = False @@ -559,7 +560,7 @@ class StableDiffusion: if not self.low_vram: # for low v ram, we leave it on the cpu. 
Quantizes slower, but allows training on primary gpu - transformer.to(torch.device(self.quantize_device), dtype=dtype) + transformer.to(self.quantize_device, dtype=dtype) flush() if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: @@ -581,7 +582,7 @@ class StableDiffusion: load_lora_path, "pytorch_lora_weights.safetensors" ) elif not os.path.exists(load_lora_path): - print(f"Grabbing lora from the hub: {load_lora_path}") + print_acc(f"Grabbing lora from the hub: {load_lora_path}") new_lora_path = hf_hub_download( load_lora_path, filename="pytorch_lora_weights.safetensors" @@ -604,7 +605,7 @@ class StableDiffusion: self.model_config.lora_path = self.model_config.assistant_lora_path if self.model_config.lora_path is not None: - print("Fusing in LoRA") + print_acc("Fusing in LoRA") # need the pipe for peft pipe: FluxPipeline = FluxPipeline( scheduler=None, @@ -635,7 +636,7 @@ class StableDiffusion: # double blocks transformer.transformer_blocks = transformer.transformer_blocks.to( - torch.device(self.quantize_device), dtype=dtype + self.quantize_device, dtype=dtype ) pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double") pipe.fuse_lora() @@ -646,7 +647,7 @@ class StableDiffusion: # single blocks transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( - torch.device(self.quantize_device), dtype=dtype + self.quantize_device, dtype=dtype ) pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single") pipe.fuse_lora() @@ -674,7 +675,7 @@ class StableDiffusion: # patch the state dict method patch_dequantization_on_save(transformer) quantization_type = qfloat8 - print("Quantizing transformer") + print_acc("Quantizing transformer") quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs) freeze(transformer) transformer.to(self.device_torch) @@ -684,11 +685,11 @@ class StableDiffusion: flush() scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print("Loading vae") + print_acc("Loading vae") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() - print("Loading t5") + print_acc("Loading t5") tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) @@ -697,17 +698,17 @@ class StableDiffusion: flush() if self.model_config.quantize_te: - print("Quantizing T5") + print_acc("Quantizing T5") quantize(text_encoder_2, weights=qfloat8) freeze(text_encoder_2) flush() - print("Loading clip") + print_acc("Loading clip") text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) text_encoder.to(self.device_torch, dtype=dtype) - print("making pipe") + print_acc("making pipe") pipe: FluxPipeline = FluxPipeline( scheduler=scheduler, text_encoder=text_encoder, @@ -720,7 +721,7 @@ class StableDiffusion: pipe.text_encoder_2 = text_encoder_2 pipe.transformer = transformer - print("preparing") + print_acc("preparing") text_encoder = [pipe.text_encoder, pipe.text_encoder_2] tokenizer = [pipe.tokenizer, pipe.tokenizer_2] @@ -836,7 +837,7 @@ class StableDiffusion: self.is_loaded = True if self.model_config.assistant_lora_path is not None: - print("Loading assistant lora") + 
print_acc("Loading assistant lora") self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( self.model_config.assistant_lora_path, self) @@ -846,7 +847,7 @@ class StableDiffusion: self.assistant_lora.is_active = False if self.model_config.inference_lora_path is not None: - print("Loading inference lora") + print_acc("Loading inference lora") self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( self.model_config.inference_lora_path, self) # disable during training @@ -917,11 +918,12 @@ class StableDiffusion: sampler=None, pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, ): + network = unwrap_model(self.network) merge_multiplier = 1.0 flush() # if using assistant, unfuse it if self.model_config.assistant_lora_path is not None: - print("Unloading assistant lora") + print_acc("Unloading assistant lora") if self.invert_assistant_lora: self.assistant_lora.is_active = True # move weights on to the device @@ -930,18 +932,17 @@ class StableDiffusion: self.assistant_lora.is_active = False if self.model_config.inference_lora_path is not None: - print("Loading inference lora") + print_acc("Loading inference lora") self.assistant_lora.is_active = True # move weights on to the device self.assistant_lora.force_to(self.device_torch, self.torch_dtype) - if self.network is not None: - self.network.eval() - network = self.network + if network is not None: + network.eval() # check if we have the same network weight for all samples. If we do, we can merge in th # the network to drastically speed up inference unique_network_weights = set([x.network_multiplier for x in image_configs]) - if len(unique_network_weights) == 1 and self.network.can_merge_in: + if len(unique_network_weights) == 1 and network.can_merge_in: can_merge_in = True merge_multiplier = unique_network_weights.pop() network.merge_in(merge_weight=merge_multiplier) @@ -1119,15 +1120,15 @@ class StableDiffusion: flush() start_multiplier = 1.0 - if self.network is not None: - start_multiplier = self.network.multiplier + if network is not None: + start_multiplier = network.multiplier # pipeline.to(self.device_torch) with network: with torch.no_grad(): - if self.network is not None: - assert self.network.is_active + if network is not None: + assert network.is_active for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): gen_config = image_configs[i] @@ -1164,8 +1165,8 @@ class StableDiffusion: validation_image = validation_image.unsqueeze(0) self.adapter.set_reference_images(validation_image) - if self.network is not None: - self.network.multiplier = gen_config.network_multiplier + if network is not None: + network.multiplier = gen_config.network_multiplier torch.manual_seed(gen_config.seed) torch.cuda.manual_seed(gen_config.seed) @@ -1332,6 +1333,12 @@ class StableDiffusion: **extra ).images[0] else: + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + latents = callback_kwargs["latents"] + if latents.dtype != self.unet.dtype: + latents = latents.to(self.unet.dtype) + return {"latents": latents} img = pipeline( prompt_embeds=conditional_embeds.text_embeds, pooled_prompt_embeds=conditional_embeds.pooled_embeds, @@ -1343,6 +1350,7 @@ class StableDiffusion: guidance_scale=gen_config.guidance_scale, latents=gen_config.latents, generator=generator, + callback_on_step_end=callback_on_step_end, **extra ).images[0] elif self.is_pixart: @@ -1448,9 +1456,9 @@ class StableDiffusion: torch.cuda.set_rng_state(cuda_rng_state) 
self.restore_device_state() - if self.network is not None: - self.network.train() - self.network.multiplier = start_multiplier + if network is not None: + network.train() + network.multiplier = start_multiplier self.unet.to(self.device_torch, dtype=self.torch_dtype) if network.is_merged_in: @@ -1459,7 +1467,7 @@ class StableDiffusion: # refuse loras if self.model_config.assistant_lora_path is not None: - print("Loading assistant lora") + print_acc("Loading assistant lora") if self.invert_assistant_lora: self.assistant_lora.is_active = False # move weights off the device @@ -1468,7 +1476,7 @@ class StableDiffusion: self.assistant_lora.is_active = True if self.model_config.inference_lora_path is not None: - print("Unloading inference lora") + print_acc("Unloading inference lora") self.assistant_lora.is_active = False # move weights off the device self.assistant_lora.force_to('cpu', self.torch_dtype) @@ -1867,6 +1875,11 @@ class StableDiffusion: bypass_flux_guidance(self.unet) cast_dtype = self.unet.dtype + # changes from orig implementation + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype): noise_pred = self.unet( hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64] @@ -2513,7 +2526,7 @@ class StableDiffusion: params.append(named_params[diffusers_key]) param_data = {"params": params, "lr": unet_lr} trainable_parameters.append(param_data) - print(f"Found {len(params)} trainable parameter in unet") + print_acc(f"Found {len(params)} trainable parameter in unet") if text_encoder: named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) @@ -2526,7 +2539,7 @@ class StableDiffusion: param_data = {"params": params, "lr": text_encoder_lr} trainable_parameters.append(param_data) - print(f"Found {len(params)} trainable parameter in text encoder") + print_acc(f"Found {len(params)} trainable parameter in text encoder") if refiner: named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, @@ -2541,7 +2554,7 @@ class StableDiffusion: param_data = {"params": params, "lr": refiner_lr} trainable_parameters.append(param_data) - print(f"Found {len(params)} trainable parameter in refiner") + print_acc(f"Found {len(params)} trainable parameter in refiner") return trainable_parameters
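
For context, the two new helper modules boil down to a process-wide `Accelerator` singleton and a rank-gated print, so every module logs once per node instead of once per GPU process. A minimal sketch of that pattern, assuming the job is started with the usual `accelerate launch run.py <config>` entry point:

```python
# Pattern implemented by toolkit/accelerator.py and toolkit/print.py:
# one shared Accelerator per process, and printing only on the local main rank
# so multi-GPU runs do not repeat every log line once per GPU.
from accelerate import Accelerator

_accelerator = None  # module-level cache, mirroring global_accelerator in the patch


def get_accelerator() -> Accelerator:
    """Create the Accelerator lazily and reuse the same instance everywhere."""
    global _accelerator
    if _accelerator is None:
        _accelerator = Accelerator()
    return _accelerator


def print_acc(*args, **kwargs):
    """Print only on the local main process (rank 0 of each node)."""
    if get_accelerator().is_local_main_process:
        print(*args, **kwargs)


if __name__ == "__main__":
    acc = get_accelerator()
    print_acc(f"running on {acc.num_processes} process(es), device {acc.device}")
```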
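The trainer changes in `SDTrainer.py` and `BaseSDTrainProcess.py` follow the standard Accelerate training-step recipe: wrap the models and optimizer with `accelerator.prepare`, replace `loss.backward()` and the `GradScaler` with `accelerator.backward`, use `accelerator.clip_grad_norm_` in place of `torch.nn.utils.clip_grad_norm_`, run the step under `accelerator.accumulate(...)`, and gate sampling/saving on `accelerator.is_main_process` after `wait_for_everyone()`. A self-contained sketch of that recipe with a dummy model and data (placeholders, not the toolkit's classes):

```python
# Self-contained sketch (dummy model and data, not the toolkit's classes) of the
# Accelerate-based step this patch moves the trainer onto: accelerator.backward()
# replaces loss.backward()/GradScaler, accelerator.clip_grad_norm_ replaces
# torch.nn.utils.clip_grad_norm_, and accumulation is handled by accumulate().
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=2)
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=8)

# prepare() wraps the model for DDP, shards the dataloader, and wraps the optimizer
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for step, (batch,) in enumerate(dataloader):
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)                        # scales/syncs grads across processes
        if accelerator.sync_gradients:                    # clip only on the real optimizer step
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    if (step + 1) % 4 == 0:
        accelerator.wait_for_everyone()                   # sync ranks before rank-0-only work
        if accelerator.is_main_process:
            pass                                          # sampling/saving run here in the trainer

accelerator.end_training()
```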
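The latent-caching change wraps the whole caching loop in `accelerator.main_process_first()`, so rank 0 encodes and writes the cache while the other ranks wait, then read the finished files instead of recomputing them. This assumes all ranks see the same cache directory, as on a single multi-GPU node. A condensed sketch of that idea with an illustrative path:

```python
# Condensed sketch of the main_process_first() caching idea used in
# LatentCachingMixin.cache_latents_all_latents(). The cache path below is
# illustrative only; the real path comes from file_item.get_latent_path().
import os
import torch
from safetensors.torch import save_file, load_file
from accelerate import Accelerator

accelerator = Accelerator()
cache_path = "latent_cache/example.safetensors"  # assumed shared between all ranks

with accelerator.main_process_first():
    # rank 0 enters first and writes the cache; the other ranks block here,
    # then enter after it is done and find the file already on disk.
    if not os.path.exists(cache_path):
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        latent = torch.randn(4, 64, 64)          # stands in for sd.encode_images(...)
        save_file({"latent": latent}, cache_path)

# every rank can now load the shared cache instead of re-encoding the image
latent = load_file(cache_path, device="cpu")["latent"]
print(f"rank {accelerator.process_index}: cached latent shape {tuple(latent.shape)}")
```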