remove files

This commit is contained in:
layerdiffusion
2024-08-05 21:24:35 -07:00
parent ae1d995d0d
commit 37e656d5aa
13 changed files with 573 additions and 1274 deletions

View File

@@ -1,72 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"

View File

@@ -1,73 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr_m18.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"

View File

@@ -1,98 +0,0 @@
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.
model:
base_learning_rate: 1.0e-04
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: edited
cond_stage_key: edit
# image_size: 64
# image_size: 32
image_size: 16
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: false
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 0 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 128
num_workers: 1
wrap: false
validation:
target: edit_dataset.EditDataset
params:
path: data/clip-filtered-dataset
cache_dir: data/
cache_name: data_10k
split: val
min_text_sim: 0.2
min_image_sim: 0.75
min_direction_sim: 0.2
max_samples_per_prompt: 1
min_resize_res: 512
max_resize_res: 512
crop_res: 512
output_as_edit: False
real_input: True

View File

@@ -1,5 +0,0 @@
model:
target: modules.models.sd3.sd3_model.SD3Inferencer
params:
shift: 3
state_dict: null

View File

@@ -1,98 +0,0 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: False
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@@ -1,70 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

View File

@@ -1,70 +0,0 @@
model:
base_learning_rate: 7.5e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid # important
monitor: val/loss_simple_ema
scale_factor: 0.18215
finetune_keys: null
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

View File

@@ -1,98 +1,98 @@
import logging # import logging
#
import torch # import torch
from torch import Tensor # from torch import Tensor
import platform # import platform
from modules.sd_hijack_utils import CondFunc # from modules.sd_hijack_utils import CondFunc
from packaging import version # from packaging import version
from modules import shared # from modules import shared
#
log = logging.getLogger(__name__) # log = logging.getLogger(__name__)
#
#
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, # # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility. # # use check `getattr` and try it for compatibility.
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability, # # in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 # # since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool: # def check_for_mps() -> bool:
if version.parse(torch.__version__) <= version.parse("2.0.1"): # if version.parse(torch.__version__) <= version.parse("2.0.1"):
if not getattr(torch, 'has_mps', False): # if not getattr(torch, 'has_mps', False):
return False # return False
try: # try:
torch.zeros(1).to(torch.device("mps")) # torch.zeros(1).to(torch.device("mps"))
return True # return True
except Exception: # except Exception:
return False # return False
else: # else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built() # return torch.backends.mps.is_available() and torch.backends.mps.is_built()
#
#
has_mps = check_for_mps() # has_mps = check_for_mps()
#
#
def torch_mps_gc() -> None: # def torch_mps_gc() -> None:
try: # try:
if shared.state.current_latent is not None: # if shared.state.current_latent is not None:
log.debug("`current_latent` is set, skipping MPS garbage collection") # log.debug("`current_latent` is set, skipping MPS garbage collection")
return # return
from torch.mps import empty_cache # from torch.mps import empty_cache
empty_cache() # empty_cache()
except Exception: # except Exception:
log.warning("MPS garbage collection failed", exc_info=True) # log.warning("MPS garbage collection failed", exc_info=True)
#
#
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 # # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs): # def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps': # if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype) # output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64: # if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) # return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): # elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) # return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs) # return cumsum_func(input, *args, **kwargs)
#
#
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 # # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor: # def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
try: # try:
return orig_func(*args, **kwargs) # return orig_func(*args, **kwargs)
except RuntimeError as e: # except RuntimeError as e:
if "not implemented for" in str(e) and "Half" in str(e): # if "not implemented for" in str(e) and "Half" in str(e):
input_tensor = args[0] # input_tensor = args[0]
return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype) # return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
else: # else:
print(f"An unexpected RuntimeError occurred: {str(e)}") # print(f"An unexpected RuntimeError occurred: {str(e)}")
#
if has_mps: # if has_mps:
if platform.mac_ver()[0].startswith("13.2."): # if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) # # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760) # CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
#
if version.parse(torch.__version__) < version.parse("1.13"): # if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working # # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
#
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), # CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) # lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 # # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), # CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') # lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532 # # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) # CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"): # elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) # cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) # cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None) # CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) # CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) # CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
#
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113 # # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') # CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
#
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 # # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None) # CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
#
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311 # # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386': # if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']: # for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') # CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')

