From b4f64de4c20d6a9e2fbdbd63f02b3895c80ab892 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 28 Sep 2024 15:36:42 -0600
Subject: [PATCH] Quick patch to scope xformer imports until a better solution

---
 requirements.txt                 |  3 +--
 toolkit/models/pixtral_vision.py | 15 +++++++++------
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 7f5d36b9..c251a40a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,5 +32,4 @@ sentencepiece
 huggingface_hub
 peft
 gradio
-python-slugify
-xformers
\ No newline at end of file
+python-slugify
\ No newline at end of file
diff --git a/toolkit/models/pixtral_vision.py b/toolkit/models/pixtral_vision.py
index 51da56b1..815f3310 100644
--- a/toolkit/models/pixtral_vision.py
+++ b/toolkit/models/pixtral_vision.py
@@ -1,15 +1,16 @@
 import math
-from typing import List, Optional, Tuple, Any, Union
+from typing import List, Optional, Tuple, Any, Union, TYPE_CHECKING
 import os
 import torch
 import torch.nn as nn
 from dataclasses import dataclass
-from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-from xformers.ops.fmha import memory_efficient_attention
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file
 import json
 
+if TYPE_CHECKING:
+    from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -86,8 +87,9 @@ class Attention(nn.Module):
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         cache: Optional[Any] = None,
-        mask: Optional[BlockDiagonalMask] = None,
+        mask: Optional['BlockDiagonalMask'] = None,
     ) -> torch.Tensor:
+        from xformers.ops.fmha import memory_efficient_attention
         assert mask is None or cache is None
         seqlen_sum, _ = x.shape
 
@@ -155,7 +157,7 @@ class TransformerBlock(nn.Module):
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         cache: Optional[Any] = None,
-        mask: Optional[BlockDiagonalMask] = None,
+        mask: Optional['BlockDiagonalMask'] = None,
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
         h = x + r
@@ -317,6 +319,7 @@ class PixtralVisionEncoder(nn.Module):
         self,
         images: List[torch.Tensor],
     ) -> torch.Tensor:
+        from xformers.ops.fmha.attn_bias import BlockDiagonalMask
         """
         Args:
             images: list of N_img images of variable sizes, each of shape (C, H, W)
@@ -387,7 +390,7 @@ class VisionTransformerBlocks(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        mask: BlockDiagonalMask,
+        mask: 'BlockDiagonalMask',
         freqs_cis: Optional[torch.Tensor],
     ) -> torch.Tensor:
         for layer in self.layers:
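
Note: the patch above keeps xformers out of module import time by moving the imports into the functions that need them and guarding the annotation-only import behind TYPE_CHECKING. Below is a minimal, self-contained sketch of the same deferred-import pattern; the ToyAttention class and its tensor shapes are illustrative only and are not part of this repository.

# Sketch of the scoped ("deferred") xformers import pattern used in the patch.
# Importing this module succeeds even when xformers is not installed; an
# ImportError only surfaces if forward() is actually called.
from typing import Optional, TYPE_CHECKING

import torch
import torch.nn as nn

if TYPE_CHECKING:
    # Evaluated by type checkers only, never executed at runtime.
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask


class ToyAttention(nn.Module):  # hypothetical example class
    def forward(
        self,
        q: torch.Tensor,  # expected shape (batch, seq_len, num_heads, head_dim)
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional["BlockDiagonalMask"] = None,  # string annotation avoids the import
    ) -> torch.Tensor:
        # The heavy dependency is imported only when attention actually runs.
        from xformers.ops.fmha import memory_efficient_attention

        return memory_efficient_attention(q, k, v, attn_bias=mask)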