mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Work on additional image embedding methods. Finalized the zipper resampler; it works amazingly well.
This commit is contained in:
@@ -557,6 +557,21 @@ class CustomAdapter(torch.nn.Module):
|
||||
quad_count=4,
|
||||
) -> PromptEmbeds:
|
||||
if self.adapter_type == 'ilora':
|
||||
if tensors_0_1 is None:
|
||||
# scale the noise down
|
||||
tensors_0_1 = torch.rand([1, 3, self.input_size, self.input_size], device=self.device)
|
||||
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
|
||||
dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
tensors_0_1 = tensors_0_1 * noise_scale
|
||||
# tensors_0_1 = tensors_0_1 * 0
|
||||
mean = torch.tensor(self.clip_image_processor.image_mean).to(
|
||||
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
|
||||
).detach()
|
||||
std = torch.tensor(self.clip_image_processor.image_std).to(
|
||||
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
|
||||
).detach()
|
||||
tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
|
||||
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
|
||||
with torch.no_grad():
|
||||
# on training the clip image is created in the dataloader
|
||||
if not has_been_preprocessed:
|
||||
@@ -626,7 +641,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
if not is_training or not self.config.train_image_encoder:
|
||||
img_embeds = img_embeds.detach()
|
||||
|
||||
self.ilora_module.img_embeds = img_embeds
|
||||
self.ilora_module(img_embeds)
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
if self.config.type == 'photo_maker':
|
||||
|
||||
Reference in New Issue
Block a user