Added initial direct vision pixtral support

This commit is contained in:
Jaret Burkett
2024-09-28 10:47:51 -06:00
parent 86b5938cf3
commit 58537fc92b
7 changed files with 165 additions and 23 deletions

1
.gitignore vendored
View File

@@ -175,3 +175,4 @@ cython_debug/
!/extensions/example
/temp
/wandb
.vscode/settings.json

View File

@@ -206,6 +206,8 @@ class AdapterConfig:
self.ilora_down: bool = kwargs.get('ilora_down', True)
self.ilora_mid: bool = kwargs.get('ilora_mid', True)
self.ilora_up: bool = kwargs.get('ilora_up', True)
self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
self.flux_only_double: bool = kwargs.get('flux_only_double', False)

View File

@@ -17,6 +17,7 @@ from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
from toolkit.train_tools import get_torch_dtype
from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
@@ -257,6 +258,13 @@ class CustomAdapter(torch.nn.Module):
self.vision_encoder = SiglipVisionModel.from_pretrained(
adapter_config.image_encoder_path,
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'pixtral':
self.image_processor = PixtralVisionImagePreprocessorCompatible(
max_image_size=self.config.pixtral_max_image_size,
)
self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
adapter_config.image_encoder_path,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'vit':
try:
self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
@@ -700,9 +708,11 @@ class CustomAdapter(torch.nn.Module):
else:
return prompt_embeds
def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
def get_empty_clip_image(self, batch_size: int, shape=None) -> torch.Tensor:
with torch.no_grad():
tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
if shape is None:
shape = [batch_size, 3, self.input_size, self.input_size]
tensors_0_1 = torch.rand(shape, device=self.device)
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
dtype=get_torch_dtype(self.sd_ref().dtype))
tensors_0_1 = tensors_0_1 * noise_scale
@@ -761,7 +771,7 @@ class CustomAdapter(torch.nn.Module):
batch_size = clip_image.shape[0]
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
# add an unconditional so we can save it
unconditional = self.get_empty_clip_image(batch_size).to(
unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
clip_image.device, dtype=clip_image.dtype
)
clip_image = torch.cat([unconditional, clip_image], dim=0)

View File

@@ -17,6 +17,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image, ImageFilter, ImageOps
@@ -733,6 +734,7 @@ class ClipImageFileItemDTOMixin:
return self._clip_vision_embeddings_path
def load_clip_image(self: 'FileItemDTO'):
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
if self.is_vision_clip_cached:
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
@@ -759,8 +761,24 @@ class ClipImageFileItemDTOMixin:
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if img.width != img.height:
if is_dynamic_size_and_aspect:
# just match the bucket size for now
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
elif img.width != img.height:
min_size = min(img.width, img.height)
if self.dataset_config.square_crop:
# center crop to a square

View File

