remove files

This commit is contained in:
layerdiffusion
2024-08-05 21:24:35 -07:00
parent ae1d995d0d
commit 37e656d5aa
13 changed files with 573 additions and 1274 deletions

View File

@@ -1,72 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"

View File

@@ -1,73 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr_m18.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"

View File

@@ -1,98 +0,0 @@
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.
model:
base_learning_rate: 1.0e-04
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: edited
cond_stage_key: edit
# image_size: 64
# image_size: 32
image_size: 16
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: false
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 0 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 128
num_workers: 1
wrap: false
validation:
target: edit_dataset.EditDataset
params:
path: data/clip-filtered-dataset
cache_dir: data/
cache_name: data_10k
split: val
min_text_sim: 0.2
min_image_sim: 0.75
min_direction_sim: 0.2
max_samples_per_prompt: 1
min_resize_res: 512
max_resize_res: 512
crop_res: 512
output_as_edit: False
real_input: True

View File

@@ -1,5 +0,0 @@
model:
target: modules.models.sd3.sd3_model.SD3Inferencer
params:
shift: 3
state_dict: null

View File

@@ -1,98 +0,0 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: False
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@@ -1,70 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

View File

@@ -1,70 +0,0 @@
model:
base_learning_rate: 7.5e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid # important
monitor: val/loss_simple_ema
scale_factor: 0.18215
finetune_keys: null
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

View File

@@ -1,98 +1,98 @@
import logging
import torch
from torch import Tensor
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
from modules import shared
log = logging.getLogger(__name__)
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
if version.parse(torch.__version__) <= version.parse("2.0.1"):
if not getattr(torch, 'has_mps', False):
return False
try:
torch.zeros(1).to(torch.device("mps"))
return True
except Exception:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_mps = check_for_mps()
def torch_mps_gc() -> None:
try:
if shared.state.current_latent is not None:
log.debug("`current_latent` is set, skipping MPS garbage collection")
return
from torch.mps import empty_cache
empty_cache()
except Exception:
log.warning("MPS garbage collection failed", exc_info=True)
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
try:
return orig_func(*args, **kwargs)
except RuntimeError as e:
if "not implemented for" in str(e) and "Half" in str(e):
input_tensor = args[0]
return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
else:
print(f"An unexpected RuntimeError occurred: {str(e)}")
if has_mps:
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
# import logging
#
# import torch
# from torch import Tensor
# import platform
# from modules.sd_hijack_utils import CondFunc
# from packaging import version
# from modules import shared
#
# log = logging.getLogger(__name__)
#
#
# # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# # use check `getattr` and try it for compatibility.
# # in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
# # since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
# def check_for_mps() -> bool:
# if version.parse(torch.__version__) <= version.parse("2.0.1"):
# if not getattr(torch, 'has_mps', False):
# return False
# try:
# torch.zeros(1).to(torch.device("mps"))
# return True
# except Exception:
# return False
# else:
# return torch.backends.mps.is_available() and torch.backends.mps.is_built()
#
#
# has_mps = check_for_mps()
#
#
# def torch_mps_gc() -> None:
# try:
# if shared.state.current_latent is not None:
# log.debug("`current_latent` is set, skipping MPS garbage collection")
# return
# from torch.mps import empty_cache
# empty_cache()
# except Exception:
# log.warning("MPS garbage collection failed", exc_info=True)
#
#
# # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
# def cumsum_fix(input, cumsum_func, *args, **kwargs):
# if input.device.type == 'mps':
# output_dtype = kwargs.get('dtype', input.dtype)
# if output_dtype == torch.int64:
# return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
# elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
# return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
# return cumsum_func(input, *args, **kwargs)
#
#
# # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
# def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
# try:
# return orig_func(*args, **kwargs)
# except RuntimeError as e:
# if "not implemented for" in str(e) and "Half" in str(e):
# input_tensor = args[0]
# return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
# else:
# print(f"An unexpected RuntimeError occurred: {str(e)}")
#
# if has_mps:
# if platform.mac_ver()[0].startswith("13.2."):
# # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
# CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
#
# if version.parse(torch.__version__) < version.parse("1.13"):
# # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
#
# # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
# CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
# lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
# # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
# CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
# lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
# CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
# elif version.parse(torch.__version__) > version.parse("1.13.1"):
# cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
# cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
# CondFunc('torch.cumsum', cumsum_fix_func, None)
# CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
# CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
#
# # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
# CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
#
# # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
# CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
#
# # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
# if platform.processor() == 'i386':
# for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
# CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')