View File

@@ -1,31 +1,31 @@
import importlib # import importlib
import torch # import torch
#
from modules import shared # from modules import shared
#
#
def check_for_npu(): # def check_for_npu():
if importlib.util.find_spec("torch_npu") is None: # if importlib.util.find_spec("torch_npu") is None:
return False # return False
import torch_npu # import torch_npu
#
try: # try:
# Will raise a RuntimeError if no NPU is found # # Will raise a RuntimeError if no NPU is found
_ = torch_npu.npu.device_count() # _ = torch_npu.npu.device_count()
return torch.npu.is_available() # return torch.npu.is_available()
except RuntimeError: # except RuntimeError:
return False # return False
#
#
def get_npu_device_string(): # def get_npu_device_string():
if shared.cmd_opts.device_id is not None: # if shared.cmd_opts.device_id is not None:
return f"npu:{shared.cmd_opts.device_id}" # return f"npu:{shared.cmd_opts.device_id}"
return "npu:0" # return "npu:0"
#
#
def torch_npu_gc(): # def torch_npu_gc():
with torch.npu.device(get_npu_device_string()): # with torch.npu.device(get_npu_device_string()):
torch.npu.empty_cache() # torch.npu.empty_cache()
#
#
has_npu = check_for_npu() # has_npu = check_for_npu()

View File

@@ -1,215 +0,0 @@
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
# Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2
from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple
def narrow_trunc(
input: Tensor,
dim: int,
start: int,
length: int
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
class SummarizeChunk:
@staticmethod
def __call__(
query: Tensor,
key: Tensor,
value: Tensor,
) -> AttnChunk: ...
class ComputeQueryChunkAttn:
@staticmethod
def __call__(
query: Tensor,
key: Tensor,
value: Tensor,
) -> Tensor: ...
def _summarize_chunk(
query: Tensor,
key: Tensor,
value: Tensor,
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
beta=0,
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
def _query_chunk_attention(
query: Tensor,
key: Tensor,
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
) -> Tensor:
batch_x_heads, k_tokens, k_channels_per_head = key.shape
_, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk:
key_chunk = narrow_trunc(
key,
1,
chunk_idx,
kv_chunk_size
)
value_chunk = narrow_trunc(
value,
1,
chunk_idx,
kv_chunk_size
)
return summarize_chunk(query, key_chunk, value_chunk)
chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk
global_max, _ = torch.max(chunk_max, 0, keepdim=True)
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= torch.unsqueeze(max_diffs, -1)
chunk_weights *= max_diffs
all_values = chunk_values.sum(dim=0)
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
key: Tensor,
value: Tensor,
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
beta=0,
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
return hidden_states_slice
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,
value: Tensor,
query_chunk_size=1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
):
"""Computes efficient dot-product attention given query, key, and value.
This is efficient version of attention presented in
https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
Args:
query: queries for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
key: keys for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
value: values to be used in attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
query_chunk_size: int: query chunks size
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
Returns:
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
"""
batch_x_heads, q_tokens, q_channels_per_head = query.shape
_, k_tokens, _ = key.shape
scale = q_channels_per_head ** -0.5
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
def get_query_chunk(chunk_idx: int) -> Tensor:
return narrow_trunc(
query,
1,
chunk_idx,
min(query_chunk_size, q_tokens)
)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
_get_attention_scores_no_kv_chunking,
scale=scale
) if k_tokens <= kv_chunk_size else (
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
partial(
_query_chunk_attention,
kv_chunk_size=kv_chunk_size,
summarize_chunk=summarize_chunk,
)
)
if q_tokens <= query_chunk_size:
# fast-path for when there's just 1 query chunk
return compute_query_chunk_attn(
query=query,
key=key,
value=value,
)
res = torch.zeros_like(query)
for i in range(math.ceil(q_tokens / query_chunk_size)):
attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
)
res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
return res

