diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 3747972d..09156909 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -20,6 +20,7 @@ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, Guid from toolkit.image_utils import show_tensors, show_latents from toolkit.ip_adapter import IPAdapter from toolkit.custom_adapter import CustomAdapter +from toolkit.print import print_acc from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds from toolkit.reference_adapter import ReferenceAdapter from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork @@ -59,8 +60,6 @@ class SDTrainer(BaseSDTrainProcess): self.negative_prompt_pool: Union[List[str], None] = None self.batch_negative_prompt: Union[List[str], None] = None - self.scaler = torch.cuda.amp.GradScaler() - self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" self.do_grad_scale = True @@ -70,12 +69,12 @@ class SDTrainer(BaseSDTrainProcess): if self.adapter_config.train: self.do_grad_scale = False - if self.train_config.dtype in ["fp16", "float16"]: - # patch the scaler to allow fp16 training - org_unscale_grads = self.scaler._unscale_grads_ - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - self.scaler._unscale_grads_ = _unscale_grads_replacer + # if self.train_config.dtype in ["fp16", "float16"]: + # # patch the scaler to allow fp16 training + # org_unscale_grads = self.scaler._unscale_grads_ + # def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + # return org_unscale_grads(optimizer, inv_scale, found_inf, True) + # self.scaler._unscale_grads_ = _unscale_grads_replacer self.cached_blank_embeds: Optional[PromptEmbeds] = None self.cached_trigger_embeds: Optional[PromptEmbeds] = None @@ -168,11 +167,11 @@ class SDTrainer(BaseSDTrainProcess): raise ValueError("Cannot unload text encoder if training text encoder") # cache embeddings - print("\n***** UNLOADING TEXT ENCODER *****") - print("This will train only with a blank prompt or trigger word, if set") - print("If this is not what you want, remove the unload_text_encoder flag") - print("***********************************") - print("") + print_acc("\n***** UNLOADING TEXT ENCODER *****") + print_acc("This will train only with a blank prompt or trigger word, if set") + print_acc("If this is not what you want, remove the unload_text_encoder flag") + print_acc("***********************************") + print_acc("") self.sd.text_encoder_to(self.device_torch) self.cached_blank_embeds = self.sd.encode_prompt("") if self.trigger_word is not None: @@ -388,7 +387,7 @@ class SDTrainer(BaseSDTrainProcess): pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1)) additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \ self.train_config.diffusion_feature_extractor_weight - else: + elif self.dfe.version == 2: # version 2 # do diffusion feature extraction on target with torch.no_grad(): @@ -403,6 +402,17 @@ class SDTrainer(BaseSDTrainProcess): dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean") additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0 + elif self.dfe.version == 3: + dfe_loss = self.dfe( + noise_pred=noise_pred, + 
noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + scheduler=self.sd.noise_scheduler + ) + additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight + else: + raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}") if target is None: @@ -484,7 +494,7 @@ class SDTrainer(BaseSDTrainProcess): prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier if torch.isnan(prior_loss).any(): - print("Prior loss is nan") + print_acc("Prior loss is nan") prior_loss = None else: prior_loss = prior_loss.mean([1, 2, 3]) @@ -553,7 +563,6 @@ class SDTrainer(BaseSDTrainProcess): noise=noise, sd=self.sd, unconditional_embeds=unconditional_embeds, - scaler=self.scaler, **kwargs ) @@ -668,7 +677,7 @@ class SDTrainer(BaseSDTrainProcess): # loss = self.apply_snr(loss, timesteps) loss = loss.mean() - loss.backward() + self.accelerator.backward(loss) # detach it so parent class can run backward on no grads without throwing error loss = loss.detach() @@ -823,7 +832,7 @@ class SDTrainer(BaseSDTrainProcess): # loss = self.apply_snr(loss, timesteps) loss = loss.mean() - loss.backward() + self.accelerator.backward(loss) # detach it so parent class can run backward on no grads without throwing error loss = loss.detach() @@ -1087,6 +1096,8 @@ class SDTrainer(BaseSDTrainProcess): # expand to match latents mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + # make avg 1.0 + mask_multiplier = mask_multiplier / mask_multiplier.mean() def get_adapter_multiplier(): if self.adapter and isinstance(self.adapter, T2IAdapter): @@ -1446,8 +1457,8 @@ class SDTrainer(BaseSDTrainProcess): quad_count=quad_count ) else: - print("No Clip Image") - print([file_item.path for file_item in batch.file_items]) + print_acc("No Clip Image") + print_acc([file_item.path for file_item in batch.file_items]) raise ValueError("Could not find clip image") if not self.adapter_config.train_image_encoder: @@ -1625,7 +1636,7 @@ class SDTrainer(BaseSDTrainProcess): ) # check if nan if torch.isnan(loss): - print("loss is nan") + print_acc("loss is nan") loss = torch.zeros_like(loss).requires_grad_(True) with self.timer('backward'): @@ -1640,10 +1651,7 @@ class SDTrainer(BaseSDTrainProcess): # if self.is_bfloat: # loss.backward() # else: - if not self.do_grad_scale: - loss.backward() - else: - self.scaler.scale(loss).backward() + self.accelerator.backward(loss) return loss.detach() # flush() @@ -1668,21 +1676,14 @@ class SDTrainer(BaseSDTrainProcess): if not self.is_grad_accumulation_step: # fix this for multi params if self.train_config.optimizer != 'adafactor': - if self.do_grad_scale: - self.scaler.unscale_(self.optimizer) if isinstance(self.params[0], dict): for i in range(len(self.params)): - torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) + self.accelerator.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) else: - torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm) # only step if we are not accumulating with self.timer('optimizer_step'): - # self.optimizer.step() - if not self.do_grad_scale: - self.optimizer.step() - else: - self.scaler.step(self.optimizer) - self.scaler.update() + self.optimizer.step() self.optimizer.zero_grad(set_to_none=True) if self.adapter and 
isinstance(self.adapter, CustomAdapter): diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 3b210154..2183709f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -61,6 +61,11 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, Netw DecoratorConfig from toolkit.logging import create_logger from diffusers import FluxTransformer2DModel +from toolkit.accelerator import get_accelerator +from toolkit.print import print_acc +from accelerate import Accelerator +import transformers +import diffusers def flush(): torch.cuda.empty_cache() @@ -71,6 +76,14 @@ class BaseSDTrainProcess(BaseTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None): super().__init__(process_id, job, config) + self.accelerator: Accelerator = get_accelerator() + if self.accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_error() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + self.sd: StableDiffusion self.embedding: Union[Embedding, None] = None @@ -82,8 +95,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.grad_accumulation_step = 1 # if true, then we do not do an optimizer step. We are accumulating gradients self.is_grad_accumulation_step = False - self.device = self.get_conf('device', self.job.device) - self.device_torch = torch.device(self.device) + self.device = str(self.accelerator.device) + self.device_torch = self.accelerator.device network_config = self.get_conf('network', None) if network_config is not None: self.network_config = NetworkConfig(**network_config) @@ -91,6 +104,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.network_config = None self.train_config = TrainConfig(**self.get_conf('train', {})) model_config = self.get_conf('model', {}) + self.modules_being_trained: List[torch.nn.Module] = [] # update modelconfig dtype to match train model_config['dtype'] = self.train_config.dtype @@ -222,6 +236,8 @@ class BaseSDTrainProcess(BaseTrainProcess): return generate_image_config_list def sample(self, step=None, is_first=False): + if not self.accelerator.is_main_process: + return flush() sample_folder = os.path.join(self.save_root, 'samples') gen_img_config_list = [] @@ -316,6 +332,8 @@ class BaseSDTrainProcess(BaseTrainProcess): elif self.model_config.is_xl: o_dict['ss_base_model_version'] = 'sdxl_1.0' + elif self.model_config.is_flux: + o_dict['ss_base_model_version'] = 'flux.1' else: o_dict['ss_base_model_version'] = 'sd_1.5' @@ -344,6 +362,8 @@ class BaseSDTrainProcess(BaseTrainProcess): return info def clean_up_saves(self): + if not self.accelerator.is_main_process: + return # remove old saves # get latest saved step latest_item = None @@ -400,7 +420,7 @@ class BaseSDTrainProcess(BaseTrainProcess): items_to_remove = list(dict.fromkeys(items_to_remove)) for item in items_to_remove: - self.print(f"Removing old save: {item}") + print_acc(f"Removing old save: {item}") if os.path.isdir(item): shutil.rmtree(item) else: @@ -418,6 +438,8 @@ class BaseSDTrainProcess(BaseTrainProcess): pass def save(self, step=None): + if not self.accelerator.is_main_process: + return flush() if self.ema is not None: # always save params as ema @@ -594,10 +616,10 @@ class BaseSDTrainProcess(BaseTrainProcess): state_dict = self.optimizer.state_dict() torch.save(state_dict, file_path) except Exception as e: - print(e) - print("Could not save 
optimizer") + print_acc(e) + print_acc("Could not save optimizer") - self.print(f"Saved to {file_path}") + print_acc(f"Saved to {file_path}") self.clean_up_saves() self.post_save_hook(file_path) @@ -619,7 +641,49 @@ class BaseSDTrainProcess(BaseTrainProcess): return params def hook_before_train_loop(self): - self.logger.start() + if self.accelerator.is_main_process: + self.logger.start() + self.prepare_accelerator() + + + def prepare_accelerator(self): + # set some config + self.accelerator.even_batches=False + + # # prepare all the models stuff for accelerator (hopefully we dont miss any) + self.sd.vae = self.accelerator.prepare(self.sd.vae) + if self.sd.unet is not None: + self.sd.unet_unwrapped = self.sd.unet + self.sd.unet = self.accelerator.prepare(self.sd.unet) + # todo always tdo it? + self.modules_being_trained.append(self.sd.unet) + if self.sd.text_encoder is not None and self.train_config.train_text_encoder: + if isinstance(self.sd.text_encoder, list): + self.sd.text_encoder = [self.accelerator.prepare(model) for model in self.sd.text_encoder] + self.modules_being_trained.extend(self.sd.text_encoder) + else: + self.sd.text_encoder = self.accelerator.prepare(self.sd.text_encoder) + self.modules_being_trained.append(self.sd.text_encoder) + if self.sd.refiner_unet is not None and self.train_config.train_refiner: + self.sd.refiner_unet = self.accelerator.prepare(self.sd.refiner_unet) + self.modules_being_trained.append(self.sd.refiner_unet) + # todo, do we need to do the network or will "unet" get it? + if self.sd.network is not None: + self.sd.network = self.accelerator.prepare(self.sd.network) + self.modules_being_trained.append(self.sd.network) + if self.adapter is not None and self.adapter_config.train: + # todo adapters may not be a module. 
need to check + self.adapter = self.accelerator.prepare(self.adapter) + self.modules_being_trained.append(self.adapter) + + # prepare other things + self.optimizer = self.accelerator.prepare(self.optimizer) + if self.lr_scheduler is not None: + self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler) + # self.data_loader = self.accelerator.prepare(self.data_loader) + # if self.data_loader_reg is not None: + # self.data_loader_reg = self.accelerator.prepare(self.data_loader_reg) + def ensure_params_requires_grad(self, force=False): if self.train_config.do_paramiter_swapping and not force: @@ -692,6 +756,8 @@ class BaseSDTrainProcess(BaseTrainProcess): return latest_path def load_training_state_from_metadata(self, path): + if not self.accelerator.is_main_process: + return meta = None # if path is folder, then it is diffusers if os.path.isdir(path): @@ -708,7 +774,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if 'epoch' in meta['training_info']: self.epoch_num = meta['training_info']['epoch'] self.start_step = self.step_num - print(f"Found step {self.step_num} in metadata, starting from there") + print_acc(f"Found step {self.step_num} in metadata, starting from there") def load_weights(self, path): if self.network is not None: @@ -716,7 +782,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.load_training_state_from_metadata(path) return extra_weights else: - print("load_weights not implemented for non-network models") + print_acc("load_weights not implemented for non-network models") return None def apply_snr(self, seperated_loss, timesteps): @@ -747,7 +813,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if 'epoch' in meta['training_info']: self.epoch_num = meta['training_info']['epoch'] self.start_step = self.step_num - print(f"Found step {self.step_num} in metadata, starting from there") + print_acc(f"Found step {self.step_num} in metadata, starting from there") # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): # self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch) @@ -1000,7 +1066,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd.noise_scheduler.set_train_timesteps( num_train_timesteps, device=self.device_torch, - timestep_type=timestep_type + timestep_type=timestep_type, + latents=latents ) else: self.sd.noise_scheduler.set_timesteps( @@ -1244,7 +1311,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.adapter.to(self.device_torch, dtype=dtype) if latest_save_path is not None and not is_control_net: # load adapter from path - print(f"Loading adapter from {latest_save_path}") + print_acc(f"Loading adapter from {latest_save_path}") if is_t2i: loaded_state_dict = load_t2i_model( latest_save_path, @@ -1290,7 +1357,7 @@ class BaseSDTrainProcess(BaseTrainProcess): latest_save_path = self.get_latest_save_path() if latest_save_path is not None: - print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") model_config_to_load.name_or_path = latest_save_path self.load_training_state_from_metadata(latest_save_path) @@ -1357,7 +1424,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # block.attn.set_processor(processor) # except ImportError: - # print("sage attention is not installed. Using SDP instead") + # print_acc("sage attention is not installed. 
Using SDP instead") if self.train_config.gradient_checkpointing: if self.sd.is_flux: @@ -1531,8 +1598,8 @@ class BaseSDTrainProcess(BaseTrainProcess): latest_save_path = self.get_latest_save_path(lora_name) extra_weights = None if latest_save_path is not None: - self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") - self.print(f"Loading from {latest_save_path}") + print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + print_acc(f"Loading from {latest_save_path}") extra_weights = self.load_weights(latest_save_path) self.network.multiplier = 1.0 @@ -1569,7 +1636,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if latest_save_path is not None: state_dict = load_file(latest_save_path) self.decorator.load_state_dict(state_dict) - self.load_training_state_from_metadata(path) + self.load_training_state_from_metadata(latest_save_path) params.append({ 'params': list(self.decorator.parameters()), @@ -1665,17 +1732,17 @@ class BaseSDTrainProcess(BaseTrainProcess): previous_lrs.append(group['lr']) try: - print(f"Loading optimizer state from {optimizer_state_file_path}") + print_acc(f"Loading optimizer state from {optimizer_state_file_path}") optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) optimizer.load_state_dict(optimizer_state_dict) del optimizer_state_dict flush() except Exception as e: - print(f"Failed to load optimizer state from {optimizer_state_file_path}") - print(e) + print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}") + print_acc(e) # update the optimizer LR from the params - print(f"Updating optimizer LR from params") + print_acc(f"Updating optimizer LR from params") if len(previous_lrs) > 0: for i, group in enumerate(optimizer.param_groups): group['lr'] = previous_lrs[i] @@ -1711,24 +1778,27 @@ class BaseSDTrainProcess(BaseTrainProcess): self.hook_before_train_loop() if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling: - self.print("Generating first sample from first sample config") + print_acc("Generating first sample from first sample config") self.sample(0, is_first=True) # sample first if self.train_config.skip_first_sample or self.train_config.disable_sampling: - self.print("Skipping first sample due to config setting") + print_acc("Skipping first sample due to config setting") elif self.step_num <= 1 or self.train_config.force_first_sample: - self.print("Generating baseline samples before training") + print_acc("Generating baseline samples before training") self.sample(self.step_num) - - self.progress_bar = ToolkitProgressBar( - total=self.train_config.steps, - desc=self.job.name, - leave=True, - initial=self.step_num, - iterable=range(0, self.train_config.steps), - ) - self.progress_bar.pause() + + if self.accelerator.is_local_main_process: + self.progress_bar = ToolkitProgressBar( + total=self.train_config.steps, + desc=self.job.name, + leave=True, + initial=self.step_num, + iterable=range(0, self.train_config.steps), + ) + self.progress_bar.pause() + else: + self.progress_bar = None if self.data_loader is not None: dataloader = self.data_loader @@ -1753,7 +1823,7 @@ class BaseSDTrainProcess(BaseTrainProcess): flush() # self.step_num = 0 - # print(f"Compiling Model") + # print_acc(f"Compiling Model") # torch.compile(self.sd.unet, dynamic=True) # make sure all params require grad @@ -1769,7 +1839,7 @@ class BaseSDTrainProcess(BaseTrainProcess): did_first_flush = False for step in range(start_step_num, self.train_config.steps): if 
self.train_config.do_paramiter_swapping: - self.optimizer.swap_paramiters() + self.optimizer.optimizer.swap_paramiters() self.timer.start('train_loop') if self.train_config.do_random_cfg: self.train_config.do_cfg = True @@ -1779,7 +1849,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.is_grad_accumulation_step = True if self.train_config.free_u: self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2) - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() with torch.no_grad(): # if is even step and we have a reg dataset, use that # todo improve this logic to send one of each through if we can buckets and batch size might be an issue @@ -1802,13 +1873,15 @@ class BaseSDTrainProcess(BaseTrainProcess): except StopIteration: with self.timer('reset_batch:reg'): # hit the end of an epoch, reset - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() dataloader_iterator_reg = iter(dataloader_reg) trigger_dataloader_setup_epoch(dataloader_reg) with self.timer('get_batch:reg'): batch = next(dataloader_iterator_reg) - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() is_reg_step = True elif dataloader is not None: try: @@ -1817,7 +1890,8 @@ class BaseSDTrainProcess(BaseTrainProcess): except StopIteration: with self.timer('reset_batch'): # hit the end of an epoch, reset - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() dataloader_iterator = iter(dataloader) trigger_dataloader_setup_epoch(dataloader) self.epoch_num += 1 @@ -1827,7 +1901,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.grad_accumulation_step = 0 with self.timer('get_batch'): batch = next(dataloader_iterator) - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() else: batch = None batch_list.append(batch) @@ -1849,8 +1924,8 @@ class BaseSDTrainProcess(BaseTrainProcess): # flush() ### HOOK ### - - loss_dict = self.hook_train_loop(batch_list) + with self.accelerator.accumulate(self.modules_being_trained): + loss_dict = self.hook_train_loop(batch_list) self.timer.stop('train_loop') if not did_first_flush: flush() @@ -1880,7 +1955,8 @@ class BaseSDTrainProcess(BaseTrainProcess): for key, value in loss_dict.items(): prog_bar_string += f" {key}: {value:.3e}" - self.progress_bar.set_postfix_str(prog_bar_string) + if self.progress_bar is not None: + self.progress_bar.set_postfix_str(prog_bar_string) # if the batch is a DataLoaderBatchDTO, then we need to clean it up if isinstance(batch, DataLoaderBatchDTO): @@ -1889,8 +1965,11 @@ class BaseSDTrainProcess(BaseTrainProcess): # don't do on first step if self.step_num != self.start_step: + if is_sample_step or is_save_step: + self.accelerator.wait_for_everyone() if is_sample_step: - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() flush() # print above the progress bar if self.train_config.free_u: @@ -1902,57 +1981,70 @@ class BaseSDTrainProcess(BaseTrainProcess): flush() self.ensure_params_requires_grad() - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() if is_save_step: + self.accelerator # print above the progress bar - self.progress_bar.pause() - self.print(f"Saving at step {self.step_num}") + if self.progress_bar is not None: + self.progress_bar.pause() + print_acc(f"Saving at step {self.step_num}") self.save(self.step_num) self.ensure_params_requires_grad() - self.progress_bar.unpause() + if 
self.progress_bar is not None: + self.progress_bar.unpause() if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() with self.timer('log_to_tensorboard'): # log to tensorboard - if self.writer is not None: - for key, value in loss_dict.items(): - self.writer.add_scalar(f"{key}", value, self.step_num) - self.writer.add_scalar(f"lr", learning_rate, self.step_num) - self.progress_bar.unpause() + if self.accelerator.is_main_process: + if self.writer is not None: + for key, value in loss_dict.items(): + self.writer.add_scalar(f"{key}", value, self.step_num) + self.writer.add_scalar(f"lr", learning_rate, self.step_num) + if self.progress_bar is not None: + self.progress_bar.unpause() - # log to logger - self.logger.log({ - 'learning_rate': learning_rate, - }) - for key, value in loss_dict.items(): + if self.accelerator.is_main_process: + # log to logger self.logger.log({ - f'loss/{key}': value, + 'learning_rate': learning_rate, }) - elif self.logging_config.log_every is None: - # log every step - self.logger.log({ - 'learning_rate': learning_rate, - }) - for key, value in loss_dict.items(): + for key, value in loss_dict.items(): + self.logger.log({ + f'loss/{key}': value, + }) + elif self.logging_config.log_every is None: + if self.accelerator.is_main_process: + # log every step self.logger.log({ - f'loss/{key}': value, + 'learning_rate': learning_rate, }) + for key, value in loss_dict.items(): + self.logger.log({ + f'loss/{key}': value, + }) if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0: - self.progress_bar.pause() + if self.progress_bar is not None: + self.progress_bar.pause() # print the timers and clear them self.timer.print() self.timer.reset() - self.progress_bar.unpause() + if self.progress_bar is not None: + self.progress_bar.unpause() # commit log - self.logger.commit(step=self.step_num) + if self.accelerator.is_main_process: + self.logger.commit(step=self.step_num) # sets progress bar to match out step - self.progress_bar.update(step - self.progress_bar.n) + if self.progress_bar is not None: + self.progress_bar.update(step - self.progress_bar.n) ############################# # End of step @@ -1966,16 +2058,19 @@ class BaseSDTrainProcess(BaseTrainProcess): ################################################################### ## END TRAIN LOOP ################################################################### - - self.progress_bar.close() + self.accelerator.wait_for_everyone() + if self.progress_bar is not None: + self.progress_bar.close() if self.train_config.free_u: self.sd.pipeline.disable_freeu() if not self.train_config.disable_sampling: self.sample(self.step_num) self.logger.commit(step=self.step_num) - print("") - self.save() - self.logger.finish() + print_acc("") + if self.accelerator.is_main_process: + self.save() + self.logger.finish() + self.accelerator.end_training() if self.save_config.push_to_hub: if("HF_TOKEN" not in os.environ): @@ -2001,6 +2096,8 @@ class BaseSDTrainProcess(BaseTrainProcess): repo_id: str, private: bool = False, ): + if not self.accelerator.is_main_process: + return readme_content = self._generate_readme(repo_id) readme_path = os.path.join(self.save_root, "README.md") with open(readme_path, "w", encoding="utf-8") as f: diff --git a/requirements.txt b/requirements.txt index f45766b8..2f2cdc3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -torch -torchvision +torch==2.5.1 
+torchvision==0.20.1 safetensors -git+https://github.com/huggingface/diffusers.git +diffusers==0.32.2 transformers lycoris-lora==1.8.3 flatten_json diff --git a/run.py b/run.py index 6f133081..9a3e57fd 100644 --- a/run.py +++ b/run.py @@ -20,20 +20,26 @@ if os.environ.get("DEBUG_TOOLKIT", "0") == "1": torch.autograd.set_detect_anomaly(True) import argparse from toolkit.job import get_job +from toolkit.accelerator import get_accelerator +from toolkit.print import print_acc + +accelerator = get_accelerator() def print_end_message(jobs_completed, jobs_failed): + if not accelerator.is_main_process: + return failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else "" completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}" - print("") - print("========================================") - print("Result:") + print_acc("") + print_acc("========================================") + print_acc("Result:") if len(completed_string) > 0: - print(f" - {completed_string}") + print_acc(f" - {completed_string}") if len(failure_string) > 0: - print(f" - {failure_string}") - print("========================================") + print_acc(f" - {failure_string}") + print_acc("========================================") def main(): @@ -70,7 +76,8 @@ def main(): jobs_completed = 0 jobs_failed = 0 - print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") + if accelerator.is_main_process: + print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}") for config_file in config_file_list: try: @@ -79,7 +86,7 @@ def main(): job.cleanup() jobs_completed += 1 except Exception as e: - print(f"Error running job: {e}") + print_acc(f"Error running job: {e}") jobs_failed += 1 if not args.recover: print_end_message(jobs_completed, jobs_failed) diff --git a/todo_multigpu.md b/todo_multigpu.md new file mode 100644 index 00000000..02d5abda --- /dev/null +++ b/todo_multigpu.md @@ -0,0 +1,3 @@ +- only do ema on main device? 
shouldn't be needed other than saving and sampling +- check when to unwrap model and what it does +- disable timer for non main local \ No newline at end of file diff --git a/toolkit/accelerator.py new file mode 100644 index 00000000..ebcf0095 --- /dev/null +++ b/toolkit/accelerator.py @@ -0,0 +1,17 @@ +from accelerate import Accelerator +from diffusers.utils.torch_utils import is_compiled_module + +global_accelerator = None + + +def get_accelerator() -> Accelerator: + global global_accelerator + if global_accelerator is None: + global_accelerator = Accelerator() + return global_accelerator + +def unwrap_model(model): + accelerator = get_accelerator() + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model diff --git a/toolkit/config_modules.py index 15dca441..7aa30b73 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -483,6 +483,7 @@ class ModelConfig: self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False) if self.split_model_over_gpus and not self.is_flux: raise ValueError("split_model_over_gpus is only supported with flux models currently") + self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3) class EMAConfig: diff --git a/toolkit/custom_adapter.py index 12a4df4b..d73570d2 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -123,11 +123,11 @@ class CustomAdapter(torch.nn.Module): torch_dtype = get_torch_dtype(self.sd_ref().dtype) if self.adapter_type == 'photo_maker': sd = self.sd_ref() - embed_dim = sd.unet.config['cross_attention_dim'] + embed_dim = sd.unet_unwrapped.config['cross_attention_dim'] self.fuse_module = FuseModule(embed_dim) elif self.adapter_type == 'clip_fusion': sd = self.sd_ref() - embed_dim = sd.unet.config['cross_attention_dim'] + embed_dim = sd.unet_unwrapped.config['cross_attention_dim'] vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2) if self.config.image_encoder_arch == 'clip': @@ -288,7 +288,7 @@ class CustomAdapter(torch.nn.Module): self.vision_encoder = SAFEVisionModel( in_channels=3, num_tokens=self.config.safe_tokens, - num_vectors=sd.unet.config['cross_attention_dim'], + num_vectors=sd.unet_unwrapped.config['cross_attention_dim'], reducer_channels=self.config.safe_reducer_channels, channels=self.config.safe_channels, downscale_factor=8 diff --git a/toolkit/data_loader.py index 5285b371..1fd6a3b8 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -20,6 +20,8 @@ from toolkit.buckets import get_bucket_for_image_size, BucketResolution from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +from toolkit.print import print_acc +from toolkit.accelerator import get_accelerator import platform @@ -90,7 +92,7 @@ class ImageDataset(Dataset, CaptionMixin): file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))] # this might take a while - print(f" - Preprocessing image dimensions") + print_acc(f" - Preprocessing image dimensions") new_file_list = [] bad_count = 0 for file in tqdm(self.file_list): @@ -102,8 +104,8 @@ class ImageDataset(Dataset, CaptionMixin): self.file_list = 
new_file_list - print(f" - Found {len(self.file_list)} images") - print(f" - Found {bad_count} images that are too small") + print_acc(f" - Found {len(self.file_list)} images") + print_acc(f" - Found {bad_count} images that are too small") assert len(self.file_list) > 0, f"no images found in {self.path}" self.transform = transforms.Compose([ @@ -128,8 +130,8 @@ class ImageDataset(Dataset, CaptionMixin): try: img = exif_transpose(Image.open(img_path)).convert('RGB') except Exception as e: - print(f"Error opening image: {img_path}") - print(e) + print_acc(f"Error opening image: {img_path}") + print_acc(e) # make a noise image if we can't open it img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)) @@ -140,7 +142,7 @@ class ImageDataset(Dataset, CaptionMixin): if self.random_crop: if self.random_scale and min_img_size > self.resolution: if min_img_size < self.resolution: - print( + print_acc( f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}") scale_size = self.resolution else: @@ -243,11 +245,11 @@ class PairedImageDataset(Dataset): matched_files = [t for t in (set(tuple(i) for i in matched_files))] self.file_list = matched_files - print(f" - Found {len(self.file_list)} matching pairs") + print_acc(f" - Found {len(self.file_list)} matching pairs") else: self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if file.lower().endswith(supported_exts)] - print(f" - Found {len(self.file_list)} images") + print_acc(f" - Found {len(self.file_list)} images") self.transform = transforms.Compose([ transforms.ToTensor(), @@ -435,11 +437,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti ]) # this might take a while - print(f"Dataset: {self.dataset_path}") - print(f" - Preprocessing image dimensions") + print_acc(f"Dataset: {self.dataset_path}") + print_acc(f" - Preprocessing image dimensions") dataset_folder = self.dataset_path if not os.path.isdir(self.dataset_path): dataset_folder = os.path.dirname(dataset_folder) + dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json') dataloader_version = "0.1.1" if os.path.exists(dataset_size_file): @@ -448,12 +451,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti self.size_database = json.load(f) if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version: - print("Upgrading size database to new version") + print_acc("Upgrading size database to new version") # old version, delete and recreate self.size_database = {} except Exception as e: - print(f"Error loading size database: {dataset_size_file}") - print(e) + print_acc(f"Error loading size database: {dataset_size_file}") + print_acc(e) self.size_database = {} else: self.size_database = {} @@ -473,22 +476,22 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti ) self.file_list.append(file_item) except Exception as e: - print(traceback.format_exc()) - print(f"Error processing image: {file}") - print(e) + print_acc(traceback.format_exc()) + print_acc(f"Error processing image: {file}") + print_acc(e) bad_count += 1 # save the size database with open(dataset_size_file, 'w') as f: json.dump(self.size_database, f) - print(f" - Found {len(self.file_list)} images") - # print(f" - Found {bad_count} images that are too small") + print_acc(f" - Found {len(self.file_list)} images") + # print_acc(f" - Found {bad_count} images that are too small") assert 
len(self.file_list) > 0, f"no images found in {self.dataset_path}" # handle x axis flips if self.dataset_config.flip_x: - print(" - adding x axis flips") + print_acc(" - adding x axis flips") current_file_list = [x for x in self.file_list] for file_item in current_file_list: # create a copy that is flipped on the x axis @@ -498,7 +501,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti # handle y axis flips if self.dataset_config.flip_y: - print(" - adding y axis flips") + print_acc(" - adding y axis flips") current_file_list = [x for x in self.file_list] for file_item in current_file_list: # create a copy that is flipped on the y axis @@ -507,7 +510,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti self.file_list.append(new_file_item) if self.dataset_config.flip_x or self.dataset_config.flip_y: - print(f" - Found {len(self.file_list)} images after adding flips") + print_acc(f" - Found {len(self.file_list)} images after adding flips") self.setup_epoch() diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 1bba4431..57626c3f 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -24,6 +24,8 @@ from torchvision import transforms from PIL import Image, ImageFilter, ImageOps from PIL.ImageOps import exif_transpose import albumentations as A +from toolkit.print import print_acc +from toolkit.accelerator import get_accelerator from toolkit.train_tools import get_torch_dtype @@ -32,6 +34,8 @@ if TYPE_CHECKING: from toolkit.data_transfer_object.data_loader import FileItemDTO from toolkit.stable_diffusion_model import StableDiffusion +accelerator = get_accelerator() + # def get_associated_caption_from_img_path(img_path): # https://demo.albumentations.ai/ class Augments: @@ -263,7 +267,7 @@ class BucketsMixin: file_item.crop_y = int((file_item.scale_to_height - new_height) / 2) if file_item.crop_y < 0 or file_item.crop_x < 0: - print('debug') + print_acc('debug') # check if bucket exists, if not, create it bucket_key = f'{file_item.crop_width}x{file_item.crop_height}' @@ -275,10 +279,10 @@ class BucketsMixin: self.shuffle_buckets() self.build_batch_indices() if not quiet: - print(f'Bucket sizes for {self.dataset_path}:') + print_acc(f'Bucket sizes for {self.dataset_path}:') for key, bucket in self.buckets.items(): - print(f'{key}: {len(bucket.file_list_idx)} files') - print(f'{len(self.buckets)} buckets made') + print_acc(f'{key}: {len(bucket.file_list_idx)} files') + print_acc(f'{len(self.buckets)} buckets made') class CaptionProcessingDTOMixin: @@ -447,8 +451,8 @@ class ImageProcessingDTOMixin: img = Image.open(self.path) img = exif_transpose(img) except Exception as e: - print(f"Error: {e}") - print(f"Error loading image: {self.path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.path}") if self.use_alpha_as_mask: # we do this to make sure it does not replace the alpha with another color @@ -462,11 +466,11 @@ class ImageProcessingDTOMixin: w, h = img.size if w > h and self.scale_to_width < self.scale_to_height: # throw error, they should match - print( + print_acc( f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") elif h > w and self.scale_to_height < self.scale_to_width: # throw error, they should match - print( + print_acc( f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, 
file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") if self.flip_x: @@ -482,7 +486,7 @@ class ImageProcessingDTOMixin: # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height: # todo look into this. This still happens sometimes - print('size mismatch') + print_acc('size mismatch') img = img.crop(( self.crop_x, self.crop_y, @@ -501,7 +505,7 @@ class ImageProcessingDTOMixin: if self.dataset_config.random_crop: if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution: if min_img_size < self.dataset_config.resolution: - print( + print_acc( f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}") scale_size = self.dataset_config.resolution else: @@ -567,8 +571,8 @@ class ControlFileItemDTOMixin: img = Image.open(self.control_path).convert('RGB') img = exif_transpose(img) except Exception as e: - print(f"Error: {e}") - print(f"Error loading image: {self.control_path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.control_path}") if self.full_size_control_images: # we just scale them to 512x512: @@ -782,8 +786,8 @@ class ClipImageFileItemDTOMixin: except Exception as e: # make a random noise image img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution)) - print(f"Error: {e}") - print(f"Error loading image: {clip_image_path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {clip_image_path}") img = img.convert('RGB') @@ -981,8 +985,8 @@ class MaskFileItemDTOMixin: img = Image.open(self.mask_path) img = exif_transpose(img) except Exception as e: - print(f"Error: {e}") - print(f"Error loading image: {self.mask_path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.mask_path}") if self.use_alpha_as_mask: # pipeline expectws an rgb image so we need to put alpha in all channels @@ -999,11 +1003,11 @@ class MaskFileItemDTOMixin: fix_size = False if w > h and self.scale_to_width < self.scale_to_height: # throw error, they should match - print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") fix_size = True elif h > w and self.scale_to_height < self.scale_to_width: # throw error, they should match - print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") fix_size = True if fix_size: @@ -1085,8 +1089,8 @@ class UnconditionalFileItemDTOMixin: img = Image.open(self.unconditional_path) img = exif_transpose(img) except Exception as e: - print(f"Error: {e}") - print(f"Error loading image: {self.mask_path}") + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {self.mask_path}") img = img.convert('RGB') w, h = img.size @@ -1166,9 +1170,9 @@ class PoiFileItemDTOMixin: with open(caption_path, 'r', encoding='utf-8') as f: json_data = json.load(f) if 'poi' not in json_data: - print(f"Warning: poi not found in caption file: 
{caption_path}") + print_acc(f"Warning: poi not found in caption file: {caption_path}") if self.poi not in json_data['poi']: - print(f"Warning: poi not found in caption file: {caption_path}") + print_acc(f"Warning: poi not found in caption file: {caption_path}") # poi has, x, y, width, height # do full image if no poi self.poi_x = 0 @@ -1242,8 +1246,8 @@ class PoiFileItemDTOMixin: # now we have our random crop, but it may be smaller than resolution. Check and expand if needed current_resolution = get_resolution(poi_width, poi_height) except Exception as e: - print(f"Error: {e}") - print(f"Error getting resolution: {self.path}") + print_acc(f"Error: {e}") + print_acc(f"Error getting resolution: {self.path}") raise e return False if current_resolution >= self.dataset_config.resolution: @@ -1252,7 +1256,7 @@ class PoiFileItemDTOMixin: else: num_loops += 1 if num_loops > 100: - print( + print_acc( f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.") return False @@ -1279,7 +1283,7 @@ class PoiFileItemDTOMixin: if self.scale_to_width < self.crop_x + self.crop_width or self.scale_to_height < self.crop_y + self.crop_height: # todo look into this. This still happens sometimes - print('size mismatch') + print_acc('size mismatch') return True @@ -1373,88 +1377,89 @@ class LatentCachingMixin: self.latent_cache = {} def cache_latents_all_latents(self: 'AiToolkitDataset'): - print(f"Caching latents for {self.dataset_path}") - # cache all latents to disk - to_disk = self.is_caching_latents_to_disk - to_memory = self.is_caching_latents_to_memory + with accelerator.main_process_first(): + print_acc(f"Caching latents for {self.dataset_path}") + # cache all latents to disk + to_disk = self.is_caching_latents_to_disk + to_memory = self.is_caching_latents_to_memory - if to_disk: - print(" - Saving latents to disk") - if to_memory: - print(" - Keeping latents in memory") - # move sd items to cpu except for vae - self.sd.set_device_state_preset('cache_latents') + if to_disk: + print_acc(" - Saving latents to disk") + if to_memory: + print_acc(" - Keeping latents in memory") + # move sd items to cpu except for vae + self.sd.set_device_state_preset('cache_latents') - # use tqdm to show progress - i = 0 - for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): - # set latent space version - if self.sd.model_config.latent_space_version is not None: - file_item.latent_space_version = self.sd.model_config.latent_space_version - elif self.sd.is_xl: - file_item.latent_space_version = 'sdxl' - elif self.sd.is_v3: - file_item.latent_space_version = 'sd3' - elif self.sd.is_auraflow: - file_item.latent_space_version = 'sdxl' - elif self.sd.is_flux: - file_item.latent_space_version = 'flux1' - elif self.sd.model_config.is_pixart_sigma: - file_item.latent_space_version = 'sdxl' - else: - file_item.latent_space_version = 'sd1' - file_item.is_caching_to_disk = to_disk - file_item.is_caching_to_memory = to_memory - file_item.latent_load_device = self.sd.device + # use tqdm to show progress + i = 0 + for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): + # set latent space version + if self.sd.model_config.latent_space_version is not None: + file_item.latent_space_version = self.sd.model_config.latent_space_version + elif self.sd.is_xl: + file_item.latent_space_version = 'sdxl' + elif self.sd.is_v3: + file_item.latent_space_version = 'sd3' + elif self.sd.is_auraflow: + file_item.latent_space_version = 
'sdxl' + elif self.sd.is_flux: + file_item.latent_space_version = 'flux1' + elif self.sd.model_config.is_pixart_sigma: + file_item.latent_space_version = 'sdxl' + else: + file_item.latent_space_version = 'sd1' + file_item.is_caching_to_disk = to_disk + file_item.is_caching_to_memory = to_memory + file_item.latent_load_device = self.sd.device - latent_path = file_item.get_latent_path(recalculate=True) - # check if it is saved to disk already - if os.path.exists(latent_path): - if to_memory: - # load it into memory - state_dict = load_file(latent_path, device='cpu') - file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) - else: - # not saved to disk, calculate - # load the image first - file_item.load_and_process_image(self.transform, only_load_latents=True) - dtype = self.sd.torch_dtype - device = self.sd.device_torch - # add batch dimension - try: - imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) - latent = self.sd.encode_images(imgs).squeeze(0) - except Exception as e: - print(f"Error processing image: {file_item.path}") - print(f"Error: {str(e)}") - raise e - # save_latent - if to_disk: - state_dict = OrderedDict([ - ('latent', latent.clone().detach().cpu()), - ]) - # metadata - meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) - os.makedirs(os.path.dirname(latent_path), exist_ok=True) - save_file(state_dict, latent_path, metadata=meta) + latent_path = file_item.get_latent_path(recalculate=True) + # check if it is saved to disk already + if os.path.exists(latent_path): + if to_memory: + # load it into memory + state_dict = load_file(latent_path, device='cpu') + file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) + else: + # not saved to disk, calculate + # load the image first + file_item.load_and_process_image(self.transform, only_load_latents=True) + dtype = self.sd.torch_dtype + device = self.sd.device_torch + # add batch dimension + try: + imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) + latent = self.sd.encode_images(imgs).squeeze(0) + except Exception as e: + print_acc(f"Error processing image: {file_item.path}") + print_acc(f"Error: {str(e)}") + raise e + # save_latent + if to_disk: + state_dict = OrderedDict([ + ('latent', latent.clone().detach().cpu()), + ]) + # metadata + meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) + os.makedirs(os.path.dirname(latent_path), exist_ok=True) + save_file(state_dict, latent_path, metadata=meta) - if to_memory: - # keep it in memory - file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) + if to_memory: + # keep it in memory + file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) - del imgs - del latent - del file_item.tensor + del imgs + del latent + del file_item.tensor - # flush(garbage_collect=False) - file_item.is_latent_cached = True - i += 1 - # flush every 100 - # if i % 100 == 0: - # flush() + # flush(garbage_collect=False) + file_item.is_latent_cached = True + i += 1 + # flush every 100 + # if i % 100 == 0: + # flush() - # restore device state - self.sd.restore_device_state() + # restore device state + self.sd.restore_device_state() class CLIPCachingMixin: @@ -1469,9 +1474,9 @@ class CLIPCachingMixin: if not self.is_caching_clip_vision_to_disk: return with torch.no_grad(): - print(f"Caching clip vision for {self.dataset_path}") + print_acc(f"Caching clip vision for {self.dataset_path}") - print(" - Saving clip to disk") + print_acc(" - Saving clip to disk") # move sd items 
to cpu except for vae self.sd.set_device_state_preset('cache_clip') @@ -1512,7 +1517,7 @@ class CLIPCachingMixin: self.clip_vision_num_unconditional_cache = 1 # cache unconditionals - print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk") + print_acc(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk") clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache') unconditional_paths = [] diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index e36b91a4..8c6fd966 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -3,6 +3,12 @@ import os from torch import nn from safetensors.torch import load_file import torch.nn.functional as F +from diffusers import AutoencoderTiny +from transformers import SiglipImageProcessor, SiglipVisionModel +import lpips + +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler class ResBlock(nn.Module): @@ -147,7 +153,189 @@ class DiffusionFeatureExtractor(nn.Module): return x +class DiffusionFeatureExtractor3(nn.Module): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16): + super().__init__() + self.version = 3 + vae = AutoencoderTiny.from_pretrained( + "madebyollin/taef1", torch_dtype=torch.bfloat16) + self.vae = vae + image_encoder_path = "google/siglip-so400m-patch14-384" + try: + self.image_processor = SiglipImageProcessor.from_pretrained( + image_encoder_path) + except EnvironmentError: + self.image_processor = SiglipImageProcessor() + self.vision_encoder = SiglipVisionModel.from_pretrained( + image_encoder_path, + ignore_mismatched_sizes=True + ).to(device, dtype=dtype) + + self.lpips_model = lpips_model = lpips.LPIPS(net='vgg') + self.lpips_model = lpips_model.to(device, dtype=torch.float32) + self.losses = {} + self.log_every = 100 + self.step = 0 + + def get_siglip_features(self, tensors_0_1): + dtype = torch.bfloat16 + device = self.vae.device + # resize to 384x384 + images = F.interpolate(tensors_0_1, size=(384, 384), + mode='bicubic', align_corners=False) + + mean = torch.tensor(self.image_processor.image_mean).to( + device, dtype=dtype + ).detach() + std = torch.tensor(self.image_processor.image_std).to( + device, dtype=dtype + ).detach() + # tensors_0_1 = torch.clip((255. 
* tensors_0_1), 0, 255).round() / 255.0 + clip_image = ( + images - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + id_embeds = self.vision_encoder( + clip_image, + output_hidden_states=True, + ) + + last_hidden_state = id_embeds['last_hidden_state'] + return last_hidden_state + + def get_lpips_features(self, tensors_0_1): + device = self.vae.device + tensors_n1p1 = (tensors_0_1 * 2) - 1 + def get_lpips_features(img): # -1 to 1 + in0_input = self.lpips_model.scaling_layer(img) + outs0 = self.lpips_model.net.forward(in0_input) + + feats0 = {} + + feats_list = [] + for kk in range(self.lpips_model.L): + feats0[kk] = lpips.normalize_tensor(outs0[kk]) + feats_list.append(feats0[kk]) + + # 512 in + # vgg + # 0 torch.Size([1, 64, 512, 512]) + # 1 torch.Size([1, 128, 256, 256]) + # 2 torch.Size([1, 256, 128, 128]) + # 3 torch.Size([1, 512, 64, 64]) + # 4 torch.Size([1, 512, 32, 32]) + + return feats_list + + # do lpips + lpips_feat_list = [x.detach() for x in get_lpips_features( + tensors_n1p1.to(device, dtype=torch.float32))] + + return lpips_feat_list + + + def forward( + self, + noise_pred, + noisy_latents, + timesteps, + batch: DataLoaderBatchDTO, + scheduler: CustomFlowMatchEulerDiscreteScheduler, + lpips_weight=20.0, + clip_weight=0.1, + pixel_weight=1.0 + ): + dtype = torch.bfloat16 + device = self.vae.device + + # first we step the scheduler from current timestep to the very end for a full denoise + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + scheduler._step_index = None + scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = scheduler.sigmas[scheduler.step_index] + sigma_next = scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + + latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) + + latents = ( + latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] + tensors_n1p1 = self.vae.decode(latents).sample # -1 to 1 + + pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1 + + pred_clip_output = self.get_siglip_features(pred_images) + lpips_feat_list_pred = self.get_lpips_features(pred_images.float()) + + with torch.no_grad(): + target_img = batch.tensor.to(device, dtype=dtype) + # go from -1 to 1 to 0 to 1 + target_img = (target_img + 1) / 2 + target_clip_output = self.get_siglip_features(target_img).detach() + lpips_feat_list_target = self.get_lpips_features(target_img.float()) + + clip_loss = torch.nn.functional.mse_loss( + pred_clip_output.float(), target_clip_output.float() + ) * clip_weight + + if 'clip_loss' not in self.losses: + self.losses['clip_loss'] = clip_loss.item() + else: + self.losses['clip_loss'] += clip_loss.item() + + total_loss = clip_loss + + lpips_loss = 0 + for idx, lpips_feat in enumerate(lpips_feat_list_pred): + lpips_loss += torch.nn.functional.mse_loss( + lpips_feat.float(), lpips_feat_list_target[idx].float() + ) * lpips_weight + + if 'lpips_loss' not in self.losses: + self.losses['lpips_loss'] = lpips_loss.item() + else: + self.losses['lpips_loss'] += lpips_loss.item() + + total_loss += lpips_loss + + mse_loss = torch.nn.functional.mse_loss( + stepped_latents.float(), 
batch.latents.float() + ) * pixel_weight + + if 'pixel_loss' not in self.losses: + self.losses['pixel_loss'] = mse_loss.item() + else: + self.losses['pixel_loss'] += mse_loss.item() + + if self.step % self.log_every == 0 and self.step > 0: + print(f"DFE losses:") + for key in self.losses: + self.losses[key] /= self.log_every + # print in 2.000e-01 format + print(f" - {key}: {self.losses[key]:.3e}") + self.losses[key] = 0.0 + + total_loss += mse_loss + self.step += 1 + + return total_loss + + def load_dfe(model_path) -> DiffusionFeatureExtractor: + if model_path == "v3": + dfe = DiffusionFeatureExtractor3() + dfe.eval() + return dfe if not os.path.exists(model_path): raise FileNotFoundError(f"Model file not found: {model_path}") # if it ende with safetensors diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py index 829283e4..c2fb5ac9 100644 --- a/toolkit/models/flux.py +++ b/toolkit/models/flux.py @@ -117,16 +117,18 @@ def split_gpu_single_block_forward( return hidden_state_out -def add_model_gpu_splitter_to_flux(transformer: FluxTransformer2DModel): +def add_model_gpu_splitter_to_flux( + transformer: FluxTransformer2DModel, + # ~ 5 billion for all other params + other_module_params: Optional[int] = 5e9, + # since they are not trainable, multiply by smaller number + other_module_param_count_scale: Optional[float] = 0.3 +): gpu_id_list = [i for i in range(torch.cuda.device_count())] # if len(gpu_id_list) > 2: # raise ValueError("Cannot split to more than 2 GPUs currently.") - - # ~ 5 billion for all other params - other_module_params = 5e9 - # since they are not trainable, multiply by smaller number - other_module_params *= 0.5 + other_module_params *= other_module_param_count_scale # since we are not tuning the total_params = sum(p.numel() for p in transformer.parameters()) + other_module_params diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py index 00cf06ee..4d97b2cd 100644 --- a/toolkit/optimizers/adafactor.py +++ b/toolkit/optimizers/adafactor.py @@ -108,6 +108,7 @@ class Adafactor(torch.optim.Optimizer): warmup_init=False, do_paramiter_swapping=False, paramiter_swapping_factor=0.1, + stochastic_accumulation=True, ): if lr is not None and relative_step: raise ValueError( @@ -136,13 +137,14 @@ class Adafactor(torch.optim.Optimizer): self.is_stochastic_rounding_accumulation = False # setup stochastic grad accum hooks - for group in self.param_groups: - for param in group['params']: - if param.requires_grad and param.dtype != torch.float32: - self.is_stochastic_rounding_accumulation = True - param.register_post_accumulate_grad_hook( - stochastic_grad_accummulation - ) + if stochastic_accumulation: + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and param.dtype != torch.float32: + self.is_stochastic_rounding_accumulation = True + param.register_post_accumulate_grad_hook( + stochastic_grad_accummulation + ) self.do_paramiter_swapping = do_paramiter_swapping self.paramiter_swapping_factor = paramiter_swapping_factor diff --git a/toolkit/print.py b/toolkit/print.py new file mode 100644 index 00000000..0ada4102 --- /dev/null +++ b/toolkit/print.py @@ -0,0 +1,6 @@ +from toolkit.accelerator import get_accelerator + + +def print_acc(*args, **kwargs): + if get_accelerator().is_local_main_process: + print(*args, **kwargs) diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index 6c5b90df..a4c53db1 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ 
b/toolkit/samplers/custom_flowmatch_sampler.py @@ -3,12 +3,27 @@ from typing import Union from torch.distributions import LogNormal from diffusers import FlowMatchEulerDiscreteScheduler import torch +import numpy as np + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.init_noise_sigma = 1.0 + self.timestep_type = "linear" with torch.no_grad(): # create weights for timesteps @@ -89,7 +104,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: return sample - def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'): + def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None): + self.timestep_type = timestep_type if timestep_type == 'linear': timesteps = torch.linspace(1000, 0, num_timesteps, device=device) self.timesteps = timesteps @@ -108,6 +124,42 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): self.timesteps = timesteps.to(device=device) return timesteps + elif timestep_type == 'flux_shift': + # matches inference dynamic shifting + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps + ) + + sigmas = timesteps / self.config.num_train_timesteps + + if latents is None: + raise ValueError('latents is None') + + h = latents.shape[2] // 2 # Divide by ph + w = latents.shape[3] // 2 # Divide by pw + image_seq_len = h * w + + # todo need to know the mu for the shift + mu = calculate_shift( + image_seq_len, + self.config.get("base_image_seq_len", 256), + self.config.get("max_image_seq_len", 4096), + self.config.get("base_shift", 0.5), + self.config.get("max_shift", 1.16), + ) + sigmas = self.time_shift(mu, 1.0, sigmas) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps + + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas + + self.timesteps = timesteps.to(device=device) + return timesteps + elif timestep_type == 'lognorm_blend': # disgtribute timestepd to the center/early and blend in linear alpha = 0.75 @@ -128,5 +180,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): timesteps, _ = torch.sort(timesteps, descending=True) timesteps = timesteps.to(torch.int) + self.timesteps = timesteps.to(device=device) + return timesteps else: raise ValueError(f"Invalid timestep type: {timestep_type}") diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index b27d7c1b..7bdc586b 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -63,7 +63,10 @@ from huggingface_hub import hf_hub_download from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING +from toolkit.print import 
print_acc +from diffusers import FluxFillPipeline if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -133,18 +136,17 @@ class StableDiffusion: noise_scheduler=None, quantize_device=None, ): + self.accelerator = get_accelerator() self.custom_pipeline = custom_pipeline - self.device = device + self.device = str(self.accelerator.device) self.dtype = dtype self.torch_dtype = get_torch_dtype(dtype) - self.device_torch = torch.device(self.device) + self.device_torch = self.accelerator.device - self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device( - model_config.vae_device) + self.vae_device_torch = self.accelerator.device self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype) - self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device( - model_config.te_device) + self.te_device_torch = self.accelerator.device self.te_torch_dtype = get_torch_dtype(model_config.te_dtype) self.model_config = model_config @@ -155,6 +157,7 @@ class StableDiffusion: self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] self.vae: Union[None, 'AutoencoderKL'] self.unet: Union[None, 'UNet2DConditionModel'] + self.unet_unwrapped: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler @@ -189,7 +192,7 @@ class StableDiffusion: if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): self.is_flow_matching = True - self.quantize_device = quantize_device if quantize_device is not None else self.device + self.quantize_device = self.device_torch self.low_vram = self.model_config.low_vram # merge in and preview active with -1 weight @@ -257,8 +260,8 @@ class StableDiffusion: pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) if self.model_config.experimental_xl: - print("Experimental XL mode enabled") - print("Loading and injecting alt weights") + print_acc("Experimental XL mode enabled") + print_acc("Loading and injecting alt weights") # load the mismatched weight and force it in raw_state_dict = load_file(model_path) replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone() @@ -268,17 +271,17 @@ class StableDiffusion: # replace weight with mismatched weight te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype) flush() - print("Injecting alt weights") + print_acc("Injecting alt weights") elif self.model_config.is_v3: if self.custom_pipeline is not None: pipln = self.custom_pipeline else: pipln = StableDiffusion3Pipeline - print("Loading SD3 model") + print_acc("Loading SD3 model") # assume it is the large model base_model_path = "stabilityai/stable-diffusion-3.5-large" - print("Loading transformer") + print_acc("Loading transformer") subfolder = 'transformer' transformer_path = model_path # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set @@ -301,7 +304,7 @@ class StableDiffusion: ) if not self.low_vram: # for low v ram, we leave it on the cpu. 
Quantizes slower, but allows training on primary gpu - transformer.to(torch.device(self.quantize_device), dtype=dtype) + transformer.to(self.quantize_device, dtype=dtype) flush() if self.model_config.lora_path is not None: @@ -309,7 +312,7 @@ class StableDiffusion: if self.model_config.quantize: quantization_type = qfloat8 - print("Quantizing transformer") + print_acc("Quantizing transformer") quantize(transformer, weights=quantization_type) freeze(transformer) transformer.to(self.device_torch) @@ -317,11 +320,11 @@ class StableDiffusion: transformer.to(self.device_torch, dtype=dtype) scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print("Loading vae") + print_acc("Loading vae") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() - print("Loading t5") + print_acc("Loading t5") tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype) text_encoder_3 = T5EncoderModel.from_pretrained( base_model_path, @@ -333,7 +336,7 @@ class StableDiffusion: flush() if self.model_config.quantize: - print("Quantizing T5") + print_acc("Quantizing T5") quantize(text_encoder_3, weights=qfloat8) freeze(text_encoder_3) flush() @@ -357,7 +360,7 @@ class StableDiffusion: **load_args ) except Exception as e: - print(f"Error loading from pretrained: {e}") + print_acc(f"Error loading from pretrained: {e}") raise e else: @@ -532,10 +535,10 @@ class StableDiffusion: tokenizer = pipe.tokenizer elif self.model_config.is_flux: - print("Loading Flux model") + print_acc("Loading Flux model") # base_model_path = "black-forest-labs/FLUX.1-schnell" base_model_path = self.model_config.name_or_path_original - print("Loading transformer") + print_acc("Loading transformer") subfolder = 'transformer' transformer_path = model_path local_files_only = False @@ -558,11 +561,14 @@ class StableDiffusion: ) # hack in model gpu splitter if self.model_config.split_model_over_gpus: - add_model_gpu_splitter_to_flux(transformer) + add_model_gpu_splitter_to_flux( + transformer, + other_module_param_count_scale=self.model_config.split_model_other_module_param_count_scale + ) if not self.low_vram: # for low v ram, we leave it on the cpu. 
Quantizes slower, but allows training on primary gpu - transformer.to(torch.device(self.quantize_device), dtype=dtype) + transformer.to(self.quantize_device, dtype=dtype) flush() if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: @@ -584,7 +590,7 @@ class StableDiffusion: load_lora_path, "pytorch_lora_weights.safetensors" ) elif not os.path.exists(load_lora_path): - print(f"Grabbing lora from the hub: {load_lora_path}") + print_acc(f"Grabbing lora from the hub: {load_lora_path}") new_lora_path = hf_hub_download( load_lora_path, filename="pytorch_lora_weights.safetensors" @@ -607,7 +613,7 @@ class StableDiffusion: self.model_config.lora_path = self.model_config.assistant_lora_path if self.model_config.lora_path is not None: - print("Fusing in LoRA") + print_acc("Fusing in LoRA") # need the pipe for peft pipe: FluxPipeline = FluxPipeline( scheduler=None, @@ -638,7 +644,7 @@ class StableDiffusion: # double blocks transformer.transformer_blocks = transformer.transformer_blocks.to( - torch.device(self.quantize_device), dtype=dtype + self.quantize_device, dtype=dtype ) pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double") pipe.fuse_lora() @@ -649,7 +655,7 @@ class StableDiffusion: # single blocks transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( - torch.device(self.quantize_device), dtype=dtype + self.quantize_device, dtype=dtype ) pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single") pipe.fuse_lora() @@ -677,7 +683,7 @@ class StableDiffusion: # patch the state dict method patch_dequantization_on_save(transformer) quantization_type = qfloat8 - print("Quantizing transformer") + print_acc("Quantizing transformer") quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs) freeze(transformer) transformer.to(self.device_torch) @@ -687,11 +693,11 @@ class StableDiffusion: flush() scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - print("Loading vae") + print_acc("Loading vae") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() - print("Loading t5") + print_acc("Loading t5") tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) @@ -700,17 +706,17 @@ class StableDiffusion: flush() if self.model_config.quantize_te: - print("Quantizing T5") + print_acc("Quantizing T5") quantize(text_encoder_2, weights=qfloat8) freeze(text_encoder_2) flush() - print("Loading clip") + print_acc("Loading clip") text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) text_encoder.to(self.device_torch, dtype=dtype) - print("making pipe") + print_acc("making pipe") pipe: FluxPipeline = FluxPipeline( scheduler=scheduler, text_encoder=text_encoder, @@ -723,7 +729,7 @@ class StableDiffusion: pipe.text_encoder_2 = text_encoder_2 pipe.transformer = transformer - print("preparing") + print_acc("preparing") text_encoder = [pipe.text_encoder, pipe.text_encoder_2] tokenizer = [pipe.tokenizer, pipe.tokenizer_2] @@ -839,7 +845,7 @@ class StableDiffusion: self.is_loaded = True if self.model_config.assistant_lora_path is not None: - print("Loading assistant lora") + 
print_acc("Loading assistant lora") self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( self.model_config.assistant_lora_path, self) @@ -849,7 +855,7 @@ class StableDiffusion: self.assistant_lora.is_active = False if self.model_config.inference_lora_path is not None: - print("Loading inference lora") + print_acc("Loading inference lora") self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( self.model_config.inference_lora_path, self) # disable during training @@ -920,11 +926,12 @@ class StableDiffusion: sampler=None, pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, ): + network = unwrap_model(self.network) merge_multiplier = 1.0 flush() # if using assistant, unfuse it if self.model_config.assistant_lora_path is not None: - print("Unloading assistant lora") + print_acc("Unloading assistant lora") if self.invert_assistant_lora: self.assistant_lora.is_active = True # move weights on to the device @@ -933,18 +940,17 @@ class StableDiffusion: self.assistant_lora.is_active = False if self.model_config.inference_lora_path is not None: - print("Loading inference lora") + print_acc("Loading inference lora") self.assistant_lora.is_active = True # move weights on to the device self.assistant_lora.force_to(self.device_torch, self.torch_dtype) - if self.network is not None: - self.network.eval() - network = self.network + if network is not None: + network.eval() # check if we have the same network weight for all samples. If we do, we can merge in th # the network to drastically speed up inference unique_network_weights = set([x.network_multiplier for x in image_configs]) - if len(unique_network_weights) == 1 and self.network.can_merge_in: + if len(unique_network_weights) == 1 and network.can_merge_in: can_merge_in = True merge_multiplier = unique_network_weights.pop() network.merge_in(merge_weight=merge_multiplier) @@ -1029,9 +1035,9 @@ class StableDiffusion: if self.model_config.use_flux_cfg: pipeline = FluxWithCFGPipeline( vae=self.vae, - transformer=self.unet, - text_encoder=self.text_encoder[0], - text_encoder_2=self.text_encoder[1], + transformer=unwrap_model(self.unet), + text_encoder=unwrap_model(self.text_encoder[0]), + text_encoder_2=unwrap_model(self.text_encoder[1]), tokenizer=self.tokenizer[0], tokenizer_2=self.tokenizer[1], scheduler=noise_scheduler, @@ -1041,9 +1047,9 @@ class StableDiffusion: else: pipeline = FluxPipeline( vae=self.vae, - transformer=self.unet, - text_encoder=self.text_encoder[0], - text_encoder_2=self.text_encoder[1], + transformer=unwrap_model(self.unet), + text_encoder=unwrap_model(self.text_encoder[0]), + text_encoder_2=unwrap_model(self.text_encoder[1]), tokenizer=self.tokenizer[0], tokenizer_2=self.tokenizer[1], scheduler=noise_scheduler, @@ -1122,15 +1128,15 @@ class StableDiffusion: flush() start_multiplier = 1.0 - if self.network is not None: - start_multiplier = self.network.multiplier + if network is not None: + start_multiplier = network.multiplier # pipeline.to(self.device_torch) with network: with torch.no_grad(): - if self.network is not None: - assert self.network.is_active + if network is not None: + assert network.is_active for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): gen_config = image_configs[i] @@ -1167,8 +1173,8 @@ class StableDiffusion: validation_image = validation_image.unsqueeze(0) self.adapter.set_reference_images(validation_image) - if self.network is not None: - self.network.multiplier = gen_config.network_multiplier + if network 
is not None: + network.multiplier = gen_config.network_multiplier torch.manual_seed(gen_config.seed) torch.cuda.manual_seed(gen_config.seed) @@ -1335,6 +1341,12 @@ class StableDiffusion: **extra ).images[0] else: + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + latents = callback_kwargs["latents"] + if latents.dtype != self.unet.dtype: + latents = latents.to(self.unet.dtype) + return {"latents": latents} img = pipeline( prompt_embeds=conditional_embeds.text_embeds, pooled_prompt_embeds=conditional_embeds.pooled_embeds, @@ -1346,6 +1358,7 @@ class StableDiffusion: guidance_scale=gen_config.guidance_scale, latents=gen_config.latents, generator=generator, + callback_on_step_end=callback_on_step_end, **extra ).images[0] elif self.is_pixart: @@ -1451,9 +1464,9 @@ class StableDiffusion: torch.cuda.set_rng_state(cuda_rng_state) self.restore_device_state() - if self.network is not None: - self.network.train() - self.network.multiplier = start_multiplier + if network is not None: + network.train() + network.multiplier = start_multiplier self.unet.to(self.device_torch, dtype=self.torch_dtype) if network.is_merged_in: @@ -1462,7 +1475,7 @@ class StableDiffusion: # refuse loras if self.model_config.assistant_lora_path is not None: - print("Loading assistant lora") + print_acc("Loading assistant lora") if self.invert_assistant_lora: self.assistant_lora.is_active = False # move weights off the device @@ -1471,7 +1484,7 @@ class StableDiffusion: self.assistant_lora.is_active = True if self.model_config.inference_lora_path is not None: - print("Unloading inference lora") + print_acc("Unloading inference lora") self.assistant_lora.is_active = False # move weights off the device self.assistant_lora.force_to('cpu', self.torch_dtype) @@ -1497,7 +1510,7 @@ class StableDiffusion: if width is None: width = pixel_width // VAE_SCALE_FACTOR - num_channels = self.unet.config['in_channels'] + num_channels = self.unet_unwrapped.config['in_channels'] if self.is_flux: # has 64 channels in for some reason num_channels = 16 @@ -1805,8 +1818,8 @@ class StableDiffusion: ratios=aspect_ratio_bin) added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - if self.unet.config.sample_size == 128 or ( - self.vae_scale_factor == 16 and self.unet.config.sample_size == 64): + if self.unet_unwrapped.config.sample_size == 128 or ( + self.vae_scale_factor == 16 and self.unet_unwrapped.config.sample_size == 64): resolution = torch.tensor([height, width]).repeat(batch_size, 1) aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) @@ -1829,7 +1842,7 @@ class StableDiffusion: )[0] # learned sigma - if self.unet.config.out_channels // 2 == self.unet.config.in_channels: + if self.unet_unwrapped.config.out_channels // 2 == self.unet_unwrapped.config.in_channels: noise_pred = noise_pred.chunk(2, dim=1)[0] else: noise_pred = noise_pred @@ -1857,7 +1870,7 @@ class StableDiffusion: txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) # # handle guidance - if self.unet.config.guidance_embeds: + if self.unet_unwrapped.config.guidance_embeds: if isinstance(guidance_embedding_scale, list): guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch) else: @@ -1870,6 +1883,11 @@ class StableDiffusion: bypass_flux_guidance(self.unet) cast_dtype = self.unet.dtype + # changes from orig implementation + if txt_ids.ndim == 3: + txt_ids = 
txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype): noise_pred = self.unet( hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64] @@ -2444,7 +2462,7 @@ class StableDiffusion: # diffusers if self.is_flux: # only save the unet - transformer: FluxTransformer2DModel = self.unet + transformer: FluxTransformer2DModel = unwrap_model(self.unet) transformer.save_pretrained( save_directory=os.path.join(output_file, 'transformer'), safe_serialization=True, @@ -2516,7 +2534,7 @@ class StableDiffusion: params.append(named_params[diffusers_key]) param_data = {"params": params, "lr": unet_lr} trainable_parameters.append(param_data) - print(f"Found {len(params)} trainable parameter in unet") + print_acc(f"Found {len(params)} trainable parameter in unet") if text_encoder: named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) @@ -2529,7 +2547,7 @@ class StableDiffusion: param_data = {"params": params, "lr": text_encoder_lr} trainable_parameters.append(param_data) - print(f"Found {len(params)} trainable parameter in text encoder") + print_acc(f"Found {len(params)} trainable parameter in text encoder") if refiner: named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, @@ -2544,7 +2562,7 @@ class StableDiffusion: param_data = {"params": params, "lr": refiner_lr} trainable_parameters.append(param_data) - print(f"Found {len(params)} trainable parameter in refiner") + print_acc(f"Found {len(params)} trainable parameter in refiner") return trainable_parameters
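
Note on the DiffusionFeatureExtractor v3 hunk (illustrative sketch, not part of the patch): the new forward() takes the flow-matching prediction at the sampled timestep, performs a single Euler step straight to the final sigma to get an approximate clean latent, decodes that latent with the VAE, and then sums three weighted terms: an MSE over SigLIP features (clip_weight), an MSE over the LPIPS/VGG feature pyramid (lpips_weight), and an MSE between the stepped latents and the cached batch latents (pixel_weight). A minimal sketch of the one-step denoise only, using generic names rather than the toolkit's API:

```python
import torch

def step_to_final_sigma(noisy_latents: torch.Tensor,
                        model_output: torch.Tensor,
                        sigma_t: float,
                        sigma_last: float = 0.0) -> torch.Tensor:
    """Single Euler step of a flow-matching sampler from sigma_t to sigma_last.

    x_{t'} = x_t + (sigma_{t'} - sigma_t) * v_pred, so stepping to the last
    sigma (0.0 for a full denoise) yields an approximate clean latent in one step,
    mirroring the per-sample loop in the patch.
    """
    return noisy_latents.float() + (sigma_last - sigma_t) * model_output.float()

# usage sketch with dummy tensors
noisy = torch.randn(1, 16, 64, 64)    # latent at the sampled timestep
v_pred = torch.randn_like(noisy)      # model prediction (velocity)
x0_approx = step_to_final_sigma(noisy, v_pred, sigma_t=0.7)
```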
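
Note on the Adafactor change (a sketch under assumptions; the other keyword defaults are not shown in this diff and may differ): the stochastic-rounding grad-accumulation hooks, previously registered unconditionally on trainable non-float32 parameters, are now gated behind the new stochastic_accumulation flag so they can be turned off when the training loop (e.g. accelerate) handles accumulation itself.

```python
import torch
from toolkit.optimizers.adafactor import Adafactor

model = torch.nn.Linear(64, 64).to(torch.bfloat16)

# stochastic_accumulation=False skips registering the post-accumulate-grad
# hooks on bf16/fp16 params; True (the default) keeps the previous behavior.
optimizer = Adafactor(
    model.parameters(),
    lr=1e-4,
    relative_step=False,   # an explicit lr requires relative_step=False per the constructor check
    stochastic_accumulation=False,
)
```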
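
Note on the new toolkit/print.py (usage sketch): print_acc() forwards to print() only on the local main process, so the many logging calls converted in this diff emit one copy per node instead of one line per rank when training under accelerate.

```python
from toolkit.print import print_acc

# printed once on the local main process; silent on all other ranks
print_acc("Quantizing transformer")
```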
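
Note on the new 'flux_shift' train timestep type (illustrative arithmetic only): calculate_shift() linearly interpolates mu between base_shift and max_shift from the packed image sequence length (half the latent height times half the latent width), so larger images get a stronger shift of the training sigmas, matching Flux's dynamic shifting at inference.

```python
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.16):
    # same linear interpolation as in the patch
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# 1024px image -> 128x128 latent -> (128 // 2) * (128 // 2) = 4096 tokens
print(round(calculate_shift(4096), 3))   # 1.16  -> max_shift for the largest sequences
# 512px image -> 64x64 latent -> 32 * 32 = 1024 tokens
print(round(calculate_shift(1024), 3))   # ~0.632
print(round(calculate_shift(256), 3))    # 0.5   -> base_shift for the smallest sequences
```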
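
Note on the accelerate integration in stable_diffusion_model.py (minimal sketch, assuming toolkit.accelerator is a thin wrapper around the standard accelerate API): modules returned by accelerator.prepare() may come back wrapped (e.g. in DistributedDataParallel), which is why the sampling and save paths now unwrap the network, transformer, and text encoders before handing them to diffusers pipelines or save_pretrained().

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(8, 8))   # may now be a DDP wrapper
bare = accelerator.unwrap_model(model)               # the underlying nn.Module
# `bare` is the object that is safe to pass to a pipeline or to save_pretrained()
```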