Implement many kernels from scratch

layerdiffusion
2024-08-06 18:20:34 -07:00
parent 4c8331b806
commit b57573c8da
15 changed files with 209 additions and 100 deletions

View File

@@ -955,8 +955,8 @@
"pre_tokenizer": {
"type": "Metaspace",
"replacement": "▁",
"prepend_scheme": "always",
"split": true
"add_prefix_space": true,
"prepend_scheme": "first"
},
"post_processor": {
"type": "TemplateProcessing",
@@ -1015,8 +1015,8 @@
"decoder": {
"type": "Metaspace",
"replacement": "▁",
"prepend_scheme": "always",
"split": true
"add_prefix_space": true,
"prepend_scheme": "always"
},
"model": {
"type": "Unigram",

View File

@@ -1,5 +1,4 @@
{
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<pad>",
@@ -931,7 +930,7 @@
"clean_up_tokenization_spaces": true,
"eos_token": "</s>",
"extra_ids": 100,
"legacy": true,
"legacy": false,
"model_max_length": 512,
"pad_token": "<pad>",
"sp_model_kwargs": {},

View File

@@ -12,7 +12,7 @@ from backend import memory_management
from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
from backend.nn.clip import IntegratedCLIP
from backend.nn.unet import IntegratedUNet2DConditionModel
from backend.diffusion_engine.sd15 import StableDiffusion
@@ -40,17 +40,18 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
if cls_name in ['AutoencoderKL']:
config = IntegratedAutoencoderKL.load_config(config_path)
with using_forge_operations():
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
model = IntegratedAutoencoderKL.from_config(config)
load_state_dict(model, state_dict)
return model
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
from transformers import CLIPTextConfig, CLIPTextModel
config = CLIPTextConfig.from_pretrained(config_path)
with modeling_utils.no_init_weights():
with using_forge_operations():
model = IntegratedCLIP(config)
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype()):
model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True)
load_state_dict(model, state_dict, ignore_errors=[
'transformer.text_projection.weight',
@@ -58,13 +59,30 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
'logit_scale'
], log_name=cls_name)
return model
if component_name.startswith('text_encoder') and cls_name in ['T5EncoderModel']:
from transformers import T5EncoderModel, T5Config
config = T5Config.from_pretrained(config_path)
dtype = memory_management.text_encoder_dtype()
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
dtype = sd_dtype
with modeling_utils.no_init_weights():
with using_forge_operations(device=memory_management.cpu, dtype=dtype):
model = IntegratedCLIP(T5EncoderModel, config)
load_state_dict(model, state_dict, log_name=cls_name)
return model
if cls_name == 'UNet2DConditionModel':
unet_config = guess.unet_config.copy()
state_dict_size = memory_management.state_dict_size(state_dict)
unet_config['dtype'] = memory_management.unet_dtype(model_params=state_dict_size)
with using_forge_operations():
with using_forge_operations(device=memory_management.cpu, dtype=unet_config['dtype']):
model = IntegratedUNet2DConditionModel.from_config(unet_config)
model._internal_dict = unet_config
@@ -77,14 +95,14 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
def split_state_dict(sd):
guess = huggingface_guess.guess(sd)
guess.clip_target = guess.clip_target(sd)
state_dict = {
'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
'vae': try_filter_state_dict(sd, guess.vae_key_prefix)
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
}
sd = guess.process_clip_state_dict(sd)
guess.clip_target = guess.clip_target(sd)
for k, v in guess.clip_target.items():
state_dict[v] = try_filter_state_dict(sd, [k + '.'])
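
The new T5EncoderModel branch above probes a single attention weight to pick the load dtype, so checkpoints stored in float8 are kept in float8 rather than upcast. A minimal standalone sketch of that probing idea (the float16 fallback here merely stands in for memory_management.text_encoder_dtype()):

    import torch

    def pick_t5_dtype(state_dict, fallback=torch.float16):
        # Probe one known T5 weight; if the checkpoint ships float8 tensors,
        # keep that storage dtype so the encoder is not upcast at load time.
        probe = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight']
        if probe.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            return probe.dtype
        return fallback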

View File

@@ -11,6 +11,9 @@ from backend import stream
from backend.args import args, dynamic_args
cpu = torch.device('cpu')
class VRAMState(Enum):
DISABLED = 0 # No vram present: no need to move models to vram
NO_VRAM = 1 # Very low vram: enable all the options to save vram

View File

@@ -1,12 +1,13 @@
import torch
from transformers import CLIPTextModel, CLIPTextConfig
class IntegratedCLIP(torch.nn.Module):
def __init__(self, config: CLIPTextConfig):
def __init__(self, cls, config, add_text_projection=False):
super().__init__()
self.transformer = CLIPTextModel(config)
embed_dim = config.hidden_size
self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
self.transformer = cls(config)
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
if add_text_projection:
embed_dim = config.hidden_size
self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
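
IntegratedCLIP is now generic over the encoder class and only attaches the identity text_projection on request, so one wrapper serves both CLIP and T5 text encoders. A rough sketch of how it gets instantiated, mirroring the loader changes above (config_path is a placeholder for the Hugging Face config directory):

    import torch
    from transformers import CLIPTextModel, CLIPTextConfig, T5EncoderModel, T5Config
    from backend import memory_management
    from backend.nn.clip import IntegratedCLIP
    from backend.operations import using_forge_operations

    config_path = 'path/to/text_encoder'  # placeholder

    # Parameters are filled from a state dict later, so no autograd is needed
    # during construction.
    with torch.no_grad():
        # CLIP text encoder: weights land on CPU in the text-encoder dtype, and an
        # identity text_projection is added so projected embeddings can be read off.
        with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype()):
            clip = IntegratedCLIP(CLIPTextModel, CLIPTextConfig.from_pretrained(config_path), add_text_projection=True)

        # T5 encoder: same wrapper, no projection layer attached.
        with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype()):
            t5 = IntegratedCLIP(T5EncoderModel, T5Config.from_pretrained(config_path))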

View File

@@ -397,8 +397,8 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
self.decoder = Decoder(double_z=True, z_channels=latent_channels, resolution=256,
in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult,
num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
self.embed_dim = latent_channels
self.scaling_factor = scaling_factor
self.shift_factor = shift_factor
@@ -408,7 +408,10 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
def encode(self, x, regulation=None):
z = self.encoder(x)
z = self.quant_conv(z)
if self.quant_conv is not None:
z = self.quant_conv(z)
posterior = DiagonalGaussianDistribution(z)
if regulation is not None:
return regulation(posterior)
@@ -416,7 +419,9 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
return posterior.sample()
def decode(self, z):
z = self.post_quant_conv(z)
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
x = self.decoder(z)
return x

View File

@@ -7,23 +7,29 @@ from backend import stream
stash = {}
def weights_manual_cast(layer, x):
def weights_manual_cast(layer, x, skip_dtype=False):
weight, bias, signal = None, None, None
non_blocking = True
if getattr(x.device, 'type', None) == 'mps':
non_blocking = False
target_dtype = x.dtype
target_device = x.device
if skip_dtype:
target_dtype = None
if stream.using_stream:
with stream.stream_context()(stream.mover_stream):
if layer.bias is not None:
bias = layer.bias.to(device=x.device, dtype=x.dtype, non_blocking=non_blocking)
weight = layer.weight.to(device=x.device, dtype=x.dtype, non_blocking=non_blocking)
bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
signal = stream.mover_stream.record_event()
else:
if layer.bias is not None:
bias = layer.bias.to(device=x.device, dtype=x.dtype, non_blocking=non_blocking)
weight = layer.weight.to(device=x.device, dtype=x.dtype, non_blocking=non_blocking)
bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
return weight, bias, signal
@@ -60,9 +66,19 @@ def cleanup_cache():
return
current_device = None
current_dtype = None
current_manual_cast_enabled = False
class ForgeOperations:
class Linear(torch.nn.Linear):
parameters_manual_cast = False
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
@@ -76,7 +92,12 @@ class ForgeOperations:
return super().forward(x)
class Conv2d(torch.nn.Conv2d):
parameters_manual_cast = False
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
@@ -90,7 +111,12 @@ class ForgeOperations:
return super().forward(x)
class Conv3d(torch.nn.Conv3d):
parameters_manual_cast = False
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
@@ -103,8 +129,98 @@ class ForgeOperations:
else:
return super().forward(x)
class Conv1d(torch.nn.Conv1d):
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
def forward(self, x):
if self.parameters_manual_cast:
weight, bias, signal = weights_manual_cast(self, x)
with main_stream_worker(weight, bias, signal):
return self._conv_forward(x, weight, bias)
else:
return super().forward(x)
class ConvTranspose2d(torch.nn.ConvTranspose2d):
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
def forward(self, x, output_size=None):
if self.parameters_manual_cast:
num_spatial_dims = 2
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
weight, bias, signal = weights_manual_cast(self, x)
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
else:
return super().forward(x, output_size)
class ConvTranspose1d(torch.nn.ConvTranspose1d):
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
def forward(self, x, output_size=None):
if self.parameters_manual_cast:
num_spatial_dims = 1
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
weight, bias, signal = weights_manual_cast(self, x)
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
else:
return super().forward(x, output_size)
class ConvTranspose3d(torch.nn.ConvTranspose3d):
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
def forward(self, x, output_size=None):
if self.parameters_manual_cast:
num_spatial_dims = 3
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
weight, bias, signal = weights_manual_cast(self, x)
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
else:
return super().forward(x, output_size)
class GroupNorm(torch.nn.GroupNorm):
parameters_manual_cast = False
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
@@ -118,7 +234,12 @@ class ForgeOperations:
return super().forward(x)
class LayerNorm(torch.nn.LayerNorm):
parameters_manual_cast = False
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
@@ -131,34 +252,37 @@ class ForgeOperations:
else:
return super().forward(x)
class Embedding(torch.nn.Embedding):
class ForgeOperationsWithManualCast(ForgeOperations):
class Linear(ForgeOperations.Linear):
parameters_manual_cast = True
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
super().__init__(*args, **kwargs)
self.parameters_manual_cast = current_manual_cast_enabled
self.bias = None
class Conv2d(ForgeOperations.Conv2d):
parameters_manual_cast = True
def reset_parameters(self):
self.bias = None
return None
class Conv3d(ForgeOperations.Conv3d):
parameters_manual_cast = True
class GroupNorm(ForgeOperations.GroupNorm):
parameters_manual_cast = True
class LayerNorm(ForgeOperations.LayerNorm):
parameters_manual_cast = True
def forward(self, x):
if self.parameters_manual_cast:
weight, bias, signal = weights_manual_cast(self, x, skip_dtype=True)
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
else:
return super().forward(x)
@contextlib.contextmanager
def using_forge_operations(parameters_manual_cast=False, operations=None):
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False):
global current_device, current_dtype, current_manual_cast_enabled
current_device, current_dtype, current_manual_cast_enabled = device, dtype, manual_cast_enabled
if operations is None:
operations = ForgeOperations
if parameters_manual_cast:
operations = ForgeOperationsWithManualCast
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
try:
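
With this rework, every replaced torch.nn kernel reads its construction device/dtype from the active context and tags itself with parameters_manual_cast, replacing the old ForgeOperationsWithManualCast split. A rough usage sketch (the layers built here are purely illustrative):

    import torch
    from backend import memory_management
    from backend.operations import using_forge_operations

    # While the context is active, torch.nn.Conv2d / GroupNorm / etc. are swapped
    # for the Forge subclasses: parameters are allocated directly on CPU in fp16,
    # left uninitialized (reset_parameters is a no-op; a state dict fills them),
    # and flagged for manual casting at forward time.
    with using_forge_operations(device=memory_management.cpu, dtype=torch.float16, manual_cast_enabled=True):
        conv = torch.nn.Conv2d(4, 8, kernel_size=3, padding=1)
        norm = torch.nn.GroupNorm(2, 8)

    print(conv.weight.device, conv.weight.dtype)   # cpu torch.float16
    print(conv.parameters_manual_cast)             # True -> weights_manual_cast() moves weights per call

Embedding is the one patched layer that keeps its stored dtype: its forward calls weights_manual_cast(self, x, skip_dtype=True), so only the device changes when the lookup runs.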

View File

@@ -5,7 +5,7 @@ from backend.misc import image_resize
from backend import memory_management, state_dict, utils
from backend.nn.cnets import cldm, t2i_adapter
from backend.patcher.base import ModelPatcher
from backend.operations import using_forge_operations, ForgeOperationsWithManualCast, main_stream_worker, weights_manual_cast
from backend.operations import using_forge_operations, ForgeOperations, main_stream_worker, weights_manual_cast
def compute_controlnet_weighting(control, cnet):
@@ -282,7 +282,7 @@ class ControlNet(ControlBase):
super().cleanup()
class ControlLoraOps(ForgeOperationsWithManualCast):
class ControlLoraOps(ForgeOperations):
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
super().__init__()

View File

@@ -401,7 +401,7 @@ def prepare_environment():
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "2cecc9aec5b9476ad16d0b0c4a3c779f048e7cdd")
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "60a0f76d537df765570f8d497eb33ef5dfc6aa60")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:

View File

@@ -62,8 +62,8 @@ def initialize_forge():
from modules_forge.shared import diffusers_dir
if 'TRANSFORMERS_CACHE' not in os.environ:
os.environ['TRANSFORMERS_CACHE'] = diffusers_dir
# if 'TRANSFORMERS_CACHE' not in os.environ:
# os.environ['TRANSFORMERS_CACHE'] = diffusers_dir
if 'HF_HOME' not in os.environ:
os.environ['HF_HOME'] = diffusers_dir

View File

@@ -1,3 +0,0 @@
pytest-base-url~=2.0
pytest-cov~=4.0
pytest~=7.3

View File

@@ -1,34 +0,0 @@
GitPython
Pillow
accelerate
blendmodes
clean-fid
diskcache
einops
facexlib
fastapi>=0.90.1
gradio
inflection
jsonmerge
kornia
lark
numpy
omegaconf
open-clip-torch
piexif
protobuf==3.20.0
psutil
pytorch_lightning
requests
resize-right
safetensors
scikit-image>=0.19
tomesd
torch
torchdiffeq
torchsde
transformers==4.30.2
pillow-avif-plugin==1.4.3

View File

@@ -1,4 +0,0 @@
cloudpickle
decorator
synr==0.5.0
tornado

View File

@@ -30,7 +30,7 @@ tomesd==0.1.3
torch
torchdiffeq==0.2.3
torchsde==0.2.6
transformers==4.30.2
transformers==4.44.0
httpx==0.24.1
pillow-avif-plugin==1.4.3
basicsr==1.4.2

Binary file not shown (before: 411 KiB)