Mirror of https://github.com/ostris/ai-toolkit.git
Various bug fixes. Created a contextual alpha mask module to calculate an alpha mask.
@@ -1085,7 +1085,7 @@ class SDTrainer(BaseSDTrainProcess):
                     noise=noise,
                     batch=batch,
                     unconditional_embeds=unconditional_embeds
-                )
+                ).detach()

                 # do the custom adapter after the prior prediction
                 if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
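Detaching the prior prediction turns it into a fixed target, so the loss cannot push gradients back through the network that produced it. A minimal sketch of that pattern, with illustrative module and tensor names rather than the trainer's actual ones:

    import torch

    model = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4)

    # target produced by the same network, treated as a constant
    prior_pred = model(x).detach()
    pred = model(x)

    loss = torch.nn.functional.mse_loss(pred, prior_pred)
    loss.backward()  # gradients flow only through `pred`, not the detached target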
@@ -186,6 +186,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         sample_config = self.first_sample_config if is_first else self.sample_config
         start_seed = sample_config.seed
         current_seed = start_seed
+
+        test_image_paths = []
+        if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
+            test_image_path_list = self.adapter_config.test_img_path.split(',')
+            test_image_path_list = [p.strip() for p in test_image_path_list]
+            test_image_path_list = [p for p in test_image_path_list if p != '']
+            # divide up images so they are evenly distributed across prompts
+            for i in range(len(sample_config.prompts)):
+                test_image_paths.append(test_image_path_list[i % len(test_image_path_list)])
+
         for i in range(len(sample_config.prompts)):
             if sample_config.walk_seed:
                 current_seed = start_seed + i
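The modulo index above cycles through the adapter test images so every sample prompt gets exactly one image, repeating images when there are more prompts than images. A standalone sketch of the same round-robin assignment, with made-up file names:

    test_image_path_list = ['person1.jpg', 'person2.jpg']
    prompts = ['a photo of X', 'a painting of X', 'a sketch of X']

    test_image_paths = []
    for i in range(len(prompts)):
        # cycle through the images so every prompt gets one
        test_image_paths.append(test_image_path_list[i % len(test_image_path_list)])

    print(test_image_paths)  # ['person1.jpg', 'person2.jpg', 'person1.jpg']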
@@ -219,7 +229,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

             extra_args = {}
             if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
-                extra_args['adapter_image_path'] = self.adapter_config.test_img_path
+                extra_args['adapter_image_path'] = test_image_paths[i]

             gen_img_config_list.append(GenerateImageConfig(
                 prompt=prompt,  # it will autoparse the prompt
@@ -403,6 +403,7 @@ class DatasetConfig:
             # remove empty lines
             random_triggers = [line for line in random_triggers if line.strip() != '']
         self.random_triggers: List[str] = random_triggers
+        self.random_triggers_max: int = kwargs.get('random_triggers_max', 1)
         self.caption_ext: str = kwargs.get('caption_ext', None)
         self.random_scale: bool = kwargs.get('random_scale', False)
         self.random_crop: bool = kwargs.get('random_crop', False)
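With the new option, a dataset config can cap how many random triggers are appended per caption; it defaults to 1, preserving the old behavior. A hedged sketch of the relevant kwargs as they might be passed to DatasetConfig (the path and trigger values are purely illustrative):

    dataset_kwargs = {
        'folder_path': '/path/to/images',        # illustrative path, not from the diff
        'caption_ext': 'txt',
        'random_triggers': ['p3r5on', 'styl3'],  # illustrative trigger words
        'random_triggers_max': 3,                # append up to 3 triggers per caption
    }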
@@ -277,7 +277,7 @@ class CustomAdapter(torch.nn.Module):
                     raise ValueError(f"unknown shape: {v.shape}")
             self.fuse_module.load_state_dict(current_state_dict, strict=strict)

-        if 'vision_encoder' in state_dict:
+        if 'vision_encoder' in state_dict and self.config.train_image_encoder:
             self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)

         if 'fuse_module' in state_dict:
@@ -411,7 +411,7 @@ class CustomAdapter(torch.nn.Module):
         if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
             if is_unconditional:
                 # we dont condition the negative embeds for photo maker
-                return prompt_embeds
+                return prompt_embeds.clone()
             with torch.no_grad():
                 # on training the clip image is created in the dataloader
                 if not has_been_preprocessed:
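Returning `prompt_embeds.clone()` instead of the original object means later in-place edits cannot silently corrupt the caller's embeddings. A minimal sketch of the aliasing problem the clone avoids (tensor values are illustrative):

    import torch

    prompt_embeds = torch.ones(1, 77, 768)

    aliased = prompt_embeds            # same underlying storage
    aliased[:, 0, :] = 0               # silently modifies prompt_embeds too
    print(prompt_embeds[0, 0, 0])      # tensor(0.)

    prompt_embeds = torch.ones(1, 77, 768)
    copied = prompt_embeds.clone()     # independent copy
    copied[:, 0, :] = 0
    print(prompt_embeds[0, 0, 0])      # tensor(1.)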
@@ -348,8 +348,14 @@ class CaptionProcessingDTOMixin:
             caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)

         if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
-            # add random triggers
-            caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
+            num_triggers = self.dataset_config.random_triggers_max
+            if num_triggers > 1:
+                num_triggers = random.randint(0, num_triggers)
+
+            if num_triggers > 0:
+                # add random triggers
+                for i in range(num_triggers):
+                    caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)

         if self.dataset_config.shuffle_tokens:
             # shuffle again
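The reworked logic first decides how many triggers to append: with `random_triggers_max` of 1 it behaves as before and appends exactly one, while larger values draw a count between 0 and the max. A standalone sketch of the sampling with illustrative values:

    import random

    random_triggers = ['p3r5on', 'styl3', 'subj3ct']
    random_triggers_max = 3

    caption = 'a photo of a dog'

    num_triggers = random_triggers_max
    if num_triggers > 1:
        # anywhere from 0 to random_triggers_max triggers
        num_triggers = random.randint(0, num_triggers)

    for _ in range(num_triggers):
        caption = caption + ', ' + random.choice(random_triggers)

    print(caption)  # e.g. 'a photo of a dog, styl3, p3r5on'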
@@ -86,6 +86,49 @@ class ZipperBlock(nn.Module):
         return x


+class ContextualAlphaMask(nn.Module):
+    def __init__(
+            self,
+            dim: int = 768,
+    ):
+        super(ContextualAlphaMask, self).__init__()
+        self.dim = dim
+
+        half_dim = dim // 2
+        quarter_dim = dim // 4
+
+        self.fc1 = nn.Linear(self.dim, self.dim)
+        self.fc2 = nn.Linear(self.dim, half_dim)
+        self.norm1 = nn.LayerNorm(half_dim)
+        self.fc3 = nn.Linear(half_dim, half_dim)
+        self.fc4 = nn.Linear(half_dim, quarter_dim)
+        self.norm2 = nn.LayerNorm(quarter_dim)
+        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
+        self.fc6 = nn.Linear(quarter_dim, 1)
+        # set fc6 weights to near zero
+        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        # x = (batch_size, 77, 768)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x = self.norm1(x)
+        x = self.act_fn(x)
+        x = self.fc3(x)
+        x = self.act_fn(x)
+        x = self.fc4(x)
+        x = self.norm2(x)
+        x = self.act_fn(x)
+        x = self.fc5(x)
+        x = self.act_fn(x)
+        x = self.fc6(x)
+        x = torch.sigmoid(x)
+        return x
+
+
+
 # CLIPFusionModule
 # Fuses any size of vision and text embeddings into a single embedding.
 # remaps tokens and vectors.
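The new module squeezes each token embedding down to a single value in (0, 1), so for CLIP-L text embeddings it yields one alpha per token. A quick shape check, assuming the ContextualAlphaMask definition from the hunk above is in scope with its default dim of 768:

    import torch
    import torch.nn as nn  # required by the class definition above

    ctx_alpha = ContextualAlphaMask(dim=768)

    text_embeds = torch.randn(2, 77, 768)   # (batch, tokens, hidden)
    alpha = ctx_alpha(text_embeds)

    print(alpha.shape)                                            # torch.Size([2, 77, 1])
    print(float(alpha.min()) >= 0.0, float(alpha.max()) <= 1.0)   # True True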
@@ -96,7 +139,7 @@ class CLIPFusionModule(nn.Module):
             text_tokens: int = 77,
             vision_hidden_size: int = 1024,
             vision_tokens: int = 257,
-            num_blocks: int = 2,
+            num_blocks: int = 1,
     ):
         super(CLIPFusionModule, self).__init__()

@@ -125,6 +168,10 @@ class CLIPFusionModule(nn.Module):
             ) for i in range(num_blocks)
         ])

+        self.ctx_alpha = ContextualAlphaMask(
+            dim=self.text_hidden_size,
+        )
+
     def forward(self, text_embeds, vision_embeds):
         # text_embeds = (batch_size, 77, 768)
         # vision_embeds = (batch_size, 257, 1024)
@@ -138,6 +185,8 @@ class CLIPFusionModule(nn.Module):
             x = block(x)
             x = x + res

-        x = text_embeds + x
+        # alpha mask
+        alpha = self.ctx_alpha(text_embeds)
+        x = alpha * x + (1 - alpha) * text_embeds

         return x
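Instead of adding the fused features directly onto the text embeddings, the output is now a per-token convex blend: where alpha is near 0 the original text embedding passes through, and where it is near 1 the fused features dominate. A small numeric sketch of the blend with stand-in tensors:

    import torch

    text_embeds = torch.zeros(1, 77, 768)     # stand-in for the original embeddings
    fused = torch.ones(1, 77, 768)            # stand-in for the fusion output
    alpha = torch.full((1, 77, 1), 0.25)      # stand-in for ctx_alpha(text_embeds)

    out = alpha * fused + (1 - alpha) * text_embeds
    print(out[0, 0, 0])                       # tensor(0.2500): 25% fused, 75% original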
@@ -804,28 +804,27 @@ class StableDiffusion:
             detach_unconditional=False,
             **kwargs,
     ):
-        with torch.no_grad():
-            # get the embeddings
-            if text_embeddings is None and conditional_embeddings is None:
-                raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
-            if text_embeddings is None and unconditional_embeddings is not None:
-                text_embeddings = concat_prompt_embeds([
-                    unconditional_embeddings,  # negative embedding
-                    conditional_embeddings,  # positive embedding
-                ])
-            elif text_embeddings is None and conditional_embeddings is not None:
-                # not doing cfg
-                text_embeddings = conditional_embeddings
+        # get the embeddings
+        if text_embeddings is None and conditional_embeddings is None:
+            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+        if text_embeddings is None and unconditional_embeddings is not None:
+            text_embeddings = concat_prompt_embeds([
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+            ])
+        elif text_embeddings is None and conditional_embeddings is not None:
+            # not doing cfg
+            text_embeddings = conditional_embeddings

-            # CFG is comparing neg and positive, if we have concatenated embeddings
-            # then we are doing it, otherwise we are not and takes half the time.
-            do_classifier_free_guidance = True
+        # CFG is comparing neg and positive, if we have concatenated embeddings
+        # then we are doing it, otherwise we are not and takes half the time.
+        do_classifier_free_guidance = True

-            # check if batch size of embeddings matches batch size of latents
-            if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
-                do_classifier_free_guidance = False
-            elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
-                raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
+        # check if batch size of embeddings matches batch size of latents
+        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+            do_classifier_free_guidance = False
+        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

         latents = latents.to(self.device_torch)
         text_embeddings = text_embeddings.to(self.device_torch)
         timestep = timestep.to(self.device_torch)
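The batch-size check decides whether classifier-free guidance runs: if the text embeddings have twice the latent batch, the extra half is the negative prompt and CFG stays on; if the batches match, CFG is skipped; anything else is an error. A standalone sketch of the decision, using plain tensors where the real code reads shapes off a prompt-embedding wrapper via `.text_embeds` (shapes illustrative):

    import torch

    latents = torch.randn(2, 4, 64, 64)       # batch of 2 latents
    text_embeds = torch.randn(4, 77, 768)     # 2 negative + 2 positive embeddings

    do_classifier_free_guidance = True
    if latents.shape[0] == text_embeds.shape[0]:
        do_classifier_free_guidance = False   # only positives supplied, skip CFG
    elif latents.shape[0] * 2 != text_embeds.shape[0]:
        raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

    print(do_classifier_free_guidance)        # True for the shapes above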