diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index d5a2d95e..a79b0ccf 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -356,7 +356,7 @@ class SDTrainer(BaseSDTrainProcess):
             # we have to encode images into latents for now
             # we also denoise as the unaugmented tensor is not a noisy diffirental
             with torch.no_grad():
-                unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
+                unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
                 unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
                 target = unaugmented_latents.detach()

@@ -907,6 +907,17 @@ class SDTrainer(BaseSDTrainProcess):
         self.timer.start('preprocess_batch')
         batch = self.preprocess_batch(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
+        # sanity check
+        if self.sd.vae.dtype != self.sd.vae_torch_dtype:
+            self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
+        if isinstance(self.sd.text_encoder, list):
+            for encoder in self.sd.text_encoder:
+                if encoder.dtype != self.sd.te_torch_dtype:
+                    encoder.to(self.sd.te_torch_dtype)
+        else:
+            if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
+                self.sd.text_encoder.to(self.sd.te_torch_dtype)
+
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
         if self.train_config.do_cfg or self.train_config.do_random_cfg:
             # pick random negative prompts
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 68b6512f..846156b2 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -83,7 +83,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             self.network_config = None
         self.train_config = TrainConfig(**self.get_conf('train', {}))
-        self.model_config = ModelConfig(**self.get_conf('model', {}))
+        model_config = self.get_conf('model', {})
+
+        # update modelconfig dtype to match train
+        model_config['dtype'] = self.train_config.dtype
+        self.model_config = ModelConfig(**model_config)
+
         self.save_config = SaveConfig(**self.get_conf('save', {}))
         self.sample_config = SampleConfig(**self.get_conf('sample', {}))
         first_sample_config = self.get_conf('first_sample', None)
@@ -723,6 +728,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             noise_offset=self.train_config.noise_offset,
         ).to(self.device_torch, dtype=dtype)

+        if self.train_config.random_noise_shift > 0.0:
+            # get random noise -1 to 1
+            noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
+                                     dtype=noise.dtype) * 2 - 1
+
+            # multiply by shift amount
+            noise_shift *= self.train_config.random_noise_shift
+
+            # add to noise
+            noise += noise_shift
+
         return noise

     def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index f05b8b3e..c8ef256f 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -160,6 +160,8 @@ class AdapterConfig:
         self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
         if self.train_only_image_encoder:
             self.train_image_encoder = True
+        self.train_only_image_encoder_positional_embedding: bool = kwargs.get(
+            'train_only_image_encoder_positional_embedding', False)
         self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip vit vit_hybrid, safe
         self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
         self.safe_channels: int = kwargs.get('safe_channels', 2048)
@@ -260,6 +262,7 @@ class TrainConfig:
         # multiplier applied to loos on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
         self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
+        self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)

         # dropout that happens before encoding. It functions independently per text encoder
         self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
@@ -385,6 +388,11 @@ class ModelConfig:
         self.text_encoder_bits = kwargs.get('text_encoder_bits', 16)  # 16, 8, 4
         self.unet_path = kwargs.get("unet_path", None)
         self.unet_sample_size = kwargs.get("unet_sample_size", None)
+        self.vae_device = kwargs.get("vae_device", None)
+        self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
+        self.te_device = kwargs.get("te_device", None)
+        self.te_dtype = kwargs.get("te_dtype", self.dtype)
+
         pass

 class EMAConfig:
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 56909655..22a82439 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -394,7 +394,7 @@ class IPAdapter(torch.nn.Module):
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.target_hidden_size if not self.config.image_encoder_arch.startswith(
+            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
                 'convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]

@@ -964,7 +964,10 @@ class IPAdapter(torch.nn.Module):

     def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.train_only_image_encoder:
-            yield from self.image_encoder.parameters(recurse)
+            if self.config.train_only_image_encoder_positional_embedding:
+                yield from self.image_encoder.vision_model.embeddings.position_embedding.parameters(recurse)
+            else:
+                yield from self.image_encoder.parameters(recurse)
             return
         if self.config.train_scaler:
             # no params
diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py
index 267a0dde..ad43f221 100644
--- a/toolkit/models/te_adapter.py
+++ b/toolkit/models/te_adapter.py
@@ -21,7 +21,7 @@
 from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0

 if TYPE_CHECKING:
-    from toolkit.stable_diffusion_model import StableDiffusion
+    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
     from toolkit.custom_adapter import CustomAdapter


@@ -202,6 +202,10 @@ class TEAdapterAttnProcessor(nn.Module):

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
+        # remove attn mask if doing clip
+        if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip":
+            attention_mask = None
+
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
@@ -246,7 +250,7 @@ class TEAdapter(torch.nn.Module):
         if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
             self.token_size = self.te_ref().config.d_model
         else:
-            self.token_size = self.te_ref().config.target_hidden_size
+            self.token_size = self.te_ref().config.hidden_size

         # add text projection if is sdxl
         self.text_projection = None
@@ -388,8 +392,17 @@
             # ).input_ids.to(te.device)
             # outputs = te(input_ids=input_ids)
             # outputs = outputs.last_hidden_state
+            if self.adapter_ref().config.text_encoder_arch == "clip":
+                embeds = train_tools.encode_prompts(
+                    tokenizer,
+                    te,
+                    text,
+                    truncate=True,
+                    max_length=self.adapter_ref().config.num_tokens,
+                )
+                attention_mask = torch.ones(embeds.shape[:2], device=embeds.device)

-            if self.adapter_ref().config.text_encoder_arch == "pile-t5":
+            elif self.adapter_ref().config.text_encoder_arch == "pile-t5":
                 # just use aura pile
                 embeds, attention_mask = train_tools.encode_prompts_auraflow(
                     tokenizer,
@@ -407,7 +420,8 @@
                 truncate=True,
                 max_length=self.adapter_ref().config.num_tokens,
             )
-            attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
+            if attention_mask is not None:
+                attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
             if self.text_projection is not None:
                 # pool the output of embeds ignoring 0 in the attention mask
                 pooled_output = embeds * attn_mask_float.unsqueeze(-1)
@@ -420,19 +434,19 @@

                 pooled_embeds = self.text_projection(pooled_output)

-                t5_embeds = PromptEmbeds(
+                prompt_embeds = PromptEmbeds(
                     (embeds, pooled_embeds),
                     attention_mask=attention_mask,
                 ).detach()
             else:
-                t5_embeds = PromptEmbeds(
+                prompt_embeds = PromptEmbeds(
                     embeds,
                     attention_mask=attention_mask,
                 ).detach()

-        return t5_embeds
+        return prompt_embeds
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 6ee84a89..143f5ca4 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -123,6 +123,13 @@ class StableDiffusion:
         self.dtype = dtype
         self.torch_dtype = get_torch_dtype(dtype)
         self.device_torch = torch.device(self.device)
+
+        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
+        self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
+
+        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
+        self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
+
         self.model_config = model_config
         self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"

@@ -220,11 +227,13 @@
             text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
             tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
             for text_encoder in text_encoders:
-                text_encoder.to(self.device_torch, dtype=dtype)
+                text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
                 text_encoder.requires_grad_(False)
                 text_encoder.eval()
             text_encoder = text_encoders

+            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+
             if self.model_config.experimental_xl:
                 print("Experimental XL mode enabled")
                 print("Loading and injecting alt weights")
@@ -333,6 +342,8 @@

             # replace the to function with a no-op since it throws an error instead of a warning
             text_encoder.to = lambda *args, **kwargs: None
+            text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
+
         if self.model_config.is_pixart_sigma:
             # load the transformer only from the save
             transformer = Transformer2DModel.from_pretrained(
@@ -375,6 +386,8 @@
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
             tokenizer = pipe.tokenizer

+            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+
         elif self.model_config.is_auraflow:
             te_kwargs = {}

@@ -427,7 +440,7 @@
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)

             # patch auraflow so it can handle other aspect ratios
-            patch_auraflow_pos_embed(pipe.transformer.pos_embed)
+            # patch_auraflow_pos_embed(pipe.transformer.pos_embed)

             flush()
             # text_encoder = pipe.text_encoder
@@ -442,6 +455,31 @@
         else:
             pipln = StableDiffusionPipeline

+        if self.model_config.text_encoder_bits < 16:
+            # this is only supported for T5 models for now
+            te_kwargs = {}
+            # handle quantization of TE
+            te_is_quantized = False
+            if self.model_config.text_encoder_bits == 8:
+                te_kwargs['load_in_8bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+            elif self.model_config.text_encoder_bits == 4:
+                te_kwargs['load_in_4bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+
+            text_encoder = T5EncoderModel.from_pretrained(
+                model_path,
+                subfolder="text_encoder",
+                torch_dtype=self.te_torch_dtype,
+                **te_kwargs
+            )
+            # replace the to function with a no-op since it throws an error instead of a warning
+            text_encoder.to = lambda *args, **kwargs: None
+
+            load_args['text_encoder'] = text_encoder
+
         # see if path exists
         if not os.path.exists(model_path) or os.path.isdir(model_path):
             # try to load with default diffusers
@@ -455,7 +493,7 @@
                 # variant="fp16",
                 trust_remote_code=True,
                 **load_args
-            ).to(self.device_torch)
+            )
         else:
             pipe = pipln.from_single_file(
                 model_path,
@@ -467,12 +505,12 @@
                 safety_checker=None,
                 trust_remote_code=True,
                 **load_args
-            ).to(self.device_torch)
+            )
         flush()

         pipe.register_to_config(requires_safety_checker=False)
         text_encoder = pipe.text_encoder
-        text_encoder.to(self.device_torch, dtype=dtype)
+        text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
         text_encoder.requires_grad_(False)
         text_encoder.eval()
         tokenizer = pipe.tokenizer
@@ -488,7 +526,7 @@
             self.unet = pipe.transformer
         else:
             self.unet: 'UNet2DConditionModel' = pipe.unet
-        self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
+        self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
         self.vae.eval()
         self.vae.requires_grad_(False)
         VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -707,7 +745,7 @@
             feature_extractor=None,
             requires_safety_checker=False,
             **extra_args
-        ).to(self.device_torch)
+        )
         flush()
         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)
@@ -873,6 +911,9 @@
             if not self.is_xl:
                 raise ValueError("Refiner is only supported for XL models")

+        conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+        unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+
         if self.is_xl:
             # fix guidance rescale for sdxl
             # was trained on 0.7 (I believe)
@@ -1014,6 +1055,7 @@

             self.network.train()
             self.network.multiplier = start_multiplier
+        self.unet.to(self.device_torch, dtype=self.torch_dtype)
         if network.is_merged_in:
             network.merge_out(merge_multiplier)
         # self.tokenizer.to(original_device_dict['tokenizer'])
@@ -1655,18 +1697,18 @@
             dtype=None
     ):
         if device is None:
-            device = self.device
+            device = self.vae_device_torch
         if dtype is None:
-            dtype = self.torch_dtype
+            dtype = self.vae_torch_dtype

         latent_list = []
         # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
-            self.vae.to(self.device)
+            self.vae.to(device)
         self.vae.eval()
         self.vae.requires_grad_(False)

         # move to device and dtype
-        image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
+        image_list = [image.to(device, dtype=dtype) for image in image_list]

         VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -2158,7 +2200,7 @@
         # vae
         state['vae'] = {
             'training': 'vae' in training_modules,
-            'device': self.device_torch if 'vae' in active_modules else 'cpu',
+            'device': self.vae_device_torch if 'vae' in active_modules else 'cpu',
             'requires_grad': 'vae' in training_modules,
         }

@@ -2182,13 +2224,13 @@
             for i, encoder in enumerate(self.text_encoder):
                 state['text_encoder'].append({
                     'training': 'text_encoder' in training_modules,
-                    'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                    'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
                     'requires_grad': 'text_encoder' in training_modules,
                 })
         else:
             state['text_encoder'] = {
                 'training': 'text_encoder' in training_modules,
-                'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
                 'requires_grad': 'text_encoder' in training_modules,
             }
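
Note on the new random_noise_shift option: the block added in jobs/process/BaseSDTrainProcess.py draws one uniform value in [-1, 1] per sample and per latent channel, scales it by train_config.random_noise_shift, and adds it to the sampled noise, so every spatial location of a channel gets the same offset (the channel's noise mean shifts while its spatial variance is unchanged), similar in spirit to the noise_offset argument visible in the same hunk. Below is a minimal standalone sketch of just that block for illustration; the helper name apply_random_noise_shift is hypothetical and not part of the toolkit, and the shift argument corresponds to train_config.random_noise_shift.

import torch

def apply_random_noise_shift(noise: torch.Tensor, shift: float) -> torch.Tensor:
    # hypothetical helper mirroring the added block; noise is (batch, channels, h, w)
    if shift <= 0.0:
        return noise
    # one offset per (sample, channel), drawn uniformly from [-1, 1]
    noise_shift = torch.rand(
        (noise.shape[0], noise.shape[1], 1, 1),
        device=noise.device,
        dtype=noise.dtype,
    ) * 2 - 1
    # scale to [-shift, +shift] and broadcast-add over the spatial dimensions
    return noise + noise_shift * shift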