diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index dacb23c7..03c5f213 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -329,24 +329,7 @@ class SDTrainer(BaseSDTrainProcess):
                 target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
             elif self.sd.is_flow_matching:
-                # only if preconditioning model outputs
-                # if not preconditioning, (target = noise - batch.latents)
-
-                # if preconditioning outputs, target latents
-                # model_pred = model_pred * (-sigmas) + noisy_model_input
-                if self.train_config.target_noise_multiplier != 1.0:
-                    # we are adjusting the target noise, need to recompute the noisy latents with
-                    # the noise adjusted above
-                    with torch.no_grad():
-                        noisy_latents = self.sd.add_noise(batch.latents, noise, timesteps).detach()
-
-                noise_pred = precondition_model_outputs_flow_match(
-                    noise_pred,
-                    noisy_latents,
-                    timesteps,
-                    self.sd.noise_scheduler
-                )
-                target = batch.latents.detach()
+                target = (noise - batch.latents).detach()
             else:
                 target = noise
@@ -392,14 +375,8 @@ class SDTrainer(BaseSDTrainProcess):
                 loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
                 loss = loss_per_element
             else:
-                # handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
-                if self.sd.is_flow_matching and prior_pred is None:
-                    # outputs should be preprocessed latents
-                    sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
-                    weighting = torch.ones_like(sigmas)
-                    loss = (weighting.float() * (pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)
-
-                elif self.train_config.loss_type == "mae":
+                # handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
+                if self.train_config.loss_type == "mae":
                     loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
                 else:
                     loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
diff --git a/repositories/sd-scripts b/repositories/sd-scripts
index b78c0e2a..25f961bc 160000
--- a/repositories/sd-scripts
+++ b/repositories/sd-scripts
@@ -1 +1 @@
-Subproject commit b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
+Subproject commit 25f961bc779bc79aef440813e3e8e92244ac5739
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index b9059b7b..3cd9efb8 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -13,7 +13,6 @@ from toolkit.paths import SD_SCRIPTS_ROOT
 sys.path.append(SD_SCRIPTS_ROOT)
 
 from diffusers import (
-    StableDiffusionPipeline,
     DDPMScheduler,
     EulerAncestralDiscreteScheduler,
     DPMSolverMultistepScheduler,
@@ -24,10 +23,8 @@ from diffusers import (
     EulerDiscreteScheduler,
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
-    KDPM2AncestralDiscreteScheduler,
-    StableDiffusion3Pipeline
+    KDPM2AncestralDiscreteScheduler
 )
-from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import torch
 import re
 from transformers import T5Tokenizer, T5EncoderModel, UMT5EncoderModel
@@ -136,261 +133,6 @@ def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None):
     return noise
 
 
-def sample_images(
-    accelerator,
-    args: argparse.Namespace,
-    epoch,
-    steps,
-    device,
-    vae,
-    tokenizer,
-    text_encoder,
-    unet,
-    prompt_replacement=None,
-    force_sample=False
-):
-    """
-    StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
-    """
-    if not force_sample:
-        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
-            return
-        if args.sample_every_n_epochs is not None:
-            # sample_every_n_steps は無視する
-            if epoch is None or epoch % args.sample_every_n_epochs != 0:
-                return
-        else:
-            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
-                return
-
-    is_sample_only = args.sample_only
-    is_generating_only = hasattr(args, "is_generating_only") and args.is_generating_only
-
-    print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
-    if not os.path.isfile(args.sample_prompts):
-        print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
-        return
-
-    org_vae_device = vae.device  # CPUにいるはず
-    vae.to(device)
-
-    # read prompts
-
-    # with open(args.sample_prompts, "rt", encoding="utf-8") as f:
-    #     prompts = f.readlines()
-
-    if args.sample_prompts.endswith(".txt"):
-        with open(args.sample_prompts, "r", encoding="utf-8") as f:
-            lines = f.readlines()
-        prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
-    elif args.sample_prompts.endswith(".json"):
-        with open(args.sample_prompts, "r", encoding="utf-8") as f:
-            prompts = json.load(f)
-
-    # schedulerを用意する
-    sched_init_args = {}
-    if args.sample_sampler == "ddim":
-        scheduler_cls = DDIMScheduler
-    elif args.sample_sampler == "ddpm":  # ddpmはおかしくなるのでoptionから外してある
-        scheduler_cls = DDPMScheduler
-    elif args.sample_sampler == "pndm":
-        scheduler_cls = PNDMScheduler
-    elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
-        scheduler_cls = LMSDiscreteScheduler
-    elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
-        scheduler_cls = EulerDiscreteScheduler
-    elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
-        scheduler_cls = EulerAncestralDiscreteScheduler
-    elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
-        scheduler_cls = DPMSolverMultistepScheduler
-        sched_init_args["algorithm_type"] = args.sample_sampler
-    elif args.sample_sampler == "dpmsingle":
-        scheduler_cls = DPMSolverSinglestepScheduler
-    elif args.sample_sampler == "heun":
-        scheduler_cls = HeunDiscreteScheduler
-    elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
-        scheduler_cls = KDPM2DiscreteScheduler
-    elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
-        scheduler_cls = KDPM2AncestralDiscreteScheduler
-    else:
-        scheduler_cls = DDIMScheduler
-
-    if args.v_parameterization:
-        sched_init_args["prediction_type"] = "v_prediction"
-
-    scheduler = scheduler_cls(
-        num_train_timesteps=SCHEDULER_TIMESTEPS,
-        beta_start=SCHEDULER_LINEAR_START,
-        beta_end=SCHEDULER_LINEAR_END,
-        beta_schedule=SCHEDLER_SCHEDULE,
-        **sched_init_args,
-    )
-
-    # clip_sample=Trueにする
-    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
-        # print("set clip_sample to True")
-        scheduler.config.clip_sample = True
-
-    pipeline = StableDiffusionLongPromptWeightingPipeline(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-        clip_skip=args.clip_skip,
-        safety_checker=None,
-        feature_extractor=None,
-        requires_safety_checker=False,
-    )
-    pipeline.to(device)
-
-    if is_generating_only:
-        save_dir = args.output_dir
-    else:
-        save_dir = args.output_dir + "/sample"
-    os.makedirs(save_dir, exist_ok=True)
-
-    rng_state = torch.get_rng_state()
-    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
-
-    with torch.no_grad():
-        with accelerator.autocast():
-            for i, prompt in enumerate(prompts):
-                if not accelerator.is_main_process:
-                    continue
-
-                if isinstance(prompt, dict):
-                    negative_prompt = prompt.get("negative_prompt")
-                    sample_steps = prompt.get("sample_steps", 30)
-                    width = prompt.get("width", 512)
-                    height = prompt.get("height", 512)
-                    scale = prompt.get("scale", 7.5)
-                    seed = prompt.get("seed")
-                    prompt = prompt.get("prompt")
-
-                    prompt = replace_filewords_prompt(prompt, args)
-                    negative_prompt = replace_filewords_prompt(negative_prompt, args)
-                else:
-                    prompt = replace_filewords_prompt(prompt, args)
-                    # prompt = prompt.strip()
-                    # if len(prompt) == 0 or prompt[0] == "#":
-                    #     continue
-
-                    # subset of gen_img_diffusers
-                    prompt_args = prompt.split(" --")
-                    prompt = prompt_args[0]
-                    negative_prompt = None
-                    sample_steps = 30
-                    width = height = 512
-                    scale = 7.5
-                    seed = None
-                    for parg in prompt_args:
-                        try:
-                            m = re.match(r"w (\d+)", parg, re.IGNORECASE)
-                            if m:
-                                width = int(m.group(1))
-                                continue
-
-                            m = re.match(r"h (\d+)", parg, re.IGNORECASE)
-                            if m:
-                                height = int(m.group(1))
-                                continue
-
-                            m = re.match(r"d (\d+)", parg, re.IGNORECASE)
-                            if m:
-                                seed = int(m.group(1))
-                                continue
-
-                            m = re.match(r"s (\d+)", parg, re.IGNORECASE)
-                            if m:  # steps
-                                sample_steps = max(1, min(1000, int(m.group(1))))
-                                continue
-
-                            m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
-                            if m:  # scale
-                                scale = float(m.group(1))
-                                continue
-
-                            m = re.match(r"n (.+)", parg, re.IGNORECASE)
-                            if m:  # negative prompt
-                                negative_prompt = m.group(1)
-                                continue
-
-                        except ValueError as ex:
-                            print(f"Exception in parsing / 解析エラー: {parg}")
-                            print(ex)
-
-                if seed is not None:
-                    torch.manual_seed(seed)
-                    torch.cuda.manual_seed(seed)
-
-                if prompt_replacement is not None:
-                    prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
-                    if negative_prompt is not None:
-                        negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
-
-                height = max(64, height - height % 8)  # round to divisible by 8
-                width = max(64, width - width % 8)  # round to divisible by 8
-                print(f"prompt: {prompt}")
-                print(f"negative_prompt: {negative_prompt}")
-                print(f"height: {height}")
-                print(f"width: {width}")
-                print(f"sample_steps: {sample_steps}")
-                print(f"scale: {scale}")
-                image = pipeline(
-                    prompt=prompt,
-                    height=height,
-                    width=width,
-                    num_inference_steps=sample_steps,
-                    guidance_scale=scale,
-                    negative_prompt=negative_prompt,
-                ).images[0]
-
-                ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
-                num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
-                seed_suffix = "" if seed is None else f"_{seed}"
-
-                if is_generating_only:
-                    img_filename = (
-                        f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
-                    )
-                else:
-                    img_filename = (
-                        f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{i:04d}{seed_suffix}.png"
-                    )
-                if is_sample_only:
-                    # make prompt txt file
-                    img_path_no_ext = os.path.join(save_dir, img_filename[:-4])
-                    with open(img_path_no_ext + ".txt", "w") as f:
-                        # put prompt in txt file
-                        f.write(prompt)
-                        # close file
-                        f.close()
-
-                image.save(os.path.join(save_dir, img_filename))
-
-                # wandb有効時のみログを送信
-                try:
-                    wandb_tracker = accelerator.get_tracker("wandb")
-                    try:
-                        import wandb
-                    except ImportError:  # 事前に一度確認するのでここはエラー出ないはず
-                        raise ImportError("No wandb / wandb がインストールされていないようです")
-
-                    wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
-                except:  # wandb 無効時
-                    pass
-
-    # clear pipeline and cache to reduce vram usage
-    del pipeline
-    torch.cuda.empty_cache()
-
-    torch.set_rng_state(rng_state)
-    if cuda_rng_state is not None:
-        torch.cuda.set_rng_state(cuda_rng_state)
-    vae.to(org_vae_device)
-
-
 # https://www.crosslabs.org//blog/diffusion-with-offset-noise
 def apply_noise_offset(noise, noise_offset):
     if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001):
@@ -591,7 +333,7 @@ def encode_prompts_sd3(
         truncate: bool = True,
         max_length=None,
         dropout_prob=0.0,
-        pipeline: StableDiffusion3Pipeline = None,
+        pipeline = None,
 ):
     text_embeds_list = []
     pooled_text_embeds = None  # always text_encoder_2's pool
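
Note on the SDTrainer.py hunks above: with the preconditioning path removed, the flow-matching branch now trains the model to predict the velocity directly, and the raw prediction is compared against noise - latents through the ordinary MAE/MSE branch. The snippet below is a minimal, self-contained sketch of that loss convention; the function name, tensor shapes, and the random stand-in prediction are illustrative only and are not part of the toolkit's API.

import torch
import torch.nn.functional as F


def flow_matching_loss_sketch(latents: torch.Tensor, noise: torch.Tensor, model_pred: torch.Tensor) -> torch.Tensor:
    # Target is the raw flow-matching velocity (noise - latents), detached,
    # mirroring the `target = (noise - batch.latents).detach()` line in the diff.
    target = (noise - latents).detach()
    # Per-element MSE, matching the generic loss branch that now also covers
    # flow matching after the special-cased branch was removed.
    return F.mse_loss(model_pred.float(), target.float(), reduction="none")


# Illustrative usage with random tensors standing in for VAE latents and the model output.
latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)
model_pred = torch.randn_like(latents)
loss = flow_matching_loss_sketch(latents, noise, model_pred).mean()
print(loss.item())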