Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00

Commit: Bug fixes and minor features
@@ -1450,8 +1450,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         flush()
         # self.step_num = 0

-        print(f"Compiling Model")
-        torch.compile(self.sd.unet, dynamic=True)
+        # print(f"Compiling Model")
+        # torch.compile(self.sd.unet, dynamic=True)

         ###################################################################
         # TRAIN LOOP
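For reference, a minimal standalone sketch of the call being disabled here; the checkpoint path is a placeholder, and note that torch.compile returns an optimized wrapper, so the result normally has to be assigned back for the compiled graph to be used:

import torch
from diffusers import UNet2DConditionModel

# placeholder path; any diffusers UNet works the same way
unet = UNet2DConditionModel.from_pretrained("path/to/model", subfolder="unet")

# torch.compile returns an OptimizedModule wrapper; assign it back so the
# compiled graph is actually used on subsequent forward passes
unet = torch.compile(unet, dynamic=True)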
@@ -23,4 +23,5 @@ prodigyopt
 controlnet_aux==0.0.7
 python-dotenv
 bitsandbytes
-xformers
+xformers
+hf_transfer
run.py
@@ -1,4 +1,5 @@
 import os
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 import sys
 from typing import Union, OrderedDict
 from dotenv import load_dotenv
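A minimal sketch of what this enables, assuming the hf_transfer package from requirements is installed and using an illustrative repo id; the environment variable must be set before any download so huggingface_hub picks up the accelerated backend:

import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # must be set before downloads start

from huggingface_hub import snapshot_download

# illustrative repo id; downloads now go through the Rust-based hf_transfer backend
local_dir = snapshot_download("google/flan-t5-xl")
print(local_dir)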
testing/merge_in_text_encoder_adapter.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import os

import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from safetensors.torch import load_file, save_file
from collections import OrderedDict
import json

model_path = "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/objective_reality_v2.safetensors"
te_path = "google/flan-t5-xl"
te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
output_path = "/home/jaret/Dev/models/hf/t5xl_sd15_v1"

print("Loading te adapter")
te_aug_sd = load_file(te_aug_path)

print("Loading model")
sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)

print("Loading Text Encoder")
# Load the text encoder
te = T5EncoderModel.from_pretrained(te_path, torch_dtype=torch.float16)

# patch it
sd.text_encoder = te
sd.tokenizer = T5Tokenizer.from_pretrained(te_path)

unet_sd = sd.unet.state_dict()

weight_idx = 1

new_cross_attn_dim = None

print("Patching UNet")
for name in sd.unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
    if name.startswith("mid_block"):
        hidden_size = sd.unet.config['block_out_channels'][-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = sd.unet.config['block_out_channels'][block_id]
    else:
        # they didnt have this, but would lead to undefined below
        raise ValueError(f"unknown attn processor name: {name}")
    if cross_attention_dim is None:
        pass
    else:
        layer_name = name.split(".processor")[0]
        to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
        to_v_adapter = unet_sd[layer_name + ".to_v.weight"]

        te_aug_name = None
        while True:
            te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter"
            if f"{te_aug_name}.weight" in te_aug_sd:
                # increment so we dont redo it next time
                weight_idx += 1
                break
            else:
                weight_idx += 1

            if weight_idx > 1000:
                raise ValueError("Could not find the next weight")

        unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"]
        unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"]

        if new_cross_attn_dim is None:
            new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1]


print("Saving unmodified model")
sd.save_pretrained(
    output_path,
    safe_serialization=True,
)

# overwrite the unet
unet_folder = os.path.join(output_path, "unet")

# move state_dict to cpu
unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()}

meta = OrderedDict()
meta["format"] = "pt"

print("Patching new unet")

save_file(unet_sd, os.path.join(unet_folder, "diffusion_pytorch_model.safetensors"), meta)

# load the json file
with open(os.path.join(unet_folder, "config.json"), 'r') as f:
    config = json.load(f)

config['cross_attention_dim'] = new_cross_attn_dim

# save it
with open(os.path.join(unet_folder, "config.json"), 'w') as f:
    json.dump(config, f, indent=2)

print("Done")
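A quick sanity check one could run afterwards, as a sketch rather than part of the script; the layer key below and the expected 2048 hidden size for flan-t5-xl are assumptions:

import json, os
from safetensors.torch import load_file

output_path = "/home/jaret/Dev/models/hf/t5xl_sd15_v1"  # same output_path as above
unet_folder = os.path.join(output_path, "unet")

with open(os.path.join(unet_folder, "config.json")) as f:
    config = json.load(f)

state = load_file(os.path.join(unet_folder, "diffusion_pytorch_model.safetensors"))
key = "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight"  # assumed cross-attn layer
# both should report the T5 hidden size (2048 for flan-t5-xl)
print(config["cross_attention_dim"], state[key].shape[1])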
@@ -344,6 +344,7 @@ class ModelConfig:
         self._original_refiner_name_or_path = self.refiner_name_or_path
         self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
         self.lora_path = kwargs.get('lora_path', None)
+        self.latent_space_version = kwargs.get('latent_space_version', None)

         # only for SDXL models for now
         self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
@@ -111,6 +111,7 @@ class CustomAdapter(torch.nn.Module):
         self.load_state_dict(loaded_state_dict, strict=False)

     def setup_adapter(self):
+        torch_dtype = get_torch_dtype(self.sd_ref().dtype)
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
             embed_dim = sd.unet.config['cross_attention_dim']
@@ -146,14 +147,23 @@ class CustomAdapter(torch.nn.Module):
             )
         elif self.adapter_type == 'text_encoder':
             if self.config.text_encoder_arch == 't5':
-                self.te = T5EncoderModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
-                                                                                           dtype=get_torch_dtype(
-                                                                                               self.sd_ref().dtype))
+                te_kwargs = {}
+                # te_kwargs['load_in_4bit'] = True
+                # te_kwargs['load_in_8bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+
+                self.te = T5EncoderModel.from_pretrained(
+                    self.config.text_encoder_path,
+                    torch_dtype=torch_dtype,
+                    **te_kwargs
+                )
+
+                # self.te.to = lambda *args, **kwargs: None
                 self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
             elif self.config.text_encoder_arch == 'clip':
                 self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
-                                                                                          dtype=get_torch_dtype(
-                                                                                              self.sd_ref().dtype))
+                                                                                          dtype=torch_dtype)
                 self.tokenizer = CLIPTokenizer.from_pretrained(self.config.text_encoder_path)
             else:
                 raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
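A standalone sketch of the new loading path, assuming accelerate is installed (required for device_map="auto") and using an illustrative checkpoint:

import torch
from transformers import T5EncoderModel, T5Tokenizer

te_kwargs = {"device_map": "auto"}  # places/shards the encoder across available devices

te = T5EncoderModel.from_pretrained(
    "google/flan-t5-xl",  # illustrative checkpoint
    torch_dtype=torch.float16,
    **te_kwargs,
)
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")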
@@ -531,7 +541,14 @@ class CustomAdapter(torch.nn.Module):
             has_been_preprocessed=False,
             is_unconditional=False,
             quad_count=4,
+            is_generating_samples=False,
     ) -> PromptEmbeds:
+        if self.adapter_type == 'text_encoder' and is_generating_samples:
+            # replace the prompt embed with ours
+            if is_unconditional:
+                return self.unconditional_embeds.clone()
+            return self.conditional_embeds.clone()
+
         if self.adapter_type == 'ilora':
             return prompt_embeds
@@ -585,7 +585,7 @@ def get_dataloader_from_datasets(
             drop_last=False,
             shuffle=True,
             collate_fn=dto_collation,  # Use the custom collate function
-            num_workers=4
+            num_workers=8
         )
     else:
         data_loader = DataLoader(
@@ -361,7 +361,7 @@ class CaptionProcessingDTOMixin:
             caption = ', '.join(token_list)
         caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)

-        if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
+        if self.dataset_config.random_triggers:
             num_triggers = self.dataset_config.random_triggers_max
             if num_triggers > 1:
                 num_triggers = random.randint(0, num_triggers)
@@ -369,6 +369,9 @@ class CaptionProcessingDTOMixin:
             if num_triggers > 0:
                 # add random triggers
                 for i in range(num_triggers):
                     caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)

         if self.dataset_config.shuffle_tokens:
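A standalone sketch of the random-trigger behaviour above, with made-up trigger words:

import random

random_triggers = ["p3rs0n", "styl3x"]  # made-up trigger tokens
random_triggers_max = 2

num_triggers = random_triggers_max
if num_triggers > 1:
    num_triggers = random.randint(0, num_triggers)  # can be 0, i.e. sometimes no trigger is added

caption = "a photo of a dog"
for _ in range(num_triggers):
    caption = caption + ', ' + random.choice(random_triggers)
print(caption)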
@@ -1316,7 +1319,9 @@ class LatentCachingMixin:
         i = 0
         for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
             # set latent space version
-            if self.sd.is_xl:
+            if self.sd.model_config.latent_space_version is not None:
+                file_item.latent_space_version = self.sd.model_config.latent_space_version
+            elif self.sd.is_xl:
                 file_item.latent_space_version = 'sdxl'
             else:
                 file_item.latent_space_version = 'sd1'
@@ -487,7 +487,7 @@ class IPAdapter(torch.nn.Module):
         attn_processor_names = []

         for name in attn_processor_keys:
-            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
+            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
                 sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
@@ -540,9 +540,6 @@ class IPAdapter(torch.nn.Module):
                 module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
             self.adapter_modules = torch.nn.ModuleList(
                 [
-                    transformer.transformer_blocks[i].attn1.processor for i in
-                    range(len(transformer.transformer_blocks))
-                ] + [
                     transformer.transformer_blocks[i].attn2.processor for i in
                     range(len(transformer.transformer_blocks))
                 ])
@@ -6,8 +6,12 @@ import torch.nn.functional as F
 import weakref
 from typing import Union, TYPE_CHECKING

-from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
+from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection

+from toolkit import train_tools
 from toolkit.paths import REPOS_ROOT
+from toolkit.prompt_utils import PromptEmbeds

 sys.path.append(REPOS_ROOT)

 from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
@@ -106,14 +110,14 @@ class TEAdapterAttnProcessor(nn.Module):

         # only use one TE or the other. If our adapter is active only use ours
         if self.is_active and self.conditional_embeds is not None:
-            adapter_hidden_states = self.conditional_embeds
+            adapter_hidden_states = self.conditional_embeds.text_embeds
             # check if we are doing unconditional
             if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != encoder_hidden_states.shape[0]:
                 # concat unconditional to match the hidden state batch size
-                if self.unconditional_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
-                    unconditional = torch.cat([self.unconditional_embeds] * adapter_hidden_states.shape[0], dim=0)
+                if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
+                    unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
                 else:
-                    unconditional = self.unconditional_embeds
+                    unconditional = self.unconditional_embeds.text_embeds
                 adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
             # for ip-adapter
             key = self.to_k_adapter(adapter_hidden_states)
@@ -163,7 +167,7 @@ class TEAdapter(torch.nn.Module):
         self,
         adapter: 'CustomAdapter',
         sd: 'StableDiffusion',
-        te: Union[T5EncoderModel, CLIPTextModel],
+        te: Union[T5EncoderModel],
         tokenizer: CLIPTokenizer
     ):
         super(TEAdapter, self).__init__()
@@ -178,6 +182,12 @@ class TEAdapter(torch.nn.Module):
         else:
             self.token_size = self.te_ref().config.hidden_size

+        # add text projection if is sdxl
+        self.text_projection = None
+        if sd.is_xl:
+            clip_with_projection: CLIPTextModelWithProjection = sd.text_encoder[0]
+            self.text_projection = nn.Linear(te.config.hidden_size, clip_with_projection.config.projection_dim, bias=False)
+
         # init adapter modules
         attn_procs = {}
         unet_sd = sd.unet.state_dict()
@@ -258,16 +268,48 @@ class TEAdapter(torch.nn.Module):
         te: T5EncoderModel = self.te_ref()
         tokenizer: T5Tokenizer = self.tokenizer_ref()

-        input_ids = tokenizer(
-            text,
-            max_length=77,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt",
-        ).input_ids.to(te.device)
-        outputs = te(input_ids=input_ids)
-        outputs = outputs.last_hidden_state
-        return outputs
+        # input_ids = tokenizer(
+        #     text,
+        #     max_length=77,
+        #     padding="max_length",
+        #     truncation=True,
+        #     return_tensors="pt",
+        # ).input_ids.to(te.device)
+        # outputs = te(input_ids=input_ids)
+        # outputs = outputs.last_hidden_state
+        embeds, attention_mask = train_tools.encode_prompts_pixart(
+            tokenizer,
+            te,
+            text,
+            truncate=True,
+            max_length=self.adapter_ref().config.num_tokens,
+        )
+        attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
+        if self.text_projection is not None:
+            # pool the output of embeds ignoring 0 in the attention mask
+            pooled_output = embeds * attn_mask_float.unsqueeze(-1)
+
+            # reduce along dim 1 while maintaining batch and dim 2
+            pooled_output_sum = pooled_output.sum(dim=1)
+            attn_mask_sum = attn_mask_float.sum(dim=1).unsqueeze(-1)
+
+            pooled_output = pooled_output_sum / attn_mask_sum
+
+            pooled_embeds = self.text_projection(pooled_output)
+
+            t5_embeds = PromptEmbeds(
+                (embeds, pooled_embeds),
+                attention_mask=attention_mask,
+            ).detach()
+
+        else:
+            t5_embeds = PromptEmbeds(
+                embeds,
+                attention_mask=attention_mask,
+            ).detach()
+
+        return t5_embeds
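For intuition, a minimal standalone sketch of the attention-mask-weighted mean pooling used above, with made-up tensor shapes:

import torch

# batch of 2 prompts, 77 tokens, hidden size 2048 (illustrative shapes)
embeds = torch.randn(2, 77, 2048)
attention_mask = torch.zeros(2, 77)
attention_mask[0, :10] = 1  # first prompt has 10 real tokens
attention_mask[1, :4] = 1   # second prompt has 4 real tokens

mask = attention_mask.to(embeds.dtype)
masked = embeds * mask.unsqueeze(-1)                         # zero out padding positions
pooled = masked.sum(dim=1) / mask.sum(dim=1).unsqueeze(-1)   # mean over real tokens only
print(pooled.shape)  # torch.Size([2, 2048])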
@@ -672,6 +672,7 @@ class StableDiffusion:
                 prompt_embeds=conditional_embeds,
                 is_training=False,
                 has_been_preprocessed=False,
+                is_generating_samples=True,
             )
             unconditional_embeds = self.adapter.condition_encoded_embeds(
                 tensors_0_1=validation_image,
@@ -679,6 +680,7 @@ class StableDiffusion:
                 is_training=False,
                 has_been_preprocessed=False,
                 is_unconditional=True,
+                is_generating_samples=True,
             )

         if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
@@ -1324,6 +1326,20 @@ class StableDiffusion:
                     attention_mask=attention_mask,
                 )

+            elif isinstance(self.text_encoder, T5EncoderModel):
+                embeds, attention_mask = train_tools.encode_prompts_pixart(
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    truncate=not long_prompts,
+                    max_length=77,  # todo set this higher when not transfer learning
+                    dropout_prob=dropout_prob
+                )
+                return PromptEmbeds(
+                    embeds,
+                    # do we want attn mask here?
+                    # attention_mask=attention_mask,
+                )
             else:
                 return PromptEmbeds(
                     train_tools.encode_prompts(
@@ -665,7 +665,9 @@ def encode_prompts_pixart(
     prompt_attention_mask = text_inputs.attention_mask
     prompt_attention_mask = prompt_attention_mask.to(text_encoder.device)

-    prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), attention_mask=prompt_attention_mask)
+    text_input_ids = text_input_ids.to(text_encoder.device)
+
+    prompt_embeds = text_encoder(text_input_ids, attention_mask=prompt_attention_mask)

     return prompt_embeds.last_hidden_state, prompt_attention_mask
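A minimal standalone sketch of the same T5 encoding path with an illustrative checkpoint, showing the shapes that encode_prompts_pixart returns:

import torch
from transformers import T5EncoderModel, T5Tokenizer

name = "google/flan-t5-xl"  # illustrative checkpoint
tokenizer = T5Tokenizer.from_pretrained(name)
text_encoder = T5EncoderModel.from_pretrained(name)

text_inputs = tokenizer(
    ["a photo of a dog"],
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(text_encoder.device)
prompt_attention_mask = text_inputs.attention_mask.to(text_encoder.device)

with torch.no_grad():
    prompt_embeds = text_encoder(text_input_ids, attention_mask=prompt_attention_mask)

# (batch, 77, hidden) last hidden states and the matching attention mask
print(prompt_embeds.last_hidden_state.shape, prompt_attention_mask.shape)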