View File

@@ -1,140 +1,140 @@
from transformers import BertPreTrainedModel, BertConfig # from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn # import torch.nn as nn
import torch # import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig # from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer # from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional # from typing import Optional
#
from modules import torch_utils # from modules import torch_utils
#
#
class BertSeriesConfig(BertConfig): # class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): # def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
#
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) # super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
self.project_dim = project_dim # self.project_dim = project_dim
self.pooler_fn = pooler_fn # self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder # self.learn_encoder = learn_encoder
#
class RobertaSeriesConfig(XLMRobertaConfig): # class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): # def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) # super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim # self.project_dim = project_dim
self.pooler_fn = pooler_fn # self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder # self.learn_encoder = learn_encoder
#
#
class BertSeriesModelWithTransformation(BertPreTrainedModel): # class BertSeriesModelWithTransformation(BertPreTrainedModel):
#
_keys_to_ignore_on_load_unexpected = [r"pooler"] # _keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
config_class = BertSeriesConfig # config_class = BertSeriesConfig
#
def __init__(self, config=None, **kargs): # def __init__(self, config=None, **kargs):
# modify initialization for autoloading # # modify initialization for autoloading
if config is None: # if config is None:
config = XLMRobertaConfig() # config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1 # config.attention_probs_dropout_prob= 0.1
config.bos_token_id=0 # config.bos_token_id=0
config.eos_token_id=2 # config.eos_token_id=2
config.hidden_act='gelu' # config.hidden_act='gelu'
config.hidden_dropout_prob=0.1 # config.hidden_dropout_prob=0.1
config.hidden_size=1024 # config.hidden_size=1024
config.initializer_range=0.02 # config.initializer_range=0.02
config.intermediate_size=4096 # config.intermediate_size=4096
config.layer_norm_eps=1e-05 # config.layer_norm_eps=1e-05
config.max_position_embeddings=514 # config.max_position_embeddings=514
#
config.num_attention_heads=16 # config.num_attention_heads=16
config.num_hidden_layers=24 # config.num_hidden_layers=24
config.output_past=True # config.output_past=True
config.pad_token_id=1 # config.pad_token_id=1
config.position_embedding_type= "absolute" # config.position_embedding_type= "absolute"
#
config.type_vocab_size= 1 # config.type_vocab_size= 1
config.use_cache=True # config.use_cache=True
config.vocab_size= 250002 # config.vocab_size= 250002
config.project_dim = 768 # config.project_dim = 768
config.learn_encoder = False # config.learn_encoder = False
super().__init__(config) # super().__init__(config)
self.roberta = XLMRobertaModel(config) # self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size,config.project_dim) # self.transformation = nn.Linear(config.hidden_size,config.project_dim)
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') # self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
self.pooler = lambda x: x[:,0] # self.pooler = lambda x: x[:,0]
self.post_init() # self.post_init()
#
def encode(self,c): # def encode(self,c):
device = torch_utils.get_param(self).device # device = torch_utils.get_param(self).device
text = self.tokenizer(c, # text = self.tokenizer(c,
truncation=True, # truncation=True,
max_length=77, # max_length=77,
return_length=False, # return_length=False,
return_overflowing_tokens=False, # return_overflowing_tokens=False,
padding="max_length", # padding="max_length",
return_tensors="pt") # return_tensors="pt")
text["input_ids"] = torch.tensor(text["input_ids"]).to(device) # text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
text["attention_mask"] = torch.tensor( # text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device) # text['attention_mask']).to(device)
features = self(**text) # features = self(**text)
return features['projection_state'] # return features['projection_state']
#
def forward( # def forward(
self, # self,
input_ids: Optional[torch.Tensor] = None, # input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, # attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, # token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None, # position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None, # head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None, # inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None, # encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, # encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None, # output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None, # return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, # output_hidden_states: Optional[bool] = None,
) : # ) :
r""" # r"""
""" # """
#
return_dict = return_dict if return_dict is not None else self.config.use_return_dict # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
#
#
outputs = self.roberta( # outputs = self.roberta(
input_ids=input_ids, # input_ids=input_ids,
attention_mask=attention_mask, # attention_mask=attention_mask,
token_type_ids=token_type_ids, # token_type_ids=token_type_ids,
position_ids=position_ids, # position_ids=position_ids,
head_mask=head_mask, # head_mask=head_mask,
inputs_embeds=inputs_embeds, # inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states, # encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, # encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions, # output_attentions=output_attentions,
output_hidden_states=True, # output_hidden_states=True,
return_dict=return_dict, # return_dict=return_dict,
) # )
#
# last module outputs # # last module outputs
sequence_output = outputs[0] # sequence_output = outputs[0]
#
#
# project every module # # project every module
sequence_output_ln = self.pre_LN(sequence_output) # sequence_output_ln = self.pre_LN(sequence_output)
#
# pooler # # pooler
pooler_output = self.pooler(sequence_output_ln) # pooler_output = self.pooler(sequence_output_ln)
pooler_output = self.transformation(pooler_output) # pooler_output = self.transformation(pooler_output)
projection_state = self.transformation(outputs.last_hidden_state) # projection_state = self.transformation(outputs.last_hidden_state)
#
return { # return {
'pooler_output':pooler_output, # 'pooler_output':pooler_output,
'last_hidden_state':outputs.last_hidden_state, # 'last_hidden_state':outputs.last_hidden_state,
'hidden_states':outputs.hidden_states, # 'hidden_states':outputs.hidden_states,
'attentions':outputs.attentions, # 'attentions':outputs.attentions,
'projection_state':projection_state, # 'projection_state':projection_state,
'sequence_out': sequence_output # 'sequence_out': sequence_output
} # }
#
#
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): # class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta' # base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig # config_class= RobertaSeriesConfig

