diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index bd8cf957..3d8fcde3 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1341,6 +1341,12 @@ class SDTrainer(BaseSDTrainProcess):
                 quad_count=quad_count
             )

+        if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None:
+            self.adapter.add_extra_values(batch.extra_values.detach())
+
+            if self.train_config.do_cfg:
+                self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)
+
         self.before_unet_predict()
         # do a prior pred if we have an unconditional image, we will swap out the giadance later
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 5ed8005f..8e91c825 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -246,6 +246,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 output_ext=sample_config.ext,
                 adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
                 refiner_start_at=sample_config.refiner_start_at,
+                extra_values=sample_config.extra_values,
                 **extra_args
             ))

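
Note: a minimal sketch of the tensors this training-loop change passes around. The shape is inferred from the DataLoaderBatchDTO change below; the `adapter` calls are shown as comments because the adapter object comes from the trainer.

    import torch

    batch_size, num_values = 4, 1
    # DataLoaderBatchDTO stacks per-item extra_values into (batch_size, num_values)
    extra_values = torch.rand(batch_size, num_values)

    # conditional pass uses the dataset values; CFG uses an all-zeros tensor
    # adapter.add_extra_values(extra_values.detach())
    # adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)
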
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 4eba8ff4..54bfcb7f 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -50,6 +50,7 @@ class SampleConfig:
         self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
         self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)  # step to start using refiner on sample if it exists
+        self.extra_values = kwargs.get('extra_values', [])


 class LormModuleSettingsConfig:
@@ -526,6 +527,7 @@ class DatasetConfig:
         self.num_workers: int = kwargs.get('num_workers', 4)
         self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
+        self.extra_values: List[float] = kwargs.get('extra_values', [])


 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
@@ -574,6 +576,7 @@ class GenerateImageConfig:
             latents: Union[torch.Tensor | None] = None,  # input latent to start with,
             extra_kwargs: dict = None,  # extra data to save with prompt file
             refiner_start_at: float = 0.5,  # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
+            extra_values: List[float] = None,  # extra values to save with prompt file
     ):
         self.width: int = width
         self.height: int = height
@@ -601,6 +604,7 @@ class GenerateImageConfig:
         self.adapter_conditioning_scale: float = adapter_conditioning_scale
         self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
         self.refiner_start_at = refiner_start_at
+        self.extra_values = extra_values if extra_values is not None else []

         # prompt string will override any settings above
         self._process_prompt_string()
@@ -610,7 +614,7 @@ class GenerateImageConfig:
             self.negative_prompt_2 = negative_prompt

         if prompt_2 is None:
-            self.prompt_2 = prompt
+            self.prompt_2 = self.prompt

         # parse prompt paths
         if self.output_path is None and self.output_folder is None:
@@ -759,6 +763,12 @@ class GenerateImageConfig:
                 self.adapter_conditioning_scale = float(content)
             elif flag == 'ref':
                 self.refiner_start_at = float(content)
+            elif flag == 'ev':
+                # split by comma
+                self.extra_values = [float(val) for val in content.split(',')]
+            elif flag == 'extra_values':
+                # split by comma
+                self.extra_values = [float(val) for val in content.split(',')]

     def post_process_embeddings(
         self,
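
Note: a hedged example of the new per-prompt flag. The `--flag value` prompt syntax is assumed from the existing prompt-flag parser; only the 'ev'/'extra_values' flag names and the comma splitting come from this diff.

    # in a sample prompt list:
    prompt = "photo of a person --ev 0.25,0.75"

    # the flag parser splits the flag content on commas:
    extra_values = [float(v) for v in "0.25,0.75".split(',')]  # -> [0.25, 0.75]
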
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index f59eb124..eaf443af 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -9,6 +9,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5En
 from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.models.ilora import InstantLoRAModule
+from toolkit.models.single_value_adapter import SingleValueAdapter
 from toolkit.models.te_adapter import TEAdapter
 from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
@@ -87,6 +88,7 @@ class CustomAdapter(torch.nn.Module):
         self.te_adapter: TEAdapter = None
         self.te_augmenter: TEAugAdapter = None
         self.vd_adapter: VisionDirectAdapter = None
+        self.single_value_adapter: SingleValueAdapter = None
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -173,6 +175,8 @@ class CustomAdapter(torch.nn.Module):
             self.te_augmenter = TEAugAdapter(self, self.sd_ref())
         elif self.adapter_type == 'vision_direct':
             self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
+        elif self.adapter_type == 'single_value':
+            self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
@@ -204,7 +208,7 @@ class CustomAdapter(torch.nn.Module):
     def setup_clip(self):
         adapter_config = self.config
         sd = self.sd_ref()
-        if self.config.type == "text_encoder":
+        if self.config.type == "text_encoder" or self.config.type == "single_value":
             return
         if self.config.type == 'photo_maker':
             try:
@@ -374,6 +378,9 @@ class CustomAdapter(torch.nn.Module):
         if 'dvadapter' in state_dict:
             self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)

+        if 'sv_adapter' in state_dict:
+            self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
+
         if 'vision_encoder' in state_dict and self.config.train_image_encoder:
             self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
@@ -417,6 +424,9 @@ class CustomAdapter(torch.nn.Module):
             if self.config.train_image_encoder:
                 state_dict["vision_encoder"] = self.vision_encoder.state_dict()
             return state_dict
+        elif self.adapter_type == 'single_value':
+            state_dict["sv_adapter"] = self.single_value_adapter.state_dict()
+            return state_dict
         elif self.adapter_type == 'ilora':
             if self.config.train_image_encoder:
                 state_dict["vision_encoder"] = self.vision_encoder.state_dict()
@@ -425,6 +435,14 @@ class CustomAdapter(torch.nn.Module):
         else:
             raise NotImplementedError

+    def add_extra_values(self, extra_values: torch.Tensor, is_unconditional=False):
+        if self.adapter_type == 'single_value':
+            if is_unconditional:
+                self.unconditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
+            else:
+                self.conditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
+
+
     def condition_prompt(
         self,
         prompt: Union[List[str], str],
@@ -843,6 +861,8 @@ class CustomAdapter(torch.nn.Module):
             yield from self.te_augmenter.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
+        elif self.config.type == 'single_value':
+            yield from self.single_value_adapter.parameters(recurse)
         else:
             raise NotImplementedError
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 02c5e327..b94ddddc 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -99,6 +99,7 @@ class DataLoaderBatchDTO:
         self.clip_image_embeds: Union[List[dict], None] = None
         self.clip_image_embeds_unconditional: Union[List[dict], None] = None
         self.sigmas: Union[torch.Tensor, None] = None  # can be added elseware and passed along training code
+        self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
         if not is_latents_cached:
             # only return a tensor if latents are not cached
             self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index a0007d50..2d55e563 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -266,6 +266,9 @@ class CaptionProcessingDTOMixin:
         self.caption: str = None
         self.caption_short: str = None

+        dataset_config: DatasetConfig = kwargs.get('dataset_config', None)
+        self.extra_values: List[float] = dataset_config.extra_values
+
     # todo allow for loading from sd-scripts style dict
     def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
         if self.raw_caption is not None:
@@ -292,11 +295,15 @@ class CaptionProcessingDTOMixin:
                 prompt = prompt.replace('\n', ' ')
                 prompt = prompt.replace('\r', ' ')

-                prompt = json.loads(prompt)
-                if 'caption' in prompt:
-                    prompt = prompt['caption']
-                if 'caption_short' in prompt:
-                    short_caption = prompt['caption_short']
+                prompt_json = json.loads(prompt)
+                if 'caption' in prompt_json:
+                    prompt = prompt_json['caption']
+                if 'caption_short' in prompt_json:
+                    short_caption = prompt_json['caption_short']
+
+                if 'extra_values' in prompt_json:
+                    self.extra_values = prompt_json['extra_values']
+
             prompt = clean_caption(prompt)
         if short_caption is not None:
             short_caption = clean_caption(short_caption)
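
Note: a hedged sketch of the two places the dataloader can now pick up extra_values — a dataset-wide default from DatasetConfig and a per-image override from a JSON caption. The field names come from this diff; the other DatasetConfig keys are illustrative only.

    import json
    from toolkit.config_modules import DatasetConfig

    # dataset-wide default, copied onto every FileItemDTO in this dataset
    dataset = DatasetConfig(folder_path="/path/to/images", extra_values=[0.0])

    # per-image override inside a JSON caption file
    caption = json.loads('{"caption": "a photo of a person", "extra_values": [0.8]}')
    # load_caption() would then set file_item.extra_values = [0.8]
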
diff --git a/toolkit/models/single_value_adapter.py b/toolkit/models/single_value_adapter.py
new file mode 100644
index 00000000..9284d020
--- /dev/null
+++ b/toolkit/models/single_value_adapter.py
@@ -0,0 +1,402 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import weakref
+from typing import Union, TYPE_CHECKING
+
+from diffusers import Transformer2DModel
+from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
+from toolkit.paths import REPOS_ROOT
+sys.path.append(REPOS_ROOT)
+
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+    from toolkit.custom_adapter import CustomAdapter
+
+class AttnProcessor2_0(torch.nn.Module):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(
+        self,
+        hidden_size=None,
+        cross_attention_dim=None,
+    ):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
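+
+# SingleValueAdapterAttnProcessor (below) keeps the regular text cross-attention and
+# adds an IP-Adapter-style second attention pass whose keys/values are projected from
+# the scalar conditioning values, weighted by `scale`.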
+class SingleValueAdapterAttnProcessor(nn.Module):
+    r"""
+    Attention processor for Custom TE for PyTorch 2.0.
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            the weight scale of image prompt.
+        adapter
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
+                 adapter_hidden_size=None, has_bias=False, **kwargs):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+
+        self.hidden_size = hidden_size
+        self.adapter_hidden_size = adapter_hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+
+        self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+        self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
+        # return False
+
+    @property
+    def unconditional_embeds(self):
+        return self.adapter_ref().adapter_ref().unconditional_embeds
+
+    @property
+    def conditional_embeds(self):
+        return self.adapter_ref().adapter_ref().conditional_embeds
+
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
+        is_active = self.adapter_ref().is_active
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        # will be none if disabled
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # only use one TE or the other. If our adapter is active only use ours
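+        # conditional_embeds holds the scalar values as a (batch, num_values) tensor;
+        # when CFG has doubled the query batch, the zero unconditional_embeds are
+        # concatenated in front so the adapter sees one entry per query item.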
+        if self.is_active and self.conditional_embeds is not None:
+
+            adapter_hidden_states = self.conditional_embeds
+            if adapter_hidden_states.shape[0] < batch_size:
+                # doing cfg
+                adapter_hidden_states = torch.cat([
+                    self.unconditional_embeds,
+                    adapter_hidden_states
+                ], dim=0)
+            # needs to be shape (batch, 1, 1)
+            if len(adapter_hidden_states.shape) == 2:
+                adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
+            # conditional_batch_size = adapter_hidden_states.shape[0]
+            # conditional_query = query
+
+            # for ip-adapter
+            vd_key = self.to_k_adapter(adapter_hidden_states)
+            vd_value = self.to_v_adapter(adapter_hidden_states)
+
+            vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            vd_hidden_states = F.scaled_dot_product_attention(
+                query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+
+            vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            vd_hidden_states = vd_hidden_states.to(query.dtype)
+
+            hidden_states = hidden_states + self.scale * vd_hidden_states
+
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class SingleValueAdapter(torch.nn.Module):
+    def __init__(
+        self,
+        adapter: 'CustomAdapter',
+        sd: 'StableDiffusion',
+        num_values: int = 1,
+    ):
+        super(SingleValueAdapter, self).__init__()
+        is_pixart = sd.is_pixart
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.sd_ref: weakref.ref = weakref.ref(sd)
+        self.token_size = num_values
+
+        # init adapter modules
+        attn_procs = {}
+        unet_sd = sd.unet.state_dict()
+
+        attn_processor_keys = []
+        if is_pixart:
+            transformer: Transformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+
+                # cross attention
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+
+        else:
+            attn_processor_keys = list(sd.unet.attn_processors.keys())
+
+        for name in attn_processor_keys:
+            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim']
+            if name.startswith("mid_block"):
+                hidden_size = sd.unet.config['block_out_channels'][-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = sd.unet.config['block_out_channels'][block_id]
+            elif name.startswith("transformer"):
+                hidden_size = sd.unet.config['cross_attention_dim']
+            else:
+                # they didnt have this, but would lead to undefined below
+                raise ValueError(f"unknown attn processor name: {name}")
+            if cross_attention_dim is None:
+                attn_procs[name] = AttnProcessor2_0()
+            else:
+                layer_name = name.split(".processor")[0]
+                to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
+                to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
".to_v.weight"] + # if is_pixart: + # to_k_bias = unet_sd[layer_name + ".to_k.bias"] + # to_v_bias = unet_sd[layer_name + ".to_v.bias"] + # else: + # to_k_bias = None + # to_v_bias = None + + # add zero padding to the adapter + if to_k_adapter.shape[1] < self.token_size: + to_k_adapter = torch.cat([ + to_k_adapter, + torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + to_v_adapter = torch.cat([ + to_v_adapter, + torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to( + to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01 + ], + dim=1 + ) + # if is_pixart: + # to_k_bias = torch.cat([ + # to_k_bias, + # torch.zeros(self.token_size - to_k_adapter.shape[1]).to( + # to_k_adapter.device, dtype=to_k_adapter.dtype) + # ], + # dim=0 + # ) + # to_v_bias = torch.cat([ + # to_v_bias, + # torch.zeros(self.token_size - to_v_adapter.shape[1]).to( + # to_k_adapter.device, dtype=to_k_adapter.dtype) + # ], + # dim=0 + # ) + elif to_k_adapter.shape[1] > self.token_size: + to_k_adapter = to_k_adapter[:, :self.token_size] + to_v_adapter = to_v_adapter[:, :self.token_size] + # if is_pixart: + # to_k_bias = to_k_bias[:self.token_size] + # to_v_bias = to_v_bias[:self.token_size] + else: + to_k_adapter = to_k_adapter + to_v_adapter = to_v_adapter + # if is_pixart: + # to_k_bias = to_k_bias + # to_v_bias = to_v_bias + + weights = { + "to_k_adapter.weight": to_k_adapter * 0.01, + "to_v_adapter.weight": to_v_adapter * 0.01, + } + # if is_pixart: + # weights["to_k_adapter.bias"] = to_k_bias + # weights["to_v_adapter.bias"] = to_v_bias + + attn_procs[name] = SingleValueAdapterAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + adapter=self, + adapter_hidden_size=self.token_size, + has_bias=False, + ) + attn_procs[name].load_state_dict(weights) + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList([ + transformer.transformer_blocks[i].attn1.processor for i in range(len(transformer.transformer_blocks)) + ] + [ + transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks)) + ]) + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + # make a getter to see if is active + @property + def is_active(self): + return self.adapter_ref().is_active + + def forward(self, input): + return input diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index cf7225b7..adf91d32 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -685,6 +685,14 @@ class StableDiffusion: is_generating_samples=True, ) + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values(extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and 
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index cf7225b7..adf91d32 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -685,6 +685,14 @@ class StableDiffusion:
                     is_generating_samples=True,
                 )

+                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(gen_config.extra_values) > 0:
+                    extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, dtype=self.torch_dtype)
+                    # apply extra values to the embeddings
+                    self.adapter.add_extra_values(extra_values, is_unconditional=False)
+                    self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)
+                    pass  # todo remove, for debugging
+
+
                 if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                     # if we have a refiner loaded, set the denoising end at the refiner start
                     extra['denoising_end'] = gen_config.refiner_start_at
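
Note: end to end, the sample-time path is SampleConfig.extra_values -> GenerateImageConfig.extra_values -> CustomAdapter.add_extra_values. A hedged sketch of wiring that up; SampleConfig kwargs other than extra_values are assumed from the toolkit's existing config and shown only for illustration.

    from toolkit.config_modules import SampleConfig

    sample = SampleConfig(
        prompts=["a photo of a person --ev 0.5"],  # per-prompt override via the new 'ev' flag
        extra_values=[0.0],                        # default applied to every sample prompt
    )
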