diff --git a/backend/huggingface/black-forest-labs/FLUX.1-dev/tokenizer_2/tokenizer.json b/backend/huggingface/black-forest-labs/FLUX.1-dev/tokenizer_2/tokenizer.json
index 21ed409a..b11c92d7 100644
--- a/backend/huggingface/black-forest-labs/FLUX.1-dev/tokenizer_2/tokenizer.json
+++ b/backend/huggingface/black-forest-labs/FLUX.1-dev/tokenizer_2/tokenizer.json
@@ -955,8 +955,8 @@
   "pre_tokenizer": {
     "type": "Metaspace",
     "replacement": "▁",
-    "prepend_scheme": "always",
-    "split": true
+    "add_prefix_space": true,
+    "prepend_scheme": "first"
   },
   "post_processor": {
     "type": "TemplateProcessing",
@@ -1015,8 +1015,8 @@
   "decoder": {
     "type": "Metaspace",
     "replacement": "▁",
-    "prepend_scheme": "always",
-    "split": true
+    "add_prefix_space": true,
+    "prepend_scheme": "always"
   },
   "model": {
     "type": "Unigram",
diff --git a/backend/huggingface/black-forest-labs/FLUX.1-dev/tokenizer_2/tokenizer_config.json b/backend/huggingface/black-forest-labs/FLUX.1-dev/tokenizer_2/tokenizer_config.json
index b336fa23..02020eb6 100644
--- a/backend/huggingface/black-forest-labs/FLUX.1-dev/tokenizer_2/tokenizer_config.json
+++ b/backend/huggingface/black-forest-labs/FLUX.1-dev/tokenizer_2/tokenizer_config.json
@@ -1,5 +1,4 @@
 {
-  "add_prefix_space": true,
   "added_tokens_decoder": {
     "0": {
       "content": "<pad>",
@@ -931,7 +930,7 @@
   "clean_up_tokenization_spaces": true,
   "eos_token": "</s>",
   "extra_ids": 100,
-  "legacy": true,
+  "legacy": false,
   "model_max_length": 512,
   "pad_token": "<pad>",
   "sp_model_kwargs": {},
diff --git a/backend/loader.py b/backend/loader.py
index b479a007..f5d251b4 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -12,7 +12,7 @@ from backend import memory_management
 from backend.state_dict import try_filter_state_dict, load_state_dict
 from backend.operations import using_forge_operations
 from backend.nn.vae import IntegratedAutoencoderKL
-from backend.nn.clip import IntegratedCLIP, CLIPTextConfig
+from backend.nn.clip import IntegratedCLIP
 from backend.nn.unet import IntegratedUNet2DConditionModel
 
 from backend.diffusion_engine.sd15 import StableDiffusion
@@ -40,17 +40,18 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
     if cls_name in ['AutoencoderKL']:
         config = IntegratedAutoencoderKL.load_config(config_path)
 
-        with using_forge_operations():
+        with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
             model = IntegratedAutoencoderKL.from_config(config)
 
         load_state_dict(model, state_dict)
         return model
 
     if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
+        from transformers import CLIPTextConfig, CLIPTextModel
         config = CLIPTextConfig.from_pretrained(config_path)
 
         with modeling_utils.no_init_weights():
-            with using_forge_operations():
-                model = IntegratedCLIP(config)
+            with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype()):
+                model = IntegratedCLIP(CLIPTextModel, config, add_text_projection=True)
 
         load_state_dict(model, state_dict, ignore_errors=[
             'transformer.text_projection.weight',
@@ -58,13 +59,30 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
             'logit_scale'
         ], log_name=cls_name)
 
+        return model
+    if component_name.startswith('text_encoder') and cls_name in ['T5EncoderModel']:
+        from transformers import T5EncoderModel, T5Config
+        config = T5Config.from_pretrained(config_path)
+
+        dtype = memory_management.text_encoder_dtype()
+        sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
+
+        if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            dtype = sd_dtype
+
+        with modeling_utils.no_init_weights():
+            with using_forge_operations(device=memory_management.cpu, dtype=dtype):
+                model = IntegratedCLIP(T5EncoderModel, config)
+
+        load_state_dict(model, state_dict, log_name=cls_name)
+        return model
 
     if cls_name == 'UNet2DConditionModel':
         unet_config = guess.unet_config.copy()
         state_dict_size = memory_management.state_dict_size(state_dict)
         unet_config['dtype'] = memory_management.unet_dtype(model_params=state_dict_size)
 
-        with using_forge_operations():
+        with using_forge_operations(device=memory_management.cpu, dtype=unet_config['dtype']):
             model = IntegratedUNet2DConditionModel.from_config(unet_config)
 
         model._internal_dict = unet_config
@@ -77,14 +95,14 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
 
 def split_state_dict(sd):
     guess = huggingface_guess.guess(sd)
+    guess.clip_target = guess.clip_target(sd)
 
     state_dict = {
-        'unet': try_filter_state_dict(sd, ['model.diffusion_model.']),
-        'vae': try_filter_state_dict(sd, guess.vae_key_prefix)
+        guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
+        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
     }
 
     sd = guess.process_clip_state_dict(sd)
-    guess.clip_target = guess.clip_target(sd)
 
     for k, v in guess.clip_target.items():
         state_dict[v] = try_filter_state_dict(sd, [k + '.'])
diff --git a/backend/memory_management.py b/backend/memory_management.py
index 95a2f8e5..ea64d1b7 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -11,6 +11,9 @@ from backend import stream
 from backend.args import args, dynamic_args
 
 
+cpu = torch.device('cpu')
+
+
 class VRAMState(Enum):
     DISABLED = 0    # No vram present: no need to move models to vram
     NO_VRAM = 1     # Very low vram: enable all the options to save vram
diff --git a/backend/nn/clip.py b/backend/nn/clip.py
index 373fb73d..42005d85 100644
--- a/backend/nn/clip.py
+++ b/backend/nn/clip.py
@@ -1,12 +1,13 @@
 import torch
 
-from transformers import CLIPTextModel, CLIPTextConfig
-
 
 class IntegratedCLIP(torch.nn.Module):
-    def __init__(self, config: CLIPTextConfig):
+    def __init__(self, cls, config, add_text_projection=False):
         super().__init__()
-        self.transformer = CLIPTextModel(config)
-        embed_dim = config.hidden_size
-        self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
-        self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
+        self.transformer = cls(config)
+        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+
+        if add_text_projection:
+            embed_dim = config.hidden_size
+            self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
+            self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
diff --git a/backend/nn/vae.py b/backend/nn/vae.py
index d98644c1..09723d6c 100644
--- a/backend/nn/vae.py
+++ b/backend/nn/vae.py
@@ -397,8 +397,8 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
         self.decoder = Decoder(double_z=True, z_channels=latent_channels, resolution=256, in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult, num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0)
 
-        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
-        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
+        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
         self.embed_dim = latent_channels
         self.scaling_factor = scaling_factor
         self.shift_factor = shift_factor
@@ -408,7 +408,10 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
 
     def encode(self, x, regulation=None):
         z = self.encoder(x)
-        z = self.quant_conv(z)
+
+        if self.quant_conv is not None:
+            z = self.quant_conv(z)
+
         posterior = DiagonalGaussianDistribution(z)
         if regulation is not None:
             return regulation(posterior)
@@ -416,7 +419,9 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
         return posterior.sample()
 
     def decode(self, z):
-        z = self.post_quant_conv(z)
+        if self.post_quant_conv is not None:
+            z = self.post_quant_conv(z)
+
         x = self.decoder(z)
         return x
diff --git a/backend/operations.py b/backend/operations.py
index a066223d..060b42a0 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -7,23 +7,29 @@ from backend import stream
 
 stash = {}
 
 
-def weights_manual_cast(layer, x):
+def weights_manual_cast(layer, x, skip_dtype=False):
     weight, bias, signal = None, None, None
     non_blocking = True
 
     if getattr(x.device, 'type', None) == 'mps':
         non_blocking = False
 
+    target_dtype = x.dtype
+    target_device = x.device
+
+    if skip_dtype:
+        target_dtype = None
+
     if stream.using_stream:
         with stream.stream_context()(stream.mover_stream):
             if layer.bias is not None:
-                bias = layer.bias.to(device=x.device, dtype=x.dtype, non_blocking=non_blocking)
-            weight = layer.weight.to(device=x.device, dtype=x.dtype, non_blocking=non_blocking)
+                bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+            weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
             signal = stream.mover_stream.record_event()
     else:
         if layer.bias is not None:
-            bias = layer.bias.to(device=x.device, dtype=x.dtype, non_blocking=non_blocking)
-        weight = layer.weight.to(device=x.device, dtype=x.dtype, non_blocking=non_blocking)
+            bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+            weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
 
     return weight, bias, signal
@@ -60,9 +66,19 @@ def cleanup_cache():
     return
 
 
+current_device = None
+current_dtype = None
+current_manual_cast_enabled = False
+
+
 class ForgeOperations:
     class Linear(torch.nn.Linear):
-        parameters_manual_cast = False
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
 
         def reset_parameters(self):
             return None
@@ -76,7 +92,12 @@
                 return super().forward(x)
 
     class Conv2d(torch.nn.Conv2d):
-        parameters_manual_cast = False
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
 
         def reset_parameters(self):
             return None
@@ -90,7 +111,12 @@
                 return super().forward(x)
 
     class Conv3d(torch.nn.Conv3d):
-        parameters_manual_cast = False
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
 
         def reset_parameters(self):
             return None
@@ -103,8 +129,98 @@
             else:
                 return super().forward(x)
 
+    class Conv1d(torch.nn.Conv1d):
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
+
+        def reset_parameters(self):
+            return None
+
+        def forward(self, x):
+            if self.parameters_manual_cast:
+                weight, bias, signal = weights_manual_cast(self, x)
+                with main_stream_worker(weight, bias, signal):
+                    return self._conv_forward(x, weight, bias)
+            else:
+                return super().forward(x)
+
+    class ConvTranspose2d(torch.nn.ConvTranspose2d):
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
+
+        def reset_parameters(self):
+            return None
+
+        def forward(self, x, output_size=None):
+            if self.parameters_manual_cast:
+                num_spatial_dims = 2
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+
+                weight, bias, signal = weights_manual_cast(self, x)
+                with main_stream_worker(weight, bias, signal):
+                    return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
+            else:
+                return super().forward(x, output_size)
+
+    class ConvTranspose1d(torch.nn.ConvTranspose1d):
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
+
+        def reset_parameters(self):
+            return None
+
+        def forward(self, x, output_size=None):
+            if self.parameters_manual_cast:
+                num_spatial_dims = 1
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+
+                weight, bias, signal = weights_manual_cast(self, x)
+                with main_stream_worker(weight, bias, signal):
+                    return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
+            else:
+                return super().forward(x, output_size)
+
+    class ConvTranspose3d(torch.nn.ConvTranspose3d):
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
+
+        def reset_parameters(self):
+            return None
+
+        def forward(self, x, output_size=None):
+            if self.parameters_manual_cast:
+                num_spatial_dims = 3
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+
+                weight, bias, signal = weights_manual_cast(self, x)
+                with main_stream_worker(weight, bias, signal):
+                    return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
+            else:
+                return super().forward(x, output_size)
+
     class GroupNorm(torch.nn.GroupNorm):
-        parameters_manual_cast = False
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
 
         def reset_parameters(self):
             return None
@@ -118,7 +234,12 @@
                 return super().forward(x)
 
     class LayerNorm(torch.nn.LayerNorm):
-        parameters_manual_cast = False
+
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            kwargs['dtype'] = current_dtype
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
 
         def reset_parameters(self):
             return None
@@ -131,34 +252,37 @@
             else:
                 return super().forward(x)
 
+    class Embedding(torch.nn.Embedding):
 
-class ForgeOperationsWithManualCast(ForgeOperations):
-    class Linear(ForgeOperations.Linear):
-        parameters_manual_cast = True
+        def __init__(self, *args, **kwargs):
+            kwargs['device'] = current_device
+            super().__init__(*args, **kwargs)
+            self.parameters_manual_cast = current_manual_cast_enabled
+            self.bias = None
 
-    class Conv2d(ForgeOperations.Conv2d):
-        parameters_manual_cast = True
+        def reset_parameters(self):
+            self.bias = None
+            return None
 
-    class Conv3d(ForgeOperations.Conv3d):
-        parameters_manual_cast = True
-
-    class GroupNorm(ForgeOperations.GroupNorm):
-        parameters_manual_cast = True
-
-    class LayerNorm(ForgeOperations.LayerNorm):
-        parameters_manual_cast = True
+        def forward(self, x):
+            if self.parameters_manual_cast:
+                weight, bias, signal = weights_manual_cast(self, x, skip_dtype=True)
+                with main_stream_worker(weight, bias, signal):
+                    return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+            else:
+                return super().forward(x)
 
 
 @contextlib.contextmanager
-def using_forge_operations(parameters_manual_cast=False, operations=None):
+def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False):
+    global current_device, current_dtype, current_manual_cast_enabled
+
+    current_device, current_dtype, current_manual_cast_enabled = device, dtype, manual_cast_enabled
     if operations is None:
         operations = ForgeOperations
 
-    if parameters_manual_cast:
-        operations = ForgeOperationsWithManualCast
-
-    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
+    op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
     backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
 
     try:
diff --git a/backend/patcher/controlnet.py b/backend/patcher/controlnet.py
index 15ecf962..7ea5721b 100644
--- a/backend/patcher/controlnet.py
+++ b/backend/patcher/controlnet.py
@@ -5,7 +5,7 @@ from backend.misc import image_resize
 from backend import memory_management, state_dict, utils
 from backend.nn.cnets import cldm, t2i_adapter
 from backend.patcher.base import ModelPatcher
-from backend.operations import using_forge_operations, ForgeOperationsWithManualCast, main_stream_worker, weights_manual_cast
+from backend.operations import using_forge_operations, ForgeOperations, main_stream_worker, weights_manual_cast
 
 
 def compute_controlnet_weighting(control, cnet):
@@ -282,7 +282,7 @@ class ControlNet(ControlBase):
         super().cleanup()
 
 
-class ControlLoraOps(ForgeOperationsWithManualCast):
+class ControlLoraOps(ForgeOperations):
     class Linear(torch.nn.Module):
         def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
             super().__init__()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index c11ed6a1..0751994a 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -401,7 +401,7 @@ def prepare_environment():
     # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "2cecc9aec5b9476ad16d0b0c4a3c779f048e7cdd")
+    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "60a0f76d537df765570f8d497eb33ef5dfc6aa60")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
 
     try:
diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py
index 48e333e8..72dd9dfa 100644
--- a/modules_forge/initialization.py
+++ b/modules_forge/initialization.py
@@ -62,8 +62,8 @@ def initialize_forge():
 
     from modules_forge.shared import diffusers_dir
 
-    if 'TRANSFORMERS_CACHE' not in os.environ:
-        os.environ['TRANSFORMERS_CACHE'] = diffusers_dir
+    # if 'TRANSFORMERS_CACHE' not in os.environ:
+    #     os.environ['TRANSFORMERS_CACHE'] = diffusers_dir
 
     if 'HF_HOME' not in os.environ:
         os.environ['HF_HOME'] = diffusers_dir
diff --git a/requirements-test.txt b/requirements-test.txt
deleted file mode 100644
index 37838ca2..00000000
--- a/requirements-test.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-pytest-base-url~=2.0
-pytest-cov~=4.0
-pytest~=7.3
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 98d9430a..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,34 +0,0 @@
-GitPython
-Pillow
-accelerate
-
-blendmodes
-clean-fid
-diskcache
-einops
-facexlib
-fastapi>=0.90.1
-gradio
-inflection
-jsonmerge
-kornia
-lark
-numpy
-omegaconf
-open-clip-torch
-
-piexif
-protobuf==3.20.0
-psutil
-pytorch_lightning
-requests
-resize-right
-
-safetensors
-scikit-image>=0.19
-tomesd
-torch
-torchdiffeq
-torchsde
-transformers==4.30.2
-pillow-avif-plugin==1.4.3
\ No newline at end of file
diff --git a/requirements_npu.txt b/requirements_npu.txt
deleted file mode 100644
index 5e6a4364..00000000
--- a/requirements_npu.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-cloudpickle
-decorator
-synr==0.5.0
-tornado
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 4f1672d1..dbab9461 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -30,7 +30,7 @@ tomesd==0.1.3
 torch
 torchdiffeq==0.2.3
 torchsde==0.2.6
-transformers==4.30.2
+transformers==4.44.0
 httpx==0.24.1
 pillow-avif-plugin==1.4.3
 basicsr==1.4.2
diff --git a/screenshot.png b/screenshot.png
deleted file mode 100644
index 47a1be4e..00000000
Binary files a/screenshot.png and /dev/null differ