View File

@@ -1,166 +1,166 @@
from transformers import BertPreTrainedModel,BertConfig # from transformers import BertPreTrainedModel,BertConfig
import torch.nn as nn # import torch.nn as nn
import torch # import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig # from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer # from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional # from typing import Optional
from modules import torch_utils # from modules import torch_utils
#
#
class BertSeriesConfig(BertConfig): # class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): # def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
#
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) # super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
self.project_dim = project_dim # self.project_dim = project_dim
self.pooler_fn = pooler_fn # self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder # self.learn_encoder = learn_encoder
#
class RobertaSeriesConfig(XLMRobertaConfig): # class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): # def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) # super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim # self.project_dim = project_dim
self.pooler_fn = pooler_fn # self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder # self.learn_encoder = learn_encoder
#
#
class BertSeriesModelWithTransformation(BertPreTrainedModel): # class BertSeriesModelWithTransformation(BertPreTrainedModel):
#
_keys_to_ignore_on_load_unexpected = [r"pooler"] # _keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] # _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
config_class = BertSeriesConfig # config_class = BertSeriesConfig
#
def __init__(self, config=None, **kargs): # def __init__(self, config=None, **kargs):
# modify initialization for autoloading # # modify initialization for autoloading
if config is None: # if config is None:
config = XLMRobertaConfig() # config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1 # config.attention_probs_dropout_prob= 0.1
config.bos_token_id=0 # config.bos_token_id=0
config.eos_token_id=2 # config.eos_token_id=2
config.hidden_act='gelu' # config.hidden_act='gelu'
config.hidden_dropout_prob=0.1 # config.hidden_dropout_prob=0.1
config.hidden_size=1024 # config.hidden_size=1024
config.initializer_range=0.02 # config.initializer_range=0.02
config.intermediate_size=4096 # config.intermediate_size=4096
config.layer_norm_eps=1e-05 # config.layer_norm_eps=1e-05
config.max_position_embeddings=514 # config.max_position_embeddings=514
#
config.num_attention_heads=16 # config.num_attention_heads=16
config.num_hidden_layers=24 # config.num_hidden_layers=24
config.output_past=True # config.output_past=True
config.pad_token_id=1 # config.pad_token_id=1
config.position_embedding_type= "absolute" # config.position_embedding_type= "absolute"
#
config.type_vocab_size= 1 # config.type_vocab_size= 1
config.use_cache=True # config.use_cache=True
config.vocab_size= 250002 # config.vocab_size= 250002
config.project_dim = 1024 # config.project_dim = 1024
config.learn_encoder = False # config.learn_encoder = False
super().__init__(config) # super().__init__(config)
self.roberta = XLMRobertaModel(config) # self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size,config.project_dim) # self.transformation = nn.Linear(config.hidden_size,config.project_dim)
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') # self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
# self.pooler = lambda x: x[:,0] # # self.pooler = lambda x: x[:,0]
# self.post_init() # # self.post_init()
#
self.has_pre_transformation = True # self.has_pre_transformation = True
if self.has_pre_transformation: # if self.has_pre_transformation:
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) # self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_init() # self.post_init()
#
def encode(self,c): # def encode(self,c):
device = torch_utils.get_param(self).device # device = torch_utils.get_param(self).device
text = self.tokenizer(c, # text = self.tokenizer(c,
truncation=True, # truncation=True,
max_length=77, # max_length=77,
return_length=False, # return_length=False,
return_overflowing_tokens=False, # return_overflowing_tokens=False,
padding="max_length", # padding="max_length",
return_tensors="pt") # return_tensors="pt")
text["input_ids"] = torch.tensor(text["input_ids"]).to(device) # text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
text["attention_mask"] = torch.tensor( # text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device) # text['attention_mask']).to(device)
features = self(**text) # features = self(**text)
return features['projection_state'] # return features['projection_state']
#
def forward( # def forward(
self, # self,
input_ids: Optional[torch.Tensor] = None, # input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, # attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, # token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None, # position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None, # head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None, # inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None, # encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, # encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None, # output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None, # return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, # output_hidden_states: Optional[bool] = None,
) : # ) :
r""" # r"""
""" # """
#
return_dict = return_dict if return_dict is not None else self.config.use_return_dict # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
#
#
outputs = self.roberta( # outputs = self.roberta(
input_ids=input_ids, # input_ids=input_ids,
attention_mask=attention_mask, # attention_mask=attention_mask,
token_type_ids=token_type_ids, # token_type_ids=token_type_ids,
position_ids=position_ids, # position_ids=position_ids,
head_mask=head_mask, # head_mask=head_mask,
inputs_embeds=inputs_embeds, # inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states, # encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, # encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions, # output_attentions=output_attentions,
output_hidden_states=True, # output_hidden_states=True,
return_dict=return_dict, # return_dict=return_dict,
) # )
#
# # last module outputs # # # last module outputs
# sequence_output = outputs[0] # # sequence_output = outputs[0]
#
#
# # project every module # # # project every module
# sequence_output_ln = self.pre_LN(sequence_output) # # sequence_output_ln = self.pre_LN(sequence_output)
#
# # pooler # # # pooler
# pooler_output = self.pooler(sequence_output_ln) # # pooler_output = self.pooler(sequence_output_ln)
# pooler_output = self.transformation(pooler_output) # # pooler_output = self.transformation(pooler_output)
# projection_state = self.transformation(outputs.last_hidden_state) # # projection_state = self.transformation(outputs.last_hidden_state)
#
if self.has_pre_transformation: # if self.has_pre_transformation:
sequence_output2 = outputs["hidden_states"][-2] # sequence_output2 = outputs["hidden_states"][-2]
sequence_output2 = self.pre_LN(sequence_output2) # sequence_output2 = self.pre_LN(sequence_output2)
projection_state2 = self.transformation_pre(sequence_output2) # projection_state2 = self.transformation_pre(sequence_output2)
#
return { # return {
"projection_state": projection_state2, # "projection_state": projection_state2,
"last_hidden_state": outputs.last_hidden_state, # "last_hidden_state": outputs.last_hidden_state,
"hidden_states": outputs.hidden_states, # "hidden_states": outputs.hidden_states,
"attentions": outputs.attentions, # "attentions": outputs.attentions,
} # }
else: # else:
projection_state = self.transformation(outputs.last_hidden_state) # projection_state = self.transformation(outputs.last_hidden_state)
return { # return {
"projection_state": projection_state, # "projection_state": projection_state,
"last_hidden_state": outputs.last_hidden_state, # "last_hidden_state": outputs.last_hidden_state,
"hidden_states": outputs.hidden_states, # "hidden_states": outputs.hidden_states,
"attentions": outputs.attentions, # "attentions": outputs.attentions,
} # }
#
#
# return { # # return {
# 'pooler_output':pooler_output, # # 'pooler_output':pooler_output,
# 'last_hidden_state':outputs.last_hidden_state, # # 'last_hidden_state':outputs.last_hidden_state,
# 'hidden_states':outputs.hidden_states, # # 'hidden_states':outputs.hidden_states,
# 'attentions':outputs.attentions, # # 'attentions':outputs.attentions,
# 'projection_state':projection_state, # # 'projection_state':projection_state,
# 'sequence_out': sequence_output # # 'sequence_out': sequence_output
# } # # }
#
#
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): # class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta' # base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig # config_class= RobertaSeriesConfig

