mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 10:59:47 +00:00
remove files
This commit is contained in:
@@ -1,72 +0,0 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: modules.xlmr.BertSeriesModelWithTransformation
|
||||
params:
|
||||
name: "XLMR-Large"
|
||||
@@ -1,73 +0,0 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: modules.xlmr_m18.BertSeriesModelWithTransformation
|
||||
params:
|
||||
name: "XLMR-Large"
|
||||
@@ -1,98 +0,0 @@
|
||||
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
||||
# See more details in LICENSE.
|
||||
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: edited
|
||||
cond_stage_key: edit
|
||||
# image_size: 64
|
||||
# image_size: 32
|
||||
image_size: 16
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: false
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 0 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 128
|
||||
num_workers: 1
|
||||
wrap: false
|
||||
validation:
|
||||
target: edit_dataset.EditDataset
|
||||
params:
|
||||
path: data/clip-filtered-dataset
|
||||
cache_dir: data/
|
||||
cache_name: data_10k
|
||||
split: val
|
||||
min_text_sim: 0.2
|
||||
min_image_sim: 0.75
|
||||
min_direction_sim: 0.2
|
||||
max_samples_per_prompt: 1
|
||||
min_resize_res: 512
|
||||
max_resize_res: 512
|
||||
crop_res: 512
|
||||
output_as_edit: False
|
||||
real_input: True
|
||||
@@ -1,5 +0,0 @@
|
||||
model:
|
||||
target: modules.models.sd3.sd3_model.SD3Inferencer
|
||||
params:
|
||||
shift: 3
|
||||
state_dict: null
|
||||
@@ -1,98 +0,0 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.13025
|
||||
disable_first_stage_autocast: True
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
adm_in_channels: 2816
|
||||
num_classes: sequential
|
||||
use_checkpoint: False
|
||||
in_channels: 9
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
|
||||
context_dim: 2048
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
legacy: False
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
layer: hidden
|
||||
layer_idx: 11
|
||||
# crossattn and vector cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
||||
params:
|
||||
arch: ViT-bigG-14
|
||||
version: laion2b_s39b_b160k
|
||||
freeze: True
|
||||
layer: penultimate
|
||||
always_return_pooled: True
|
||||
legacy: False
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: target_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
@@ -1,70 +0,0 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
@@ -1,70 +0,0 @@
|
||||
model:
|
||||
base_learning_rate: 7.5e-05
|
||||
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid # important
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
finetune_keys: null
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
@@ -1,98 +1,98 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import platform
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
from modules import shared
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||
# use check `getattr` and try it for compatibility.
|
||||
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
|
||||
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
||||
def check_for_mps() -> bool:
|
||||
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
else:
|
||||
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||
|
||||
|
||||
has_mps = check_for_mps()
|
||||
|
||||
|
||||
def torch_mps_gc() -> None:
|
||||
try:
|
||||
if shared.state.current_latent is not None:
|
||||
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
||||
return
|
||||
from torch.mps import empty_cache
|
||||
empty_cache()
|
||||
except Exception:
|
||||
log.warning("MPS garbage collection failed", exc_info=True)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
output_dtype = kwargs.get('dtype', input.dtype)
|
||||
if output_dtype == torch.int64:
|
||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
return cumsum_func(input, *args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
|
||||
try:
|
||||
return orig_func(*args, **kwargs)
|
||||
except RuntimeError as e:
|
||||
if "not implemented for" in str(e) and "Half" in str(e):
|
||||
input_tensor = args[0]
|
||||
return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
|
||||
else:
|
||||
print(f"An unexpected RuntimeError occurred: {str(e)}")
|
||||
|
||||
if has_mps:
|
||||
if platform.mac_ver()[0].startswith("13.2."):
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
||||
|
||||
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||
CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||
if platform.processor() == 'i386':
|
||||
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
||||
# import logging
|
||||
#
|
||||
# import torch
|
||||
# from torch import Tensor
|
||||
# import platform
|
||||
# from modules.sd_hijack_utils import CondFunc
|
||||
# from packaging import version
|
||||
# from modules import shared
|
||||
#
|
||||
# log = logging.getLogger(__name__)
|
||||
#
|
||||
#
|
||||
# # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||
# # use check `getattr` and try it for compatibility.
|
||||
# # in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
|
||||
# # since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
||||
# def check_for_mps() -> bool:
|
||||
# if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
||||
# if not getattr(torch, 'has_mps', False):
|
||||
# return False
|
||||
# try:
|
||||
# torch.zeros(1).to(torch.device("mps"))
|
||||
# return True
|
||||
# except Exception:
|
||||
# return False
|
||||
# else:
|
||||
# return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||
#
|
||||
#
|
||||
# has_mps = check_for_mps()
|
||||
#
|
||||
#
|
||||
# def torch_mps_gc() -> None:
|
||||
# try:
|
||||
# if shared.state.current_latent is not None:
|
||||
# log.debug("`current_latent` is set, skipping MPS garbage collection")
|
||||
# return
|
||||
# from torch.mps import empty_cache
|
||||
# empty_cache()
|
||||
# except Exception:
|
||||
# log.warning("MPS garbage collection failed", exc_info=True)
|
||||
#
|
||||
#
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
# def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
# if input.device.type == 'mps':
|
||||
# output_dtype = kwargs.get('dtype', input.dtype)
|
||||
# if output_dtype == torch.int64:
|
||||
# return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
# elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
# return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
# return cumsum_func(input, *args, **kwargs)
|
||||
#
|
||||
#
|
||||
# # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||
# def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
|
||||
# try:
|
||||
# return orig_func(*args, **kwargs)
|
||||
# except RuntimeError as e:
|
||||
# if "not implemented for" in str(e) and "Half" in str(e):
|
||||
# input_tensor = args[0]
|
||||
# return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
|
||||
# else:
|
||||
# print(f"An unexpected RuntimeError occurred: {str(e)}")
|
||||
#
|
||||
# if has_mps:
|
||||
# if platform.mac_ver()[0].startswith("13.2."):
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||
# CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||
#
|
||||
# if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
#
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
# CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||
# lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
# CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||
# lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
# CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||
# elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
# cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
# cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||
# CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
# CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
# CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
#
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||
# CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
||||
#
|
||||
# # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||
# CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
|
||||
#
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||
# if platform.processor() == 'i386':
|
||||
# for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||
# CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
||||
|
||||
@@ -1,31 +1,31 @@
|
||||
import importlib
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
def check_for_npu():
|
||||
if importlib.util.find_spec("torch_npu") is None:
|
||||
return False
|
||||
import torch_npu
|
||||
|
||||
try:
|
||||
# Will raise a RuntimeError if no NPU is found
|
||||
_ = torch_npu.npu.device_count()
|
||||
return torch.npu.is_available()
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
def get_npu_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"npu:{shared.cmd_opts.device_id}"
|
||||
return "npu:0"
|
||||
|
||||
|
||||
def torch_npu_gc():
|
||||
with torch.npu.device(get_npu_device_string()):
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
has_npu = check_for_npu()
|
||||
# import importlib
|
||||
# import torch
|
||||
#
|
||||
# from modules import shared
|
||||
#
|
||||
#
|
||||
# def check_for_npu():
|
||||
# if importlib.util.find_spec("torch_npu") is None:
|
||||
# return False
|
||||
# import torch_npu
|
||||
#
|
||||
# try:
|
||||
# # Will raise a RuntimeError if no NPU is found
|
||||
# _ = torch_npu.npu.device_count()
|
||||
# return torch.npu.is_available()
|
||||
# except RuntimeError:
|
||||
# return False
|
||||
#
|
||||
#
|
||||
# def get_npu_device_string():
|
||||
# if shared.cmd_opts.device_id is not None:
|
||||
# return f"npu:{shared.cmd_opts.device_id}"
|
||||
# return "npu:0"
|
||||
#
|
||||
#
|
||||
# def torch_npu_gc():
|
||||
# with torch.npu.device(get_npu_device_string()):
|
||||
# torch.npu.empty_cache()
|
||||
#
|
||||
#
|
||||
# has_npu = check_for_npu()
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
# original source:
|
||||
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
|
||||
# license:
|
||||
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
|
||||
# credit:
|
||||
# Amin Rezaei (original author)
|
||||
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
|
||||
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
|
||||
# implementation of:
|
||||
# Self-attention Does Not Need O(n2) Memory":
|
||||
# https://arxiv.org/abs/2112.05682v2
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import math
|
||||
from typing import Optional, NamedTuple
|
||||
|
||||
|
||||
def narrow_trunc(
|
||||
input: Tensor,
|
||||
dim: int,
|
||||
start: int,
|
||||
length: int
|
||||
) -> Tensor:
|
||||
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
|
||||
|
||||
|
||||
class AttnChunk(NamedTuple):
|
||||
exp_values: Tensor
|
||||
exp_weights_sum: Tensor
|
||||
max_score: Tensor
|
||||
|
||||
|
||||
class SummarizeChunk:
|
||||
@staticmethod
|
||||
def __call__(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
) -> AttnChunk: ...
|
||||
|
||||
|
||||
class ComputeQueryChunkAttn:
|
||||
@staticmethod
|
||||
def __call__(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
) -> Tensor: ...
|
||||
|
||||
|
||||
def _summarize_chunk(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
) -> AttnChunk:
|
||||
attn_weights = torch.baddbmm(
|
||||
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
query,
|
||||
key.transpose(1,2),
|
||||
alpha=scale,
|
||||
beta=0,
|
||||
)
|
||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
exp_weights = torch.exp(attn_weights - max_score)
|
||||
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
|
||||
|
||||
def _query_chunk_attention(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
summarize_chunk: SummarizeChunk,
|
||||
kv_chunk_size: int,
|
||||
) -> Tensor:
|
||||
batch_x_heads, k_tokens, k_channels_per_head = key.shape
|
||||
_, _, v_channels_per_head = value.shape
|
||||
|
||||
def chunk_scanner(chunk_idx: int) -> AttnChunk:
|
||||
key_chunk = narrow_trunc(
|
||||
key,
|
||||
1,
|
||||
chunk_idx,
|
||||
kv_chunk_size
|
||||
)
|
||||
value_chunk = narrow_trunc(
|
||||
value,
|
||||
1,
|
||||
chunk_idx,
|
||||
kv_chunk_size
|
||||
)
|
||||
return summarize_chunk(query, key_chunk, value_chunk)
|
||||
|
||||
chunks: list[AttnChunk] = [
|
||||
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||
]
|
||||
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
|
||||
chunk_values, chunk_weights, chunk_max = acc_chunk
|
||||
|
||||
global_max, _ = torch.max(chunk_max, 0, keepdim=True)
|
||||
max_diffs = torch.exp(chunk_max - global_max)
|
||||
chunk_values *= torch.unsqueeze(max_diffs, -1)
|
||||
chunk_weights *= max_diffs
|
||||
|
||||
all_values = chunk_values.sum(dim=0)
|
||||
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
||||
return all_values / all_weights
|
||||
|
||||
|
||||
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
||||
def _get_attention_scores_no_kv_chunking(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
) -> Tensor:
|
||||
attn_scores = torch.baddbmm(
|
||||
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
query,
|
||||
key.transpose(1,2),
|
||||
alpha=scale,
|
||||
beta=0,
|
||||
)
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
|
||||
return hidden_states_slice
|
||||
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
chunk_idx: int
|
||||
attn_chunk: AttnChunk
|
||||
|
||||
|
||||
def efficient_dot_product_attention(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
query_chunk_size=1024,
|
||||
kv_chunk_size: Optional[int] = None,
|
||||
kv_chunk_size_min: Optional[int] = None,
|
||||
use_checkpoint=True,
|
||||
):
|
||||
"""Computes efficient dot-product attention given query, key, and value.
|
||||
This is efficient version of attention presented in
|
||||
https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
|
||||
Args:
|
||||
query: queries for calculating attention with shape of
|
||||
`[batch * num_heads, tokens, channels_per_head]`.
|
||||
key: keys for calculating attention with shape of
|
||||
`[batch * num_heads, tokens, channels_per_head]`.
|
||||
value: values to be used in attention with shape of
|
||||
`[batch * num_heads, tokens, channels_per_head]`.
|
||||
query_chunk_size: int: query chunks size
|
||||
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
|
||||
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
|
||||
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
|
||||
Returns:
|
||||
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
|
||||
"""
|
||||
batch_x_heads, q_tokens, q_channels_per_head = query.shape
|
||||
_, k_tokens, _ = key.shape
|
||||
scale = q_channels_per_head ** -0.5
|
||||
|
||||
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
|
||||
if kv_chunk_size_min is not None:
|
||||
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
|
||||
|
||||
def get_query_chunk(chunk_idx: int) -> Tensor:
|
||||
return narrow_trunc(
|
||||
query,
|
||||
1,
|
||||
chunk_idx,
|
||||
min(query_chunk_size, q_tokens)
|
||||
)
|
||||
|
||||
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
||||
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
|
||||
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
|
||||
_get_attention_scores_no_kv_chunking,
|
||||
scale=scale
|
||||
) if k_tokens <= kv_chunk_size else (
|
||||
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
|
||||
partial(
|
||||
_query_chunk_attention,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
summarize_chunk=summarize_chunk,
|
||||
)
|
||||
)
|
||||
|
||||
if q_tokens <= query_chunk_size:
|
||||
# fast-path for when there's just 1 query chunk
|
||||
return compute_query_chunk_attn(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
|
||||
res = torch.zeros_like(query)
|
||||
for i in range(math.ceil(q_tokens / query_chunk_size)):
|
||||
attn_scores = compute_query_chunk_attn(
|
||||
query=get_query_chunk(i * query_chunk_size),
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
|
||||
res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
|
||||
|
||||
return res
|
||||
280
modules/xlmr.py
280
modules/xlmr.py
@@ -1,140 +1,140 @@
|
||||
from transformers import BertPreTrainedModel, BertConfig
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
|
||||
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||
self.project_dim = project_dim
|
||||
self.pooler_fn = pooler_fn
|
||||
self.learn_encoder = learn_encoder
|
||||
|
||||
class RobertaSeriesConfig(XLMRobertaConfig):
|
||||
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
self.project_dim = project_dim
|
||||
self.pooler_fn = pooler_fn
|
||||
self.learn_encoder = learn_encoder
|
||||
|
||||
|
||||
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
config_class = BertSeriesConfig
|
||||
|
||||
def __init__(self, config=None, **kargs):
|
||||
# modify initialization for autoloading
|
||||
if config is None:
|
||||
config = XLMRobertaConfig()
|
||||
config.attention_probs_dropout_prob= 0.1
|
||||
config.bos_token_id=0
|
||||
config.eos_token_id=2
|
||||
config.hidden_act='gelu'
|
||||
config.hidden_dropout_prob=0.1
|
||||
config.hidden_size=1024
|
||||
config.initializer_range=0.02
|
||||
config.intermediate_size=4096
|
||||
config.layer_norm_eps=1e-05
|
||||
config.max_position_embeddings=514
|
||||
|
||||
config.num_attention_heads=16
|
||||
config.num_hidden_layers=24
|
||||
config.output_past=True
|
||||
config.pad_token_id=1
|
||||
config.position_embedding_type= "absolute"
|
||||
|
||||
config.type_vocab_size= 1
|
||||
config.use_cache=True
|
||||
config.vocab_size= 250002
|
||||
config.project_dim = 768
|
||||
config.learn_encoder = False
|
||||
super().__init__(config)
|
||||
self.roberta = XLMRobertaModel(config)
|
||||
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||
self.pooler = lambda x: x[:,0]
|
||||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = torch_utils.get_param(self).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt")
|
||||
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||
text["attention_mask"] = torch.tensor(
|
||||
text['attention_mask']).to(device)
|
||||
features = self(**text)
|
||||
return features['projection_state']
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) :
|
||||
r"""
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
|
||||
outputs = self.roberta(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# last module outputs
|
||||
sequence_output = outputs[0]
|
||||
|
||||
|
||||
# project every module
|
||||
sequence_output_ln = self.pre_LN(sequence_output)
|
||||
|
||||
# pooler
|
||||
pooler_output = self.pooler(sequence_output_ln)
|
||||
pooler_output = self.transformation(pooler_output)
|
||||
projection_state = self.transformation(outputs.last_hidden_state)
|
||||
|
||||
return {
|
||||
'pooler_output':pooler_output,
|
||||
'last_hidden_state':outputs.last_hidden_state,
|
||||
'hidden_states':outputs.hidden_states,
|
||||
'attentions':outputs.attentions,
|
||||
'projection_state':projection_state,
|
||||
'sequence_out': sequence_output
|
||||
}
|
||||
|
||||
|
||||
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||
base_model_prefix = 'roberta'
|
||||
config_class= RobertaSeriesConfig
|
||||
# from transformers import BertPreTrainedModel, BertConfig
|
||||
# import torch.nn as nn
|
||||
# import torch
|
||||
# from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
# from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
# from typing import Optional
|
||||
#
|
||||
# from modules import torch_utils
|
||||
#
|
||||
#
|
||||
# class BertSeriesConfig(BertConfig):
|
||||
# def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
#
|
||||
# super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||
# self.project_dim = project_dim
|
||||
# self.pooler_fn = pooler_fn
|
||||
# self.learn_encoder = learn_encoder
|
||||
#
|
||||
# class RobertaSeriesConfig(XLMRobertaConfig):
|
||||
# def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||
# super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
# self.project_dim = project_dim
|
||||
# self.pooler_fn = pooler_fn
|
||||
# self.learn_encoder = learn_encoder
|
||||
#
|
||||
#
|
||||
# class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
#
|
||||
# _keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
# config_class = BertSeriesConfig
|
||||
#
|
||||
# def __init__(self, config=None, **kargs):
|
||||
# # modify initialization for autoloading
|
||||
# if config is None:
|
||||
# config = XLMRobertaConfig()
|
||||
# config.attention_probs_dropout_prob= 0.1
|
||||
# config.bos_token_id=0
|
||||
# config.eos_token_id=2
|
||||
# config.hidden_act='gelu'
|
||||
# config.hidden_dropout_prob=0.1
|
||||
# config.hidden_size=1024
|
||||
# config.initializer_range=0.02
|
||||
# config.intermediate_size=4096
|
||||
# config.layer_norm_eps=1e-05
|
||||
# config.max_position_embeddings=514
|
||||
#
|
||||
# config.num_attention_heads=16
|
||||
# config.num_hidden_layers=24
|
||||
# config.output_past=True
|
||||
# config.pad_token_id=1
|
||||
# config.position_embedding_type= "absolute"
|
||||
#
|
||||
# config.type_vocab_size= 1
|
||||
# config.use_cache=True
|
||||
# config.vocab_size= 250002
|
||||
# config.project_dim = 768
|
||||
# config.learn_encoder = False
|
||||
# super().__init__(config)
|
||||
# self.roberta = XLMRobertaModel(config)
|
||||
# self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
# self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||
# self.pooler = lambda x: x[:,0]
|
||||
# self.post_init()
|
||||
#
|
||||
# def encode(self,c):
|
||||
# device = torch_utils.get_param(self).device
|
||||
# text = self.tokenizer(c,
|
||||
# truncation=True,
|
||||
# max_length=77,
|
||||
# return_length=False,
|
||||
# return_overflowing_tokens=False,
|
||||
# padding="max_length",
|
||||
# return_tensors="pt")
|
||||
# text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||
# text["attention_mask"] = torch.tensor(
|
||||
# text['attention_mask']).to(device)
|
||||
# features = self(**text)
|
||||
# return features['projection_state']
|
||||
#
|
||||
# def forward(
|
||||
# self,
|
||||
# input_ids: Optional[torch.Tensor] = None,
|
||||
# attention_mask: Optional[torch.Tensor] = None,
|
||||
# token_type_ids: Optional[torch.Tensor] = None,
|
||||
# position_ids: Optional[torch.Tensor] = None,
|
||||
# head_mask: Optional[torch.Tensor] = None,
|
||||
# inputs_embeds: Optional[torch.Tensor] = None,
|
||||
# encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
# encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
# output_attentions: Optional[bool] = None,
|
||||
# return_dict: Optional[bool] = None,
|
||||
# output_hidden_states: Optional[bool] = None,
|
||||
# ) :
|
||||
# r"""
|
||||
# """
|
||||
#
|
||||
# return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
#
|
||||
#
|
||||
# outputs = self.roberta(
|
||||
# input_ids=input_ids,
|
||||
# attention_mask=attention_mask,
|
||||
# token_type_ids=token_type_ids,
|
||||
# position_ids=position_ids,
|
||||
# head_mask=head_mask,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# encoder_hidden_states=encoder_hidden_states,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
# output_attentions=output_attentions,
|
||||
# output_hidden_states=True,
|
||||
# return_dict=return_dict,
|
||||
# )
|
||||
#
|
||||
# # last module outputs
|
||||
# sequence_output = outputs[0]
|
||||
#
|
||||
#
|
||||
# # project every module
|
||||
# sequence_output_ln = self.pre_LN(sequence_output)
|
||||
#
|
||||
# # pooler
|
||||
# pooler_output = self.pooler(sequence_output_ln)
|
||||
# pooler_output = self.transformation(pooler_output)
|
||||
# projection_state = self.transformation(outputs.last_hidden_state)
|
||||
#
|
||||
# return {
|
||||
# 'pooler_output':pooler_output,
|
||||
# 'last_hidden_state':outputs.last_hidden_state,
|
||||
# 'hidden_states':outputs.hidden_states,
|
||||
# 'attentions':outputs.attentions,
|
||||
# 'projection_state':projection_state,
|
||||
# 'sequence_out': sequence_output
|
||||
# }
|
||||
#
|
||||
#
|
||||
# class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||
# base_model_prefix = 'roberta'
|
||||
# config_class= RobertaSeriesConfig
|
||||
|
||||
@@ -1,166 +1,166 @@
|
||||
from transformers import BertPreTrainedModel,BertConfig
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
|
||||
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||
self.project_dim = project_dim
|
||||
self.pooler_fn = pooler_fn
|
||||
self.learn_encoder = learn_encoder
|
||||
|
||||
class RobertaSeriesConfig(XLMRobertaConfig):
|
||||
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
self.project_dim = project_dim
|
||||
self.pooler_fn = pooler_fn
|
||||
self.learn_encoder = learn_encoder
|
||||
|
||||
|
||||
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
config_class = BertSeriesConfig
|
||||
|
||||
def __init__(self, config=None, **kargs):
|
||||
# modify initialization for autoloading
|
||||
if config is None:
|
||||
config = XLMRobertaConfig()
|
||||
config.attention_probs_dropout_prob= 0.1
|
||||
config.bos_token_id=0
|
||||
config.eos_token_id=2
|
||||
config.hidden_act='gelu'
|
||||
config.hidden_dropout_prob=0.1
|
||||
config.hidden_size=1024
|
||||
config.initializer_range=0.02
|
||||
config.intermediate_size=4096
|
||||
config.layer_norm_eps=1e-05
|
||||
config.max_position_embeddings=514
|
||||
|
||||
config.num_attention_heads=16
|
||||
config.num_hidden_layers=24
|
||||
config.output_past=True
|
||||
config.pad_token_id=1
|
||||
config.position_embedding_type= "absolute"
|
||||
|
||||
config.type_vocab_size= 1
|
||||
config.use_cache=True
|
||||
config.vocab_size= 250002
|
||||
config.project_dim = 1024
|
||||
config.learn_encoder = False
|
||||
super().__init__(config)
|
||||
self.roberta = XLMRobertaModel(config)
|
||||
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||
# self.pooler = lambda x: x[:,0]
|
||||
# self.post_init()
|
||||
|
||||
self.has_pre_transformation = True
|
||||
if self.has_pre_transformation:
|
||||
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
|
||||
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = torch_utils.get_param(self).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt")
|
||||
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||
text["attention_mask"] = torch.tensor(
|
||||
text['attention_mask']).to(device)
|
||||
features = self(**text)
|
||||
return features['projection_state']
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) :
|
||||
r"""
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
|
||||
outputs = self.roberta(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# # last module outputs
|
||||
# sequence_output = outputs[0]
|
||||
|
||||
|
||||
# # project every module
|
||||
# sequence_output_ln = self.pre_LN(sequence_output)
|
||||
|
||||
# # pooler
|
||||
# pooler_output = self.pooler(sequence_output_ln)
|
||||
# pooler_output = self.transformation(pooler_output)
|
||||
# projection_state = self.transformation(outputs.last_hidden_state)
|
||||
|
||||
if self.has_pre_transformation:
|
||||
sequence_output2 = outputs["hidden_states"][-2]
|
||||
sequence_output2 = self.pre_LN(sequence_output2)
|
||||
projection_state2 = self.transformation_pre(sequence_output2)
|
||||
|
||||
return {
|
||||
"projection_state": projection_state2,
|
||||
"last_hidden_state": outputs.last_hidden_state,
|
||||
"hidden_states": outputs.hidden_states,
|
||||
"attentions": outputs.attentions,
|
||||
}
|
||||
else:
|
||||
projection_state = self.transformation(outputs.last_hidden_state)
|
||||
return {
|
||||
"projection_state": projection_state,
|
||||
"last_hidden_state": outputs.last_hidden_state,
|
||||
"hidden_states": outputs.hidden_states,
|
||||
"attentions": outputs.attentions,
|
||||
}
|
||||
|
||||
|
||||
# return {
|
||||
# 'pooler_output':pooler_output,
|
||||
# 'last_hidden_state':outputs.last_hidden_state,
|
||||
# 'hidden_states':outputs.hidden_states,
|
||||
# 'attentions':outputs.attentions,
|
||||
# 'projection_state':projection_state,
|
||||
# 'sequence_out': sequence_output
|
||||
# }
|
||||
|
||||
|
||||
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||
base_model_prefix = 'roberta'
|
||||
config_class= RobertaSeriesConfig
|
||||
# from transformers import BertPreTrainedModel,BertConfig
|
||||
# import torch.nn as nn
|
||||
# import torch
|
||||
# from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
# from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
# from typing import Optional
|
||||
# from modules import torch_utils
|
||||
#
|
||||
#
|
||||
# class BertSeriesConfig(BertConfig):
|
||||
# def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
#
|
||||
# super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||
# self.project_dim = project_dim
|
||||
# self.pooler_fn = pooler_fn
|
||||
# self.learn_encoder = learn_encoder
|
||||
#
|
||||
# class RobertaSeriesConfig(XLMRobertaConfig):
|
||||
# def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||
# super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
# self.project_dim = project_dim
|
||||
# self.pooler_fn = pooler_fn
|
||||
# self.learn_encoder = learn_encoder
|
||||
#
|
||||
#
|
||||
# class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
#
|
||||
# _keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
# config_class = BertSeriesConfig
|
||||
#
|
||||
# def __init__(self, config=None, **kargs):
|
||||
# # modify initialization for autoloading
|
||||
# if config is None:
|
||||
# config = XLMRobertaConfig()
|
||||
# config.attention_probs_dropout_prob= 0.1
|
||||
# config.bos_token_id=0
|
||||
# config.eos_token_id=2
|
||||
# config.hidden_act='gelu'
|
||||
# config.hidden_dropout_prob=0.1
|
||||
# config.hidden_size=1024
|
||||
# config.initializer_range=0.02
|
||||
# config.intermediate_size=4096
|
||||
# config.layer_norm_eps=1e-05
|
||||
# config.max_position_embeddings=514
|
||||
#
|
||||
# config.num_attention_heads=16
|
||||
# config.num_hidden_layers=24
|
||||
# config.output_past=True
|
||||
# config.pad_token_id=1
|
||||
# config.position_embedding_type= "absolute"
|
||||
#
|
||||
# config.type_vocab_size= 1
|
||||
# config.use_cache=True
|
||||
# config.vocab_size= 250002
|
||||
# config.project_dim = 1024
|
||||
# config.learn_encoder = False
|
||||
# super().__init__(config)
|
||||
# self.roberta = XLMRobertaModel(config)
|
||||
# self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||
# # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
# self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||
# # self.pooler = lambda x: x[:,0]
|
||||
# # self.post_init()
|
||||
#
|
||||
# self.has_pre_transformation = True
|
||||
# if self.has_pre_transformation:
|
||||
# self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
|
||||
# self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
# self.post_init()
|
||||
#
|
||||
# def encode(self,c):
|
||||
# device = torch_utils.get_param(self).device
|
||||
# text = self.tokenizer(c,
|
||||
# truncation=True,
|
||||
# max_length=77,
|
||||
# return_length=False,
|
||||
# return_overflowing_tokens=False,
|
||||
# padding="max_length",
|
||||
# return_tensors="pt")
|
||||
# text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||
# text["attention_mask"] = torch.tensor(
|
||||
# text['attention_mask']).to(device)
|
||||
# features = self(**text)
|
||||
# return features['projection_state']
|
||||
#
|
||||
# def forward(
|
||||
# self,
|
||||
# input_ids: Optional[torch.Tensor] = None,
|
||||
# attention_mask: Optional[torch.Tensor] = None,
|
||||
# token_type_ids: Optional[torch.Tensor] = None,
|
||||
# position_ids: Optional[torch.Tensor] = None,
|
||||
# head_mask: Optional[torch.Tensor] = None,
|
||||
# inputs_embeds: Optional[torch.Tensor] = None,
|
||||
# encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
# encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
# output_attentions: Optional[bool] = None,
|
||||
# return_dict: Optional[bool] = None,
|
||||
# output_hidden_states: Optional[bool] = None,
|
||||
# ) :
|
||||
# r"""
|
||||
# """
|
||||
#
|
||||
# return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
#
|
||||
#
|
||||
# outputs = self.roberta(
|
||||
# input_ids=input_ids,
|
||||
# attention_mask=attention_mask,
|
||||
# token_type_ids=token_type_ids,
|
||||
# position_ids=position_ids,
|
||||
# head_mask=head_mask,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# encoder_hidden_states=encoder_hidden_states,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
# output_attentions=output_attentions,
|
||||
# output_hidden_states=True,
|
||||
# return_dict=return_dict,
|
||||
# )
|
||||
#
|
||||
# # # last module outputs
|
||||
# # sequence_output = outputs[0]
|
||||
#
|
||||
#
|
||||
# # # project every module
|
||||
# # sequence_output_ln = self.pre_LN(sequence_output)
|
||||
#
|
||||
# # # pooler
|
||||
# # pooler_output = self.pooler(sequence_output_ln)
|
||||
# # pooler_output = self.transformation(pooler_output)
|
||||
# # projection_state = self.transformation(outputs.last_hidden_state)
|
||||
#
|
||||
# if self.has_pre_transformation:
|
||||
# sequence_output2 = outputs["hidden_states"][-2]
|
||||
# sequence_output2 = self.pre_LN(sequence_output2)
|
||||
# projection_state2 = self.transformation_pre(sequence_output2)
|
||||
#
|
||||
# return {
|
||||
# "projection_state": projection_state2,
|
||||
# "last_hidden_state": outputs.last_hidden_state,
|
||||
# "hidden_states": outputs.hidden_states,
|
||||
# "attentions": outputs.attentions,
|
||||
# }
|
||||
# else:
|
||||
# projection_state = self.transformation(outputs.last_hidden_state)
|
||||
# return {
|
||||
# "projection_state": projection_state,
|
||||
# "last_hidden_state": outputs.last_hidden_state,
|
||||
# "hidden_states": outputs.hidden_states,
|
||||
# "attentions": outputs.attentions,
|
||||
# }
|
||||
#
|
||||
#
|
||||
# # return {
|
||||
# # 'pooler_output':pooler_output,
|
||||
# # 'last_hidden_state':outputs.last_hidden_state,
|
||||
# # 'hidden_states':outputs.hidden_states,
|
||||
# # 'attentions':outputs.attentions,
|
||||
# # 'projection_state':projection_state,
|
||||
# # 'sequence_out': sequence_output
|
||||
# # }
|
||||
#
|
||||
#
|
||||
# class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||
# base_model_prefix = 'roberta'
|
||||
# config_class= RobertaSeriesConfig
|
||||
|
||||
@@ -1,138 +1,138 @@
|
||||
from modules import shared
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
|
||||
has_ipex = False
|
||||
try:
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
has_ipex = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def check_for_xpu():
|
||||
return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()
|
||||
|
||||
|
||||
def get_xpu_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"xpu:{shared.cmd_opts.device_id}"
|
||||
return "xpu"
|
||||
|
||||
|
||||
def torch_xpu_gc():
|
||||
with torch.xpu.device(get_xpu_device_string()):
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
|
||||
has_xpu = check_for_xpu()
|
||||
|
||||
|
||||
# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
|
||||
# Here we implement a slicing algorithm to split large batch size into smaller chunks,
|
||||
# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
|
||||
# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
|
||||
# which is the best trade-off between VRAM usage and performance.
|
||||
ARC_SINGLE_ALLOCATION_LIMIT = {}
|
||||
orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
|
||||
def torch_xpu_scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
|
||||
):
|
||||
# cast to same dtype first
|
||||
key = key.to(query.dtype)
|
||||
value = value.to(query.dtype)
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(query.dtype)
|
||||
|
||||
N = query.shape[:-2] # Batch size
|
||||
L = query.size(-2) # Target sequence length
|
||||
E = query.size(-1) # Embedding dimension of the query and key
|
||||
S = key.size(-2) # Source sequence length
|
||||
Ev = value.size(-1) # Embedding dimension of the value
|
||||
|
||||
total_batch_size = torch.numel(torch.empty(N))
|
||||
device_id = query.device.index
|
||||
if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
|
||||
ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
|
||||
batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
|
||||
|
||||
if total_batch_size <= batch_size_limit:
|
||||
return orig_sdp_attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
query = torch.reshape(query, (-1, L, E))
|
||||
key = torch.reshape(key, (-1, S, E))
|
||||
value = torch.reshape(value, (-1, S, Ev))
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.view(-1, L, S)
|
||||
chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
|
||||
outputs = []
|
||||
for i in range(chunk_count):
|
||||
attn_mask_chunk = (
|
||||
None
|
||||
if attn_mask is None
|
||||
else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
|
||||
)
|
||||
chunk_output = orig_sdp_attn_func(
|
||||
query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
attn_mask_chunk,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
*args, **kwargs
|
||||
)
|
||||
outputs.append(chunk_output)
|
||||
result = torch.cat(outputs, dim=0)
|
||||
return torch.reshape(result, (*N, L, Ev))
|
||||
|
||||
|
||||
def is_xpu_device(device: str | torch.device = None):
|
||||
if device is None:
|
||||
return False
|
||||
if isinstance(device, str):
|
||||
return device.startswith("xpu")
|
||||
return device.type == "xpu"
|
||||
|
||||
|
||||
if has_xpu:
|
||||
try:
|
||||
# torch.Generator supports "xpu" device since 2.1
|
||||
torch.Generator("xpu")
|
||||
except RuntimeError:
|
||||
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
|
||||
CondFunc('torch.Generator',
|
||||
lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||
lambda orig_func, device=None: is_xpu_device(device))
|
||||
|
||||
# W/A for some OPs that could not handle different input dtypes
|
||||
CondFunc('torch.nn.functional.layer_norm',
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
weight is not None and input.dtype != weight.data.dtype)
|
||||
CondFunc('torch.nn.modules.GroupNorm.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.nn.modules.linear.Linear.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.bmm',
|
||||
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
|
||||
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
|
||||
CondFunc('torch.cat',
|
||||
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
|
||||
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
|
||||
CondFunc('torch.nn.functional.scaled_dot_product_attention',
|
||||
lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
|
||||
lambda orig_func, query, *args, **kwargs: query.is_xpu)
|
||||
# from modules import shared
|
||||
# from modules.sd_hijack_utils import CondFunc
|
||||
#
|
||||
# has_ipex = False
|
||||
# try:
|
||||
# import torch
|
||||
# import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
# has_ipex = True
|
||||
# except Exception:
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def check_for_xpu():
|
||||
# return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()
|
||||
#
|
||||
#
|
||||
# def get_xpu_device_string():
|
||||
# if shared.cmd_opts.device_id is not None:
|
||||
# return f"xpu:{shared.cmd_opts.device_id}"
|
||||
# return "xpu"
|
||||
#
|
||||
#
|
||||
# def torch_xpu_gc():
|
||||
# with torch.xpu.device(get_xpu_device_string()):
|
||||
# torch.xpu.empty_cache()
|
||||
#
|
||||
#
|
||||
# has_xpu = check_for_xpu()
|
||||
#
|
||||
#
|
||||
# # Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
|
||||
# # Here we implement a slicing algorithm to split large batch size into smaller chunks,
|
||||
# # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
|
||||
# # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
|
||||
# # which is the best trade-off between VRAM usage and performance.
|
||||
# ARC_SINGLE_ALLOCATION_LIMIT = {}
|
||||
# orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
|
||||
# def torch_xpu_scaled_dot_product_attention(
|
||||
# query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
|
||||
# ):
|
||||
# # cast to same dtype first
|
||||
# key = key.to(query.dtype)
|
||||
# value = value.to(query.dtype)
|
||||
# if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
# attn_mask = attn_mask.to(query.dtype)
|
||||
#
|
||||
# N = query.shape[:-2] # Batch size
|
||||
# L = query.size(-2) # Target sequence length
|
||||
# E = query.size(-1) # Embedding dimension of the query and key
|
||||
# S = key.size(-2) # Source sequence length
|
||||
# Ev = value.size(-1) # Embedding dimension of the value
|
||||
#
|
||||
# total_batch_size = torch.numel(torch.empty(N))
|
||||
# device_id = query.device.index
|
||||
# if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
|
||||
# ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
|
||||
# batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
|
||||
#
|
||||
# if total_batch_size <= batch_size_limit:
|
||||
# return orig_sdp_attn_func(
|
||||
# query,
|
||||
# key,
|
||||
# value,
|
||||
# attn_mask,
|
||||
# dropout_p,
|
||||
# is_causal,
|
||||
# *args, **kwargs
|
||||
# )
|
||||
#
|
||||
# query = torch.reshape(query, (-1, L, E))
|
||||
# key = torch.reshape(key, (-1, S, E))
|
||||
# value = torch.reshape(value, (-1, S, Ev))
|
||||
# if attn_mask is not None:
|
||||
# attn_mask = attn_mask.view(-1, L, S)
|
||||
# chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
|
||||
# outputs = []
|
||||
# for i in range(chunk_count):
|
||||
# attn_mask_chunk = (
|
||||
# None
|
||||
# if attn_mask is None
|
||||
# else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
|
||||
# )
|
||||
# chunk_output = orig_sdp_attn_func(
|
||||
# query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
# key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
# value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
# attn_mask_chunk,
|
||||
# dropout_p,
|
||||
# is_causal,
|
||||
# *args, **kwargs
|
||||
# )
|
||||
# outputs.append(chunk_output)
|
||||
# result = torch.cat(outputs, dim=0)
|
||||
# return torch.reshape(result, (*N, L, Ev))
|
||||
#
|
||||
#
|
||||
# def is_xpu_device(device: str | torch.device = None):
|
||||
# if device is None:
|
||||
# return False
|
||||
# if isinstance(device, str):
|
||||
# return device.startswith("xpu")
|
||||
# return device.type == "xpu"
|
||||
#
|
||||
#
|
||||
# if has_xpu:
|
||||
# try:
|
||||
# # torch.Generator supports "xpu" device since 2.1
|
||||
# torch.Generator("xpu")
|
||||
# except RuntimeError:
|
||||
# # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
|
||||
# CondFunc('torch.Generator',
|
||||
# lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||
# lambda orig_func, device=None: is_xpu_device(device))
|
||||
#
|
||||
# # W/A for some OPs that could not handle different input dtypes
|
||||
# CondFunc('torch.nn.functional.layer_norm',
|
||||
# lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
# orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
||||
# lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
# weight is not None and input.dtype != weight.data.dtype)
|
||||
# CondFunc('torch.nn.modules.GroupNorm.forward',
|
||||
# lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
# lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# CondFunc('torch.nn.modules.linear.Linear.forward',
|
||||
# lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
# lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||
# lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
# lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# CondFunc('torch.bmm',
|
||||
# lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
|
||||
# lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
|
||||
# CondFunc('torch.cat',
|
||||
# lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
|
||||
# lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
|
||||
# CondFunc('torch.nn.functional.scaled_dot_product_attention',
|
||||
# lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
|
||||
# lambda orig_func, query, *args, **kwargs: query.is_xpu)
|
||||
|
||||
Reference in New Issue
Block a user