mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added initial direct vision pixtral support
This commit is contained in:
@@ -33,7 +33,8 @@ class FeedForward(nn.Module):
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
|
||||
# type: ignore
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -103,15 +104,18 @@ class Attention(nn.Module):
|
||||
else:
|
||||
cache.update(xk, xv)
|
||||
key, val = cache.key, cache.value
|
||||
key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
|
||||
val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
|
||||
key = key.view(seqlen_sum * cache.max_seq_len,
|
||||
self.n_kv_heads, self.head_dim)
|
||||
val = val.view(seqlen_sum * cache.max_seq_len,
|
||||
self.n_kv_heads, self.head_dim)
|
||||
|
||||
# Repeat keys and values to match number of query heads
|
||||
key, val = repeat_kv(key, val, self.repeats, dim=1)
|
||||
|
||||
# xformers requires (B=1, S, H, D)
|
||||
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
|
||||
output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
|
||||
output = memory_efficient_attention(
|
||||
xq, key, val, mask if cache is None else cache.mask)
|
||||
output = output.view(seqlen_sum, self.n_heads * self.head_dim)
|
||||
|
||||
assert isinstance(output, torch.Tensor)
|
||||
@@ -260,8 +264,8 @@ class PixtralVisionEncoder(nn.Module):
|
||||
assert head_dim % 2 == 0, "ROPE requires even head_dim"
|
||||
self._freqs_cis: Optional[torch.Tensor] = None
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
model_folder = pretrained_model_name_or_path
|
||||
else:
|
||||
@@ -275,11 +279,12 @@ class PixtralVisionEncoder(nn.Module):
|
||||
with open(os.path.join(model_folder, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
model = PixtralVisionEncoder(**config)
|
||||
model = cls(**config)
|
||||
|
||||
# see if there is a state_dict
|
||||
if os.path.exists(os.path.join(model_folder, "model.safetensors")):
|
||||
state_dict = load_file(os.path.join(model_folder, "model.safetensors"))
|
||||
state_dict = load_file(os.path.join(
|
||||
model_folder, "model.safetensors"))
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
return model
|
||||
@@ -319,14 +324,17 @@ class PixtralVisionEncoder(nn.Module):
|
||||
image_features: tensor of token features for all tokens of all images of
|
||||
shape (N_toks, D)
|
||||
"""
|
||||
assert isinstance(images, list), f"Expected list of images, got {type(images)}"
|
||||
assert isinstance(
|
||||
images, list), f"Expected list of images, got {type(images)}"
|
||||
assert all(len(img.shape) == 3 for img in
|
||||
images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
|
||||
# pass images through initial convolution independently
|
||||
patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
|
||||
patch_embeds_list = [self.patch_conv(
|
||||
img.unsqueeze(0)).squeeze(0) for img in images]
|
||||
|
||||
# flatten to a single sequence
|
||||
patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
|
||||
patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
|
||||
for p in patch_embeds_list], dim=0)
|
||||
patch_embeds = self.ln_pre(patch_embeds)
|
||||
|
||||
# positional embeddings
|
||||
@@ -355,7 +363,8 @@ class VisionLanguageAdapter(nn.Module):
|
||||
self.w_out = nn.Linear(out_dim, out_dim, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return]
|
||||
# type: ignore[no-any-return]
|
||||
return self.w_out(self.gelu(self.w_in(x)))
|
||||
|
||||
|
||||
class VisionTransformerBlocks(nn.Module):
|
||||
@@ -401,7 +410,8 @@ def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> tor
|
||||
Returns:
|
||||
torch.Tensor: Normalized image with shape (C, H, W).
|
||||
"""
|
||||
assert image.shape[0] == len(mean) == len(std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
|
||||
assert image.shape[0] == len(mean) == len(
|
||||
std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
|
||||
|
||||
# Reshape mean and std to (C, 1, 1) for broadcasting
|
||||
mean = mean.view(-1, 1, 1)
|
||||
@@ -473,10 +483,12 @@ class PixtralVisionImagePreprocessor:
|
||||
"""
|
||||
# should not have batch
|
||||
if len(image.shape) == 4:
|
||||
raise ValueError(f"Expected image with shape (C, H, W), got {image.shape}")
|
||||
raise ValueError(
|
||||
f"Expected image with shape (C, H, W), got {image.shape}")
|
||||
|
||||
if image.min() < 0.0 or image.max() > 1.0:
|
||||
raise ValueError(f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
|
||||
raise ValueError(
|
||||
f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
|
||||
|
||||
w, h = self._image_to_num_tokens(image)
|
||||
assert w > 0
|
||||
@@ -490,3 +502,98 @@ class PixtralVisionImagePreprocessor:
|
||||
processed_image = transform_image(image, new_image_size)
|
||||
|
||||
return processed_image
|
||||
|
||||
|
||||
class PixtralVisionImagePreprocessorCompatibleReturn:
|
||||
def __init__(self, pixel_values) -> None:
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
|
||||
# Compatable version with ai toolkit flow
|
||||
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
|
||||
def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
|
||||
super().__init__(
|
||||
image_patch_size=image_patch_size,
|
||||
max_image_size=max_image_size
|
||||
)
|
||||
self.size = {
|
||||
'height': max_image_size,
|
||||
'width': max_image_size
|
||||
}
|
||||
self.image_mean = DATASET_MEAN
|
||||
self.image_std = DATASET_STD
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images,
|
||||
return_tensors="pt",
|
||||
do_resize=True,
|
||||
do_rescale=False,
|
||||
) -> torch.Tensor:
|
||||
out_stack = []
|
||||
if len(images.shape) == 3:
|
||||
images = images.unsqueeze(0)
|
||||
for i in range(images.shape[0]):
|
||||
image = images[i]
|
||||
processed_image = super().__call__(image)
|
||||
out_stack.append(processed_image)
|
||||
|
||||
output = torch.stack(out_stack, dim=0)
|
||||
return PixtralVisionImagePreprocessorCompatibleReturn(output)
|
||||
|
||||
|
||||
class PixtralVisionEncoderCompatibleReturn:
|
||||
def __init__(self, hidden_states) -> None:
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
|
||||
class PixtralVisionEncoderCompatibleConfig:
|
||||
def __init__(self):
|
||||
self.image_size = 1024
|
||||
self.hidden_size = 1024
|
||||
self.patch_size = 16
|
||||
|
||||
|
||||
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
intermediate_size: int = 4096,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
rope_theta: float = 1e4, # for rope-2D
|
||||
image_token_id: int = 10,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
hidden_size=hidden_size,
|
||||
num_channels=num_channels,
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
rope_theta=rope_theta,
|
||||
image_token_id=image_token_id,
|
||||
)
|
||||
self.config = PixtralVisionEncoderCompatibleConfig()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images,
|
||||
output_hidden_states=True,
|
||||
) -> torch.Tensor:
|
||||
out_stack = []
|
||||
if len(images.shape) == 3:
|
||||
images = images.unsqueeze(0)
|
||||
for i in range(images.shape[0]):
|
||||
image = images[i]
|
||||
# must be in an array
|
||||
image_output = super().forward([image])
|
||||
out_stack.append(image_output)
|
||||
|
||||
output = torch.stack(out_stack, dim=0)
|
||||
return PixtralVisionEncoderCompatibleReturn([output])
|
||||
|
||||
Reference in New Issue
Block a user