Quick patch to scope xformers imports until a better solution is in place

Author: Jaret Burkett
Date:   2024-09-28 15:36:42 -06:00
Parent: 2e5f6668dc
Commit: b4f64de4c2
2 changed files with 10 additions and 8 deletions


@@ -32,5 +32,4 @@ sentencepiece
 huggingface_hub
 peft
 gradio
-python-slugify
-xformers
+python-slugify
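
With xformers removed from the pinned requirements, it becomes an optional dependency: it only has to be installed when the attention code paths below actually run. The patch adds no runtime guard for that case; the following is a minimal sketch of one common way downstream code could check availability and fail with an actionable message (XFORMERS_AVAILABLE and require_xformers are hypothetical names, not defined in this repo):

    # Not part of this commit: an optional-dependency check for xformers.
    import importlib.util

    XFORMERS_AVAILABLE = importlib.util.find_spec("xformers") is not None

    def require_xformers() -> None:
        # Fail early with a clear message instead of a bare ImportError surfacing
        # from a deferred import deep inside a forward() call.
        if not XFORMERS_AVAILABLE:
            raise ImportError(
                "xformers is not installed but is required for memory-efficient "
                "attention; install it with `pip install xformers`."
            )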


@@ -1,15 +1,16 @@
 import math
-from typing import List, Optional, Tuple, Any, Union
+from typing import List, Optional, Tuple, Any, Union, TYPE_CHECKING
 import os
 import torch
 import torch.nn as nn
 from dataclasses import dataclass
-from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-from xformers.ops.fmha import memory_efficient_attention
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file
 import json
+
+if TYPE_CHECKING:
+    from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -86,8 +87,9 @@ class Attention(nn.Module):
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         cache: Optional[Any] = None,
-        mask: Optional[BlockDiagonalMask] = None,
+        mask: Optional['BlockDiagonalMask'] = None,
     ) -> torch.Tensor:
+        from xformers.ops.fmha import memory_efficient_attention
         assert mask is None or cache is None
 
         seqlen_sum, _ = x.shape
@@ -155,7 +157,7 @@ class TransformerBlock(nn.Module):
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         cache: Optional[Any] = None,
-        mask: Optional[BlockDiagonalMask] = None,
+        mask: Optional['BlockDiagonalMask'] = None,
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
         h = x + r
@@ -317,6 +319,7 @@ class PixtralVisionEncoder(nn.Module):
         self,
         images: List[torch.Tensor],
     ) -> torch.Tensor:
+        from xformers.ops.fmha.attn_bias import BlockDiagonalMask
         """
         Args:
             images: list of N_img images of variable sizes, each of shape (C, H, W)
@@ -387,7 +390,7 @@ class VisionTransformerBlocks(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        mask: BlockDiagonalMask,
+        mask: 'BlockDiagonalMask',
         freqs_cis: Optional[torch.Tensor],
     ) -> torch.Tensor:
         for layer in self.layers:
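
Taken together, the hunks above apply a standard pattern for keeping a heavy dependency off the module import path while preserving type hints: the type is imported only under typing.TYPE_CHECKING, annotations reference it as a string (a forward reference), and the runtime import is deferred into the function that actually needs it. Below is a minimal standalone sketch of that pattern using the same xformers names as the diff; attend() is a hypothetical stand-in for the attention forward methods, not code from this repo:

    from typing import Optional, TYPE_CHECKING

    import torch

    if TYPE_CHECKING:
        # Seen only by static type checkers (mypy, pyright); never executed at
        # runtime, so importing this module does not require xformers.
        from xformers.ops.fmha.attn_bias import BlockDiagonalMask


    def attend(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional["BlockDiagonalMask"] = None,
    ) -> torch.Tensor:
        # Deferred import: xformers has to be installed only if this path runs.
        from xformers.ops.fmha import memory_efficient_attention

        return memory_efficient_attention(q, k, v, attn_bias=mask)

One caveat of the string annotation: tools that evaluate annotations at runtime (for example typing.get_type_hints) will still try to resolve BlockDiagonalMask and fail when xformers is absent; ordinary callers are unaffected.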