diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml deleted file mode 100644 index 4944ab5c..00000000 --- a/configs/alt-diffusion-inference.yaml +++ /dev/null @@ -1,72 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: False - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: modules.xlmr.BertSeriesModelWithTransformation - params: - name: "XLMR-Large" \ No newline at end of file diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml deleted file mode 100644 index c60dca8c..00000000 --- a/configs/alt-diffusion-m18-inference.yaml +++ /dev/null @@ -1,73 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - use_checkpoint: False - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: modules.xlmr_m18.BertSeriesModelWithTransformation - params: - name: "XLMR-Large" diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml deleted file mode 100644 index 564e50ae..00000000 --- a/configs/instruct-pix2pix.yaml +++ /dev/null @@ -1,98 +0,0 @@ -# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). -# See more details in LICENSE. - -model: - base_learning_rate: 1.0e-04 - target: modules.models.diffusion.ddpm_edit.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: edited - cond_stage_key: edit - # image_size: 64 - # image_size: 32 - image_size: 16 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: hybrid - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: false - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 0 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: False - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - -data: - target: main.DataModuleFromConfig - params: - batch_size: 128 - num_workers: 1 - wrap: false - validation: - target: edit_dataset.EditDataset - params: - path: data/clip-filtered-dataset - cache_dir: data/ - cache_name: data_10k - split: val - min_text_sim: 0.2 - min_image_sim: 0.75 - min_direction_sim: 0.2 - max_samples_per_prompt: 1 - min_resize_res: 512 - max_resize_res: 512 - crop_res: 512 - output_as_edit: False - real_input: True diff --git a/configs/sd3-inference.yaml b/configs/sd3-inference.yaml deleted file mode 100644 index bccb69d2..00000000 --- a/configs/sd3-inference.yaml +++ /dev/null @@ -1,5 +0,0 @@ -model: - target: modules.models.sd3.sd3_model.SD3Inferencer - params: - shift: 3 - state_dict: null diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml deleted file mode 100644 index f40f45e3..00000000 --- a/configs/sd_xl_inpaint.yaml +++ /dev/null @@ -1,98 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.13025 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - weighting_config: - target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - adm_in_channels: 2816 - num_classes: sequential - use_checkpoint: False - in_channels: 9 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2] - num_res_blocks: 2 - channel_mult: [1, 2, 4] - num_head_channels: 64 - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 - context_dim: 2048 - spatial_transformer_attn_type: softmax-xformers - legacy: False - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - # crossattn cond - - is_trainable: False - input_key: txt - target: sgm.modules.encoders.modules.FrozenCLIPEmbedder - params: - layer: hidden - layer_idx: 11 - # crossattn and vector cond - - is_trainable: False - input_key: txt - target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 - params: - arch: ViT-bigG-14 - version: laion2b_s39b_b160k - freeze: True - layer: penultimate - always_return_pooled: True - legacy: False - # vector cond - - is_trainable: False - input_key: original_size_as_tuple - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 # 
multiplied by two - # vector cond - - is_trainable: False - input_key: crop_coords_top_left - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 # multiplied by two - # vector cond - - is_trainable: False - input_key: target_size_as_tuple - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 # multiplied by two - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml deleted file mode 100644 index 25c4d9ed..00000000 --- a/configs/v1-inference.yaml +++ /dev/null @@ -1,70 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: False - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/configs/v1-inpainting-inference.yaml b/configs/v1-inpainting-inference.yaml deleted file mode 100644 index 68c199f9..00000000 --- a/configs/v1-inpainting-inference.yaml +++ /dev/null @@ -1,70 +0,0 @@ -model: - base_learning_rate: 7.5e-05 - target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: hybrid # important - monitor: val/loss_simple_ema - scale_factor: 0.18215 - finetune_keys: null - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 9 # 4 data + 4 downscaled image + 1 mask - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: False - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 039689f3..a42741f6 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,98 +1,98 @@ -import logging - -import torch -from torch import Tensor -import platform -from modules.sd_hijack_utils import CondFunc -from packaging import version -from modules import shared - -log = logging.getLogger(__name__) - - -# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, -# use check `getattr` and try it for compatibility. -# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability, -# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 -def check_for_mps() -> bool: - if version.parse(torch.__version__) <= version.parse("2.0.1"): - if not getattr(torch, 'has_mps', False): - return False - try: - torch.zeros(1).to(torch.device("mps")) - return True - except Exception: - return False - else: - return torch.backends.mps.is_available() and torch.backends.mps.is_built() - - -has_mps = check_for_mps() - - -def torch_mps_gc() -> None: - try: - if shared.state.current_latent is not None: - log.debug("`current_latent` is set, skipping MPS garbage collection") - return - from torch.mps import empty_cache - empty_cache() - except Exception: - log.warning("MPS garbage collection failed", exc_info=True) - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 -def cumsum_fix(input, cumsum_func, *args, **kwargs): - if input.device.type == 'mps': - output_dtype = kwargs.get('dtype', input.dtype) - if output_dtype == torch.int64: - return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) - elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): - return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) - return cumsum_func(input, *args, **kwargs) - - -# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 -def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor: - try: - return orig_func(*args, **kwargs) - except RuntimeError as e: - if "not implemented for" in str(e) and "Half" in str(e): - input_tensor = args[0] - return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype) - else: - print(f"An unexpected RuntimeError occurred: {str(e)}") - -if has_mps: - if platform.mac_ver()[0].startswith("13.2."): - # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk 
(https://github.com/explosion/curated-transformers/pull/124) - CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760) - - if version.parse(torch.__version__) < version.parse("1.13"): - # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working - - # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 - CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), - lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) - # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 - CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), - lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') - # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 - CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) - elif version.parse(torch.__version__) > version.parse("1.13.1"): - cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) - cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) - CondFunc('torch.cumsum', cumsum_fix_func, None) - CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) - CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) - - # MPS workaround for https://github.com/pytorch/pytorch/issues/96113 - CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') - - # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 - CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None) - - # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 - if platform.processor() == 'i386': - for funcName in ['torch.argmax', 'torch.Tensor.argmax']: - CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') +# import logging +# +# import torch +# from torch import Tensor +# import platform +# from modules.sd_hijack_utils import CondFunc +# from packaging import version +# from modules import shared +# +# log = logging.getLogger(__name__) +# +# +# # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, +# # use check `getattr` and try it for compatibility. 
+# # in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability, +# # since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 +# def check_for_mps() -> bool: +# if version.parse(torch.__version__) <= version.parse("2.0.1"): +# if not getattr(torch, 'has_mps', False): +# return False +# try: +# torch.zeros(1).to(torch.device("mps")) +# return True +# except Exception: +# return False +# else: +# return torch.backends.mps.is_available() and torch.backends.mps.is_built() +# +# +# has_mps = check_for_mps() +# +# +# def torch_mps_gc() -> None: +# try: +# if shared.state.current_latent is not None: +# log.debug("`current_latent` is set, skipping MPS garbage collection") +# return +# from torch.mps import empty_cache +# empty_cache() +# except Exception: +# log.warning("MPS garbage collection failed", exc_info=True) +# +# +# # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +# def cumsum_fix(input, cumsum_func, *args, **kwargs): +# if input.device.type == 'mps': +# output_dtype = kwargs.get('dtype', input.dtype) +# if output_dtype == torch.int64: +# return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) +# elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): +# return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) +# return cumsum_func(input, *args, **kwargs) +# +# +# # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 +# def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor: +# try: +# return orig_func(*args, **kwargs) +# except RuntimeError as e: +# if "not implemented for" in str(e) and "Half" in str(e): +# input_tensor = args[0] +# return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype) +# else: +# print(f"An unexpected RuntimeError occurred: {str(e)}") +# +# if has_mps: +# if platform.mac_ver()[0].startswith("13.2."): +# # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) +# CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760) +# +# if version.parse(torch.__version__) < version.parse("1.13"): +# # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working +# +# # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 +# CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), +# lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) +# # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 +# CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), +# lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') +# # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 +# CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, 
**kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) +# elif version.parse(torch.__version__) > version.parse("1.13.1"): +# cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) +# cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) +# CondFunc('torch.cumsum', cumsum_fix_func, None) +# CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) +# CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) +# +# # MPS workaround for https://github.com/pytorch/pytorch/issues/96113 +# CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') +# +# # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 +# CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None) +# +# # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 +# if platform.processor() == 'i386': +# for funcName in ['torch.argmax', 'torch.Tensor.argmax']: +# CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') diff --git a/modules/npu_specific.py b/modules/npu_specific.py index 94100691..66ba3102 100644 --- a/modules/npu_specific.py +++ b/modules/npu_specific.py @@ -1,31 +1,31 @@ -import importlib -import torch - -from modules import shared - - -def check_for_npu(): - if importlib.util.find_spec("torch_npu") is None: - return False - import torch_npu - - try: - # Will raise a RuntimeError if no NPU is found - _ = torch_npu.npu.device_count() - return torch.npu.is_available() - except RuntimeError: - return False - - -def get_npu_device_string(): - if shared.cmd_opts.device_id is not None: - return f"npu:{shared.cmd_opts.device_id}" - return "npu:0" - - -def torch_npu_gc(): - with torch.npu.device(get_npu_device_string()): - torch.npu.empty_cache() - - -has_npu = check_for_npu() +# import importlib +# import torch +# +# from modules import shared +# +# +# def check_for_npu(): +# if importlib.util.find_spec("torch_npu") is None: +# return False +# import torch_npu +# +# try: +# # Will raise a RuntimeError if no NPU is found +# _ = torch_npu.npu.device_count() +# return torch.npu.is_available() +# except RuntimeError: +# return False +# +# +# def get_npu_device_string(): +# if shared.cmd_opts.device_id is not None: +# return f"npu:{shared.cmd_opts.device_id}" +# return "npu:0" +# +# +# def torch_npu_gc(): +# with torch.npu.device(get_npu_device_string()): +# torch.npu.empty_cache() +# +# +# has_npu = check_for_npu() diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py deleted file mode 100644 index 4cb561ef..00000000 --- a/modules/sub_quadratic_attention.py +++ /dev/null @@ -1,215 +0,0 @@ -# original source: -# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py -# license: -# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full 
license) -# credit: -# Amin Rezaei (original author) -# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) -# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) -# implementation of: -# Self-attention Does Not Need O(n2) Memory": -# https://arxiv.org/abs/2112.05682v2 - -from functools import partial -import torch -from torch import Tensor -from torch.utils.checkpoint import checkpoint -import math -from typing import Optional, NamedTuple - - -def narrow_trunc( - input: Tensor, - dim: int, - start: int, - length: int -) -> Tensor: - return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) - - -class AttnChunk(NamedTuple): - exp_values: Tensor - exp_weights_sum: Tensor - max_score: Tensor - - -class SummarizeChunk: - @staticmethod - def __call__( - query: Tensor, - key: Tensor, - value: Tensor, - ) -> AttnChunk: ... - - -class ComputeQueryChunkAttn: - @staticmethod - def __call__( - query: Tensor, - key: Tensor, - value: Tensor, - ) -> Tensor: ... - - -def _summarize_chunk( - query: Tensor, - key: Tensor, - value: Tensor, - scale: float, -) -> AttnChunk: - attn_weights = torch.baddbmm( - torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype), - query, - key.transpose(1,2), - alpha=scale, - beta=0, - ) - max_score, _ = torch.max(attn_weights, -1, keepdim=True) - max_score = max_score.detach() - exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype) - max_score = max_score.squeeze(-1) - return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) - - -def _query_chunk_attention( - query: Tensor, - key: Tensor, - value: Tensor, - summarize_chunk: SummarizeChunk, - kv_chunk_size: int, -) -> Tensor: - batch_x_heads, k_tokens, k_channels_per_head = key.shape - _, _, v_channels_per_head = value.shape - - def chunk_scanner(chunk_idx: int) -> AttnChunk: - key_chunk = narrow_trunc( - key, - 1, - chunk_idx, - kv_chunk_size - ) - value_chunk = narrow_trunc( - value, - 1, - chunk_idx, - kv_chunk_size - ) - return summarize_chunk(query, key_chunk, value_chunk) - - chunks: list[AttnChunk] = [ - chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) - ] - acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) - chunk_values, chunk_weights, chunk_max = acc_chunk - - global_max, _ = torch.max(chunk_max, 0, keepdim=True) - max_diffs = torch.exp(chunk_max - global_max) - chunk_values *= torch.unsqueeze(max_diffs, -1) - chunk_weights *= max_diffs - - all_values = chunk_values.sum(dim=0) - all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) - return all_values / all_weights - - -# TODO: refactor CrossAttention#get_attention_scores to share code with this -def _get_attention_scores_no_kv_chunking( - query: Tensor, - key: Tensor, - value: Tensor, - scale: float, -) -> Tensor: - attn_scores = torch.baddbmm( - torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype), - query, - key.transpose(1,2), - alpha=scale, - beta=0, - ) - attn_probs = attn_scores.softmax(dim=-1) - del attn_scores - hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype) - return hidden_states_slice - - -class ScannedChunk(NamedTuple): - chunk_idx: int - attn_chunk: AttnChunk - - -def efficient_dot_product_attention( - query: Tensor, 
- key: Tensor, - value: Tensor, - query_chunk_size=1024, - kv_chunk_size: Optional[int] = None, - kv_chunk_size_min: Optional[int] = None, - use_checkpoint=True, -): - """Computes efficient dot-product attention given query, key, and value. - This is efficient version of attention presented in - https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. - Args: - query: queries for calculating attention with shape of - `[batch * num_heads, tokens, channels_per_head]`. - key: keys for calculating attention with shape of - `[batch * num_heads, tokens, channels_per_head]`. - value: values to be used in attention with shape of - `[batch * num_heads, tokens, channels_per_head]`. - query_chunk_size: int: query chunks size - kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) - kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). - use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) - Returns: - Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. - """ - batch_x_heads, q_tokens, q_channels_per_head = query.shape - _, k_tokens, _ = key.shape - scale = q_channels_per_head ** -0.5 - - kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) - if kv_chunk_size_min is not None: - kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) - - def get_query_chunk(chunk_idx: int) -> Tensor: - return narrow_trunc( - query, - 1, - chunk_idx, - min(query_chunk_size, q_tokens) - ) - - summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) - summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk - compute_query_chunk_attn: ComputeQueryChunkAttn = partial( - _get_attention_scores_no_kv_chunking, - scale=scale - ) if k_tokens <= kv_chunk_size else ( - # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) - partial( - _query_chunk_attention, - kv_chunk_size=kv_chunk_size, - summarize_chunk=summarize_chunk, - ) - ) - - if q_tokens <= query_chunk_size: - # fast-path for when there's just 1 query chunk - return compute_query_chunk_attn( - query=query, - key=key, - value=value, - ) - - res = torch.zeros_like(query) - for i in range(math.ceil(q_tokens / query_chunk_size)): - attn_scores = compute_query_chunk_attn( - query=get_query_chunk(i * query_chunk_size), - key=key, - value=value, - ) - - res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores - - return res diff --git a/modules/xlmr.py b/modules/xlmr.py index 319771b7..029821be 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -1,140 +1,140 @@ -from transformers import BertPreTrainedModel, BertConfig -import torch.nn as nn -import torch -from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig -from transformers import XLMRobertaModel,XLMRobertaTokenizer -from typing import Optional - -from modules import torch_utils - - -class BertSeriesConfig(BertConfig): - def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, 
type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): - - super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) - self.project_dim = project_dim - self.pooler_fn = pooler_fn - self.learn_encoder = learn_encoder - -class RobertaSeriesConfig(XLMRobertaConfig): - def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): - super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) - self.project_dim = project_dim - self.pooler_fn = pooler_fn - self.learn_encoder = learn_encoder - - -class BertSeriesModelWithTransformation(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - config_class = BertSeriesConfig - - def __init__(self, config=None, **kargs): - # modify initialization for autoloading - if config is None: - config = XLMRobertaConfig() - config.attention_probs_dropout_prob= 0.1 - config.bos_token_id=0 - config.eos_token_id=2 - config.hidden_act='gelu' - config.hidden_dropout_prob=0.1 - config.hidden_size=1024 - config.initializer_range=0.02 - config.intermediate_size=4096 - config.layer_norm_eps=1e-05 - config.max_position_embeddings=514 - - config.num_attention_heads=16 - config.num_hidden_layers=24 - config.output_past=True - config.pad_token_id=1 - config.position_embedding_type= "absolute" - - config.type_vocab_size= 1 - config.use_cache=True - config.vocab_size= 250002 - config.project_dim = 768 - config.learn_encoder = False - super().__init__(config) - self.roberta = XLMRobertaModel(config) - self.transformation = nn.Linear(config.hidden_size,config.project_dim) - self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') - self.pooler = lambda x: x[:,0] - self.post_init() - - def encode(self,c): - device = torch_utils.get_param(self).device - text = self.tokenizer(c, - truncation=True, - max_length=77, - return_length=False, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt") - text["input_ids"] = torch.tensor(text["input_ids"]).to(device) - text["attention_mask"] = torch.tensor( - text['attention_mask']).to(device) - features = self(**text) - return features['projection_state'] - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ) : - r""" - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - outputs = self.roberta( - input_ids=input_ids, - 
attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=True, - return_dict=return_dict, - ) - - # last module outputs - sequence_output = outputs[0] - - - # project every module - sequence_output_ln = self.pre_LN(sequence_output) - - # pooler - pooler_output = self.pooler(sequence_output_ln) - pooler_output = self.transformation(pooler_output) - projection_state = self.transformation(outputs.last_hidden_state) - - return { - 'pooler_output':pooler_output, - 'last_hidden_state':outputs.last_hidden_state, - 'hidden_states':outputs.hidden_states, - 'attentions':outputs.attentions, - 'projection_state':projection_state, - 'sequence_out': sequence_output - } - - -class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): - base_model_prefix = 'roberta' - config_class= RobertaSeriesConfig +# from transformers import BertPreTrainedModel, BertConfig +# import torch.nn as nn +# import torch +# from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +# from transformers import XLMRobertaModel,XLMRobertaTokenizer +# from typing import Optional +# +# from modules import torch_utils +# +# +# class BertSeriesConfig(BertConfig): +# def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): +# +# super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) +# self.project_dim = project_dim +# self.pooler_fn = pooler_fn +# self.learn_encoder = learn_encoder +# +# class RobertaSeriesConfig(XLMRobertaConfig): +# def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): +# super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) +# self.project_dim = project_dim +# self.pooler_fn = pooler_fn +# self.learn_encoder = learn_encoder +# +# +# class BertSeriesModelWithTransformation(BertPreTrainedModel): +# +# _keys_to_ignore_on_load_unexpected = [r"pooler"] +# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] +# config_class = BertSeriesConfig +# +# def __init__(self, config=None, **kargs): +# # modify initialization for autoloading +# if config is None: +# config = XLMRobertaConfig() +# config.attention_probs_dropout_prob= 0.1 +# config.bos_token_id=0 +# config.eos_token_id=2 +# config.hidden_act='gelu' +# config.hidden_dropout_prob=0.1 +# config.hidden_size=1024 +# config.initializer_range=0.02 +# config.intermediate_size=4096 +# config.layer_norm_eps=1e-05 +# config.max_position_embeddings=514 +# +# config.num_attention_heads=16 +# config.num_hidden_layers=24 +# config.output_past=True +# 
config.pad_token_id=1 +# config.position_embedding_type= "absolute" +# +# config.type_vocab_size= 1 +# config.use_cache=True +# config.vocab_size= 250002 +# config.project_dim = 768 +# config.learn_encoder = False +# super().__init__(config) +# self.roberta = XLMRobertaModel(config) +# self.transformation = nn.Linear(config.hidden_size,config.project_dim) +# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +# self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') +# self.pooler = lambda x: x[:,0] +# self.post_init() +# +# def encode(self,c): +# device = torch_utils.get_param(self).device +# text = self.tokenizer(c, +# truncation=True, +# max_length=77, +# return_length=False, +# return_overflowing_tokens=False, +# padding="max_length", +# return_tensors="pt") +# text["input_ids"] = torch.tensor(text["input_ids"]).to(device) +# text["attention_mask"] = torch.tensor( +# text['attention_mask']).to(device) +# features = self(**text) +# return features['projection_state'] +# +# def forward( +# self, +# input_ids: Optional[torch.Tensor] = None, +# attention_mask: Optional[torch.Tensor] = None, +# token_type_ids: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.Tensor] = None, +# head_mask: Optional[torch.Tensor] = None, +# inputs_embeds: Optional[torch.Tensor] = None, +# encoder_hidden_states: Optional[torch.Tensor] = None, +# encoder_attention_mask: Optional[torch.Tensor] = None, +# output_attentions: Optional[bool] = None, +# return_dict: Optional[bool] = None, +# output_hidden_states: Optional[bool] = None, +# ) : +# r""" +# """ +# +# return_dict = return_dict if return_dict is not None else self.config.use_return_dict +# +# +# outputs = self.roberta( +# input_ids=input_ids, +# attention_mask=attention_mask, +# token_type_ids=token_type_ids, +# position_ids=position_ids, +# head_mask=head_mask, +# inputs_embeds=inputs_embeds, +# encoder_hidden_states=encoder_hidden_states, +# encoder_attention_mask=encoder_attention_mask, +# output_attentions=output_attentions, +# output_hidden_states=True, +# return_dict=return_dict, +# ) +# +# # last module outputs +# sequence_output = outputs[0] +# +# +# # project every module +# sequence_output_ln = self.pre_LN(sequence_output) +# +# # pooler +# pooler_output = self.pooler(sequence_output_ln) +# pooler_output = self.transformation(pooler_output) +# projection_state = self.transformation(outputs.last_hidden_state) +# +# return { +# 'pooler_output':pooler_output, +# 'last_hidden_state':outputs.last_hidden_state, +# 'hidden_states':outputs.hidden_states, +# 'attentions':outputs.attentions, +# 'projection_state':projection_state, +# 'sequence_out': sequence_output +# } +# +# +# class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): +# base_model_prefix = 'roberta' +# config_class= RobertaSeriesConfig diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index f6055504..f92afd16 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -1,166 +1,166 @@ -from transformers import BertPreTrainedModel,BertConfig -import torch.nn as nn -import torch -from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig -from transformers import XLMRobertaModel,XLMRobertaTokenizer -from typing import Optional -from modules import torch_utils - - -class BertSeriesConfig(BertConfig): - def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, 
attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): - - super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) - self.project_dim = project_dim - self.pooler_fn = pooler_fn - self.learn_encoder = learn_encoder - -class RobertaSeriesConfig(XLMRobertaConfig): - def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): - super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) - self.project_dim = project_dim - self.pooler_fn = pooler_fn - self.learn_encoder = learn_encoder - - -class BertSeriesModelWithTransformation(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - config_class = BertSeriesConfig - - def __init__(self, config=None, **kargs): - # modify initialization for autoloading - if config is None: - config = XLMRobertaConfig() - config.attention_probs_dropout_prob= 0.1 - config.bos_token_id=0 - config.eos_token_id=2 - config.hidden_act='gelu' - config.hidden_dropout_prob=0.1 - config.hidden_size=1024 - config.initializer_range=0.02 - config.intermediate_size=4096 - config.layer_norm_eps=1e-05 - config.max_position_embeddings=514 - - config.num_attention_heads=16 - config.num_hidden_layers=24 - config.output_past=True - config.pad_token_id=1 - config.position_embedding_type= "absolute" - - config.type_vocab_size= 1 - config.use_cache=True - config.vocab_size= 250002 - config.project_dim = 1024 - config.learn_encoder = False - super().__init__(config) - self.roberta = XLMRobertaModel(config) - self.transformation = nn.Linear(config.hidden_size,config.project_dim) - # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') - # self.pooler = lambda x: x[:,0] - # self.post_init() - - self.has_pre_transformation = True - if self.has_pre_transformation: - self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) - self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.post_init() - - def encode(self,c): - device = torch_utils.get_param(self).device - text = self.tokenizer(c, - truncation=True, - max_length=77, - return_length=False, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt") - text["input_ids"] = torch.tensor(text["input_ids"]).to(device) - text["attention_mask"] = torch.tensor( - text['attention_mask']).to(device) - features = self(**text) - return features['projection_state'] - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] 
= None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ) : - r""" - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - outputs = self.roberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=True, - return_dict=return_dict, - ) - - # # last module outputs - # sequence_output = outputs[0] - - - # # project every module - # sequence_output_ln = self.pre_LN(sequence_output) - - # # pooler - # pooler_output = self.pooler(sequence_output_ln) - # pooler_output = self.transformation(pooler_output) - # projection_state = self.transformation(outputs.last_hidden_state) - - if self.has_pre_transformation: - sequence_output2 = outputs["hidden_states"][-2] - sequence_output2 = self.pre_LN(sequence_output2) - projection_state2 = self.transformation_pre(sequence_output2) - - return { - "projection_state": projection_state2, - "last_hidden_state": outputs.last_hidden_state, - "hidden_states": outputs.hidden_states, - "attentions": outputs.attentions, - } - else: - projection_state = self.transformation(outputs.last_hidden_state) - return { - "projection_state": projection_state, - "last_hidden_state": outputs.last_hidden_state, - "hidden_states": outputs.hidden_states, - "attentions": outputs.attentions, - } - - - # return { - # 'pooler_output':pooler_output, - # 'last_hidden_state':outputs.last_hidden_state, - # 'hidden_states':outputs.hidden_states, - # 'attentions':outputs.attentions, - # 'projection_state':projection_state, - # 'sequence_out': sequence_output - # } - - -class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): - base_model_prefix = 'roberta' - config_class= RobertaSeriesConfig +# from transformers import BertPreTrainedModel,BertConfig +# import torch.nn as nn +# import torch +# from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +# from transformers import XLMRobertaModel,XLMRobertaTokenizer +# from typing import Optional +# from modules import torch_utils +# +# +# class BertSeriesConfig(BertConfig): +# def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): +# +# super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) +# self.project_dim = project_dim +# self.pooler_fn = pooler_fn +# self.learn_encoder = learn_encoder +# +# class RobertaSeriesConfig(XLMRobertaConfig): +# def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): +# super().__init__(pad_token_id=pad_token_id, 
bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) +# self.project_dim = project_dim +# self.pooler_fn = pooler_fn +# self.learn_encoder = learn_encoder +# +# +# class BertSeriesModelWithTransformation(BertPreTrainedModel): +# +# _keys_to_ignore_on_load_unexpected = [r"pooler"] +# _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] +# config_class = BertSeriesConfig +# +# def __init__(self, config=None, **kargs): +# # modify initialization for autoloading +# if config is None: +# config = XLMRobertaConfig() +# config.attention_probs_dropout_prob= 0.1 +# config.bos_token_id=0 +# config.eos_token_id=2 +# config.hidden_act='gelu' +# config.hidden_dropout_prob=0.1 +# config.hidden_size=1024 +# config.initializer_range=0.02 +# config.intermediate_size=4096 +# config.layer_norm_eps=1e-05 +# config.max_position_embeddings=514 +# +# config.num_attention_heads=16 +# config.num_hidden_layers=24 +# config.output_past=True +# config.pad_token_id=1 +# config.position_embedding_type= "absolute" +# +# config.type_vocab_size= 1 +# config.use_cache=True +# config.vocab_size= 250002 +# config.project_dim = 1024 +# config.learn_encoder = False +# super().__init__(config) +# self.roberta = XLMRobertaModel(config) +# self.transformation = nn.Linear(config.hidden_size,config.project_dim) +# # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +# self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') +# # self.pooler = lambda x: x[:,0] +# # self.post_init() +# +# self.has_pre_transformation = True +# if self.has_pre_transformation: +# self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) +# self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) +# self.post_init() +# +# def encode(self,c): +# device = torch_utils.get_param(self).device +# text = self.tokenizer(c, +# truncation=True, +# max_length=77, +# return_length=False, +# return_overflowing_tokens=False, +# padding="max_length", +# return_tensors="pt") +# text["input_ids"] = torch.tensor(text["input_ids"]).to(device) +# text["attention_mask"] = torch.tensor( +# text['attention_mask']).to(device) +# features = self(**text) +# return features['projection_state'] +# +# def forward( +# self, +# input_ids: Optional[torch.Tensor] = None, +# attention_mask: Optional[torch.Tensor] = None, +# token_type_ids: Optional[torch.Tensor] = None, +# position_ids: Optional[torch.Tensor] = None, +# head_mask: Optional[torch.Tensor] = None, +# inputs_embeds: Optional[torch.Tensor] = None, +# encoder_hidden_states: Optional[torch.Tensor] = None, +# encoder_attention_mask: Optional[torch.Tensor] = None, +# output_attentions: Optional[bool] = None, +# return_dict: Optional[bool] = None, +# output_hidden_states: Optional[bool] = None, +# ) : +# r""" +# """ +# +# return_dict = return_dict if return_dict is not None else self.config.use_return_dict +# +# +# outputs = self.roberta( +# input_ids=input_ids, +# attention_mask=attention_mask, +# token_type_ids=token_type_ids, +# position_ids=position_ids, +# head_mask=head_mask, +# inputs_embeds=inputs_embeds, +# encoder_hidden_states=encoder_hidden_states, +# encoder_attention_mask=encoder_attention_mask, +# output_attentions=output_attentions, +# output_hidden_states=True, +# return_dict=return_dict, +# ) +# +# # # last module outputs +# # sequence_output = outputs[0] +# +# +# # # project every module +# # sequence_output_ln = self.pre_LN(sequence_output) +# +# # # pooler +# # pooler_output = 
self.pooler(sequence_output_ln) +# # pooler_output = self.transformation(pooler_output) +# # projection_state = self.transformation(outputs.last_hidden_state) +# +# if self.has_pre_transformation: +# sequence_output2 = outputs["hidden_states"][-2] +# sequence_output2 = self.pre_LN(sequence_output2) +# projection_state2 = self.transformation_pre(sequence_output2) +# +# return { +# "projection_state": projection_state2, +# "last_hidden_state": outputs.last_hidden_state, +# "hidden_states": outputs.hidden_states, +# "attentions": outputs.attentions, +# } +# else: +# projection_state = self.transformation(outputs.last_hidden_state) +# return { +# "projection_state": projection_state, +# "last_hidden_state": outputs.last_hidden_state, +# "hidden_states": outputs.hidden_states, +# "attentions": outputs.attentions, +# } +# +# +# # return { +# # 'pooler_output':pooler_output, +# # 'last_hidden_state':outputs.last_hidden_state, +# # 'hidden_states':outputs.hidden_states, +# # 'attentions':outputs.attentions, +# # 'projection_state':projection_state, +# # 'sequence_out': sequence_output +# # } +# +# +# class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): +# base_model_prefix = 'roberta' +# config_class= RobertaSeriesConfig diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 2971dbc3..c5b445f8 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -1,138 +1,138 @@ -from modules import shared -from modules.sd_hijack_utils import CondFunc - -has_ipex = False -try: - import torch - import intel_extension_for_pytorch as ipex # noqa: F401 - has_ipex = True -except Exception: - pass - - -def check_for_xpu(): - return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available() - - -def get_xpu_device_string(): - if shared.cmd_opts.device_id is not None: - return f"xpu:{shared.cmd_opts.device_id}" - return "xpu" - - -def torch_xpu_gc(): - with torch.xpu.device(get_xpu_device_string()): - torch.xpu.empty_cache() - - -has_xpu = check_for_xpu() - - -# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627 -# Here we implement a slicing algorithm to split large batch size into smaller chunks, -# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. -# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, -# which is the best trade-off between VRAM usage and performance. 
-ARC_SINGLE_ALLOCATION_LIMIT = {}
-orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
-def torch_xpu_scaled_dot_product_attention(
-    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
-):
-    # cast to same dtype first
-    key = key.to(query.dtype)
-    value = value.to(query.dtype)
-    if attn_mask is not None and attn_mask.dtype != torch.bool:
-        attn_mask = attn_mask.to(query.dtype)
-
-    N = query.shape[:-2] # Batch size
-    L = query.size(-2) # Target sequence length
-    E = query.size(-1) # Embedding dimension of the query and key
-    S = key.size(-2) # Source sequence length
-    Ev = value.size(-1) # Embedding dimension of the value
-
-    total_batch_size = torch.numel(torch.empty(N))
-    device_id = query.device.index
-    if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
-        ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
-    batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
-
-    if total_batch_size <= batch_size_limit:
-        return orig_sdp_attn_func(
-            query,
-            key,
-            value,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            *args, **kwargs
-        )
-
-    query = torch.reshape(query, (-1, L, E))
-    key = torch.reshape(key, (-1, S, E))
-    value = torch.reshape(value, (-1, S, Ev))
-    if attn_mask is not None:
-        attn_mask = attn_mask.view(-1, L, S)
-    chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
-    outputs = []
-    for i in range(chunk_count):
-        attn_mask_chunk = (
-            None
-            if attn_mask is None
-            else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
-        )
-        chunk_output = orig_sdp_attn_func(
-            query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
-            key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
-            value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
-            attn_mask_chunk,
-            dropout_p,
-            is_causal,
-            *args, **kwargs
-        )
-        outputs.append(chunk_output)
-    result = torch.cat(outputs, dim=0)
-    return torch.reshape(result, (*N, L, Ev))
-
-
-def is_xpu_device(device: str | torch.device = None):
-    if device is None:
-        return False
-    if isinstance(device, str):
-        return device.startswith("xpu")
-    return device.type == "xpu"
-
-
-if has_xpu:
-    try:
-        # torch.Generator supports "xpu" device since 2.1
-        torch.Generator("xpu")
-    except RuntimeError:
-        # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
-        CondFunc('torch.Generator',
-            lambda orig_func, device=None: torch.xpu.Generator(device),
-            lambda orig_func, device=None: is_xpu_device(device))
-
-    # W/A for some OPs that could not handle different input dtypes
-    CondFunc('torch.nn.functional.layer_norm',
-        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
-        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
-        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
-        weight is not None and input.dtype != weight.data.dtype)
-    CondFunc('torch.nn.modules.GroupNorm.forward',
-        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
-        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
-    CondFunc('torch.nn.modules.linear.Linear.forward',
-        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
-        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
-    CondFunc('torch.nn.modules.conv.Conv2d.forward',
-        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
-        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
-    CondFunc('torch.bmm',
-        lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
-        lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
-    CondFunc('torch.cat',
-        lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
-        lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
-    CondFunc('torch.nn.functional.scaled_dot_product_attention',
-        lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
-        lambda orig_func, query, *args, **kwargs: query.is_xpu)
+# from modules import shared
+# from modules.sd_hijack_utils import CondFunc
+#
+# has_ipex = False
+# try:
+#     import torch
+#     import intel_extension_for_pytorch as ipex # noqa: F401
+#     has_ipex = True
+# except Exception:
+#     pass
+#
+#
+# def check_for_xpu():
+#     return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()
+#
+#
+# def get_xpu_device_string():
+#     if shared.cmd_opts.device_id is not None:
+#         return f"xpu:{shared.cmd_opts.device_id}"
+#     return "xpu"
+#
+#
+# def torch_xpu_gc():
+#     with torch.xpu.device(get_xpu_device_string()):
+#         torch.xpu.empty_cache()
+#
+#
+# has_xpu = check_for_xpu()
+#
+#
+# # Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
+# # Here we implement a slicing algorithm to split large batch size into smaller chunks,
+# # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
+# # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
+# # which is the best trade-off between VRAM usage and performance.
+# ARC_SINGLE_ALLOCATION_LIMIT = {}
+# orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
+# def torch_xpu_scaled_dot_product_attention(
+#     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
+# ):
+#     # cast to same dtype first
+#     key = key.to(query.dtype)
+#     value = value.to(query.dtype)
+#     if attn_mask is not None and attn_mask.dtype != torch.bool:
+#         attn_mask = attn_mask.to(query.dtype)
+#
+#     N = query.shape[:-2] # Batch size
+#     L = query.size(-2) # Target sequence length
+#     E = query.size(-1) # Embedding dimension of the query and key
+#     S = key.size(-2) # Source sequence length
+#     Ev = value.size(-1) # Embedding dimension of the value
+#
+#     total_batch_size = torch.numel(torch.empty(N))
+#     device_id = query.device.index
+#     if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+#         ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+#     batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
+#
+#     if total_batch_size <= batch_size_limit:
+#         return orig_sdp_attn_func(
+#             query,
+#             key,
+#             value,
+#             attn_mask,
+#             dropout_p,
+#             is_causal,
+#             *args, **kwargs
+#         )
+#
+#     query = torch.reshape(query, (-1, L, E))
+#     key = torch.reshape(key, (-1, S, E))
+#     value = torch.reshape(value, (-1, S, Ev))
+#     if attn_mask is not None:
+#         attn_mask = attn_mask.view(-1, L, S)
+#     chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
+#     outputs = []
+#     for i in range(chunk_count):
+#         attn_mask_chunk = (
+#             None
+#             if attn_mask is None
+#             else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
+#         )
+#         chunk_output = orig_sdp_attn_func(
+#             query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+#             key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+#             value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+#             attn_mask_chunk,
+#             dropout_p,
+#             is_causal,
+#             *args, **kwargs
+#         )
+#         outputs.append(chunk_output)
+#     result = torch.cat(outputs, dim=0)
+#     return torch.reshape(result, (*N, L, Ev))
+#
+#
+# def is_xpu_device(device: str | torch.device = None):
+#     if device is None:
+#         return False
+#     if isinstance(device, str):
+#         return device.startswith("xpu")
+#     return device.type == "xpu"
+#
+#
+# if has_xpu:
+#     try:
+#         # torch.Generator supports "xpu" device since 2.1
+#         torch.Generator("xpu")
+#     except RuntimeError:
+#         # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
+#         CondFunc('torch.Generator',
+#             lambda orig_func, device=None: torch.xpu.Generator(device),
+#             lambda orig_func, device=None: is_xpu_device(device))
+#
+#     # W/A for some OPs that could not handle different input dtypes
+#     CondFunc('torch.nn.functional.layer_norm',
+#         lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
+#         orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
+#         lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
+#         weight is not None and input.dtype != weight.data.dtype)
+#     CondFunc('torch.nn.modules.GroupNorm.forward',
+#         lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+#         lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+#     CondFunc('torch.nn.modules.linear.Linear.forward',
+#         lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+#         lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+#     CondFunc('torch.nn.modules.conv.Conv2d.forward',
+#         lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+#         lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+#     CondFunc('torch.bmm',
+#         lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
+#         lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
+#     CondFunc('torch.cat',
+#         lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
+#         lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
+#     CondFunc('torch.nn.functional.scaled_dot_product_attention',
+#         lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
+#         lambda orig_func, query, *args, **kwargs: query.is_xpu)
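
For context, the commented-out torch_xpu_scaled_dot_product_attention above slices the flattened (batch x heads) dimension so that no single SDPA call needs an attention-score buffer larger than the per-device allocation limit. The following is a minimal, CPU-only sketch of that slicing idea, not part of the diff; the fixed batch_size_limit of 3 is a hypothetical stand-in for the VRAM-derived limit min(total_memory // 8, 4 GiB) // (L * S * element_size) used in the real code.

import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, target_len, head_dim); float32 so it runs on CPU.
q = torch.randn(2, 4, 128, 32)
k = torch.randn(2, 4, 128, 32)
v = torch.randn(2, 4, 128, 32)

N = q.shape[:-2]                               # leading (batch, heads) dims
L, S, Ev = q.size(-2), k.size(-2), v.size(-1)

# Flatten the leading dims so we can slice along a single "batch" axis.
q2 = q.reshape(-1, L, q.size(-1))
k2 = k.reshape(-1, S, k.size(-1))
v2 = v.reshape(-1, S, Ev)

batch_size_limit = 3                           # hypothetical per-call limit
total_batch_size = q2.size(0)                  # 2 * 4 = 8 -> chunks of 3, 3, 2

chunks = []
for i in range(0, total_batch_size, batch_size_limit):
    chunks.append(F.scaled_dot_product_attention(
        q2[i:i + batch_size_limit],
        k2[i:i + batch_size_limit],
        v2[i:i + batch_size_limit]))
out_chunked = torch.cat(chunks, dim=0).reshape(*N, L, Ev)

# The sliced computation matches a single full-batch SDPA call.
assert torch.allclose(out_chunked, F.scaled_dot_product_attention(q, k, v), atol=1e-5)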