@@ -33,7 +33,8 @@ class FeedForward(nn.Module):
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
# type: ignore
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -103,15 +104,18 @@ class Attention(nn.Module):
else:
cache.update(xk, xv)
key, val = cache.key, cache.value
key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
key = key.view(seqlen_sum * cache.max_seq_len,
self.n_kv_heads, self.head_dim)
val = val.view(seqlen_sum * cache.max_seq_len,
self.n_kv_heads, self.head_dim)
# Repeat keys and values to match number of query heads
key, val = repeat_kv(key, val, self.repeats, dim=1)
# xformers requires (B=1, S, H, D)
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
output = memory_efficient_attention(
xq, key, val, mask if cache is None else cache.mask)
output = output.view(seqlen_sum, self.n_heads * self.head_dim)
assert isinstance(output, torch.Tensor)
@@ -260,8 +264,8 @@ class PixtralVisionEncoder(nn.Module):
assert head_dim % 2 == 0, "ROPE requires even head_dim"
self._freqs_cis: Optional[torch.Tensor] = None
@staticmethod
def from_pretrained(pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
if os.path.isdir(pretrained_model_name_or_path):
model_folder = pretrained_model_name_or_path
else:
@@ -275,11 +279,12 @@ class PixtralVisionEncoder(nn.Module):
with open(os.path.join(model_folder, "config.json"), "r") as f:
config = json.load(f)
model = PixtralVisionEncoder(**config)
model = cls(**config)
# see if there is a state_dict
if os.path.exists(os.path.join(model_folder, "model.safetensors")):
state_dict = load_file(os.path.join(model_folder, "model.safetensors"))
state_dict = load_file(os.path.join(
model_folder, "model.safetensors"))
model.load_state_dict(state_dict)
return model
@@ -319,14 +324,17 @@ class PixtralVisionEncoder(nn.Module):
image_features: tensor of token features for all tokens of all images of
shape (N_toks, D)
"""
assert isinstance(images, list), f"Expected list of images, got {type(images)}"
assert isinstance(
images, list), f"Expected list of images, got {type(images)}"
assert all(len(img.shape) == 3 for img in
images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
# pass images through initial convolution independently
patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
patch_embeds_list = [self.patch_conv(
img.unsqueeze(0)).squeeze(0) for img in images]
# flatten to a single sequence
patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
for p in patch_embeds_list], dim=0)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
@@ -355,7 +363,8 @@ class VisionLanguageAdapter(nn.Module):
self.w_out = nn.Linear(out_dim, out_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return]
# type: ignore[no-any-return]
return self.w_out(self.gelu(self.w_in(x)))
class VisionTransformerBlocks(nn.Module):
@@ -401,7 +410,8 @@ def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> tor
Returns:
torch.Tensor: Normalized image with shape (C, H, W).
"""
assert image.shape[0] == len(mean) == len(std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
assert image.shape[0] == len(mean) == len(
std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
# Reshape mean and std to (C, 1, 1) for broadcasting
mean = mean.view(-1, 1, 1)
@@ -473,10 +483,12 @@ class PixtralVisionImagePreprocessor:
"""
# should not have batch
if len(image.shape) == 4:
raise ValueError(f"Expected image with shape (C, H, W), got {image.shape}")
raise ValueError(
f"Expected image with shape (C, H, W), got {image.shape}")
if image.min() < 0.0 or image.max() > 1.0:
raise ValueError(f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
raise ValueError(
f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
w, h = self._image_to_num_tokens(image)
assert w > 0
@@ -490,3 +502,98 @@ class PixtralVisionImagePreprocessor:
processed_image = transform_image(image, new_image_size)
return processed_image
class PixtralVisionImagePreprocessorCompatibleReturn:
def __init__(self, pixel_values) -> None:
self.pixel_values = pixel_values
# Compatable version with ai toolkit flow
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
super().__init__(
image_patch_size=image_patch_size,
max_image_size=max_image_size
)
self.size = {
'height': max_image_size,
'width': max_image_size
}
self.image_mean = DATASET_MEAN
self.image_std = DATASET_STD
def __call__(
self,
images,
return_tensors="pt",
do_resize=True,
do_rescale=False,
) -> torch.Tensor:
out_stack = []
if len(images.shape) == 3:
images = images.unsqueeze(0)
for i in range(images.shape[0]):
image = images[i]
processed_image = super().__call__(image)
out_stack.append(processed_image)
output = torch.stack(out_stack, dim=0)
return PixtralVisionImagePreprocessorCompatibleReturn(output)
class PixtralVisionEncoderCompatibleReturn:
def __init__(self, hidden_states) -> None:
self.hidden_states = hidden_states
class PixtralVisionEncoderCompatibleConfig:
def __init__(self):
self.image_size = 1024
self.hidden_size = 1024
self.patch_size = 16
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
def __init__(
self,
hidden_size: int = 1024,
num_channels: int = 3,
image_size: int = 1024,
patch_size: int = 16,
intermediate_size: int = 4096,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
rope_theta: float = 1e4, # for rope-2D
image_token_id: int = 10,
**kwargs,
):
super().__init__(
hidden_size=hidden_size,
num_channels=num_channels,
image_size=image_size,
patch_size=patch_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
rope_theta=rope_theta,
image_token_id=image_token_id,
)
self.config = PixtralVisionEncoderCompatibleConfig()
def forward(
self,
images,
output_hidden_states=True,
) -> torch.Tensor:
out_stack = []
if len(images.shape) == 3:
images = images.unsqueeze(0)
for i in range(images.shape[0]):
image = images[i]
# must be in an array
image_output = super().forward([image])
out_stack.append(image_output)
output = torch.stack(out_stack, dim=0)
return PixtralVisionEncoderCompatibleReturn([output])

View File

@@ -9,6 +9,7 @@ from collections import OrderedDict
from diffusers import Transformer2DModel, FluxTransformer2DModel
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor
from toolkit.config_modules import AdapterConfig
from toolkit.paths import REPOS_ROOT

View File

@@ -1461,6 +1461,7 @@ class StableDiffusion:
detach_unconditional=False,
rescale_cfg=None,
return_conditional_pred=False,
guidance_embedding_scale=1.0,
**kwargs,
):
conditional_pred = None
@@ -1736,10 +1737,12 @@ class StableDiffusion:
txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
# # handle guidance
guidance_scale = 1.0 # ?
if self.unet.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=self.device_torch)
guidance = guidance.expand(latents.shape[0])
if isinstance(guidance_scale, list):
guidance = torch.tensor(guidance_scale, device=self.device_torch)
else:
guidance = torch.tensor([guidance_scale], device=self.device_torch)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None