View File

@@ -1,138 +1,138 @@
from modules import shared # from modules import shared
from modules.sd_hijack_utils import CondFunc # from modules.sd_hijack_utils import CondFunc
#
has_ipex = False # has_ipex = False
try: # try:
import torch # import torch
import intel_extension_for_pytorch as ipex # noqa: F401 # import intel_extension_for_pytorch as ipex # noqa: F401
has_ipex = True # has_ipex = True
except Exception: # except Exception:
pass # pass
#
#
def check_for_xpu(): # def check_for_xpu():
return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available() # return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()
#
#
def get_xpu_device_string(): # def get_xpu_device_string():
if shared.cmd_opts.device_id is not None: # if shared.cmd_opts.device_id is not None:
return f"xpu:{shared.cmd_opts.device_id}" # return f"xpu:{shared.cmd_opts.device_id}"
return "xpu" # return "xpu"
#
#
def torch_xpu_gc(): # def torch_xpu_gc():
with torch.xpu.device(get_xpu_device_string()): # with torch.xpu.device(get_xpu_device_string()):
torch.xpu.empty_cache() # torch.xpu.empty_cache()
#
#
has_xpu = check_for_xpu() # has_xpu = check_for_xpu()
#
#
# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627 # # Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
# Here we implement a slicing algorithm to split large batch size into smaller chunks, # # Here we implement a slicing algorithm to split large batch size into smaller chunks,
# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. # # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, # # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
# which is the best trade-off between VRAM usage and performance. # # which is the best trade-off between VRAM usage and performance.
ARC_SINGLE_ALLOCATION_LIMIT = {} # ARC_SINGLE_ALLOCATION_LIMIT = {}
orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention # orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
def torch_xpu_scaled_dot_product_attention( # def torch_xpu_scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs # query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
): # ):
# cast to same dtype first # # cast to same dtype first
key = key.to(query.dtype) # key = key.to(query.dtype)
value = value.to(query.dtype) # value = value.to(query.dtype)
if attn_mask is not None and attn_mask.dtype != torch.bool: # if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(query.dtype) # attn_mask = attn_mask.to(query.dtype)
#
N = query.shape[:-2] # Batch size # N = query.shape[:-2] # Batch size
L = query.size(-2) # Target sequence length # L = query.size(-2) # Target sequence length
E = query.size(-1) # Embedding dimension of the query and key # E = query.size(-1) # Embedding dimension of the query and key
S = key.size(-2) # Source sequence length # S = key.size(-2) # Source sequence length
Ev = value.size(-1) # Embedding dimension of the value # Ev = value.size(-1) # Embedding dimension of the value
#
total_batch_size = torch.numel(torch.empty(N)) # total_batch_size = torch.numel(torch.empty(N))
device_id = query.device.index # device_id = query.device.index
if device_id not in ARC_SINGLE_ALLOCATION_LIMIT: # if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) # ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size())) # batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
#
if total_batch_size <= batch_size_limit: # if total_batch_size <= batch_size_limit:
return orig_sdp_attn_func( # return orig_sdp_attn_func(
query, # query,
key, # key,
value, # value,
attn_mask, # attn_mask,
dropout_p, # dropout_p,
is_causal, # is_causal,
*args, **kwargs # *args, **kwargs
) # )
#
query = torch.reshape(query, (-1, L, E)) # query = torch.reshape(query, (-1, L, E))
key = torch.reshape(key, (-1, S, E)) # key = torch.reshape(key, (-1, S, E))
value = torch.reshape(value, (-1, S, Ev)) # value = torch.reshape(value, (-1, S, Ev))
if attn_mask is not None: # if attn_mask is not None:
attn_mask = attn_mask.view(-1, L, S) # attn_mask = attn_mask.view(-1, L, S)
chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit # chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
outputs = [] # outputs = []
for i in range(chunk_count): # for i in range(chunk_count):
attn_mask_chunk = ( # attn_mask_chunk = (
None # None
if attn_mask is None # if attn_mask is None
else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :] # else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
) # )
chunk_output = orig_sdp_attn_func( # chunk_output = orig_sdp_attn_func(
query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], # query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], # key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], # value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
attn_mask_chunk, # attn_mask_chunk,
dropout_p, # dropout_p,
is_causal, # is_causal,
*args, **kwargs # *args, **kwargs
) # )
outputs.append(chunk_output) # outputs.append(chunk_output)
result = torch.cat(outputs, dim=0) # result = torch.cat(outputs, dim=0)
return torch.reshape(result, (*N, L, Ev)) # return torch.reshape(result, (*N, L, Ev))
#
#
def is_xpu_device(device: str | torch.device = None): # def is_xpu_device(device: str | torch.device = None):
if device is None: # if device is None:
return False # return False
if isinstance(device, str): # if isinstance(device, str):
return device.startswith("xpu") # return device.startswith("xpu")
return device.type == "xpu" # return device.type == "xpu"
#
#
if has_xpu: # if has_xpu:
try: # try:
# torch.Generator supports "xpu" device since 2.1 # # torch.Generator supports "xpu" device since 2.1
torch.Generator("xpu") # torch.Generator("xpu")
except RuntimeError: # except RuntimeError:
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1) # # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
CondFunc('torch.Generator', # CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device), # lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: is_xpu_device(device)) # lambda orig_func, device=None: is_xpu_device(device))
#
# W/A for some OPs that could not handle different input dtypes # # W/A for some OPs that could not handle different input dtypes
CondFunc('torch.nn.functional.layer_norm', # CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: # lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), # orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: # lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype) # weight is not None and input.dtype != weight.data.dtype)
CondFunc('torch.nn.modules.GroupNorm.forward', # CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), # lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) # lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward', # CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), # lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) # lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward', # CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), # lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) # lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.bmm', # CondFunc('torch.bmm',
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), # lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) # lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
CondFunc('torch.cat', # CondFunc('torch.cat',
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), # lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) # lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
CondFunc('torch.nn.functional.scaled_dot_product_attention', # CondFunc('torch.nn.functional.scaled_dot_product_attention',
lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs), # lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
lambda orig_func, query, *args, **kwargs: query.is_xpu) # lambda orig_func, query, *args, **kwargs: query.is_xpu)