Various bug fixes. Created a contextual alpha mask module to calculate an alpha mask

Jaret Burkett
2024-01-18 16:34:27 -07:00
parent 86c70a2a1f
commit f17ad8d794
7 changed files with 93 additions and 28 deletions

View File

@@ -1085,7 +1085,7 @@ class SDTrainer(BaseSDTrainProcess):
                 noise=noise,
                 batch=batch,
                 unconditional_embeds=unconditional_embeds
-            )
+            ).detach()
         # do the custom adapter after the prior prediction
         if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
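Note: detaching the prior prediction turns it into a fixed target instead of a graph node, so backprop no longer flows into the network pass that produced it. A minimal standalone sketch of what .detach() buys here (illustrative code, not the trainer's):

import torch
import torch.nn.functional as F

net = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

prior_pred = net(x).detach()       # constant target: no grad through this branch
pred = net(x)
loss = F.mse_loss(pred, prior_pred)
loss.backward()                    # net.weight.grad comes only from `pred`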

View File

@@ -186,6 +186,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         sample_config = self.first_sample_config if is_first else self.sample_config
         start_seed = sample_config.seed
         current_seed = start_seed
+        test_image_paths = []
+        if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
+            test_image_path_list = self.adapter_config.test_img_path.split(',')
+            test_image_path_list = [p.strip() for p in test_image_path_list]
+            test_image_path_list = [p for p in test_image_path_list if p != '']
+            # divide up images so they are evenly distributed across prompts
+            for i in range(len(sample_config.prompts)):
+                test_image_paths.append(test_image_path_list[i % len(test_image_path_list)])
+
         for i in range(len(sample_config.prompts)):
             if sample_config.walk_seed:
                 current_seed = start_seed + i
@@ -219,7 +229,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             extra_args = {}
             if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
-                extra_args['adapter_image_path'] = self.adapter_config.test_img_path
+                extra_args['adapter_image_path'] = test_image_paths[i]
             gen_img_config_list.append(GenerateImageConfig(
                 prompt=prompt,  # it will autoparse the prompt
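Note: the two hunks above work together. Instead of passing the raw comma-separated test_img_path string to every sample, the sampler now pre-assigns one test image per prompt, cycling through the configured list. The modulo assignment in isolation (hypothetical file names):

test_image_path_list = ['a.png', 'b.png', 'c.png']
prompts = ['p0', 'p1', 'p2', 'p3', 'p4']

test_image_paths = [
    test_image_path_list[i % len(test_image_path_list)]
    for i in range(len(prompts))
]
# -> ['a.png', 'b.png', 'c.png', 'a.png', 'b.png']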

View File

@@ -403,6 +403,7 @@ class DatasetConfig:
             # remove empty lines
             random_triggers = [line for line in random_triggers if line.strip() != '']
         self.random_triggers: List[str] = random_triggers
+        self.random_triggers_max: int = kwargs.get('random_triggers_max', 1)
         self.caption_ext: str = kwargs.get('caption_ext', None)
         self.random_scale: bool = kwargs.get('random_scale', False)
         self.random_crop: bool = kwargs.get('random_crop', False)
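Note: random_triggers_max defaults to 1, which preserves the old behavior of appending exactly one trigger (see the caption hunk below). A sketch of how the new key might be passed, assuming DatasetConfig accepts these as keyword arguments as the kwargs.get calls suggest (values hypothetical):

dataset_config = DatasetConfig(
    random_triggers=['trigger_a', 'trigger_b'],
    random_triggers_max=3,  # append between 0 and 3 random triggers per caption
    caption_ext='txt',
)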

View File

@@ -277,7 +277,7 @@ class CustomAdapter(torch.nn.Module):
                 raise ValueError(f"unknown shape: {v.shape}")
             self.fuse_module.load_state_dict(current_state_dict, strict=strict)
-        if 'vision_encoder' in state_dict:
+        if 'vision_encoder' in state_dict and self.config.train_image_encoder:
             self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
         if 'fuse_module' in state_dict:
@@ -411,7 +411,7 @@ class CustomAdapter(torch.nn.Module):
         if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
             if is_unconditional:
                 # we dont condition the negative embeds for photo maker
-                return prompt_embeds
+                return prompt_embeds.clone()
         with torch.no_grad():
             # on training the clip image is created in the dataloader
             if not has_been_preprocessed:
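Note on the .clone(): returning prompt_embeds directly hands the caller an alias of the same tensor storage, so any in-place edit downstream silently mutates the original embeds. A minimal sketch of the aliasing hazard:

import torch

prompt_embeds = torch.zeros(1, 77, 768)
alias = prompt_embeds           # same storage
alias += 1.0                    # in-place add
print(prompt_embeds.max())      # tensor(1.) -- the original was mutated

prompt_embeds = torch.zeros(1, 77, 768)
copy = prompt_embeds.clone()    # independent storage
copy += 1.0
print(prompt_embeds.max())      # tensor(0.) -- the original is untouched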

View File

@@ -348,8 +348,14 @@ class CaptionProcessingDTOMixin:
             caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
         if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
-            # add random triggers
-            caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
+            num_triggers = self.dataset_config.random_triggers_max
+            if num_triggers > 1:
+                num_triggers = random.randint(0, num_triggers)
+            if num_triggers > 0:
+                # add random triggers
+                for i in range(num_triggers):
+                    caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
         if self.dataset_config.shuffle_tokens:
             # shuffle again
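Note: random.randint(0, n) is inclusive on both ends, so with random_triggers_max > 1 a caption can also receive zero triggers, and random.choice draws with replacement, so the same trigger can appear more than once. With the default of 1 the behavior stays exactly one trigger, as before. The sampling logic in isolation (hypothetical trigger list):

import random

random_triggers = ['trigger_a', 'trigger_b', 'trigger_c']
caption = 'a photo of a person'

num_triggers = 3  # random_triggers_max
if num_triggers > 1:
    num_triggers = random.randint(0, num_triggers)  # 0..max, inclusive
for _ in range(num_triggers):
    caption = caption + ', ' + random.choice(random_triggers)
# e.g. 'a photo of a person, trigger_c, trigger_a'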

View File

@@ -86,6 +86,49 @@ class ZipperBlock(nn.Module):
         return x
+
+
+class ContextualAlphaMask(nn.Module):
+    def __init__(
+            self,
+            dim: int = 768,
+    ):
+        super(ContextualAlphaMask, self).__init__()
+        self.dim = dim
+        half_dim = dim // 2
+        quarter_dim = dim // 4
+        self.fc1 = nn.Linear(self.dim, self.dim)
+        self.fc2 = nn.Linear(self.dim, half_dim)
+        self.norm1 = nn.LayerNorm(half_dim)
+        self.fc3 = nn.Linear(half_dim, half_dim)
+        self.fc4 = nn.Linear(half_dim, quarter_dim)
+        self.norm2 = nn.LayerNorm(quarter_dim)
+        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
+        self.fc6 = nn.Linear(quarter_dim, 1)
+        # set fc6 weights to near zero
+        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        # x = (batch_size, 77, 768)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x = self.norm1(x)
+        x = self.act_fn(x)
+        x = self.fc3(x)
+        x = self.act_fn(x)
+        x = self.fc4(x)
+        x = self.norm2(x)
+        x = self.act_fn(x)
+        x = self.fc5(x)
+        x = self.act_fn(x)
+        x = self.fc6(x)
+        x = torch.sigmoid(x)
+        return x
+
+
 # CLIPFusionModule
 # Fuses any size of vision and text embeddings into a single embedding.
 # remaps tokens and vectors.
@@ -96,7 +139,7 @@ class CLIPFusionModule(nn.Module):
             text_tokens: int = 77,
             vision_hidden_size: int = 1024,
             vision_tokens: int = 257,
-            num_blocks: int = 2,
+            num_blocks: int = 1,
     ):
         super(CLIPFusionModule, self).__init__()
@@ -125,6 +168,10 @@ class CLIPFusionModule(nn.Module):
             ) for i in range(num_blocks)
         ])
+        self.ctx_alpha = ContextualAlphaMask(
+            dim=self.text_hidden_size,
+        )
+
     def forward(self, text_embeds, vision_embeds):
         # text_embeds = (batch_size, 77, 768)
         # vision_embeds = (batch_size, 257, 1024)
@@ -138,6 +185,8 @@ class CLIPFusionModule(nn.Module):
             x = block(x)
             x = x + res
         x = text_embeds + x
+        # alpha mask
+        alpha = self.ctx_alpha(text_embeds)
+        x = alpha * x + (1 - alpha) * text_embeds
         return x
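Note: ContextualAlphaMask is a per-token gate: a funnel of linear layers (dim -> dim/2 -> dim/4 -> 1) ending in a sigmoid, so it emits one value in (0, 1) for each text token. Because fc6's weights are initialized near zero, the pre-sigmoid output starts near the small default bias, so alpha begins close to 0.5 and the blend starts as a roughly even mix of fused and original text embeddings. A shape walk-through using the class added above with its default dims:

import torch

ctx_alpha = ContextualAlphaMask(dim=768)

text_embeds = torch.randn(2, 77, 768)
fused = torch.randn(2, 77, 768)   # stand-in for the fusion-block output

alpha = ctx_alpha(text_embeds)    # (2, 77, 1), each value in (0, 1)
out = alpha * fused + (1 - alpha) * text_embeds  # broadcasts over the hidden dim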

View File

@@ -804,28 +804,27 @@ class StableDiffusion:
             detach_unconditional=False,
             **kwargs,
     ):
-        with torch.no_grad():
-            # get the embeddings
-            if text_embeddings is None and conditional_embeddings is None:
-                raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
-            if text_embeddings is None and unconditional_embeddings is not None:
-                text_embeddings = concat_prompt_embeds([
-                    unconditional_embeddings,  # negative embedding
-                    conditional_embeddings,  # positive embedding
-                ])
-            elif text_embeddings is None and conditional_embeddings is not None:
-                # not doing cfg
-                text_embeddings = conditional_embeddings
+        # get the embeddings
+        if text_embeddings is None and conditional_embeddings is None:
+            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+        if text_embeddings is None and unconditional_embeddings is not None:
+            text_embeddings = concat_prompt_embeds([
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+            ])
+        elif text_embeddings is None and conditional_embeddings is not None:
+            # not doing cfg
+            text_embeddings = conditional_embeddings

-            # CFG is comparing neg and positive, if we have concatenated embeddings
-            # then we are doing it, otherwise we are not and takes half the time.
-            do_classifier_free_guidance = True
+        # CFG is comparing neg and positive, if we have concatenated embeddings
+        # then we are doing it, otherwise we are not and takes half the time.
+        do_classifier_free_guidance = True

-            # check if batch size of embeddings matches batch size of latents
-            if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
-                do_classifier_free_guidance = False
-            elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
-                raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
+        # check if batch size of embeddings matches batch size of latents
+        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+            do_classifier_free_guidance = False
+        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

         latents = latents.to(self.device_torch)
         text_embeddings = text_embeddings.to(self.device_torch)
         timestep = timestep.to(self.device_torch)
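Note: the dedent removes the blanket torch.no_grad(), so gradients can now flow through this prediction path when a caller needs them, which is likely why the prior prediction in SDTrainer above is now explicitly detached. The batch-size rule, restated as a standalone helper for clarity (not code from the repo):

def wants_cfg(latent_bs: int, embed_bs: int) -> bool:
    # equal batch sizes  -> conditional-only, a single UNet pass
    # doubled embeds     -> concatenated [negative, positive], CFG pass
    if latent_bs == embed_bs:
        return False
    if latent_bs * 2 == embed_bs:
        return True
    raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

wants_cfg(4, 4)  # False -- no CFG, half the time
wants_cfg(4, 8)  # True  -- split the doubled output and apply guidance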