View File

@@ -1,31 +1,31 @@
import importlib
import torch
from modules import shared
def check_for_npu():
if importlib.util.find_spec("torch_npu") is None:
return False
import torch_npu
try:
# Will raise a RuntimeError if no NPU is found
_ = torch_npu.npu.device_count()
return torch.npu.is_available()
except RuntimeError:
return False
def get_npu_device_string():
if shared.cmd_opts.device_id is not None:
return f"npu:{shared.cmd_opts.device_id}"
return "npu:0"
def torch_npu_gc():
with torch.npu.device(get_npu_device_string()):
torch.npu.empty_cache()
has_npu = check_for_npu()
# import importlib
# import torch
#
# from modules import shared
#
#
# def check_for_npu():
# if importlib.util.find_spec("torch_npu") is None:
# return False
# import torch_npu
#
# try:
# # Will raise a RuntimeError if no NPU is found
# _ = torch_npu.npu.device_count()
# return torch.npu.is_available()
# except RuntimeError:
# return False
#
#
# def get_npu_device_string():
# if shared.cmd_opts.device_id is not None:
# return f"npu:{shared.cmd_opts.device_id}"
# return "npu:0"
#
#
# def torch_npu_gc():
# with torch.npu.device(get_npu_device_string()):
# torch.npu.empty_cache()
#
#
# has_npu = check_for_npu()

View File

@@ -1,215 +0,0 @@
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
# Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2
from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple
def narrow_trunc(
input: Tensor,
dim: int,
start: int,
length: int
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
class SummarizeChunk:
@staticmethod
def __call__(
query: Tensor,
key: Tensor,
value: Tensor,
) -> AttnChunk: ...
class ComputeQueryChunkAttn:
@staticmethod
def __call__(
query: Tensor,
key: Tensor,
value: Tensor,
) -> Tensor: ...
def _summarize_chunk(
query: Tensor,
key: Tensor,
value: Tensor,
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
beta=0,
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
def _query_chunk_attention(
query: Tensor,
key: Tensor,
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
) -> Tensor:
batch_x_heads, k_tokens, k_channels_per_head = key.shape
_, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk:
key_chunk = narrow_trunc(
key,
1,
chunk_idx,
kv_chunk_size
)
value_chunk = narrow_trunc(
value,
1,
chunk_idx,
kv_chunk_size
)
return summarize_chunk(query, key_chunk, value_chunk)
chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk
global_max, _ = torch.max(chunk_max, 0, keepdim=True)
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= torch.unsqueeze(max_diffs, -1)
chunk_weights *= max_diffs
all_values = chunk_values.sum(dim=0)
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
key: Tensor,
value: Tensor,
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
beta=0,
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
return hidden_states_slice
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,
value: Tensor,
query_chunk_size=1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
):
"""Computes efficient dot-product attention given query, key, and value.
This is efficient version of attention presented in
https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
Args:
query: queries for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
key: keys for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
value: values to be used in attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
query_chunk_size: int: query chunks size
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
Returns:
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
"""
batch_x_heads, q_tokens, q_channels_per_head = query.shape
_, k_tokens, _ = key.shape
scale = q_channels_per_head ** -0.5
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
def get_query_chunk(chunk_idx: int) -> Tensor:
return narrow_trunc(
query,
1,
chunk_idx,
min(query_chunk_size, q_tokens)
)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
_get_attention_scores_no_kv_chunking,
scale=scale
) if k_tokens <= kv_chunk_size else (
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
partial(
_query_chunk_attention,
kv_chunk_size=kv_chunk_size,
summarize_chunk=summarize_chunk,
)
)
if q_tokens <= query_chunk_size:
# fast-path for when there's just 1 query chunk
return compute_query_chunk_attn(
query=query,
key=key,
value=value,
)
res = torch.zeros_like(query)
for i in range(math.ceil(q_tokens / query_chunk_size)):
attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
)
res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
return res

