mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-26 01:09:10 +00:00
forge 2.0.0
see also discussions
This commit is contained in:
@@ -56,8 +56,6 @@ parser.add_argument("--cuda-malloc", action="store_true")
|
||||
parser.add_argument("--cuda-stream", action="store_true")
|
||||
parser.add_argument("--pin-shared-memory", action="store_true")
|
||||
|
||||
parser.add_argument("--i-am-lllyasviel", action="store_true")
|
||||
|
||||
args = parser.parse_known_args()[0]
|
||||
|
||||
# Some dynamic args that may be changed by webui rather than cmd flags.
|
||||
|
||||
@@ -7,7 +7,7 @@ from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
||||
from backend.args import dynamic_args, args
|
||||
from backend.args import dynamic_args
|
||||
from backend.modules.k_prediction import PredictionFlux
|
||||
from backend import memory_management
|
||||
|
||||
@@ -16,9 +16,6 @@ class Flux(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.Flux]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
if not args.i_am_lllyasviel:
|
||||
raise NotImplementedError('Flux is not implemented yet!')
|
||||
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
self.is_inpaint = False
|
||||
|
||||
@@ -68,9 +65,6 @@ class Flux(ForgeDiffusionEngine):
|
||||
|
||||
self.use_distilled_cfg_scale = True
|
||||
|
||||
# WebUI Legacy
|
||||
self.first_stage_model = vae.first_stage_model
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
|
||||
@@ -93,7 +87,7 @@ class Flux(ForgeDiffusionEngine):
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
_, token_count = self.text_processing_engine_t5.process_texts([prompt])
|
||||
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
|
||||
return token_count, max(255, token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@@ -53,7 +53,6 @@ class StableDiffusion(ForgeDiffusionEngine):
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd1 = True
|
||||
self.first_stage_model = vae.first_stage_model
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine.clip_skip = clip_skip
|
||||
|
||||
@@ -53,7 +53,6 @@ class StableDiffusion2(ForgeDiffusionEngine):
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd2 = True
|
||||
self.first_stage_model = vae.first_stage_model
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine.clip_skip = clip_skip
|
||||
|
||||
@@ -72,7 +72,6 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sdxl = True
|
||||
self.first_stage_model = vae.first_stage_model
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
|
||||
@@ -3,6 +3,7 @@ import torch
|
||||
import logging
|
||||
import importlib
|
||||
|
||||
import backend.args
|
||||
import huggingface_guess
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -69,9 +70,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
config = read_arbitrary_config(config_path)
|
||||
|
||||
dtype = memory_management.text_encoder_dtype()
|
||||
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
|
||||
sd_dtype = memory_management.state_dict_dtype(state_dict)
|
||||
|
||||
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
print(f'Using Detected T5 Data Type: {sd_dtype}')
|
||||
dtype = sd_dtype
|
||||
|
||||
with modeling_utils.no_init_weights():
|
||||
@@ -81,32 +83,60 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
|
||||
|
||||
return model
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
||||
model_loader = None
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
|
||||
if cls_name == 'FluxTransformer2DModel':
|
||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
|
||||
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
|
||||
to_args = dict(device=ini_device, dtype=ini_dtype)
|
||||
state_dict_dtype = memory_management.state_dict_dtype(state_dict)
|
||||
|
||||
with using_forge_operations(**to_args):
|
||||
model = IntegratedUNet2DConditionModel.from_config(unet_config).to(**to_args)
|
||||
model._internal_dict = unet_config
|
||||
storage_dtype = memory_management.unet_dtype(model_params=state_dict_size, supported_dtypes=guess.supported_inference_dtypes)
|
||||
|
||||
unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')
|
||||
|
||||
if unet_storage_dtype_overwrite is not None:
|
||||
storage_dtype = unet_storage_dtype_overwrite
|
||||
else:
|
||||
if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4']:
|
||||
print(f'Using Detected UNet Type: {state_dict_dtype}')
|
||||
storage_dtype = state_dict_dtype
|
||||
if state_dict_dtype in ['nf4', 'fp4']:
|
||||
print(f'Using pre-quant state dict!')
|
||||
|
||||
load_device = memory_management.get_torch_device()
|
||||
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes)
|
||||
offload_device = memory_management.unet_offload_device()
|
||||
|
||||
if storage_dtype in ['nf4', 'fp4']:
|
||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype)
|
||||
with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
|
||||
model = model_loader(unet_config)
|
||||
else:
|
||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=storage_dtype)
|
||||
need_manual_cast = storage_dtype != computation_dtype
|
||||
to_args = dict(device=initial_device, dtype=storage_dtype)
|
||||
|
||||
with using_forge_operations(**to_args, manual_cast_enabled=need_manual_cast):
|
||||
model = model_loader(unet_config).to(**to_args)
|
||||
|
||||
load_state_dict(model, state_dict)
|
||||
return model
|
||||
if cls_name == 'FluxTransformer2DModel':
|
||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
|
||||
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
|
||||
to_args = dict(device=ini_device, dtype=ini_dtype)
|
||||
|
||||
with using_forge_operations(**to_args):
|
||||
model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
|
||||
if hasattr(model, '_internal_dict'):
|
||||
model._internal_dict = unet_config
|
||||
else:
|
||||
model.config = unet_config
|
||||
|
||||
load_state_dict(model, state_dict)
|
||||
model.storage_dtype = storage_dtype
|
||||
model.computation_dtype = computation_dtype
|
||||
model.load_device = load_device
|
||||
model.initial_device = initial_device
|
||||
model.offload_device = offload_device
|
||||
|
||||
return model
|
||||
|
||||
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
|
||||
|
||||
@@ -8,7 +8,7 @@ import platform
|
||||
|
||||
from enum import Enum
|
||||
from backend import stream
|
||||
from backend.args import args, dynamic_args
|
||||
from backend.args import args
|
||||
|
||||
|
||||
cpu = torch.device('cpu')
|
||||
@@ -281,12 +281,8 @@ except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
if 'rtx' in torch_device_name.lower():
|
||||
if not args.pin_shared_memory:
|
||||
print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
|
||||
if not args.cuda_malloc:
|
||||
print('Hint: your device supports --cuda-malloc for potential speed improvements.')
|
||||
if not args.cuda_stream:
|
||||
print('Hint: your device supports --cuda-stream for potential speed improvements.')
|
||||
|
||||
|
||||
current_loaded_models = []
|
||||
@@ -305,8 +301,54 @@ def state_dict_size(sd, exclude_device=None):
|
||||
return module_mem
|
||||
|
||||
|
||||
def state_dict_dtype(state_dict):
    """Guess the storage dtype of a checkpoint from its state dict.

    Args:
        state_dict: Mapping of parameter name to a tensor-like value that
            exposes a ``dtype`` attribute.

    Returns:
        ``'nf4'`` or ``'fp4'`` as soon as a key containing the matching
        ``bitsandbytes__nf4`` / ``bitsandbytes__fp4`` marker is found.
        Otherwise the dtype shared by the largest number of entries (each
        entry counts once, regardless of its element count); on a tie the
        dtype seen first wins. ``None`` when the state dict is empty.
    """
    # Local import: the top-of-file import block is not visible from here.
    from collections import Counter

    # A 4-bit marker in any key decides the result; within one key the
    # nf4 marker is tested before the fp4 marker.
    for key in state_dict:
        if 'bitsandbytes__nf4' in key:
            return 'nf4'
        if 'bitsandbytes__fp4' in key:
            return 'fp4'

    # Counter keeps first-seen order and max() returns the first maximal
    # key, so ties resolve to the dtype encountered earliest.
    dtype_counts = Counter(tensor.dtype for tensor in state_dict.values())
    return max(dtype_counts, key=dtype_counts.get, default=None)
