Work on additional image embedding methods. Finalized zipper resampler. It works amazing

This commit is contained in:
Jaret Burkett
2024-02-10 09:00:05 -07:00
parent a8481c1670
commit e074058faa
7 changed files with 261 additions and 47 deletions

View File

@@ -557,6 +557,21 @@ class CustomAdapter(torch.nn.Module):
quad_count=4,
) -> PromptEmbeds:
if self.adapter_type == 'ilora':
if tensors_0_1 is None:
# scale the noise down
tensors_0_1 = torch.rand([1, 3, self.input_size, self.input_size], device=self.device)
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
dtype=get_torch_dtype(self.sd_ref().dtype))
tensors_0_1 = tensors_0_1 * noise_scale
# tensors_0_1 = tensors_0_1 * 0
mean = torch.tensor(self.clip_image_processor.image_mean).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
std = torch.tensor(self.clip_image_processor.image_std).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
with torch.no_grad():
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
@@ -626,7 +641,7 @@ class CustomAdapter(torch.nn.Module):
if not is_training or not self.config.train_image_encoder:
img_embeds = img_embeds.detach()
self.ilora_module.img_embeds = img_embeds
self.ilora_module(img_embeds)
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.type == 'photo_maker':