mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Various bug fixes
@@ -270,6 +270,10 @@ class SDTrainer(BaseSDTrainProcess):
             noise_pred = noise_pred * self.train_config.pred_scaler

         target = None
+
+        if self.train_config.target_noise_multiplier != 1.0:
+            noise = noise * self.train_config.target_noise_multiplier
+
         if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
             if self.train_config.correct_pred_norm and not is_reg:
                 with torch.no_grad():
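The new branch above scales the noise tensor before it becomes the regression target. As a rough standalone sketch (not the trainer's full loss path; the tensor shapes and the 1.1 value are made up for illustration), the effect on a plain MSE diffusion loss looks like this:

import torch
import torch.nn.functional as F

# Hypothetical illustration of target noise scaling. A multiplier of
# 1.0 (the default) leaves training unchanged; values above 1.0 ask
# the network to over-predict the injected noise.
target_noise_multiplier = 1.1

noise = torch.randn(1, 4, 64, 64)       # noise that was added to the latents
noise_pred = torch.randn(1, 4, 64, 64)  # network output

if target_noise_multiplier != 1.0:
    noise = noise * target_noise_multiplier

loss = F.mse_loss(noise_pred, noise)  # the scaled noise is the target
print(loss.item())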
@@ -11,10 +11,10 @@ import json
 # te_path = "google/flan-t5-xl"
 # te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
 # output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
-model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS"
-te_path = "google/flan-t5-base"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000227500.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5base_raw"
+model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
+te_path = "google/flan-t5-large"
+te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5l_000034000.safetensors"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw"


 print("Loading te adapter")
@@ -2,62 +2,83 @@ import torch
 from safetensors.torch import load_file, save_file
 from collections import OrderedDict

-model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.orig.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.safetensors"
-
-state_dict = load_file(model_path)
-
 meta = OrderedDict()
-meta["format"] = "pt"
+meta['format'] = "pt"

+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+    weight = weight.to(device, torch.float32)
+    original_shape = weight.shape
+    flattened = weight.view(-1, original_shape[-1])
+
+    if flattened.shape[1] <= target_size:
+        return weight
+
+    U, S, V = torch.svd(flattened)
+    reduced = torch.mm(U[:, :target_size], torch.diag(S[:target_size]))
+
+    if reduced.shape[1] < target_size:
+        padding = torch.zeros(reduced.shape[0], target_size - reduced.shape[1], device=device)
+        reduced = torch.cat((reduced, padding), dim=1)
+
+    return reduced.view(original_shape[:-1] + (target_size,))
+
+
+def reduce_bias(bias, target_size):
+    bias = bias.to(device, torch.float32)
+    original_size = bias.shape[0]
+
+    if original_size <= target_size:
+        return torch.nn.functional.pad(bias, (0, target_size - original_size))
+    else:
+        return bias.view(-1, original_size // target_size).mean(dim=1)[:target_size]
+
+
+# Load your original state dict
+state_dict = load_file(
+    "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
 new_state_dict = {}

-# Move non-blocks over
+source_hidden_size = 1152
+target_hidden_size = 1024
+
 for key, value in state_dict.items():
-    if not key.startswith("transformer_blocks."):
-        new_state_dict[key] = value
+    value = value.to(device, torch.float32)
+    if 'weight' in key or 'scale_shift_table' in key:
+        if value.shape[0] == source_hidden_size:
+            value = value[:target_hidden_size]
+        elif value.shape[0] == source_hidden_size * 4:
+            value = value[:target_hidden_size * 4]
+        elif value.shape[0] == source_hidden_size * 6:
+            value = value[:target_hidden_size * 6]

-block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
-               'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
-               'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
-               'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
-               'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
-               'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
-               'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
-               'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
-               'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
-               'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
-               'transformer_blocks.{idx}.scale_shift_table']
+        if len(value.shape) > 1 and value.shape[1] == source_hidden_size and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+            value = value[:, :target_hidden_size]
+        elif len(value.shape) > 1 and value.shape[1] == source_hidden_size * 4:
+            value = value[:, :target_hidden_size * 4]

-# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27
+    elif 'bias' in key:
+        if value.shape[0] == source_hidden_size:
+            value = value[:target_hidden_size]
+        elif value.shape[0] == source_hidden_size * 4:
+            value = value[:target_hidden_size * 4]
+        elif value.shape[0] == source_hidden_size * 6:
+            value = value[:target_hidden_size * 6]

-current_idx = 0
-for i in range(28):
-    if i not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
-        # todo merge in with previous block
-        for name in block_names:
-            continue
-            # try:
-            #     new_state_dict_key = name.format(idx=current_idx - 1)
-            #     old_state_dict_key = name.format(idx=i)
-            #     new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5)
-            # except KeyError:
-            #     raise KeyError(f"KeyError: {name.format(idx=current_idx)}")
-    else:
-        for name in block_names:
-            new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)]
-        current_idx += 1
+    new_state_dict[key] = value

-# make sure they are all fp16 and on cpu
+# Move all to CPU and convert to float16
 for key, value in new_state_dict.items():
-    new_state_dict[key] = value.to(torch.float16).cpu()
+    new_state_dict[key] = value.cpu().to(torch.float16)

-# save the new state dict
-save_file(new_state_dict, output_path, metadata=meta)
+# Save the new state dict
+save_file(new_state_dict,
+          "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+          metadata=meta)

-new_param_count = sum([v.numel() for v in new_state_dict.values()])
-old_param_count = sum([v.numel() for v in state_dict.values()])
-print(f"Old param count: {old_param_count:,}")
-print(f"New param count: {new_param_count:,}")
+print("Done!")
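The reduce_weight helper added above shrinks the last dimension by keeping the strongest singular components: it projects onto the leading left singular vectors scaled by their singular values, and zero-pads when the matrix has fewer components than requested. A minimal shape check on random data (the 4608 x 1152 size is just an example):

import torch

# Project a (4608 x 1152) weight onto its 1024 strongest singular
# directions, as reduce_weight does before any padding.
weight = torch.randn(4608, 1152)
target = 1024

U, S, V = torch.svd(weight)  # U: (4608, 1152), S: (1152,)
reduced = torch.mm(U[:, :target], torch.diag(S[:target]))

print(reduced.shape)  # torch.Size([4608, 1024])

When min(rows, cols) is smaller than target_size, U has too few columns, and the padding branch in the helper fills the remainder with zeros.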
testing/shrink_pixart_sm2.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+meta = OrderedDict()
+meta['format'] = "pt"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+    weight = weight.to(device, torch.float32)
+    original_shape = weight.shape
+
+    if len(original_shape) == 1:
+        # For 1D tensors, simply truncate
+        return weight[:target_size]
+
+    if original_shape[0] <= target_size:
+        return weight
+
+    # Reshape the tensor to 2D
+    flattened = weight.reshape(original_shape[0], -1)
+
+    # Perform SVD
+    U, S, V = torch.svd(flattened)
+
+    # Reduce the dimensions
+    reduced = torch.mm(U[:target_size, :], torch.diag(S)).mm(V.t())
+
+    # Reshape back to the original shape with reduced first dimension
+    new_shape = (target_size,) + original_shape[1:]
+    return reduced.reshape(new_shape)
+
+
+def reduce_bias(bias, target_size):
+    bias = bias.to(device, torch.float32)
+    return bias[:target_size]
+
+
+# Load your original state dict
+state_dict = load_file(
+    "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
+new_state_dict = {}
+
+for key, value in state_dict.items():
+    value = value.to(device, torch.float32)
+
+    if 'weight' in key or 'scale_shift_table' in key:
+        if value.shape[0] == 1152:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3])
+                # reshape to (1152, -1)
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 512)
+                value = value.view(output_shape)
+            else:
+                # value = reduce_weight(value.t(), 576).t().contiguous()
+                value = reduce_weight(value, 512)
+        elif value.shape[0] == 4608:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3])
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 2048)
+                value = value.view(output_shape)
+            else:
+                value = reduce_weight(value, 2048)
+        elif value.shape[0] == 6912:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3])
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 3072)
+                value = value.view(output_shape)
+            else:
+                value = reduce_weight(value, 3072)
+
+        if len(value.shape) > 1 and value.shape[1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+            # Transpose before and after reduction
+            value = reduce_weight(value.t(), 512).t().contiguous()
+        elif len(value.shape) > 1 and value.shape[1] == 4608:
+            # Transpose before and after reduction
+            value = reduce_weight(value.t(), 2048).t().contiguous()
+
+    elif 'bias' in key:
+        if value.shape[0] == 1152:
+            value = reduce_bias(value, 512)
+        elif value.shape[0] == 4608:
+            value = reduce_bias(value, 2048)
+        elif value.shape[0] == 6912:
+            value = reduce_bias(value, 3072)
+
+    new_state_dict[key] = value
+
+# Move all to CPU and convert to float16
+for key, value in new_state_dict.items():
+    new_state_dict[key] = value.cpu().to(torch.float16)
+
+# Save the new state dict
+save_file(new_state_dict,
+          "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+          metadata=meta)
+
+print("Done!")
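Unlike the previous script, shrink_pixart_sm2.py reduces the leading dimension: it rebuilds the matrix from its full SVD but keeps only the first target_size rows of U, giving a row-truncated low-rank reconstruction. A small shape check under the same convention (random data, example sizes):

import torch

# Row reduction as in reduce_weight above: (1152 x 4608) -> (512 x 4608).
weight = torch.randn(1152, 4608)
target = 512

U, S, V = torch.svd(weight)  # U: (1152, 1152), V: (4608, 1152)
reduced = torch.mm(U[:target, :], torch.diag(S)).mm(V.t())

print(reduced.shape)  # torch.Size([512, 4608])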
testing/shrink_pixart_sm3.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+meta = OrderedDict()
+meta['format'] = "pt"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+    weight = weight.to(device, torch.float32)
+    # resize so target_size is the first dimension
+    tmp_weight = weight.view(1, 1, weight.shape[0], weight.shape[1])
+
+    # use interpolate to resize the tensor
+    new_weight = torch.nn.functional.interpolate(tmp_weight, size=(target_size, weight.shape[1]), mode='bicubic', align_corners=True)
+
+    # reshape back to original shape
+    return new_weight.view(target_size, weight.shape[1])
+
+
+def reduce_bias(bias, target_size):
+    bias = bias.view(1, 1, bias.shape[0], 1)
+    new_bias = torch.nn.functional.interpolate(bias, size=(target_size, 1), mode='bicubic', align_corners=True)
+    return new_bias.view(target_size)
+
+
+# Load your original state dict
+state_dict = load_file(
+    "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
+new_state_dict = {}
+
+for key, value in state_dict.items():
+    value = value.to(device, torch.float32)
+
+    if 'weight' in key or 'scale_shift_table' in key:
+        if value.shape[0] == 1152:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3])
+                # reshape to (1152, -1)
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 512)
+                value = value.view(output_shape)
+            else:
+                # value = reduce_weight(value.t(), 576).t().contiguous()
+                value = reduce_weight(value, 512)
+        elif value.shape[0] == 4608:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3])
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 2048)
+                value = value.view(output_shape)
+            else:
+                value = reduce_weight(value, 2048)
+        elif value.shape[0] == 6912:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3])
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 3072)
+                value = value.view(output_shape)
+            else:
+                value = reduce_weight(value, 3072)
+
+        if len(value.shape) > 1 and value.shape[1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+            # Transpose before and after reduction
+            value = reduce_weight(value.t(), 512).t().contiguous()
+        elif len(value.shape) > 1 and value.shape[1] == 4608:
+            # Transpose before and after reduction
+            value = reduce_weight(value.t(), 2048).t().contiguous()
+
+    elif 'bias' in key:
+        if value.shape[0] == 1152:
+            value = reduce_bias(value, 512)
+        elif value.shape[0] == 4608:
+            value = reduce_bias(value, 2048)
+        elif value.shape[0] == 6912:
+            value = reduce_bias(value, 3072)
+
+    new_state_dict[key] = value
+
+# Move all to CPU and convert to float16
+for key, value in new_state_dict.items():
+    new_state_dict[key] = value.cpu().to(torch.float16)
+
+# Save the new state dict
+save_file(new_state_dict,
+          "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+          metadata=meta)
+
+print("Done!")
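shrink_pixart_sm3.py swaps SVD for plain bicubic resampling: a weight matrix is viewed as a 1x1-batch "image" and resized along the output dimension with torch.nn.functional.interpolate. A minimal sketch of the trick (sizes are examples):

import torch
import torch.nn.functional as F

# Resample a (1152 x 4608) weight down to (512 x 4608), mirroring
# reduce_weight in shrink_pixart_sm3.py.
weight = torch.randn(1152, 4608)
target = 512

tmp = weight.view(1, 1, weight.shape[0], weight.shape[1])
resized = F.interpolate(tmp, size=(target, weight.shape[1]), mode='bicubic', align_corners=True)

print(resized.view(target, weight.shape[1]).shape)  # torch.Size([512, 4608])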
@@ -166,7 +166,7 @@ class ClipVisionAdapter(torch.nn.Module):
         if hasattr(self.image_encoder.config, 'hidden_sizes'):
             embedding_dim = self.image_encoder.config.hidden_sizes[-1]
         else:
-            embedding_dim = self.image_encoder.config.hidden_size
+            embedding_dim = self.image_encoder.config.target_hidden_size

         if self.config.clip_layer == 'image_embeds':
             in_tokens = 1
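The hasattr probe exists because the supported vision backbones expose their width differently: ConvNeXt configs in transformers carry a hidden_sizes list (one entry per stage), while CLIP ViT configs carry a single hidden_size. A standalone sketch of the same probe, using default configs purely as examples:

from transformers import CLIPVisionConfig, ConvNextConfig

def embed_dim(config):
    if hasattr(config, 'hidden_sizes'):
        # ConvNeXt: per-stage widths; the last stage feeds the adapter
        return config.hidden_sizes[-1]
    # CLIP ViT: a single transformer width
    return config.hidden_size

print(embed_dim(CLIPVisionConfig()))  # 768
print(embed_dim(ConvNextConfig()))    # 768, the last of [96, 192, 384, 768]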
@@ -308,15 +308,15 @@ class ClipVisionAdapter(torch.nn.Module):
             # add it to the text encoder
             self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
         elif len(self.text_encoder_list) == 2:
-            if self.text_encoder_list[0].config.hidden_size + self.text_encoder_list[1].config.hidden_size != \
+            if self.text_encoder_list[0].config.target_hidden_size + self.text_encoder_list[1].config.target_hidden_size != \
                     image_prompt_embeds.shape[2]:
                 raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes")
             # sdxl variants
             # image_prompt_embeds = 2048
             # te1 = 768
             # te2 = 1280
-            te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.hidden_size]
-            te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.hidden_size:]
+            te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.target_hidden_size]
+            te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.target_hidden_size:]
             self.set_vec(te1_embeds[0], text_encoder_idx=0)
             self.set_vec(te2_embeds[0], text_encoder_idx=1)
         else:
@@ -251,6 +251,7 @@ class TrainConfig:
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
         self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i')  # t2i, control_net
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
+        self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
         self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
         self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
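The new option follows the same kwargs.get pattern as its neighbors, defaulting to 1.0 so existing job configs keep their behavior. A minimal sketch of the pattern (not the full class):

class TrainConfig:
    def __init__(self, **kwargs):
        # unrelated options omitted
        self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
        self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)

print(TrainConfig().target_noise_multiplier)                             # 1.0
print(TrainConfig(target_noise_multiplier=1.2).target_noise_multiplier)  # 1.2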
@@ -394,7 +394,7 @@ class IPAdapter(torch.nn.Module):
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
+            embedding_dim = self.image_encoder.config.target_hidden_size if not self.config.image_encoder_arch.startswith(
                 'convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]

@@ -436,7 +436,7 @@ class IPAdapter(torch.nn.Module):
         if hasattr(self.image_encoder.config, 'hidden_sizes'):
             embedding_dim = self.image_encoder.config.hidden_sizes[-1]
         else:
-            embedding_dim = self.image_encoder.config.hidden_size
+            embedding_dim = self.image_encoder.config.target_hidden_size

         image_encoder_state_dict = self.image_encoder.state_dict()
         # max_seq_len = CLIP tokens + CLS token
@@ -246,7 +246,7 @@ class TEAdapter(torch.nn.Module):
         if self.adapter_ref().config.text_encoder_arch == "t5":
             self.token_size = self.te_ref().config.d_model
         else:
-            self.token_size = self.te_ref().config.hidden_size
+            self.token_size = self.te_ref().config.target_hidden_size

         # add text projection if is sdxl
         self.text_projection = None
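The branch is needed because T5-family configs name their width d_model, while CLIP/BERT-style text encoders call it hidden_size. For example, with default transformers configs:

from transformers import T5Config, CLIPTextConfig

print(T5Config().d_model)            # 512
print(CLIPTextConfig().hidden_size)  # 512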