mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-24 16:29:26 +00:00
Added initial support for training i2v adapter WIP
This commit is contained in:
@@ -161,7 +161,8 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
vae = AutoencoderTiny.from_pretrained(
|
||||
"madebyollin/taef1", torch_dtype=torch.bfloat16)
|
||||
self.vae = vae
|
||||
image_encoder_path = "google/siglip-so400m-patch14-384"
|
||||
# image_encoder_path = "google/siglip-so400m-patch14-384"
|
||||
image_encoder_path = "google/siglip2-so400m-patch16-512"
|
||||
try:
|
||||
self.image_processor = SiglipImageProcessor.from_pretrained(
|
||||
image_encoder_path)
|
||||
@@ -182,7 +183,11 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
dtype = torch.bfloat16
|
||||
device = self.vae.device
|
||||
# resize to 384x384
|
||||
images = F.interpolate(tensors_0_1, size=(384, 384),
|
||||
if 'height' in self.image_processor.size:
|
||||
size = self.image_processor.size['height']
|
||||
else:
|
||||
size = self.image_processor.crop_size['height']
|
||||
images = F.interpolate(tensors_0_1, size=(size, size),
|
||||
mode='bicubic', align_corners=False)
|
||||
|
||||
mean = torch.tensor(self.image_processor.image_mean).to(
|
||||
|
||||
Reference in New Issue
Block a user