View File

@@ -1,140 +1,140 @@
from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
from modules import torch_utils
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class BertSeriesModelWithTransformation(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
# modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
config.bos_token_id=0
config.eos_token_id=2
config.hidden_act='gelu'
config.hidden_dropout_prob=0.1
config.hidden_size=1024
config.initializer_range=0.02
config.intermediate_size=4096
config.layer_norm_eps=1e-05
config.max_position_embeddings=514
config.num_attention_heads=16
config.num_hidden_layers=24
config.output_past=True
config.pad_token_id=1
config.position_embedding_type= "absolute"
config.type_vocab_size= 1
config.use_cache=True
config.vocab_size= 250002
config.project_dim = 768
config.learn_encoder = False
super().__init__(config)
self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
self.pooler = lambda x: x[:,0]
self.post_init()
def encode(self,c):
device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
return features['projection_state']
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) :
r"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
# last module outputs
sequence_output = outputs[0]
# project every module
sequence_output_ln = self.pre_LN(sequence_output)
# pooler
pooler_output = self.pooler(sequence_output_ln)
pooler_output = self.transformation(pooler_output)
projection_state = self.transformation(outputs.last_hidden_state)
return {
'pooler_output':pooler_output,
'last_hidden_state':outputs.last_hidden_state,
'hidden_states':outputs.hidden_states,
'attentions':outputs.attentions,
'projection_state':projection_state,
'sequence_out': sequence_output
}
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig
# from transformers import BertPreTrainedModel, BertConfig
# import torch.nn as nn
# import torch
# from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
# from transformers import XLMRobertaModel,XLMRobertaTokenizer
# from typing import Optional
#
# from modules import torch_utils
#
#
# class BertSeriesConfig(BertConfig):
# def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
#
# super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
# self.project_dim = project_dim
# self.pooler_fn = pooler_fn
# self.learn_encoder = learn_encoder
#
# class RobertaSeriesConfig(XLMRobertaConfig):
# def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
# super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
# self.project_dim = project_dim
# self.pooler_fn = pooler_fn
# self.learn_encoder = learn_encoder
#
#
# class BertSeriesModelWithTransformation(BertPreTrainedModel):
#
# _keys_to_ignore_on_load_unexpected = [r"pooler"]
# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
# config_class = BertSeriesConfig
#
# def __init__(self, config=None, **kargs):
# # modify initialization for autoloading
# if config is None:
# config = XLMRobertaConfig()
# config.attention_probs_dropout_prob= 0.1
# config.bos_token_id=0
# config.eos_token_id=2
# config.hidden_act='gelu'
# config.hidden_dropout_prob=0.1
# config.hidden_size=1024
# config.initializer_range=0.02
# config.intermediate_size=4096
# config.layer_norm_eps=1e-05
# config.max_position_embeddings=514
#
# config.num_attention_heads=16
# config.num_hidden_layers=24
# config.output_past=True
# config.pad_token_id=1
# config.position_embedding_type= "absolute"
#
# config.type_vocab_size= 1
# config.use_cache=True
# config.vocab_size= 250002
# config.project_dim = 768
# config.learn_encoder = False
# super().__init__(config)
# self.roberta = XLMRobertaModel(config)
# self.transformation = nn.Linear(config.hidden_size,config.project_dim)
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
# self.pooler = lambda x: x[:,0]
# self.post_init()
#
# def encode(self,c):
# device = torch_utils.get_param(self).device
# text = self.tokenizer(c,
# truncation=True,
# max_length=77,
# return_length=False,
# return_overflowing_tokens=False,
# padding="max_length",
# return_tensors="pt")
# text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
# text["attention_mask"] = torch.tensor(
# text['attention_mask']).to(device)
# features = self(**text)
# return features['projection_state']
#
# def forward(
# self,
# input_ids: Optional[torch.Tensor] = None,
# attention_mask: Optional[torch.Tensor] = None,
# token_type_ids: Optional[torch.Tensor] = None,
# position_ids: Optional[torch.Tensor] = None,
# head_mask: Optional[torch.Tensor] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# encoder_hidden_states: Optional[torch.Tensor] = None,
# encoder_attention_mask: Optional[torch.Tensor] = None,
# output_attentions: Optional[bool] = None,
# return_dict: Optional[bool] = None,
# output_hidden_states: Optional[bool] = None,
# ) :
# r"""
# """
#
# return_dict = return_dict if return_dict is not None else self.config.use_return_dict
#
#
# outputs = self.roberta(
# input_ids=input_ids,
# attention_mask=attention_mask,
# token_type_ids=token_type_ids,
# position_ids=position_ids,
# head_mask=head_mask,
# inputs_embeds=inputs_embeds,
# encoder_hidden_states=encoder_hidden_states,
# encoder_attention_mask=encoder_attention_mask,
# output_attentions=output_attentions,
# output_hidden_states=True,
# return_dict=return_dict,
# )
#
# # last module outputs
# sequence_output = outputs[0]
#
#
# # project every module
# sequence_output_ln = self.pre_LN(sequence_output)
#
# # pooler
# pooler_output = self.pooler(sequence_output_ln)
# pooler_output = self.transformation(pooler_output)
# projection_state = self.transformation(outputs.last_hidden_state)
#
# return {
# 'pooler_output':pooler_output,
# 'last_hidden_state':outputs.last_hidden_state,
# 'hidden_states':outputs.hidden_states,
# 'attentions':outputs.attentions,
# 'projection_state':projection_state,
# 'sequence_out': sequence_output
# }
#
#
# class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
# base_model_prefix = 'roberta'
# config_class= RobertaSeriesConfig

