mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added initial direct vision pixtral support
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -175,3 +175,4 @@ cython_debug/
|
||||
!/extensions/example
|
||||
/temp
|
||||
/wandb
|
||||
.vscode/settings.json
|
||||
@@ -206,6 +206,8 @@ class AdapterConfig:
|
||||
self.ilora_down: bool = kwargs.get('ilora_down', True)
|
||||
self.ilora_mid: bool = kwargs.get('ilora_mid', True)
|
||||
self.ilora_up: bool = kwargs.get('ilora_up', True)
|
||||
|
||||
self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
|
||||
|
||||
self.flux_only_double: bool = kwargs.get('flux_only_double', False)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
||||
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||
@@ -257,6 +258,13 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.vision_encoder = SiglipVisionModel.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'pixtral':
|
||||
self.image_processor = PixtralVisionImagePreprocessorCompatible(
|
||||
max_image_size=self.config.pixtral_max_image_size,
|
||||
)
|
||||
self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit':
|
||||
try:
|
||||
self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
|
||||
@@ -700,9 +708,11 @@ class CustomAdapter(torch.nn.Module):
|
||||
else:
|
||||
return prompt_embeds
|
||||
|
||||
def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
|
||||
def get_empty_clip_image(self, batch_size: int, shape=None) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
|
||||
if shape is None:
|
||||
shape = [batch_size, 3, self.input_size, self.input_size]
|
||||
tensors_0_1 = torch.rand(shape, device=self.device)
|
||||
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
|
||||
dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
tensors_0_1 = tensors_0_1 * noise_scale
|
||||
@@ -761,7 +771,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
batch_size = clip_image.shape[0]
|
||||
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
||||
# add an unconditional so we can save it
|
||||
unconditional = self.get_empty_clip_image(batch_size).to(
|
||||
unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
|
||||
clip_image.device, dtype=clip_image.dtype
|
||||
)
|
||||
clip_image = torch.cat([unconditional, clip_image], dim=0)
|
||||
|
||||
@@ -17,6 +17,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
from toolkit.basic import flush, value_map
|
||||
from toolkit.buckets import get_bucket_for_image_size, get_resolution
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||
from torchvision import transforms
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
@@ -733,6 +734,7 @@ class ClipImageFileItemDTOMixin:
|
||||
return self._clip_vision_embeddings_path
|
||||
|
||||
def load_clip_image(self: 'FileItemDTO'):
|
||||
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
|
||||
if self.is_vision_clip_cached:
|
||||
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
|
||||
|
||||
@@ -759,8 +761,24 @@ class ClipImageFileItemDTOMixin:
|
||||
if self.flip_y:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
if img.width != img.height:
|
||||
|
||||
if is_dynamic_size_and_aspect:
|
||||
# just match the bucket size for now
|
||||
if self.dataset_config.buckets:
|
||||
# scale and crop based on file item
|
||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# crop
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
self.crop_x + self.crop_width,
|
||||
self.crop_y + self.crop_height
|
||||
))
|
||||
else:
|
||||
raise Exception("Control images not supported for non-bucket datasets")
|
||||
|
||||
elif img.width != img.height:
|
||||
min_size = min(img.width, img.height)
|
||||
if self.dataset_config.square_crop:
|
||||
# center crop to a square
|
||||
|
||||
@@ -33,7 +33,8 @@ class FeedForward(nn.Module):
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
|
||||
# type: ignore
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -103,15 +104,18 @@ class Attention(nn.Module):
|
||||
else:
|
||||
cache.update(xk, xv)
|
||||
key, val = cache.key, cache.value
|
||||
key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
|
||||
val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
|
||||
key = key.view(seqlen_sum * cache.max_seq_len,
|
||||
self.n_kv_heads, self.head_dim)
|
||||
val = val.view(seqlen_sum * cache.max_seq_len,
|
||||
self.n_kv_heads, self.head_dim)
|
||||
|
||||
# Repeat keys and values to match number of query heads
|
||||
key, val = repeat_kv(key, val, self.repeats, dim=1)
|
||||
|
||||
# xformers requires (B=1, S, H, D)
|
||||
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
|
||||
output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
|
||||
output = memory_efficient_attention(
|
||||
xq, key, val, mask if cache is None else cache.mask)
|
||||
output = output.view(seqlen_sum, self.n_heads * self.head_dim)
|
||||
|
||||
assert isinstance(output, torch.Tensor)
|
||||
@@ -260,8 +264,8 @@ class PixtralVisionEncoder(nn.Module):
|
||||
assert head_dim % 2 == 0, "ROPE requires even head_dim"
|
||||
self._freqs_cis: Optional[torch.Tensor] = None
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
model_folder = pretrained_model_name_or_path
|
||||
else:
|
||||
@@ -275,11 +279,12 @@ class PixtralVisionEncoder(nn.Module):
|
||||
with open(os.path.join(model_folder, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
model = PixtralVisionEncoder(**config)
|
||||
model = cls(**config)
|
||||
|
||||
# see if there is a state_dict
|
||||
if os.path.exists(os.path.join(model_folder, "model.safetensors")):
|
||||
state_dict = load_file(os.path.join(model_folder, "model.safetensors"))
|
||||
state_dict = load_file(os.path.join(
|
||||
model_folder, "model.safetensors"))
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
return model
|
||||
@@ -319,14 +324,17 @@ class PixtralVisionEncoder(nn.Module):
|
||||
image_features: tensor of token features for all tokens of all images of
|
||||
shape (N_toks, D)
|
||||
"""
|
||||
assert isinstance(images, list), f"Expected list of images, got {type(images)}"
|
||||
assert isinstance(
|
||||
images, list), f"Expected list of images, got {type(images)}"
|
||||
assert all(len(img.shape) == 3 for img in
|
||||
images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
|
||||
# pass images through initial convolution independently
|
||||
patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
|
||||
patch_embeds_list = [self.patch_conv(
|
||||
img.unsqueeze(0)).squeeze(0) for img in images]
|
||||
|
||||
# flatten to a single sequence
|
||||
patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
|
||||
patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
|
||||
for p in patch_embeds_list], dim=0)
|
||||
patch_embeds = self.ln_pre(patch_embeds)
|
||||
|
||||
# positional embeddings
|
||||
@@ -355,7 +363,8 @@ class VisionLanguageAdapter(nn.Module):
|
||||
self.w_out = nn.Linear(out_dim, out_dim, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return]
|
||||
# type: ignore[no-any-return]
|
||||
return self.w_out(self.gelu(self.w_in(x)))
|
||||
|
||||
|
||||
class VisionTransformerBlocks(nn.Module):
|
||||
@@ -401,7 +410,8 @@ def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> tor
|
||||
Returns:
|
||||
torch.Tensor: Normalized image with shape (C, H, W).
|
||||
"""
|
||||
assert image.shape[0] == len(mean) == len(std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
|
||||
assert image.shape[0] == len(mean) == len(
|
||||
std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
|
||||
|
||||
# Reshape mean and std to (C, 1, 1) for broadcasting
|
||||
mean = mean.view(-1, 1, 1)
|
||||
@@ -473,10 +483,12 @@ class PixtralVisionImagePreprocessor:
|
||||
"""
|
||||
# should not have batch
|
||||
if len(image.shape) == 4:
|
||||
raise ValueError(f"Expected image with shape (C, H, W), got {image.shape}")
|
||||
raise ValueError(
|
||||
f"Expected image with shape (C, H, W), got {image.shape}")
|
||||
|
||||
if image.min() < 0.0 or image.max() > 1.0:
|
||||
raise ValueError(f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
|
||||
raise ValueError(
|
||||
f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
|
||||
|
||||
w, h = self._image_to_num_tokens(image)
|
||||
assert w > 0
|
||||
@@ -490,3 +502,98 @@ class PixtralVisionImagePreprocessor:
|
||||
processed_image = transform_image(image, new_image_size)
|
||||
|
||||
return processed_image
|
||||
|
||||
|
||||
class PixtralVisionImagePreprocessorCompatibleReturn:
|
||||
def __init__(self, pixel_values) -> None:
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
|
||||
# Compatable version with ai toolkit flow
|
||||
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
|
||||
def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
|
||||
super().__init__(
|
||||
image_patch_size=image_patch_size,
|
||||
max_image_size=max_image_size
|
||||
)
|
||||
self.size = {
|
||||
'height': max_image_size,
|
||||
'width': max_image_size
|
||||
}
|
||||
self.image_mean = DATASET_MEAN
|
||||
self.image_std = DATASET_STD
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images,
|
||||
return_tensors="pt",
|
||||
do_resize=True,
|
||||
do_rescale=False,
|
||||
) -> torch.Tensor:
|
||||
out_stack = []
|
||||
if len(images.shape) == 3:
|
||||
images = images.unsqueeze(0)
|
||||
for i in range(images.shape[0]):
|
||||
image = images[i]
|
||||
processed_image = super().__call__(image)
|
||||
out_stack.append(processed_image)
|
||||
|
||||
output = torch.stack(out_stack, dim=0)
|
||||
return PixtralVisionImagePreprocessorCompatibleReturn(output)
|
||||
|
||||
|
||||
class PixtralVisionEncoderCompatibleReturn:
|
||||
def __init__(self, hidden_states) -> None:
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
|
||||
class PixtralVisionEncoderCompatibleConfig:
|
||||
def __init__(self):
|
||||
self.image_size = 1024
|
||||
self.hidden_size = 1024
|
||||
self.patch_size = 16
|
||||
|
||||
|
||||
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
intermediate_size: int = 4096,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
rope_theta: float = 1e4, # for rope-2D
|
||||
image_token_id: int = 10,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
hidden_size=hidden_size,
|
||||
num_channels=num_channels,
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
rope_theta=rope_theta,
|
||||
image_token_id=image_token_id,
|
||||
)
|
||||
self.config = PixtralVisionEncoderCompatibleConfig()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images,
|
||||
output_hidden_states=True,
|
||||
) -> torch.Tensor:
|
||||
out_stack = []
|
||||
if len(images.shape) == 3:
|
||||
images = images.unsqueeze(0)
|
||||
for i in range(images.shape[0]):
|
||||
image = images[i]
|
||||
# must be in an array
|
||||
image_output = super().forward([image])
|
||||
out_stack.append(image_output)
|
||||
|
||||
output = torch.stack(out_stack, dim=0)
|
||||
return PixtralVisionEncoderCompatibleReturn([output])
|
||||
|
||||
@@ -9,6 +9,7 @@ from collections import OrderedDict
|
||||
|
||||
from diffusers import Transformer2DModel, FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
|
||||
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor
|
||||
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
|
||||
@@ -1461,6 +1461,7 @@ class StableDiffusion:
|
||||
detach_unconditional=False,
|
||||
rescale_cfg=None,
|
||||
return_conditional_pred=False,
|
||||
guidance_embedding_scale=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
conditional_pred = None
|
||||
@@ -1736,10 +1737,12 @@ class StableDiffusion:
|
||||
txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
|
||||
|
||||
# # handle guidance
|
||||
guidance_scale = 1.0 # ?
|
||||
if self.unet.config.guidance_embeds:
|
||||
guidance = torch.tensor([guidance_scale], device=self.device_torch)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance = torch.tensor(guidance_scale, device=self.device_torch)
|
||||
else:
|
||||
guidance = torch.tensor([guidance_scale], device=self.device_torch)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user