mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Quick patch to scope xformer imports until a better solution
This commit is contained in:
@@ -32,5 +32,4 @@ sentencepiece
|
||||
huggingface_hub
|
||||
peft
|
||||
gradio
|
||||
python-slugify
|
||||
xformers
|
||||
python-slugify
|
||||
@@ -1,15 +1,16 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Any, Union
|
||||
from typing import List, Optional, Tuple, Any, Union, TYPE_CHECKING
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
from xformers.ops.fmha import memory_efficient_attention
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
@@ -86,8 +87,9 @@ class Attention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[BlockDiagonalMask] = None,
|
||||
mask: Optional['BlockDiagonalMask'] = None,
|
||||
) -> torch.Tensor:
|
||||
from xformers.ops.fmha import memory_efficient_attention
|
||||
assert mask is None or cache is None
|
||||
seqlen_sum, _ = x.shape
|
||||
|
||||
@@ -155,7 +157,7 @@ class TransformerBlock(nn.Module):
|
||||
x: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[BlockDiagonalMask] = None,
|
||||
mask: Optional['BlockDiagonalMask'] = None,
|
||||
) -> torch.Tensor:
|
||||
r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
|
||||
h = x + r
|
||||
@@ -317,6 +319,7 @@ class PixtralVisionEncoder(nn.Module):
|
||||
self,
|
||||
images: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
"""
|
||||
Args:
|
||||
images: list of N_img images of variable sizes, each of shape (C, H, W)
|
||||
@@ -387,7 +390,7 @@ class VisionTransformerBlocks(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: BlockDiagonalMask,
|
||||
mask: 'BlockDiagonalMask',
|
||||
freqs_cis: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
for layer in self.layers:
|
||||
|
||||
Reference in New Issue
Block a user