|
||||
|
||||
|
||||
def module_size(module, exclude_device=None):
    """Estimate the memory taken by a module's parameters, in bytes.

    Args:
        module: Object whose ``parameters()`` yields parameters exposing
            ``.data`` (with ``device``, ``element_size()`` and
            ``nelement()``) and optionally a ``quant_type`` attribute.
        exclude_device: When not ``None``, parameters whose data lives on
            this device are left out of the total.

    Returns:
        The accumulated size. It is a float whenever a parameter tagged
        ``quant_type`` in ``('fp4', 'nf4')`` is counted, because those use
        fractional per-element estimates.
    """
    module_mem = 0

    for p in module.parameters():
        t = p.data

        if exclude_device is not None and t.device == exclude_device:
            continue

        element_size = t.element_size()

        if getattr(p, 'quant_type', None) in ['fp4', 'nf4']:
            if element_size > 1:
                # Not quantized yet: budget for the size after 4-bit
                # quantization, a bit more than 0.5 bytes per element
                # because of quant state parameters.
                element_size = 0.55
            else:
                # Already quantized: one-byte storage elements, a bit more
                # than 1 byte each because of quant state parameters.
                # NOTE(review): presumably each stored byte packs two 4-bit
                # values -- confirm against bitsandbytes.
                element_size = 1.1

        module_mem += t.nelement() * element_size

    return module_mem
