Added initial direct vision pixtral support

Jaret Burkett
2024-09-28 10:47:51 -06:00
parent 86b5938cf3
commit 58537fc92b
7 changed files with 165 additions and 23 deletions

.gitignore vendored
View File

@@ -175,3 +175,4 @@ cython_debug/
!/extensions/example
/temp
/wandb
.vscode/settings.json

View File

@@ -206,6 +206,8 @@ class AdapterConfig:
self.ilora_down: bool = kwargs.get('ilora_down', True)
self.ilora_mid: bool = kwargs.get('ilora_mid', True)
self.ilora_up: bool = kwargs.get('ilora_up', True)
self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
self.flux_only_double: bool = kwargs.get('flux_only_double', False)
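
For context, a minimal configuration sketch for the new option. It assumes the adapter config is built from plain kwargs as above; 'vision_direct' and the field names other than image_encoder_arch, image_encoder_path and pixtral_max_image_size follow existing toolkit conventions and are not taken from this hunk.

from toolkit.config_modules import AdapterConfig

# hypothetical values; pixtral_max_image_size caps the size the Pixtral
# preprocessor resizes control images to (defaults to 512)
adapter_config = AdapterConfig(
    type='vision_direct',
    image_encoder_arch='pixtral',
    image_encoder_path='/models/pixtral-vision-encoder',  # hypothetical local folder
    pixtral_max_image_size=512,
)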

View File

@@ -17,6 +17,7 @@ from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
from toolkit.train_tools import get_torch_dtype
from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
@@ -257,6 +258,13 @@ class CustomAdapter(torch.nn.Module):
self.vision_encoder = SiglipVisionModel.from_pretrained(
adapter_config.image_encoder_path,
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'pixtral':
self.image_processor = PixtralVisionImagePreprocessorCompatible(
max_image_size=self.config.pixtral_max_image_size,
)
self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
adapter_config.image_encoder_path,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'vit':
try:
self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
@@ -700,9 +708,11 @@ class CustomAdapter(torch.nn.Module):
else:
return prompt_embeds

def get_empty_clip_image(self, batch_size: int, shape=None) -> torch.Tensor:
with torch.no_grad():
if shape is None:
shape = [batch_size, 3, self.input_size, self.input_size]
tensors_0_1 = torch.rand(shape, device=self.device)
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
dtype=get_torch_dtype(self.sd_ref().dtype))
tensors_0_1 = tensors_0_1 * noise_scale
@@ -761,7 +771,7 @@ class CustomAdapter(torch.nn.Module):
batch_size = clip_image.shape[0]
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
# add an unconditional so we can save it
unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
clip_image.device, dtype=clip_image.dtype
)
clip_image = torch.cat([unconditional, clip_image], dim=0)
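
Putting the pieces above together, a rough runtime sketch of the pixtral branch (adapter and control_images are placeholder names, not identifiers from this diff):

import torch

# preprocess control images with the Pixtral-compatible processor built above;
# .pixel_values mirrors the CLIPImageProcessor return
clip_image = adapter.image_processor(control_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(adapter.device)

# the unconditional batch is noise matching the processed shape, which may no
# longer be square for Pixtral inputs
unconditional = adapter.get_empty_clip_image(clip_image.shape[0], shape=clip_image.shape)
clip_image = torch.cat([unconditional.to(clip_image.device, dtype=clip_image.dtype), clip_image], dim=0)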

View File

@@ -17,6 +17,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image, ImageFilter, ImageOps
@@ -733,6 +734,7 @@ class ClipImageFileItemDTOMixin:
return self._clip_vision_embeddings_path

def load_clip_image(self: 'FileItemDTO'):
is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
if self.is_vision_clip_cached:
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
@@ -759,8 +761,24 @@ class ClipImageFileItemDTOMixin:
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if is_dynamic_size_and_aspect:
# just match the bucket size for now
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
elif img.width != img.height:
min_size = min(img.width, img.height)
if self.dataset_config.square_crop:
# center crop to a square
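
As a rough illustration of the bucket path above: the scale_to_* and crop_* fields come from the existing bucket logic, while the concrete numbers below are invented.

from PIL import Image

img = Image.open('control.jpg')                 # hypothetical control image
scale_to_width, scale_to_height = 576, 448      # chosen by the bucket resolution logic
crop_x, crop_y, crop_width, crop_height = 32, 0, 512, 448

# scale to the bucket size, then crop the same region used for the latent image
img = img.resize((scale_to_width, scale_to_height), Image.BICUBIC)
img = img.crop((crop_x, crop_y, crop_x + crop_width, crop_y + crop_height))
# the Pixtral preprocessor later resizes this to fit its 16-px patch grid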

View File

@@ -33,7 +33,8 @@ class FeedForward(nn.Module):
self.w3 = nn.Linear(dim, hidden_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))  # type: ignore

def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -103,15 +104,18 @@ class Attention(nn.Module):
else:
cache.update(xk, xv)
key, val = cache.key, cache.value
key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
# Repeat keys and values to match number of query heads
key, val = repeat_kv(key, val, self.repeats, dim=1)
# xformers requires (B=1, S, H, D)
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
output = output.view(seqlen_sum, self.n_heads * self.head_dim)
assert isinstance(output, torch.Tensor)
@@ -260,8 +264,8 @@ class PixtralVisionEncoder(nn.Module):
assert head_dim % 2 == 0, "ROPE requires even head_dim"
self._freqs_cis: Optional[torch.Tensor] = None

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
if os.path.isdir(pretrained_model_name_or_path):
model_folder = pretrained_model_name_or_path
else:
@@ -275,11 +279,12 @@ class PixtralVisionEncoder(nn.Module):
with open(os.path.join(model_folder, "config.json"), "r") as f:
config = json.load(f)
model = cls(**config)
# see if there is a state_dict
if os.path.exists(os.path.join(model_folder, "model.safetensors")):
state_dict = load_file(os.path.join(model_folder, "model.safetensors"))
model.load_state_dict(state_dict)
return model
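
A quick usage sketch of the classmethod factory above; the folder path is hypothetical and is expected to contain config.json plus an optional model.safetensors.

from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionEncoderCompatible

# builds the encoder from config.json, then loads model.safetensors if present
encoder = PixtralVisionEncoder.from_pretrained('/models/pixtral-vision-encoder')

# because the factory now calls cls(**config), subclasses inherit it and get an
# instance of the subclass rather than the base class
compat_encoder = PixtralVisionEncoderCompatible.from_pretrained('/models/pixtral-vision-encoder')
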
@@ -319,14 +324,17 @@ class PixtralVisionEncoder(nn.Module):
image_features: tensor of token features for all tokens of all images of
shape (N_toks, D)
"""
assert isinstance(images, list), f"Expected list of images, got {type(images)}"
assert all(len(img.shape) == 3 for img in images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
# pass images through initial convolution independently
patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
# flatten to a single sequence
patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
@@ -355,7 +363,8 @@ class VisionLanguageAdapter(nn.Module):
self.w_out = nn.Linear(out_dim, out_dim, bias=True)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x)))  # type: ignore[no-any-return]

class VisionTransformerBlocks(nn.Module):
@@ -401,7 +410,8 @@ def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> tor
Returns:
torch.Tensor: Normalized image with shape (C, H, W).
"""
assert image.shape[0] == len(mean) == len(std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
# Reshape mean and std to (C, 1, 1) for broadcasting
mean = mean.view(-1, 1, 1)
@@ -473,10 +483,12 @@ class PixtralVisionImagePreprocessor:
""" """
# should not have batch # should not have batch
if len(image.shape) == 4: if len(image.shape) == 4:
raise ValueError(f"Expected image with shape (C, H, W), got {image.shape}") raise ValueError(
f"Expected image with shape (C, H, W), got {image.shape}")
if image.min() < 0.0 or image.max() > 1.0: if image.min() < 0.0 or image.max() > 1.0:
raise ValueError(f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}") raise ValueError(
f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
w, h = self._image_to_num_tokens(image) w, h = self._image_to_num_tokens(image)
assert w > 0 assert w > 0
@@ -490,3 +502,98 @@ class PixtralVisionImagePreprocessor:
processed_image = transform_image(image, new_image_size)
return processed_image
class PixtralVisionImagePreprocessorCompatibleReturn:
def __init__(self, pixel_values) -> None:
self.pixel_values = pixel_values
# Compatible version with the ai-toolkit flow
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
super().__init__(
image_patch_size=image_patch_size,
max_image_size=max_image_size
)
self.size = {
'height': max_image_size,
'width': max_image_size
}
self.image_mean = DATASET_MEAN
self.image_std = DATASET_STD
def __call__(
self,
images,
return_tensors="pt",
do_resize=True,
do_rescale=False,
) -> 'PixtralVisionImagePreprocessorCompatibleReturn':
out_stack = []
if len(images.shape) == 3:
images = images.unsqueeze(0)
for i in range(images.shape[0]):
image = images[i]
processed_image = super().__call__(image)
out_stack.append(processed_image)
output = torch.stack(out_stack, dim=0)
return PixtralVisionImagePreprocessorCompatibleReturn(output)
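
A small usage sketch of the wrapper above; the tensor shapes are invented, and inputs must already be RGB tensors in the 0-1 range, as the base preprocessor asserts.

import torch
from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible

processor = PixtralVisionImagePreprocessorCompatible(max_image_size=512)

images = torch.rand(2, 3, 448, 640)            # fake batch of two 0-1 RGB images
out = processor(images, return_tensors="pt")   # same call shape as CLIPImageProcessor
pixel_values = out.pixel_values                # stacked, preprocessed image tensors

Note that torch.stack in __call__ requires every image in a batch to resolve to the same processed size, which holds when the data loader feeds same-bucket batches.
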
class PixtralVisionEncoderCompatibleReturn:
def __init__(self, hidden_states) -> None:
self.hidden_states = hidden_states
class PixtralVisionEncoderCompatibleConfig:
def __init__(self):
self.image_size = 1024
self.hidden_size = 1024
self.patch_size = 16
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
def __init__(
self,
hidden_size: int = 1024,
num_channels: int = 3,
image_size: int = 1024,
patch_size: int = 16,
intermediate_size: int = 4096,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
rope_theta: float = 1e4, # for rope-2D
image_token_id: int = 10,
**kwargs,
):
super().__init__(
hidden_size=hidden_size,
num_channels=num_channels,
image_size=image_size,
patch_size=patch_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
rope_theta=rope_theta,
image_token_id=image_token_id,
)
self.config = PixtralVisionEncoderCompatibleConfig()
def forward(
self,
images,
output_hidden_states=True,
) -> 'PixtralVisionEncoderCompatibleReturn':
out_stack = []
if len(images.shape) == 3:
images = images.unsqueeze(0)
for i in range(images.shape[0]):
image = images[i]
# must be in an array
image_output = super().forward([image])
out_stack.append(image_output)
output = torch.stack(out_stack, dim=0)
return PixtralVisionEncoderCompatibleReturn([output])
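
A matching sketch for the encoder wrapper; the path is a placeholder, pixel_values is reused from the preprocessor sketch above, and the single-entry hidden_states list mimics the transformers-style return the toolkit consumes elsewhere.

from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible

encoder = PixtralVisionEncoderCompatible.from_pretrained('/models/pixtral-vision-encoder')
output = encoder(pixel_values, output_hidden_states=True)

# per-image token features stacked back into a batch dimension,
# roughly (batch, n_image_tokens, hidden_size)
tokens = output.hidden_states[-1]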

View File

@@ -9,6 +9,7 @@ from collections import OrderedDict
from diffusers import Transformer2DModel, FluxTransformer2DModel
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor
from toolkit.config_modules import AdapterConfig
from toolkit.paths import REPOS_ROOT

View File

@@ -1461,6 +1461,7 @@ class StableDiffusion:
detach_unconditional=False,
rescale_cfg=None,
return_conditional_pred=False,
guidance_embedding_scale=1.0,
**kwargs,
):
conditional_pred = None
@@ -1736,10 +1737,12 @@ class StableDiffusion:
txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
# # handle guidance
guidance_scale = 1.0 # ?
if self.unet.config.guidance_embeds:
if isinstance(guidance_scale, list):
guidance = torch.tensor(guidance_scale, device=self.device_torch)
else:
guidance = torch.tensor([guidance_scale], device=self.device_torch)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
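
A short sketch of what the new list branch enables; the method name, argument order, and the routing of guidance_embedding_scale into guidance_scale are assumptions, not shown in this diff.

# scalar: a single guidance-embedding value, expanded across the batch
pred = sd.predict_noise(latents, text_embeddings, timestep, guidance_embedding_scale=1.0)

# list: one value per latent; the list is converted to a tensor as-is,
# so its length must match the batch size
pred = sd.predict_noise(latents, text_embeddings, timestep, guidance_embedding_scale=[1.0, 3.5])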