diff --git a/.gitignore b/.gitignore
index 9e03d70f..d03f32e5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,3 +175,4 @@ cython_debug/
 !/extensions/example
 /temp
 /wandb
+.vscode/settings.json
\ No newline at end of file
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 61b6a35c..0306bd70 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -206,6 +206,8 @@ class AdapterConfig:
         self.ilora_down: bool = kwargs.get('ilora_down', True)
         self.ilora_mid: bool = kwargs.get('ilora_mid', True)
         self.ilora_up: bool = kwargs.get('ilora_up', True)
+
+        self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
 
         self.flux_only_double: bool = kwargs.get('flux_only_double', False)
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 23367bbf..2f5a79d4 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -17,6 +17,7 @@ from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
 from toolkit.train_tools import get_torch_dtype
+from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
 
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
@@ -257,6 +258,13 @@ class CustomAdapter(torch.nn.Module):
                 self.vision_encoder = SiglipVisionModel.from_pretrained(
                     adapter_config.image_encoder_path,
                     ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+            elif self.config.image_encoder_arch == 'pixtral':
+                self.image_processor = PixtralVisionImagePreprocessorCompatible(
+                    max_image_size=self.config.pixtral_max_image_size,
+                )
+                self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
+                    adapter_config.image_encoder_path,
+                ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
             elif self.config.image_encoder_arch == 'vit':
                 try:
                     self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
@@ -700,9 +708,11 @@ class CustomAdapter(torch.nn.Module):
         else:
             return prompt_embeds
 
-    def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
+    def get_empty_clip_image(self, batch_size: int, shape=None) -> torch.Tensor:
         with torch.no_grad():
-            tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
+            if shape is None:
+                shape = [batch_size, 3, self.input_size, self.input_size]
+            tensors_0_1 = torch.rand(shape, device=self.device)
             noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                      dtype=get_torch_dtype(self.sd_ref().dtype))
             tensors_0_1 = tensors_0_1 * noise_scale
@@ -761,7 +771,7 @@ class CustomAdapter(torch.nn.Module):
             batch_size = clip_image.shape[0]
             if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
                 # add an unconditional so we can save it
-                unconditional = self.get_empty_clip_image(batch_size).to(
+                unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
                     clip_image.device, dtype=clip_image.dtype
                 )
                 clip_image = torch.cat([unconditional, clip_image], dim=0)
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 80f00cf0..ca2a6815 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -17,6 +17,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from toolkit.basic import flush, value_map
 from toolkit.buckets import get_bucket_for_image_size, get_resolution
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
 from toolkit.prompt_utils import inject_trigger_into_prompt
 from torchvision import transforms
 from PIL import Image, ImageFilter, ImageOps
@@ -733,6 +734,7 @@ class ClipImageFileItemDTOMixin:
         return self._clip_vision_embeddings_path
 
     def load_clip_image(self: 'FileItemDTO'):
+        is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible)
         if self.is_vision_clip_cached:
             self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
@@ -759,8 +761,24 @@ class ClipImageFileItemDTOMixin:
         if self.flip_y:
             # do a flip
             img = img.transpose(Image.FLIP_TOP_BOTTOM)
-
-        if img.width != img.height:
+
+        if is_dynamic_size_and_aspect:
+            # just match the bucket size for now
+            if self.dataset_config.buckets:
+                # scale and crop based on file item
+                img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+                # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+                # crop
+                img = img.crop((
+                    self.crop_x,
+                    self.crop_y,
+                    self.crop_x + self.crop_width,
+                    self.crop_y + self.crop_height
+                ))
+            else:
+                raise Exception("Control images not supported for non-bucket datasets")
+
+        elif img.width != img.height:
             min_size = min(img.width, img.height)
             if self.dataset_config.square_crop:
                 # center crop to a square
diff --git a/toolkit/models/pixtral_vision.py b/toolkit/models/pixtral_vision.py
index c82a3b3f..ebebce40 100644
--- a/toolkit/models/pixtral_vision.py
+++ b/toolkit/models/pixtral_vision.py
@@ -33,7 +33,8 @@ class FeedForward(nn.Module):
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))  # type: ignore
+        # type: ignore
+        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
 
 
 def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -103,15 +104,18 @@ class Attention(nn.Module):
         else:
             cache.update(xk, xv)
             key, val = cache.key, cache.value
-            key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
-            val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
+            key = key.view(seqlen_sum * cache.max_seq_len,
+                           self.n_kv_heads, self.head_dim)
+            val = val.view(seqlen_sum * cache.max_seq_len,
+                           self.n_kv_heads, self.head_dim)
 
         # Repeat keys and values to match number of query heads
         key, val = repeat_kv(key, val, self.repeats, dim=1)
 
         # xformers requires (B=1, S, H, D)
         xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
-        output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
+        output = memory_efficient_attention(
+            xq, key, val, mask if cache is None else cache.mask)
         output = output.view(seqlen_sum, self.n_heads * self.head_dim)
 
         assert isinstance(output, torch.Tensor)
@@ -260,8 +264,8 @@ class PixtralVisionEncoder(nn.Module):
         assert head_dim % 2 == 0, "ROPE requires even head_dim"
         self._freqs_cis: Optional[torch.Tensor] = None
 
-    @staticmethod
-    def from_pretrained(pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
         if os.path.isdir(pretrained_model_name_or_path):
             model_folder = pretrained_model_name_or_path
         else:
@@ -275,11 +279,12 @@ class PixtralVisionEncoder(nn.Module):
         with open(os.path.join(model_folder, "config.json"), "r") as f:
             config = json.load(f)
 
-        model = PixtralVisionEncoder(**config)
+        model = cls(**config)
 
         # see if there is a state_dict
         if os.path.exists(os.path.join(model_folder, "model.safetensors")):
-            state_dict = load_file(os.path.join(model_folder, "model.safetensors"))
+            state_dict = load_file(os.path.join(
+                model_folder, "model.safetensors"))
             model.load_state_dict(state_dict)
 
         return model
@@ -319,14 +324,17 @@ class PixtralVisionEncoder(nn.Module):
             image_features: tensor of token features for
                 all tokens of all images of shape (N_toks, D)
         """
-        assert isinstance(images, list), f"Expected list of images, got {type(images)}"
+        assert isinstance(
+            images, list), f"Expected list of images, got {type(images)}"
         assert all(len(img.shape) == 3 for img in images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
 
         # pass images through initial convolution independently
-        patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
+        patch_embeds_list = [self.patch_conv(
+            img.unsqueeze(0)).squeeze(0) for img in images]
 
         # flatten to a single sequence
-        patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
+        patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
+                                  for p in patch_embeds_list], dim=0)
         patch_embeds = self.ln_pre(patch_embeds)
 
         # positional embeddings
@@ -355,7 +363,8 @@ class VisionLanguageAdapter(nn.Module):
         self.w_out = nn.Linear(out_dim, out_dim, bias=True)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w_out(self.gelu(self.w_in(x)))  # type: ignore[no-any-return]
+        # type: ignore[no-any-return]
+        return self.w_out(self.gelu(self.w_in(x)))
 
 
 class VisionTransformerBlocks(nn.Module):
@@ -401,7 +410,8 @@ def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> tor
     Returns:
        torch.Tensor: Normalized image with shape (C, H, W).
     """
-    assert image.shape[0] == len(mean) == len(std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
+    assert image.shape[0] == len(mean) == len(
+        std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
 
     # Reshape mean and std to (C, 1, 1) for broadcasting
     mean = mean.view(-1, 1, 1)
@@ -473,10 +483,12 @@ class PixtralVisionImagePreprocessor:
         """
         # should not have batch
         if len(image.shape) == 4:
-            raise ValueError(f"Expected image with shape (C, H, W), got {image.shape}")
+            raise ValueError(
+                f"Expected image with shape (C, H, W), got {image.shape}")
 
         if image.min() < 0.0 or image.max() > 1.0:
-            raise ValueError(f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
+            raise ValueError(
+                f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
 
         w, h = self._image_to_num_tokens(image)
         assert w > 0
@@ -490,3 +502,98 @@ class PixtralVisionImagePreprocessor:
         processed_image = transform_image(image, new_image_size)
 
         return processed_image
+
+
+class PixtralVisionImagePreprocessorCompatibleReturn:
+    def __init__(self, pixel_values) -> None:
+        self.pixel_values = pixel_values
+
+
+# Compatible version with ai toolkit flow
+class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
+    def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
+        super().__init__(
+            image_patch_size=image_patch_size,
+            max_image_size=max_image_size
+        )
+        self.size = {
+            'height': max_image_size,
+            'width': max_image_size
+        }
+        self.image_mean = DATASET_MEAN
+        self.image_std = DATASET_STD
+
+    def __call__(
+        self,
+        images,
+        return_tensors="pt",
+        do_resize=True,
+        do_rescale=False,
+    ) -> torch.Tensor:
+        out_stack = []
+        if len(images.shape) == 3:
+            images = images.unsqueeze(0)
+        for i in range(images.shape[0]):
+            image = images[i]
+            processed_image = super().__call__(image)
+            out_stack.append(processed_image)
+
+        output = torch.stack(out_stack, dim=0)
+        return PixtralVisionImagePreprocessorCompatibleReturn(output)
+
+
+class PixtralVisionEncoderCompatibleReturn:
+    def __init__(self, hidden_states) -> None:
+        self.hidden_states = hidden_states
+
+
+class PixtralVisionEncoderCompatibleConfig:
+    def __init__(self):
+        self.image_size = 1024
+        self.hidden_size = 1024
+        self.patch_size = 16
+
+
+class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
+    def __init__(
+        self,
+        hidden_size: int = 1024,
+        num_channels: int = 3,
+        image_size: int = 1024,
+        patch_size: int = 16,
+        intermediate_size: int = 4096,
+        num_hidden_layers: int = 24,
+        num_attention_heads: int = 16,
+        rope_theta: float = 1e4,  # for rope-2D
+        image_token_id: int = 10,
+        **kwargs,
+    ):
+        super().__init__(
+            hidden_size=hidden_size,
+            num_channels=num_channels,
+            image_size=image_size,
+            patch_size=patch_size,
+            intermediate_size=intermediate_size,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            rope_theta=rope_theta,
+            image_token_id=image_token_id,
+        )
+        self.config = PixtralVisionEncoderCompatibleConfig()
+
+    def forward(
+        self,
+        images,
+        output_hidden_states=True,
+    ) -> torch.Tensor:
+        out_stack = []
+        if len(images.shape) == 3:
+            images = images.unsqueeze(0)
+        for i in range(images.shape[0]):
+            image = images[i]
+            # must be in an array
+            image_output = super().forward([image])
+            out_stack.append(image_output)
+
+        output = torch.stack(out_stack, dim=0)
+        return PixtralVisionEncoderCompatibleReturn([output])
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index 793de115..e675e0b1 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -9,6 +9,7 @@ from collections import OrderedDict
 from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
 
+from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor
 from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 94140c33..6dba69b3 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1461,6 +1461,7 @@ class StableDiffusion:
             detach_unconditional=False,
             rescale_cfg=None,
             return_conditional_pred=False,
+            guidance_embedding_scale=1.0,
             **kwargs,
     ):
         conditional_pred = None
@@ -1736,10 +1737,12 @@ class StableDiffusion:
                 txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
 
                 # # handle guidance
-                guidance_scale = 1.0  # ?
                 if self.unet.config.guidance_embeds:
-                    guidance = torch.tensor([guidance_scale], device=self.device_torch)
-                    guidance = guidance.expand(latents.shape[0])
+                    if isinstance(guidance_scale, list):
+                        guidance = torch.tensor(guidance_scale, device=self.device_torch)
+                    else:
+                        guidance = torch.tensor([guidance_scale], device=self.device_torch)
+                        guidance = guidance.expand(latents.shape[0])
                 else:
                     guidance = None
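
# -----------------------------------------------------------------------------
# Usage sketch (not part of the diff above): roughly how custom_adapter.py wires
# up the new 'pixtral' image_encoder_arch using the compatibility wrappers added
# in toolkit/models/pixtral_vision.py. The checkpoint path, batch size, and the
# CUDA/bfloat16 choice are illustrative assumptions only; the wrappers expose
# CLIP-style .pixel_values / .hidden_states objects so the rest of the adapter
# flow can stay unchanged.
# -----------------------------------------------------------------------------
import torch

from toolkit.models.pixtral_vision import (
    PixtralVisionEncoderCompatible,
    PixtralVisionImagePreprocessorCompatible,
)

device = torch.device("cuda")  # the encoder uses xformers attention, so a GPU is assumed
dtype = torch.bfloat16

# Mirrors AdapterConfig: image_encoder_arch='pixtral', pixtral_max_image_size=512
image_processor = PixtralVisionImagePreprocessorCompatible(max_image_size=512)
vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
    "/path/to/pixtral_vision_encoder"  # placeholder for adapter_config.image_encoder_path
).to(device, dtype=dtype)

# Inputs are 0-1 float tensors of shape (B, C, H, W) or (C, H, W)
images = torch.rand(2, 3, 512, 512)
pixel_values = image_processor(images, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device, dtype=dtype)

with torch.no_grad():
    out = vision_encoder(pixel_values, output_hidden_states=True)

# hidden_states is a single-entry list holding the stacked per-image token
# features, shape (batch, n_tokens, hidden_size)
tokens = out.hidden_states[-1]
print(tokens.shape)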