Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-01-26 19:09:45 +00:00)

Commit: remove files
@@ -1,72 +0,0 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: False
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: modules.xlmr.BertSeriesModelWithTransformation
      params:
        name: "XLMR-Large"

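The deleted configs in this commit all follow the same latent-diffusion layout: each component is described by a `target` dotted class path plus a `params` block. A minimal sketch of how such a file is usually turned into a model, assuming a CompVis-style `ldm` package is importable and the YAML above is saved under a placeholder path:

# Sketch only; "config.yaml" is a placeholder for a config like the one above.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # resolves `target` and passes `params` as kwargs

config = OmegaConf.load("config.yaml")
model = instantiate_from_config(config.model)   # builds LatentDiffusion, which builds unet/first_stage/cond_stage from their sub-configs
model.eval()
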
@@ -1,73 +0,0 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: modules.xlmr_m18.BertSeriesModelWithTransformation
      params:
        name: "XLMR-Large"

@@ -1,98 +0,0 @@
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.

model:
  base_learning_rate: 1.0e-04
  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: edited
    cond_stage_key: edit
    # image_size: 64
    # image_size: 32
    image_size: 16
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: false

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 0 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: False
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 128
    num_workers: 1
    wrap: false
    validation:
      target: edit_dataset.EditDataset
      params:
        path: data/clip-filtered-dataset
        cache_dir: data/
        cache_name: data_10k
        split: val
        min_text_sim: 0.2
        min_image_sim: 0.75
        min_direction_sim: 0.2
        max_samples_per_prompt: 1
        min_resize_res: 512
        max_resize_res: 512
        crop_res: 512
        output_as_edit: False
        real_input: True

@@ -1,5 +0,0 @@
model:
  target: modules.models.sd3.sd3_model.SD3Inferencer
  params:
    shift: 3
    state_dict: null

@@ -1,98 +0,0 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: False
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

@@ -1,70 +0,0 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: False
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

@@ -1,70 +0,0 @@
model:
  base_learning_rate: 7.5e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: hybrid # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    finetune_keys: null

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9 # 4 data + 4 downscaled image + 1 mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: False
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

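The inpainting UNet above takes `in_channels: 9`, which the config's own comment breaks down as 4 latent channels, 4 channels of the VAE-encoded masked image, and 1 mask channel, joined through the `hybrid` (concat) conditioning path. A rough sketch of that channel assembly; the shapes assume a 512x512 image with this 8x-downsampling autoencoder, and the concatenation order shown is an illustrative assumption, not taken from the deleted code:

# Illustrative only: channel bookkeeping for a 9-channel inpainting UNet input.
import torch

noisy_latent  = torch.randn(1, 4, 64, 64)  # "4 data" channels (noised image latent)
masked_latent = torch.randn(1, 4, 64, 64)  # "4 downscaled image" channels (VAE-encoded masked image)
mask          = torch.rand(1, 1, 64, 64)   # "1 mask" channel, resized to latent resolution

unet_input = torch.cat([noisy_latent, masked_latent, mask], dim=1)  # order is an assumption
assert unet_input.shape[1] == 9  # matches in_channels: 9
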
@@ -1,98 +1,98 @@
# (this hunk replaces every line of the file with a commented-out copy of itself; the original lines are shown once below)
import logging

import torch
from torch import Tensor
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
from modules import shared

log = logging.getLogger(__name__)


# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
    if version.parse(torch.__version__) <= version.parse("2.0.1"):
        if not getattr(torch, 'has_mps', False):
            return False
        try:
            torch.zeros(1).to(torch.device("mps"))
            return True
        except Exception:
            return False
    else:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()


has_mps = check_for_mps()


def torch_mps_gc() -> None:
    try:
        if shared.state.current_latent is not None:
            log.debug("`current_latent` is set, skipping MPS garbage collection")
            return
        from torch.mps import empty_cache
        empty_cache()
    except Exception:
        log.warning("MPS garbage collection failed", exc_info=True)


# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
    if input.device.type == 'mps':
        output_dtype = kwargs.get('dtype', input.dtype)
        if output_dtype == torch.int64:
            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
        elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
    return cumsum_func(input, *args, **kwargs)


# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
    try:
        return orig_func(*args, **kwargs)
    except RuntimeError as e:
        if "not implemented for" in str(e) and "Half" in str(e):
            input_tensor = args[0]
            return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
        else:
            print(f"An unexpected RuntimeError occurred: {str(e)}")

if has_mps:
    if platform.mac_ver()[0].startswith("13.2."):
        # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
        CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)

    if version.parse(torch.__version__) < version.parse("1.13"):
        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working

        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
            lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
            lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
    elif version.parse(torch.__version__) > version.parse("1.13.1"):
        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
        CondFunc('torch.cumsum', cumsum_fix_func, None)
        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)

    # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
    CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')

    # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
    CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)

    # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
    if platform.processor() == 'i386':
        for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
            CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')

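The module above does all of its work at import time by registering `CondFunc` patches: each call names a dotted path, a replacement that receives the original callable as its first argument, and a predicate that decides when the replacement applies (`None` means always, as the calls above show). The real helper lives in modules/sd_hijack_utils.py; the sketch below is only an illustration of that pattern under those assumptions, not the actual implementation, and for brevity it handles module-level attributes only:

# Illustrative conditional monkey-patch helper (NOT the real CondFunc).
import importlib


def cond_patch(dotted_path: str, sub_func, cond_func=None):
    module_path, attr = dotted_path.rsplit('.', 1)
    module = importlib.import_module(module_path)   # only works for module-level attrs, e.g. 'torch.narrow'
    orig = getattr(module, attr)

    def wrapper(*args, **kwargs):
        # call the substitute when the condition holds (or unconditionally when cond_func is None)
        if cond_func is None or cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(module, attr, wrapper)

# e.g., in the same spirit as the MPS workarounds above:
# cond_patch('torch.narrow', lambda orig, *a, **kw: orig(*a, **kw).clone(), None)
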
@@ -1,31 +1,31 @@
# (this hunk replaces every line of the file with a commented-out copy of itself; the original lines are shown once below)
import importlib
import torch

from modules import shared


def check_for_npu():
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch_npu

    try:
        # Will raise a RuntimeError if no NPU is found
        _ = torch_npu.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False


def get_npu_device_string():
    if shared.cmd_opts.device_id is not None:
        return f"npu:{shared.cmd_opts.device_id}"
    return "npu:0"


def torch_npu_gc():
    with torch.npu.device(get_npu_device_string()):
        torch.npu.empty_cache()


has_npu = check_for_npu()

@@ -1,215 +0,0 @@
# original source:
#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
#   MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
# credit:
#   Amin Rezaei (original author)
#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
#   brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
#   Self-attention Does Not Need O(n2) Memory":
#   https://arxiv.org/abs/2112.05682v2

from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple


def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)


class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor


class SummarizeChunk:
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...


class ComputeQueryChunkAttn:
    @staticmethod
    def __call__(
        query: Tensor,
        key: Tensor,
        value: Tensor,
    ) -> Tensor: ...


def _summarize_chunk(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    attn_weights = torch.baddbmm(
        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key.transpose(1,2),
        alpha=scale,
        beta=0,
    )
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    exp_weights = torch.exp(attn_weights - max_score)
    exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)


def _query_chunk_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    batch_x_heads, k_tokens, k_channels_per_head = key.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = narrow_trunc(
            key,
            1,
            chunk_idx,
            kv_chunk_size
        )
        value_chunk = narrow_trunc(
            value,
            1,
            chunk_idx,
            kv_chunk_size
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    chunks: list[AttnChunk] = [
        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights


# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
) -> Tensor:
    attn_scores = torch.baddbmm(
        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key.transpose(1,2),
        alpha=scale,
        beta=0,
    )
    attn_probs = attn_scores.softmax(dim=-1)
    del attn_scores
    hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
    return hidden_states_slice


class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk


def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
):
    """Computes efficient dot-product attention given query, key, and value.
      This is efficient version of attention presented in
      https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
      Args:
        query: queries for calculating attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        key: keys for calculating attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        value: values to be used in attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
      Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
      """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, k_tokens, _ = key.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return narrow_trunc(
            query,
            1,
            chunk_idx,
            min(query_chunk_size, q_tokens)
        )

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
        scale=scale
    ) if k_tokens <= kv_chunk_size else (
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key=key,
            value=value,
        )

    res = torch.zeros_like(query)
    for i in range(math.ceil(q_tokens / query_chunk_size)):
        attn_scores = compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key=key,
            value=value,
        )

        res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores

    return res

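Per its docstring, `efficient_dot_product_attention` expects `[batch * num_heads, tokens, channels_per_head]` tensors and chunks the computation over queries (and, when needed, keys/values). A small illustrative call; the import path and the tensor shapes are assumptions, not taken from the deleted code:

# Illustrative shape check (assumes the module above is importable as sub_quadratic_attention).
import torch
from sub_quadratic_attention import efficient_dot_product_attention  # hypothetical import path

q = torch.randn(2 * 8, 4096, 40)   # e.g. batch=2, heads=8, 64x64 latent tokens, 40 channels per head
k = torch.randn(2 * 8, 4096, 40)
v = torch.randn(2 * 8, 4096, 40)

out = efficient_dot_product_attention(
    q, k, v,
    query_chunk_size=1024,   # queries processed 1024 tokens at a time
    kv_chunk_size=None,      # defaults to sqrt(k_tokens) per the docstring
    use_checkpoint=False,    # False for inference
)
assert out.shape == q.shape
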
modules/xlmr.py
@@ -1,140 +1,140 @@
# (this hunk replaces every line of modules/xlmr.py with a commented-out copy of itself; the original lines are shown once below)
from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional

from modules import torch_utils


class BertSeriesConfig(BertConfig):
    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):

        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder

class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class BertSeriesModelWithTransformation(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    config_class = BertSeriesConfig

    def __init__(self, config=None, **kargs):
        # modify initialization for autoloading
        if config is None:
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob= 0.1
            config.bos_token_id=0
            config.eos_token_id=2
            config.hidden_act='gelu'
            config.hidden_dropout_prob=0.1
            config.hidden_size=1024
            config.initializer_range=0.02
            config.intermediate_size=4096
            config.layer_norm_eps=1e-05
            config.max_position_embeddings=514

            config.num_attention_heads=16
            config.num_hidden_layers=24
            config.output_past=True
            config.pad_token_id=1
            config.position_embedding_type= "absolute"

            config.type_vocab_size= 1
            config.use_cache=True
            config.vocab_size= 250002
            config.project_dim = 768
            config.learn_encoder = False
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size,config.project_dim)
        self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        self.pooler = lambda x: x[:,0]
        self.post_init()

    def encode(self,c):
        device = torch_utils.get_param(self).device
        text = self.tokenizer(c,
                        truncation=True,
                        max_length=77,
                        return_length=False,
                        return_overflowing_tokens=False,
                        padding="max_length",
                        return_tensors="pt")
        text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
        text["attention_mask"] = torch.tensor(
            text['attention_mask']).to(device)
        features = self(**text)
        return features['projection_state']

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) :
        r"""
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # last module outputs
        sequence_output = outputs[0]

        # project every module
        sequence_output_ln = self.pre_LN(sequence_output)

        # pooler
        pooler_output = self.pooler(sequence_output_ln)
        pooler_output = self.transformation(pooler_output)
        projection_state = self.transformation(outputs.last_hidden_state)

        return {
            'pooler_output':pooler_output,
            'last_hidden_state':outputs.last_hidden_state,
            'hidden_states':outputs.hidden_states,
            'attentions':outputs.attentions,
            'projection_state':projection_state,
            'sequence_out': sequence_output
        }


class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
    base_model_prefix = 'roberta'
    config_class= RobertaSeriesConfig

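`BertSeriesModelWithTransformation.encode` tokenizes a list of prompts to 77 tokens and returns `projection_state`, the last hidden state projected to `project_dim` (768 with the default config above), which is what the alt-diffusion config earlier in this commit plugs in as its `cond_stage_config`. A rough shape-check sketch: the weights here are randomly initialized (the web UI loads trained checkpoint weights instead), and the tokenizer is fetched from the Hugging Face hub:

# Illustrative only; real use goes through the cond_stage_config path with trained weights.
import torch
from modules.xlmr import BertSeriesModelWithTransformation

model = BertSeriesModelWithTransformation()  # config=None -> default XLM-R-Large config, random weights
model.eval()

with torch.no_grad():
    cond = model.encode(["a photograph of an astronaut riding a horse"])

print(cond.shape)  # expected (1, 77, 768): 77 tokens projected to project_dim=768
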
@@ -1,166 +1,166 @@
# (this hunk replaces every line of the file with a commented-out copy of itself; the original lines are shown once below,
#  and the captured page is truncated partway through this file)
from transformers import BertPreTrainedModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
from modules import torch_utils


class BertSeriesConfig(BertConfig):
    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):

        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder

class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class BertSeriesModelWithTransformation(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    config_class = BertSeriesConfig

    def __init__(self, config=None, **kargs):
        # modify initialization for autoloading
        if config is None:
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob= 0.1
            config.bos_token_id=0
            config.eos_token_id=2
            config.hidden_act='gelu'
            config.hidden_dropout_prob=0.1
            config.hidden_size=1024
            config.initializer_range=0.02
            config.intermediate_size=4096
            config.layer_norm_eps=1e-05
            config.max_position_embeddings=514

            config.num_attention_heads=16
            config.num_hidden_layers=24
            config.output_past=True
            config.pad_token_id=1
            config.position_embedding_type= "absolute"

            config.type_vocab_size= 1
            config.use_cache=True
            config.vocab_size= 250002
            config.project_dim = 1024
            config.learn_encoder = False
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size,config.project_dim)
        # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        # self.pooler = lambda x: x[:,0]
        # self.post_init()

        self.has_pre_transformation = True
        if self.has_pre_transformation:
            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    def encode(self,c):
        device = torch_utils.get_param(self).device
        text = self.tokenizer(c,
                        truncation=True,
                        max_length=77,
                        return_length=False,
                        return_overflowing_tokens=False,
                        padding="max_length",
                        return_tensors="pt")
        text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
        text["attention_mask"] = torch.tensor(
            text['attention_mask']).to(device)
        features = self(**text)
        return features['projection_state']

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) :
        r"""
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # # last module outputs
|
||||||
# sequence_output = outputs[0]
|
# # sequence_output = outputs[0]
|
||||||
|
#
|
||||||
|
#
|
||||||
# # project every module
|
# # # project every module
|
||||||
# sequence_output_ln = self.pre_LN(sequence_output)
|
# # sequence_output_ln = self.pre_LN(sequence_output)
|
||||||
|
#
|
||||||
# # pooler
|
# # # pooler
|
||||||
# pooler_output = self.pooler(sequence_output_ln)
|
# # pooler_output = self.pooler(sequence_output_ln)
|
||||||
# pooler_output = self.transformation(pooler_output)
|
# # pooler_output = self.transformation(pooler_output)
|
||||||
# projection_state = self.transformation(outputs.last_hidden_state)
|
# # projection_state = self.transformation(outputs.last_hidden_state)
|
||||||
|
#
|
||||||
if self.has_pre_transformation:
|
# if self.has_pre_transformation:
|
||||||
sequence_output2 = outputs["hidden_states"][-2]
|
# sequence_output2 = outputs["hidden_states"][-2]
|
||||||
sequence_output2 = self.pre_LN(sequence_output2)
|
# sequence_output2 = self.pre_LN(sequence_output2)
|
||||||
projection_state2 = self.transformation_pre(sequence_output2)
|
# projection_state2 = self.transformation_pre(sequence_output2)
|
||||||
|
#
|
||||||
return {
|
# return {
|
||||||
"projection_state": projection_state2,
|
# "projection_state": projection_state2,
|
||||||
"last_hidden_state": outputs.last_hidden_state,
|
# "last_hidden_state": outputs.last_hidden_state,
|
||||||
"hidden_states": outputs.hidden_states,
|
# "hidden_states": outputs.hidden_states,
|
||||||
"attentions": outputs.attentions,
|
# "attentions": outputs.attentions,
|
||||||
}
|
# }
|
||||||
else:
|
# else:
|
||||||
projection_state = self.transformation(outputs.last_hidden_state)
|
# projection_state = self.transformation(outputs.last_hidden_state)
|
||||||
return {
|
# return {
|
||||||
"projection_state": projection_state,
|
# "projection_state": projection_state,
|
||||||
"last_hidden_state": outputs.last_hidden_state,
|
# "last_hidden_state": outputs.last_hidden_state,
|
||||||
"hidden_states": outputs.hidden_states,
|
# "hidden_states": outputs.hidden_states,
|
||||||
"attentions": outputs.attentions,
|
# "attentions": outputs.attentions,
|
||||||
}
|
# }
|
||||||
|
#
|
||||||
|
#
|
||||||
# return {
|
# # return {
|
||||||
# 'pooler_output':pooler_output,
|
# # 'pooler_output':pooler_output,
|
||||||
# 'last_hidden_state':outputs.last_hidden_state,
|
# # 'last_hidden_state':outputs.last_hidden_state,
|
||||||
# 'hidden_states':outputs.hidden_states,
|
# # 'hidden_states':outputs.hidden_states,
|
||||||
# 'attentions':outputs.attentions,
|
# # 'attentions':outputs.attentions,
|
||||||
# 'projection_state':projection_state,
|
# # 'projection_state':projection_state,
|
||||||
# 'sequence_out': sequence_output
|
# # 'sequence_out': sequence_output
|
||||||
# }
|
# # }
|
||||||
|
#
|
||||||
|
#
|
||||||
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
# class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||||
base_model_prefix = 'roberta'
|
# base_model_prefix = 'roberta'
|
||||||
config_class= RobertaSeriesConfig
|
# config_class= RobertaSeriesConfig
|
||||||
|
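For orientation, the class above is the XLM-RoBERTa-based text encoder: encode() tokenizes a prompt to 77 tokens, runs the backbone with hidden states enabled, and (via has_pre_transformation) returns the layer-normed, linearly projected penultimate hidden state as the conditioning tensor. A minimal usage sketch, assuming the module is importable as modules.xlmr (its upstream location); real use would load pretrained weights rather than this random-weight instantiation, which also downloads the 'xlm-roberta-large' tokenizer:

import torch
from modules.xlmr import BertSeriesModelWithTransformation  # assumed module path

encoder = BertSeriesModelWithTransformation().eval()  # default config: XLM-R Large sizes, random weights

with torch.no_grad():
    cond = encoder.encode(["a photo of a cat"])  # returns features['projection_state']

print(cond.shape)  # expected torch.Size([1, 77, 1024]) with the defaults above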
@@ -1,138 +1,138 @@
from modules import shared
from modules.sd_hijack_utils import CondFunc

has_ipex = False
try:
    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401
    has_ipex = True
except Exception:
    pass


def check_for_xpu():
    return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()


def get_xpu_device_string():
    if shared.cmd_opts.device_id is not None:
        return f"xpu:{shared.cmd_opts.device_id}"
    return "xpu"


def torch_xpu_gc():
    with torch.xpu.device(get_xpu_device_string()):
        torch.xpu.empty_cache()


has_xpu = check_for_xpu()

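The helpers above pick the XPU device string (honoring shared.cmd_opts.device_id) and empty the XPU caching allocator. A hedged usage sketch, assuming the module is importable as modules.xpu_specific (its upstream location) and that an IPEX-enabled PyTorch build is installed so has_xpu is True:

import torch
from modules.xpu_specific import has_xpu, get_xpu_device_string, torch_xpu_gc  # assumed module path

if has_xpu:
    device = torch.device(get_xpu_device_string())  # "xpu" or "xpu:<id>"
    x = torch.ones(4, device=device)
    del x
    torch_xpu_gc()  # frees cached XPU memory, analogous to torch.cuda.empty_cache()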
# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
# Here we implement a slicing algorithm to split large batch size into smaller chunks,
# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
# which is the best trade-off between VRAM usage and performance.
ARC_SINGLE_ALLOCATION_LIMIT = {}
orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention


def torch_xpu_scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
):
    # cast to same dtype first
    key = key.to(query.dtype)
    value = value.to(query.dtype)
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        attn_mask = attn_mask.to(query.dtype)

    N = query.shape[:-2]  # Batch size
    L = query.size(-2)  # Target sequence length
    E = query.size(-1)  # Embedding dimension of the query and key
    S = key.size(-2)  # Source sequence length
    Ev = value.size(-1)  # Embedding dimension of the value

    total_batch_size = torch.numel(torch.empty(N))
    device_id = query.device.index
    if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
        ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
    batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))

    if total_batch_size <= batch_size_limit:
        return orig_sdp_attn_func(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            *args, **kwargs
        )

    query = torch.reshape(query, (-1, L, E))
    key = torch.reshape(key, (-1, S, E))
    value = torch.reshape(value, (-1, S, Ev))
    if attn_mask is not None:
        attn_mask = attn_mask.view(-1, L, S)
    chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
    outputs = []
    for i in range(chunk_count):
        attn_mask_chunk = (
            None
            if attn_mask is None
            else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
        )
        chunk_output = orig_sdp_attn_func(
            query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
            key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
            value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
            attn_mask_chunk,
            dropout_p,
            is_causal,
            *args, **kwargs
        )
        outputs.append(chunk_output)
    result = torch.cat(outputs, dim=0)
    return torch.reshape(result, (*N, L, Ev))

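To make the chunking heuristic above concrete: the per-call cap is min(TOTAL_VRAM // 8, 4 GiB), and the number of flattened batch elements handled per SDPA call follows from the size of one L x S attention matrix. A small illustrative calculation with assumed numbers (fp16, 4096 x 4096 attention on a hypothetical 16 GB card), not a measurement:

total_vram = 16 * 1024**3                      # assumed 16 GB card
limit = min(total_vram // 8, 4 * 1024**3)      # -> 2 GiB, same formula as ARC_SINGLE_ALLOCATION_LIMIT
L, S, element_size = 4096, 4096, 2             # assumed fp16 query/key lengths and byte size
batch_size_limit = max(1, limit // (L * S * element_size))                    # -> 64 elements per chunk
total_batch_size = 128                         # e.g. batch x heads flattened (assumed)
chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit  # -> 2 chunks
print(batch_size_limit, chunk_count)           # 64 2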
def is_xpu_device(device: str | torch.device = None):
    if device is None:
        return False
    if isinstance(device, str):
        return device.startswith("xpu")
    return device.type == "xpu"

if has_xpu:
    try:
        # torch.Generator supports "xpu" device since 2.1
        torch.Generator("xpu")
    except RuntimeError:
        # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
        CondFunc('torch.Generator',
            lambda orig_func, device=None: torch.xpu.Generator(device),
            lambda orig_func, device=None: is_xpu_device(device))

    # W/A for some OPs that could not handle different input dtypes
    CondFunc('torch.nn.functional.layer_norm',
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
        weight is not None and input.dtype != weight.data.dtype)
    CondFunc('torch.nn.modules.GroupNorm.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.nn.modules.linear.Linear.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.nn.modules.conv.Conv2d.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.bmm',
        lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
        lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
    CondFunc('torch.cat',
        lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
        lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
    CondFunc('torch.nn.functional.scaled_dot_product_attention',
        lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
        lambda orig_func, query, *args, **kwargs: query.is_xpu)
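All of the workarounds above go through CondFunc, which, as invoked here, swaps a dotted target for a wrapper that calls the substitute only when a predicate on the call's arguments is true and otherwise falls through to the original. A simplified, self-contained sketch of that conditional-monkeypatch pattern (an illustrative stand-in handling module-level targets only, not the repo's actual implementation):

import importlib

def cond_patch(target: str, sub_func, cond_func):
    # Resolve "package.module.attr", keep the original, and install a guarded wrapper.
    module_path, attr = target.rsplit('.', 1)
    module = importlib.import_module(module_path)
    orig = getattr(module, attr)

    def wrapper(*args, **kwargs):
        if cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(module, attr, wrapper)

# Mirroring the layer_norm workaround above (left commented so the sketch has no import-time side effects):
# cond_patch('torch.nn.functional.layer_norm',
#     lambda orig, input, shape=None, weight=None, *a, **k: orig(input.to(weight.dtype), shape, weight, *a, **k),
#     lambda orig, input, shape=None, weight=None, *a, **k: weight is not None and input.dtype != weight.dtype)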