View File

@@ -1,166 +1,166 @@
from transformers import BertPreTrainedModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
from modules import torch_utils
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class BertSeriesModelWithTransformation(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
# modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
config.bos_token_id=0
config.eos_token_id=2
config.hidden_act='gelu'
config.hidden_dropout_prob=0.1
config.hidden_size=1024
config.initializer_range=0.02
config.intermediate_size=4096
config.layer_norm_eps=1e-05
config.max_position_embeddings=514
config.num_attention_heads=16
config.num_hidden_layers=24
config.output_past=True
config.pad_token_id=1
config.position_embedding_type= "absolute"
config.type_vocab_size= 1
config.use_cache=True
config.vocab_size= 250002
config.project_dim = 1024
config.learn_encoder = False
super().__init__(config)
self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
# self.pooler = lambda x: x[:,0]
# self.post_init()
self.has_pre_transformation = True
if self.has_pre_transformation:
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_init()
def encode(self,c):
device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
return features['projection_state']
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) :
r"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
# # last module outputs
# sequence_output = outputs[0]
# # project every module
# sequence_output_ln = self.pre_LN(sequence_output)
# # pooler
# pooler_output = self.pooler(sequence_output_ln)
# pooler_output = self.transformation(pooler_output)
# projection_state = self.transformation(outputs.last_hidden_state)
if self.has_pre_transformation:
sequence_output2 = outputs["hidden_states"][-2]
sequence_output2 = self.pre_LN(sequence_output2)
projection_state2 = self.transformation_pre(sequence_output2)
return {
"projection_state": projection_state2,
"last_hidden_state": outputs.last_hidden_state,
"hidden_states": outputs.hidden_states,
"attentions": outputs.attentions,
}
else:
projection_state = self.transformation(outputs.last_hidden_state)
return {
"projection_state": projection_state,
"last_hidden_state": outputs.last_hidden_state,
"hidden_states": outputs.hidden_states,
"attentions": outputs.attentions,
}
# return {
# 'pooler_output':pooler_output,
# 'last_hidden_state':outputs.last_hidden_state,
# 'hidden_states':outputs.hidden_states,
# 'attentions':outputs.attentions,
# 'projection_state':projection_state,
# 'sequence_out': sequence_output
# }
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig
# from transformers import BertPreTrainedModel,BertConfig
# import torch.nn as nn
# import torch
# from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
# from transformers import XLMRobertaModel,XLMRobertaTokenizer
# from typing import Optional
# from modules import torch_utils
#
#
# class BertSeriesConfig(BertConfig):
# def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
#
# super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
# self.project_dim = project_dim
# self.pooler_fn = pooler_fn
# self.learn_encoder = learn_encoder
#
# class RobertaSeriesConfig(XLMRobertaConfig):
# def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
# super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
# self.project_dim = project_dim
# self.pooler_fn = pooler_fn
# self.learn_encoder = learn_encoder
#
#
# class BertSeriesModelWithTransformation(BertPreTrainedModel):
#
# _keys_to_ignore_on_load_unexpected = [r"pooler"]
# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
# config_class = BertSeriesConfig
#
# def __init__(self, config=None, **kargs):
# # modify initialization for autoloading
# if config is None:
# config = XLMRobertaConfig()
# config.attention_probs_dropout_prob= 0.1
# config.bos_token_id=0
# config.eos_token_id=2
# config.hidden_act='gelu'
# config.hidden_dropout_prob=0.1
# config.hidden_size=1024
# config.initializer_range=0.02
# config.intermediate_size=4096
# config.layer_norm_eps=1e-05
# config.max_position_embeddings=514
#
# config.num_attention_heads=16
# config.num_hidden_layers=24
# config.output_past=True
# config.pad_token_id=1
# config.position_embedding_type= "absolute"
#
# config.type_vocab_size= 1
# config.use_cache=True
# config.vocab_size= 250002
# config.project_dim = 1024
# config.learn_encoder = False
# super().__init__(config)
# self.roberta = XLMRobertaModel(config)
# self.transformation = nn.Linear(config.hidden_size,config.project_dim)
# # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
# # self.pooler = lambda x: x[:,0]
# # self.post_init()
#
# self.has_pre_transformation = True
# if self.has_pre_transformation:
# self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
# self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# self.post_init()
#
# def encode(self,c):
# device = torch_utils.get_param(self).device
# text = self.tokenizer(c,
# truncation=True,
# max_length=77,
# return_length=False,
# return_overflowing_tokens=False,
# padding="max_length",
# return_tensors="pt")
# text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
# text["attention_mask"] = torch.tensor(
# text['attention_mask']).to(device)
# features = self(**text)
# return features['projection_state']
#
# def forward(
# self,
# input_ids: Optional[torch.Tensor] = None,
# attention_mask: Optional[torch.Tensor] = None,
# token_type_ids: Optional[torch.Tensor] = None,
# position_ids: Optional[torch.Tensor] = None,
# head_mask: Optional[torch.Tensor] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# encoder_hidden_states: Optional[torch.Tensor] = None,
# encoder_attention_mask: Optional[torch.Tensor] = None,
# output_attentions: Optional[bool] = None,
# return_dict: Optional[bool] = None,
# output_hidden_states: Optional[bool] = None,
# ) :
# r"""
# """
#
# return_dict = return_dict if return_dict is not None else self.config.use_return_dict
#
#
# outputs = self.roberta(
# input_ids=input_ids,
# attention_mask=attention_mask,
# token_type_ids=token_type_ids,
# position_ids=position_ids,
# head_mask=head_mask,
# inputs_embeds=inputs_embeds,
# encoder_hidden_states=encoder_hidden_states,
# encoder_attention_mask=encoder_attention_mask,
# output_attentions=output_attentions,
# output_hidden_states=True,
# return_dict=return_dict,
# )
#
# # # last module outputs
# # sequence_output = outputs[0]
#
#
# # # project every module
# # sequence_output_ln = self.pre_LN(sequence_output)
#
# # # pooler
# # pooler_output = self.pooler(sequence_output_ln)
# # pooler_output = self.transformation(pooler_output)
# # projection_state = self.transformation(outputs.last_hidden_state)
#
# if self.has_pre_transformation:
# sequence_output2 = outputs["hidden_states"][-2]
# sequence_output2 = self.pre_LN(sequence_output2)
# projection_state2 = self.transformation_pre(sequence_output2)
#
# return {
# "projection_state": projection_state2,
# "last_hidden_state": outputs.last_hidden_state,
# "hidden_states": outputs.hidden_states,
# "attentions": outputs.attentions,
# }
# else:
# projection_state = self.transformation(outputs.last_hidden_state)
# return {
# "projection_state": projection_state,
# "last_hidden_state": outputs.last_hidden_state,
# "hidden_states": outputs.hidden_states,
# "attentions": outputs.attentions,
# }
#
#
# # return {
# # 'pooler_output':pooler_output,
# # 'last_hidden_state':outputs.last_hidden_state,
# # 'hidden_states':outputs.hidden_states,
# # 'attentions':outputs.attentions,
# # 'projection_state':projection_state,
# # 'sequence_out': sequence_output
# # }
#
#
# class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
# base_model_prefix = 'roberta'
# config_class= RobertaSeriesConfig