|
||||
|
||||
|
||||
class LoadedModel:
|
||||
@@ -587,11 +629,6 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
unet_storage_dtype_overwrite = dynamic_args.get('forge_unet_storage_dtype')
|
||||
|
||||
if unet_storage_dtype_overwrite is not None:
|
||||
return unet_storage_dtype_overwrite
|
||||
|
||||
if args.unet_in_bf16:
|
||||
return torch.bfloat16
|
||||
|
||||
@@ -1040,6 +1077,18 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
|
||||
def can_install_bnb():
    """Report whether this PyTorch build meets the CUDA check used for bnb.

    Returns:
        ``True`` only when ``torch.cuda.is_available()`` is true and the
        CUDA version PyTorch was built against is at least 11.7;
        ``False`` otherwise, including builds that report no CUDA version.
    """
    if not torch.cuda.is_available():
        return False

    cuda_version_string = torch.version.cuda

    # torch.version.cuda is None on builds without CUDA (e.g. ROCm/HIP
    # builds, where torch.cuda.is_available() can still be True); calling
    # .split on it would raise AttributeError.
    if cuda_version_string is None:
        return False

    cuda_version = tuple(int(x) for x in cuda_version_string.split('.'))

    return cuda_version >= (11, 7)
|
||||
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
||||
@@ -5,16 +5,13 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
|
||||
|
||||
|
||||
class KModel(torch.nn.Module):
|
||||
def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype, k_predictor=None):
|
||||
def __init__(self, model, diffusers_scheduler, k_predictor=None):
|
||||
super().__init__()
|
||||
|
||||
self.storage_dtype = storage_dtype
|
||||
self.computation_dtype = computation_dtype
|
||||
self.storage_dtype = model.storage_dtype
|
||||
self.computation_dtype = model.computation_dtype
|
||||
|
||||
need_manual_cast = self.storage_dtype != self.computation_dtype
|
||||
operations.shift_manual_cast(model, enabled=need_manual_cast)
|
||||
|
||||
print(f'K-Model Created: {dict(storage_dtype=storage_dtype, computation_dtype=computation_dtype, manual_cast=need_manual_cast)}')
|
||||
print(f'K-Model Created: {dict(storage_dtype=self.storage_dtype, computation_dtype=self.computation_dtype)}')
|
||||
|
||||
self.diffusion_model = model
|
||||
|
||||
|
||||
@@ -10,4 +10,3 @@ class IntegratedCLIP(torch.nn.Module):
|
||||
if add_text_projection:
|
||||
embed_dim = config.hidden_size
|
||||
self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
|
||||
@@ -433,9 +433,9 @@ class ResBlock(TimestepBlock):
|
||||
def _forward(self, x, emb, transformer_options={}):
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
if "groupnorm_wrapper" in transformer_options:
|
||||
if "group_norm_wrapper" in transformer_options:
|
||||
in_norm, in_rest = in_rest[0], in_rest[1:]
|
||||
h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
|
||||
h = transformer_options["group_norm_wrapper"](in_norm, x, transformer_options)
|
||||
h = in_rest(h)
|
||||
else:
|
||||
h = in_rest(x)
|
||||
@@ -443,9 +443,9 @@ class ResBlock(TimestepBlock):
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
if "groupnorm_wrapper" in transformer_options:
|
||||
if "group_norm_wrapper" in transformer_options:
|
||||
in_norm = self.in_layers[0]
|
||||
h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
|
||||
h = transformer_options["group_norm_wrapper"](in_norm, x, transformer_options)
|
||||
h = self.in_layers[1:](h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
@@ -456,8 +456,8 @@ class ResBlock(TimestepBlock):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
if "groupnorm_wrapper" in transformer_options:
|
||||
h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
|
||||
if "group_norm_wrapper" in transformer_options:
|
||||
h = transformer_options["group_norm_wrapper"](out_norm, h, transformer_options)
|
||||
else:
|
||||
h = out_norm(h)
|
||||
if emb_out is not None:
|
||||
@@ -470,8 +470,8 @@ class ResBlock(TimestepBlock):
|
||||
if self.exchange_temb_dims:
|
||||
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
||||
h = h + emb_out
|
||||
if "groupnorm_wrapper" in transformer_options:
|
||||
h = transformer_options["groupnorm_wrapper"](self.out_layers[0], h, transformer_options)
|
||||
if "group_norm_wrapper" in transformer_options:
|
||||
h = transformer_options["group_norm_wrapper"](self.out_layers[0], h, transformer_options)
|
||||
h = self.out_layers[1:](h)
|
||||
else:
|
||||
h = self.out_layers(h)
|
||||
@@ -752,9 +752,9 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
|
||||
transformer_options["block"] = ("last", 0)
|
||||
for block_modifier in block_modifiers:
|
||||
h = block_modifier(h, 'before', transformer_options)
|
||||
if "groupnorm_wrapper" in transformer_options:
|
||||
if "group_norm_wrapper" in transformer_options:
|
||||
out_norm, out_rest = self.out[0], self.out[1:]
|
||||
h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
|
||||
h = transformer_options["group_norm_wrapper"](out_norm, h, transformer_options)
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = self.out(h)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# Copyright Forge 2024
|
||||
|
||||
import time
|
||||
import torch
|
||||
import contextlib
|
||||
@@ -8,7 +10,7 @@ from backend import stream, memory_management
|
||||
stash = {}
|
||||
|
||||
|
||||
def weights_manual_cast(layer, x, skip_dtype=False):
|
||||
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
|
||||
weight, bias, signal = None, None, None
|
||||
non_blocking = True
|
||||
|
||||
@@ -18,21 +20,28 @@ def weights_manual_cast(layer, x, skip_dtype=False):
|
||||
target_dtype = x.dtype
|
||||
target_device = x.device
|
||||
|
||||
if skip_dtype:
|
||||
target_dtype = None
|
||||
if skip_weight_dtype:
|
||||
weight_args = dict(device=target_device, non_blocking=non_blocking)
|
||||
else:
|
||||
weight_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
|
||||
if skip_bias_dtype:
|
||||
bias_args = dict(device=target_device, non_blocking=non_blocking)
|
||||
else:
|
||||
bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
|
||||
if stream.should_use_stream():
|
||||
with stream.stream_context()(stream.mover_stream):
|
||||
if layer.weight is not None:
|
||||
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
weight = layer.weight.to(**weight_args)
|
||||
if layer.bias is not None:
|
||||
bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
bias = layer.bias.to(**bias_args)
|
||||
signal = stream.mover_stream.record_event()
|
||||
else:
|
||||
if layer.weight is not None:
|
||||
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
weight = layer.weight.to(**weight_args)
|
||||
if layer.bias is not None:
|
||||
bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
|
||||
bias = layer.bias.to(**bias_args)
|
||||
|
||||
return weight, bias, signal
|
||||
|
||||
@@ -72,19 +81,27 @@ def cleanup_cache():
|
||||
current_device = None
|
||||
current_dtype = None
|
||||
current_manual_cast_enabled = False
|
||||
current_bnb_dtype = None
|
||||
|
||||
|
||||
class ForgeOperations:
|
||||
class Linear(torch.nn.Linear):
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['device'] = current_device
|
||||
kwargs['dtype'] = current_dtype
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__()
|
||||
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
self.parameters_manual_cast = current_manual_cast_enabled
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
if hasattr(self, 'dummy'):
|
||||
if prefix + 'weight' in state_dict:
|
||||
self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
|
||||
if prefix + 'bias' in state_dict:
|
||||
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
|
||||
del self.dummy
|
||||
else:
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def forward(self, x):
|
||||
if self.parameters_manual_cast:
|
||||
@@ -92,7 +109,7 @@ class ForgeOperations:
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
else:
|
||||
return super().forward(x)
|
||||
return torch.nn.functional.linear(x, self.weight, self.bias)
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
|
||||
@@ -269,21 +286,61 @@ class ForgeOperations:
|
||||
|
||||
def forward(self, x):
|
||||
if self.parameters_manual_cast:
|
||||
weight, bias, signal = weights_manual_cast(self, x, skip_dtype=True)
|
||||
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
else:
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False):
|
||||
global current_device, current_dtype, current_manual_cast_enabled
|
||||
try:
|
||||
from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits
|
||||
|
||||
current_device, current_dtype, current_manual_cast_enabled = device, dtype, manual_cast_enabled
|
||||
class ForgeOperationsBNB4bits(ForgeOperations):
|
||||
class Linear(ForgeLoader4Bit):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
|
||||
self.parameters_manual_cast = current_manual_cast_enabled
|
||||
|
||||
def forward(self, x):
|
||||
self.weight.quant_state = self.quant_state
|
||||
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
# Maybe this can also be set to all non-bnb ops since the cost is very low.
|
||||
# And it only invokes one time, and most linear does not have bias
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
if not self.parameters_manual_cast:
|
||||
return functional_linear_4bits(x, self.weight, self.bias)
|
||||
elif not self.weight.bnb_quantized:
|
||||
assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
|
||||
layer_original_device = self.weight.device
|
||||
self.weight = self.weight._quantize(x.device)
|
||||
bias = self.bias.to(x.device) if self.bias is not None else None
|
||||
out = functional_linear_4bits(x, self.weight, bias)
|
||||
self.weight = self.weight.to(layer_original_device)
|
||||
return out
|
||||
else:
|
||||
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return functional_linear_4bits(x, weight, bias)
|
||||
|
||||
bnb_avaliable = True
|
||||
except:
|
||||
bnb_avaliable = False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
|
||||
global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
|
||||
|
||||
current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype
|
||||
|
||||
if operations is None:
|
||||
operations = ForgeOperations
|
||||
if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
|
||||
operations = ForgeOperationsBNB4bits
|
||||
else:
|
||||
operations = ForgeOperations
|
||||
|
||||
op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
|
||||
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -428,11 +428,17 @@ class ControlLora(ControlNet):
|
||||
controlnet_config = model.diffusion_model.config.copy()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||
controlnet_config["dtype"] = dtype = model.storage_dtype
|
||||
|
||||
dtype = model.storage_dtype
|
||||
|
||||
if dtype in ['nf4', 'fp4']:
|
||||
dtype = torch.float16
|
||||
|
||||
controlnet_config["dtype"] = dtype
|
||||
|
||||
self.manual_cast_dtype = model.computation_dtype
|
||||
|
||||
with using_forge_operations(operations=ControlLoraOps):
|
||||
with using_forge_operations(operations=ControlLoraOps, dtype=dtype):
|
||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||
|
||||
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
|
||||
|
||||
@@ -3,20 +3,18 @@ import torch
|
||||
|
||||
from backend.modules.k_model import KModel
|
||||
from backend.patcher.base import ModelPatcher
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
class UnetPatcher(ModelPatcher):
|
||||
@classmethod
|
||||
def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
|
||||
parameters = memory_management.module_size(model)
|
||||
unet_dtype = memory_management.unet_dtype(model_params=parameters)
|
||||
load_device = memory_management.get_torch_device()
|
||||
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=config.supported_inference_dtypes)
|
||||
model.to(device=initial_load_device, dtype=unet_dtype)
|
||||
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=computation_dtype)
|
||||
return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
|
||||
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor)
|
||||
return UnetPatcher(
|
||||
model,
|
||||
load_device=model.diffusion_model.load_device,
|
||||
offload_device=model.diffusion_model.offload_device,
|
||||
current_device=model.diffusion_model.initial_device
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -169,8 +167,8 @@ class UnetPatcher(ModelPatcher):
|
||||
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def set_groupnorm_wrapper(self, wrapper):
|
||||
self.set_transformer_option('groupnorm_wrapper', wrapper)
|
||||
def set_group_norm_wrapper(self, wrapper):
|
||||
self.set_transformer_option('group_norm_wrapper', wrapper)
|
||||
return
|
||||
|
||||
def set_controlnet_model_function_wrapper(self, wrapper):
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
from collections import namedtuple
|
||||
from backend.text_processing import parsing, emphasis
|
||||
from backend.text_processing.textual_inversion import EmbeddingDatabase
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
@@ -50,9 +48,6 @@ class T5TextProcessingEngine:
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
return token_count
|
||||
|
||||
def tokenize(self, texts):
|
||||
tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
return tokenized
|
||||
@@ -112,45 +107,33 @@ class T5TextProcessingEngine:
|
||||
|
||||
return chunks, token_count
|
||||
|
||||
def process_texts(self, texts):
|
||||
token_count = 0
|
||||
|
||||
def __call__(self, texts):
|
||||
zs = []
|
||||
cache = {}
|
||||
batch_chunks = []
|
||||
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
chunks = cache[line]
|
||||
line_z_values = cache[line]
|
||||
else:
|
||||
chunks, current_token_count = self.tokenize_line(line)
|
||||
token_count = max(current_token_count, token_count)
|
||||
chunks, token_count = self.tokenize_line(line)
|
||||
line_z_values = []
|
||||
for chunk in chunks:
|
||||
tokens = chunk.tokens
|
||||
multipliers = chunk.multipliers
|
||||
z = self.process_tokens([tokens], [multipliers])[0]
|
||||
line_z_values.append(z)
|
||||
cache[line] = line_z_values
|
||||
|
||||
cache[line] = chunks
|
||||
zs.extend(line_z_values)
|
||||
|
||||
batch_chunks.append(chunks)
|
||||
return torch.stack(zs)
|
||||
|
||||
return batch_chunks, token_count
|
||||
|
||||
def __call__(self, texts):
|
||||
batch_chunks, token_count = self.process_texts(texts)
|
||||
chunk_count = max([len(x) for x in batch_chunks])
|
||||
|
||||
zs = []
|
||||
|
||||
for i in range(chunk_count):
|
||||
batch_chunk = [chunks[i] for chunks in batch_chunks]
|
||||
tokens = [x.tokens for x in batch_chunk]
|
||||
multipliers = [x.multipliers for x in batch_chunk]
|
||||
z = self.process_tokens(tokens, multipliers)
|
||||
zs.append(z)
|
||||
|
||||
return torch.hstack(zs)
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
tokens = torch.asarray(remade_batch_tokens)
|
||||
def process_tokens(self, batch_tokens, batch_multipliers):
|
||||
tokens = torch.asarray(batch_tokens)
|
||||
|
||||
z = self.encode_with_transformers(tokens)
|
||||
|
||||
self.emphasis.tokens = remade_batch_tokens
|
||||
self.emphasis.tokens = batch_tokens
|
||||
self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
|
||||
self.emphasis.z = z
|
||||
self.emphasis.after_transformers()
|
||||
|
||||
Reference in New Issue
Block a user