Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added comparative loss when training the CLIP encoder. Allow selecting the CLIP layer on the IP adapter. Improvements to prior prediction.
@@ -225,52 +225,51 @@ class SDTrainer(BaseSDTrainProcess):
        noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
        noise_pred = noise_pred * (noise_norm / noise_pred_norm)

+       target = None
+       if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
            if self.train_config.correct_pred_norm and not is_reg:
                with torch.no_grad():
                    # this only works if doing a prior pred
                    if prior_pred is not None:
                        prior_mean = prior_pred.mean([2,3], keepdim=True)
                        prior_std = prior_pred.std([2,3], keepdim=True)
                        noise_mean = noise_pred.mean([2,3], keepdim=True)
                        noise_std = noise_pred.std([2,3], keepdim=True)

                        mean_adjust = prior_mean - noise_mean
                        std_adjust = prior_std - noise_std

                        mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
                        std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier

                        target_mean = noise_mean + mean_adjust
                        target_std = noise_std + std_adjust

                        eps = 1e-5
                        # adjust the noise target to match the current knowledge of the model
                        # noise_mean, noise_std = get_mean_std(noise)
                        # match the noise to the prior
                        noise = (noise - noise_mean) / (noise_std + eps)
                        noise = noise * (target_std + eps) + target_mean
                        noise = noise.detach()

            if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
                assert not self.train_config.train_turbo
                # we need to make the noise prediction be a masked blending of noise and prior_pred
                stretched_mask_multiplier = value_map(
                    mask_multiplier,
                    batch.file_items[0].dataset_config.mask_min_value,
                    1.0,
                    0.0,
                    1.0
                )

                prior_mask_multiplier = 1.0 - stretched_mask_multiplier

                # target_mask_multiplier = mask_multiplier
                # mask_multiplier = 1.0
                target = noise
                # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
                # set masked multiplier to 1.0 so we dont double apply it
                # mask_multiplier = 1.0
        elif prior_pred is not None:
            assert not self.train_config.train_turbo
            # matching adapter prediction
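The correct_pred_norm branch above nudges the noise target toward the statistics of the prior prediction before the loss is taken. A condensed sketch of that transform (the helper name is illustrative; multiplier corresponds to correct_pred_norm_multiplier):

import torch

def match_noise_to_prior(noise, noise_pred, prior_pred, multiplier=1.0, eps=1e-5):
    # per-sample, per-channel statistics over the spatial dims (H, W)
    prior_mean = prior_pred.mean([2, 3], keepdim=True)
    prior_std = prior_pred.std([2, 3], keepdim=True)
    noise_mean = noise_pred.mean([2, 3], keepdim=True)
    noise_std = noise_pred.std([2, 3], keepdim=True)

    # move the target statistics part of the way toward the prior's statistics
    target_mean = noise_mean + (prior_mean - noise_mean) * multiplier
    target_std = noise_std + (prior_std - noise_std) * multiplier

    # standardize the noise, then rescale it to the blended statistics
    noise = (noise - noise_mean) / (noise_std + eps)
    return (noise * (target_std + eps) + target_mean).detach()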
@@ -281,6 +280,9 @@ class SDTrainer(BaseSDTrainProcess):
        else:
            target = noise

+       if target is None:
+           target = noise
+
        pred = noise_pred

        if self.train_config.train_turbo:
@@ -360,6 +362,13 @@ class SDTrainer(BaseSDTrainProcess):
            loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

        loss = loss.mean()

+       # check for additional losses
+       if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
+           loss = loss + self.adapter.additional_loss.mean()
+           self.adapter.additional_loss = None
+
        return loss

    def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
@@ -677,6 +686,7 @@ class SDTrainer(BaseSDTrainProcess):
            batch: 'DataLoaderBatchDTO',
            noise: torch.Tensor,
            unconditional_embeds: Optional[PromptEmbeds] = None,
+           conditioned_prompts=None,
            **kwargs
    ):
        # todo for embeddings, we need to run without trigger words
@@ -980,6 +990,17 @@ class SDTrainer(BaseSDTrainProcess):
                    # it will be injected into the tokenizer when called
                    self.adapter(conditional_clip_embeds)

+       # do the custom adapter after the prior prediction
+       if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
+           quad_count = random.randint(1, 4)
+           self.adapter.train()
+           self.adapter.trigger_pre_te(
+               tensors_0_1=clip_images,
+               is_training=True,
+               has_been_preprocessed=True,
+               quad_count=quad_count
+           )
+
        with self.timer('encode_prompt'):
            unconditional_embeds = None
            if grad_on_text_encoder:
@@ -1140,6 +1161,7 @@ class SDTrainer(BaseSDTrainProcess):
                        unconditional_clip_embeds = unconditional_clip_embeds.detach()

                with self.timer('encode_adapter'):
+                   self.adapter.train()
                    conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
                    if self.train_config.do_cfg:
                        unconditional_embeds = self.adapter(unconditional_embeds.detach(),
@@ -1170,8 +1192,10 @@ class SDTrainer(BaseSDTrainProcess):
        if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
            do_inverted_masked_prior = True

+       do_correct_pred_norm_prior = self.train_config.correct_pred_norm
+
        if ((
-               has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
+               has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
            with self.timer('prior predict'):
                prior_pred = self.get_prior_prediction(
                    noisy_latents=noisy_latents,
@@ -1182,8 +1206,12 @@ class SDTrainer(BaseSDTrainProcess):
                    pred_kwargs=pred_kwargs,
                    noise=noise,
                    batch=batch,
-                   unconditional_embeds=unconditional_embeds
-               ).detach()
+                   unconditional_embeds=unconditional_embeds,
+                   conditioned_prompts=conditioned_prompts
+               )
+               if prior_pred is not None:
+                   prior_pred = prior_pred.detach()


        # do the custom adapter after the prior prediction
        if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
@@ -130,6 +130,7 @@ class NetworkConfig:

AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker']

+CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']

class AdapterConfig:
    def __init__(self, **kwargs):
@@ -169,6 +170,13 @@ class AdapterConfig:
        self.class_names = kwargs.get('class_names', [])

+       self.clip_layer: CLIPLayer = kwargs.get('clip_layer', None)
+       if self.clip_layer is None:
+           if self.type.startswith('ip+'):
+               self.clip_layer = 'penultimate_hidden_states'
+           else:
+               self.clip_layer = 'last_hidden_state'
+

class EmbeddingConfig:
    def __init__(self, **kwargs):
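With the new clip_layer key, the vision-encoder output an IP adapter consumes is chosen in the config rather than hard-coded per adapter type. A minimal sketch of the default resolution above, assuming AdapterConfig is built from plain kwargs and the remaining fields default sensibly:

# 'ip+' adapters default to the penultimate hidden states
print(AdapterConfig(type='ip+').clip_layer)                              # 'penultimate_hidden_states'

# every other type defaults to the last hidden state
print(AdapterConfig(type='ip').clip_layer)                               # 'last_hidden_state'

# an explicit value overrides the per-type default
print(AdapterConfig(type='ip+', clip_layer='image_embeds').clip_layer)   # 'image_embeds'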
@@ -438,7 +438,10 @@ class CustomAdapter(torch.nn.Module):
            is_unconditional=False,
            quad_count=4,
    ) -> PromptEmbeds:
-       if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
+       if self.adapter_type == 'ilora':
+           return prompt_embeds
+
+       if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
            if is_unconditional:
                # we dont condition the negative embeds for photo maker
                return prompt_embeds.clone()
@@ -503,7 +506,7 @@ class CustomAdapter(torch.nn.Module):
                self.token_mask
            )
            return prompt_embeds
-       elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
+       elif self.adapter_type == 'clip_fusion':
            with torch.set_grad_enabled(is_training):
                if is_training and self.config.train_image_encoder:
                    self.vision_encoder.train()
@@ -535,22 +538,96 @@ class CustomAdapter(torch.nn.Module):
                if not is_training or not self.config.train_image_encoder:
                    img_embeds = img_embeds.detach()

-               if self.adapter_type == 'ilora':
-                   self.ilora_module.img_embeds = img_embeds
-
-                   return prompt_embeds
-               else:
-
-                   prompt_embeds.text_embeds = self.clip_fusion_module(
-                       prompt_embeds.text_embeds,
-                       img_embeds
-                   )
-                   return prompt_embeds
+               prompt_embeds.text_embeds = self.clip_fusion_module(
+                   prompt_embeds.text_embeds,
+                   img_embeds
+               )
+               return prompt_embeds


        else:
            raise NotImplementedError

+   def trigger_pre_te(
+           self,
+           tensors_0_1: torch.Tensor,
+           is_training=False,
+           has_been_preprocessed=False,
+           quad_count=4,
+   ) -> PromptEmbeds:
+       if self.adapter_type == 'ilora':
+           with torch.no_grad():
+               # on training the clip image is created in the dataloader
+               if not has_been_preprocessed:
+                   # tensors should be 0-1
+                   if tensors_0_1.ndim == 3:
+                       tensors_0_1 = tensors_0_1.unsqueeze(0)
+                   # training tensors are 0 - 1
+                   tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+                   # if images are out of this range throw error
+                   if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+                       raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+                           tensors_0_1.min(), tensors_0_1.max()
+                       ))
+                   clip_image = self.image_processor(
+                       images=tensors_0_1,
+                       return_tensors="pt",
+                       do_resize=True,
+                       do_rescale=False,
+                   ).pixel_values
+               else:
+                   clip_image = tensors_0_1
+               clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

+               if self.config.quad_image:
+                   # split the 4x4 grid and stack on batch
+                   ci1, ci2 = clip_image.chunk(2, dim=2)
+                   ci1, ci3 = ci1.chunk(2, dim=3)
+                   ci2, ci4 = ci2.chunk(2, dim=3)
+                   to_cat = []
+                   for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+                       if i < quad_count:
+                           to_cat.append(ci)
+                       else:
+                           break

+                   clip_image = torch.cat(to_cat, dim=0).detach()

+       if self.adapter_type == 'ilora':
+           with torch.set_grad_enabled(is_training):
+               if is_training and self.config.train_image_encoder:
+                   self.vision_encoder.train()
+                   clip_image = clip_image.requires_grad_(True)
+                   id_embeds = self.vision_encoder(
+                       clip_image,
+                       output_hidden_states=True,
+                   )
+               else:
+                   with torch.no_grad():
+                       self.vision_encoder.eval()
+                       id_embeds = self.vision_encoder(
+                           clip_image, output_hidden_states=True
+                       )

+               img_embeds = id_embeds['last_hidden_state']

+               if self.config.quad_image:
+                   # get the outputs of the quat
+                   chunks = img_embeds.chunk(quad_count, dim=0)
+                   chunk_sum = torch.zeros_like(chunks[0])
+                   for chunk in chunks:
+                       chunk_sum = chunk_sum + chunk
+                   # get the mean of them

+                   img_embeds = chunk_sum / quad_count

+               if not is_training or not self.config.train_image_encoder:
+                   img_embeds = img_embeds.detach()

+               self.ilora_module.img_embeds = img_embeds

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        if self.config.type == 'photo_maker':
            yield from self.fuse_module.parameters(recurse)
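The quad_image path above packs several crops of the same subject into one 2x2 grid, splits the grid back into tiles stacked on the batch axis for the vision encoder, and then mean-pools the per-tile embeddings. A standalone sketch of that reshaping with made-up tensor sizes:

import torch

grid = torch.randn(1, 3, 448, 448)          # hypothetical 2x2 grid of 224x224 tiles
top, bottom = grid.chunk(2, dim=2)          # split along height
tl, tr = top.chunk(2, dim=3)                # split along width
bl, br = bottom.chunk(2, dim=3)
tiles = torch.cat([tl, bl, tr, br], dim=0)  # (4, 3, 224, 224), one tile per batch entry

embeds = tiles.flatten(1)                   # stand-in for the vision-encoder output per tile
pooled = sum(embeds.chunk(4, dim=0)) / 4    # mean of the tile embeddings, as in the code above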
@@ -5,6 +5,7 @@ import sys

from PIL import Image
from torch.nn import Parameter
+from torch.nn.modules.module import T
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
@@ -173,6 +174,7 @@ class IPAdapter(torch.nn.Module):
        self.input_size = 224
        self.clip_noise_zero = True
        self.unconditional: torch.Tensor = None
+       self.additional_loss = None
        if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
            try:
                self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -451,10 +453,7 @@ class IPAdapter(torch.nn.Module):
    ):
        with torch.no_grad():
            device = self.sd_ref().unet.device
-           if self.config.type.startswith('ip+'):
-               clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0)
-           else:
-               clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0)
+           clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0)

            if self.config.quad_image:
                # get the outputs of the quat
@@ -548,7 +547,7 @@ class IPAdapter(torch.nn.Module):
            # if drop:
            # clip_image = clip_image * 0
            with torch.set_grad_enabled(is_training):
-               if is_training:
+               if is_training and self.config.train_image_encoder:
                    self.image_encoder.train()
                    clip_image = clip_image.requires_grad_(True)
                    if self.preprocessor is not None:
@@ -565,16 +564,39 @@ class IPAdapter(torch.nn.Module):
                        clip_image, output_hidden_states=True
                    )

-               if self.config.type.startswith('ip+'):
+               if self.config.clip_layer == 'penultimate_hidden_states':
                    # they skip last layer for ip+
                    # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
                    clip_image_embeds = clip_output.hidden_states[-2]
+               elif self.config.clip_layer == 'last_hidden_state':
+                   clip_image_embeds = clip_output.hidden_states[-1]
                else:
                    clip_image_embeds = clip_output.image_embeds

            if self.config.quad_image:
                # get the outputs of the quat
                chunks = clip_image_embeds.chunk(quad_count, dim=0)
+               if self.config.train_image_encoder and is_training:
+                   # perform a loss across all chunks this will teach the vision encoder to
+                   # identify similarities in our pairs of images and ignore things that do not make them similar
+                   num_losses = 0
+                   total_loss = None
+                   for chunk in chunks:
+                       for chunk2 in chunks:
+                           if chunk is not chunk2:
+                               loss = F.mse_loss(chunk, chunk2)
+                               if total_loss is None:
+                                   total_loss = loss
+                               else:
+                                   total_loss = total_loss + loss
+                               num_losses += 1
+                   if total_loss is not None:
+                       total_loss = total_loss / num_losses
+                       total_loss = total_loss * 1e-2
+                       if self.additional_loss is not None:
+                           total_loss = total_loss + self.additional_loss
+                       self.additional_loss = total_loss
+
                chunk_sum = torch.zeros_like(chunks[0])
                for chunk in chunks:
                    chunk_sum = chunk_sum + chunk
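The block above is the comparative loss from the commit message: when the image encoder is being trained, every pair of quad crops of the same image is pulled together with an MSE penalty, scaled by 1e-2 and accumulated on additional_loss, which the trainer later adds to the main loss and clears. Roughly the same computation as a standalone helper (the function name is illustrative):

import torch.nn.functional as F

def comparative_chunk_loss(chunks, scale=1e-2):
    # mean pairwise MSE between embeddings of different crops of the same image
    total, count = None, 0
    for a in chunks:
        for b in chunks:
            if a is not b:
                pair_loss = F.mse_loss(a, b)
                total = pair_loss if total is None else total + pair_loss
                count += 1
    return (total / count) * scale if count > 0 else None

# e.g. chunks = clip_image_embeds.chunk(quad_count, dim=0); adapter.additional_loss = comparative_chunk_loss(chunks)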
@@ -582,7 +604,7 @@ class IPAdapter(torch.nn.Module):

                clip_image_embeds = chunk_sum / quad_count

-           if not is_training:
+           if not is_training or not self.config.train_image_encoder:
                clip_image_embeds = clip_image_embeds.detach()

            return clip_image_embeds
@@ -594,6 +616,17 @@ class IPAdapter(torch.nn.Module):
        embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
        return embeddings

+   def train(self: T, mode: bool = True) -> T:
+       if self.config.train_image_encoder:
+           self.image_encoder.train(mode)
+       if not self.config.train_only_image_encoder:
+           for attn_processor in self.adapter_modules:
+               attn_processor.train(mode)
+           if self.image_proj_model is not None:
+               self.image_proj_model.train(mode)
+       return super().train(mode)
+
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        if self.config.train_only_image_encoder:
            yield from self.image_encoder.parameters(recurse)
@@ -53,7 +53,6 @@ class InstantLoRAMidModule(torch.nn.Module):
                # reshape if needed
                if len(x.shape) == 3:
                    scaler = scaler.unsqueeze(1)
                x = x * scaler
            except Exception as e:
                print(e)
                print(x.shape)
@@ -354,6 +354,7 @@ class ToolkitNetworkMixin:
        self.is_ssd = is_ssd
        self.is_vega = is_vega
        self.is_v2 = is_v2
+       self.is_v1 = not is_v2 and not is_sdxl and not is_ssd and not is_vega
        self.is_merged_in = False
        self.is_lorm = is_lorm
        self.network_config: NetworkConfig = network_config
@@ -361,7 +362,7 @@ class ToolkitNetworkMixin:
        self.lorm_train_mode: Literal['local', None] = None
        self.can_merge_in = not is_lorm

-   def get_keymap(self: Network):
+   def get_keymap(self: Network, force_weight_mapping=False):
        use_weight_mapping = False

        if self.is_ssd:
@@ -377,6 +378,9 @@ class ToolkitNetworkMixin:
        else:
            keymap_tail = 'sd1'
            # todo double check this
            # use_weight_mapping = True

+       if force_weight_mapping:
+           use_weight_mapping = True
+
        # load keymap
@@ -440,9 +444,9 @@ class ToolkitNetworkMixin:
        else:
            torch.save(save_dict, file)

-   def load_weights(self: Network, file):
+   def load_weights(self: Network, file, force_weight_mapping=False):
        # allows us to save and load to and from ldm weights
-       keymap = self.get_keymap()
+       keymap = self.get_keymap(force_weight_mapping)
        keymap = {} if keymap is None else keymap

        if os.path.splitext(file)[1] == ".safetensors":
@@ -468,6 +472,11 @@ class ToolkitNetworkMixin:
        for key in to_delete:
            del load_sd[key]

        print(f"Missing keys: {to_delete}")
+       if len(to_delete) > 0 and self.is_v1:
+           print(" Attempting to load with forced keymap")
+           return self.load_weights(file, force_weight_mapping=True)
+
        info = self.load_state_dict(load_sd, False)
        if len(extra_dict.keys()) == 0:
            extra_dict = None
@@ -528,6 +528,14 @@ class StableDiffusion:
            )
            gen_config.negative_prompt_2 = gen_config.negative_prompt

+       if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+           self.adapter.trigger_pre_te(
+               tensors_0_1=validation_image,
+               is_training=False,
+               has_been_preprocessed=False,
+               quad_count=4
+           )
+
        # encode the prompt ourselves so we can do fun stuff with embeddings
        conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)