View File

@@ -1,138 +1,138 @@
from modules import shared
from modules.sd_hijack_utils import CondFunc
has_ipex = False
try:
import torch
import intel_extension_for_pytorch as ipex # noqa: F401
has_ipex = True
except Exception:
pass
def check_for_xpu():
return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()
def get_xpu_device_string():
if shared.cmd_opts.device_id is not None:
return f"xpu:{shared.cmd_opts.device_id}"
return "xpu"
def torch_xpu_gc():
with torch.xpu.device(get_xpu_device_string()):
torch.xpu.empty_cache()
has_xpu = check_for_xpu()
# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
# Here we implement a slicing algorithm to split large batch size into smaller chunks,
# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
# which is the best trade-off between VRAM usage and performance.
ARC_SINGLE_ALLOCATION_LIMIT = {}
orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
def torch_xpu_scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
):
# cast to same dtype first
key = key.to(query.dtype)
value = value.to(query.dtype)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(query.dtype)
N = query.shape[:-2] # Batch size
L = query.size(-2) # Target sequence length
E = query.size(-1) # Embedding dimension of the query and key
S = key.size(-2) # Source sequence length
Ev = value.size(-1) # Embedding dimension of the value
total_batch_size = torch.numel(torch.empty(N))
device_id = query.device.index
if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
if total_batch_size <= batch_size_limit:
return orig_sdp_attn_func(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
*args, **kwargs
)
query = torch.reshape(query, (-1, L, E))
key = torch.reshape(key, (-1, S, E))
value = torch.reshape(value, (-1, S, Ev))
if attn_mask is not None:
attn_mask = attn_mask.view(-1, L, S)
chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
outputs = []
for i in range(chunk_count):
attn_mask_chunk = (
None
if attn_mask is None
else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
)
chunk_output = orig_sdp_attn_func(
query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
attn_mask_chunk,
dropout_p,
is_causal,
*args, **kwargs
)
outputs.append(chunk_output)
result = torch.cat(outputs, dim=0)
return torch.reshape(result, (*N, L, Ev))
def is_xpu_device(device: str | torch.device = None):
if device is None:
return False
if isinstance(device, str):
return device.startswith("xpu")
return device.type == "xpu"
if has_xpu:
try:
# torch.Generator supports "xpu" device since 2.1
torch.Generator("xpu")
except RuntimeError:
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: is_xpu_device(device))
# W/A for some OPs that could not handle different input dtypes
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.bmm',
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
CondFunc('torch.cat',
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
CondFunc('torch.nn.functional.scaled_dot_product_attention',
lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
lambda orig_func, query, *args, **kwargs: query.is_xpu)
# from modules import shared
# from modules.sd_hijack_utils import CondFunc
#
# has_ipex = False
# try:
# import torch
# import intel_extension_for_pytorch as ipex # noqa: F401
# has_ipex = True
# except Exception:
# pass
#
#
# def check_for_xpu():
# return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()
#
#
# def get_xpu_device_string():
# if shared.cmd_opts.device_id is not None:
# return f"xpu:{shared.cmd_opts.device_id}"
# return "xpu"
#
#
# def torch_xpu_gc():
# with torch.xpu.device(get_xpu_device_string()):
# torch.xpu.empty_cache()
#
#
# has_xpu = check_for_xpu()
#
#
# # Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
# # Here we implement a slicing algorithm to split large batch size into smaller chunks,
# # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
# # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
# # which is the best trade-off between VRAM usage and performance.
# ARC_SINGLE_ALLOCATION_LIMIT = {}
# orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
# def torch_xpu_scaled_dot_product_attention(
# query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
# ):
# # cast to same dtype first
# key = key.to(query.dtype)
# value = value.to(query.dtype)
# if attn_mask is not None and attn_mask.dtype != torch.bool:
# attn_mask = attn_mask.to(query.dtype)
#
# N = query.shape[:-2] # Batch size
# L = query.size(-2) # Target sequence length
# E = query.size(-1) # Embedding dimension of the query and key
# S = key.size(-2) # Source sequence length
# Ev = value.size(-1) # Embedding dimension of the value
#
# total_batch_size = torch.numel(torch.empty(N))
# device_id = query.device.index
# if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
# ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
# batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
#
# if total_batch_size <= batch_size_limit:
# return orig_sdp_attn_func(
# query,
# key,
# value,
# attn_mask,
# dropout_p,
# is_causal,
# *args, **kwargs
# )
#
# query = torch.reshape(query, (-1, L, E))
# key = torch.reshape(key, (-1, S, E))
# value = torch.reshape(value, (-1, S, Ev))
# if attn_mask is not None:
# attn_mask = attn_mask.view(-1, L, S)
# chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
# outputs = []
# for i in range(chunk_count):
# attn_mask_chunk = (
# None
# if attn_mask is None
# else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
# )
# chunk_output = orig_sdp_attn_func(
# query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
# key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
# value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
# attn_mask_chunk,
# dropout_p,
# is_causal,
# *args, **kwargs
# )
# outputs.append(chunk_output)
# result = torch.cat(outputs, dim=0)
# return torch.reshape(result, (*N, L, Ev))
#
#
# def is_xpu_device(device: str | torch.device = None):
# if device is None:
# return False
# if isinstance(device, str):
# return device.startswith("xpu")
# return device.type == "xpu"
#
#
# if has_xpu:
# try:
# # torch.Generator supports "xpu" device since 2.1
# torch.Generator("xpu")
# except RuntimeError:
# # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
# CondFunc('torch.Generator',
# lambda orig_func, device=None: torch.xpu.Generator(device),
# lambda orig_func, device=None: is_xpu_device(device))
#
# # W/A for some OPs that could not handle different input dtypes
# CondFunc('torch.nn.functional.layer_norm',
# lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
# orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
# lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
# weight is not None and input.dtype != weight.data.dtype)
# CondFunc('torch.nn.modules.GroupNorm.forward',
# lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
# lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# CondFunc('torch.nn.modules.linear.Linear.forward',
# lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
# lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# CondFunc('torch.nn.modules.conv.Conv2d.forward',
# lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
# lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# CondFunc('torch.bmm',
# lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
# lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
# CondFunc('torch.cat',
# lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
# lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
# CondFunc('torch.nn.functional.scaled_dot_product_attention',
# lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
# lambda orig_func, query, *args, **kwargs: query.is_xpu)