More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bugfixes and changes
@@ -68,6 +68,8 @@ class SDTrainer(BaseSDTrainProcess):
         self.sd.vae.to('cpu')
         flush()
         add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
+        if self.adapter is not None:
+            self.adapter.to(self.device_torch)

     # you can expand these in a child class to make customization easier
     def calculate_loss(
@@ -507,8 +509,8 @@ class SDTrainer(BaseSDTrainProcess):
                 self.sd.unet.train()
                 prior_pred = prior_pred.detach()
                 # remove the residuals as we won't use them on prediction when matching control
-                if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
-                    del pred_kwargs['down_block_additional_residuals']
+                if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
+                    del pred_kwargs['down_intrablock_additional_residuals']
                 # restore network
                 # self.network.multiplier = network_weight_list
                 self.network.is_active = was_network_active
@@ -746,7 +748,7 @@ class SDTrainer(BaseSDTrainProcess):
                     down_block_additional_residuals
                 ]

-            pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
+            pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals

         prior_pred = None
         if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
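Note on the kwarg rename in the two hunks above: recent diffusers releases route T2I-Adapter residuals through UNet2DConditionModel.forward as down_intrablock_additional_residuals, while down_block_additional_residuals remains the ControlNet input. A hedged sketch of the assumed call shape (unet, noisy_latents, text_embeds and adapter_residuals are illustrative names, not from this commit):

# T2I-Adapter residuals ride the renamed kwarg; ControlNet residuals
# would still use down_block_additional_residuals.
noise_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=text_embeds,
    down_intrablock_additional_residuals=[r.clone() for r in adapter_residuals],
).sample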
@@ -913,7 +913,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.load_training_state_from_metadata(latest_save_path)

         # get the noise scheduler
-        sampler = get_sampler(self.train_config.noise_scheduler)
+        sampler = get_sampler(
+            self.train_config.noise_scheduler,
+            {
+                "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
+            }
+        )

         if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
             previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
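Why the prediction type now flows into get_sampler: diffusers schedulers default to epsilon prediction, so a v-prediction checkpoint (for example SD 2.1 at 768px) would otherwise train against the wrong target. A small self-contained illustration (assumes diffusers is installed):

from diffusers import DDPMScheduler

# The default scheduler config predicts epsilon; v-pred models need
# "v_prediction" or the training target silently disagrees with the checkpoint.
eps_sched = DDPMScheduler()
vpred_sched = DDPMScheduler(prediction_type="v_prediction")
print(eps_sched.config.prediction_type)    # epsilon
print(vpred_sched.config.prediction_type)  # v_prediction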
@@ -1051,6 +1056,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                 is_v2=self.model_config.is_v2,
                 is_ssd=self.model_config.is_ssd,
+                is_vega=self.model_config.is_vega,
                 dropout=self.network_config.dropout,
                 use_text_encoder_1=self.model_config.use_text_encoder_1,
                 use_text_encoder_2=self.model_config.use_text_encoder_2,
@@ -54,6 +54,7 @@ parser.add_argument('--name', type=str, default='stable_diffusion', help='name f
 parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
 parser.add_argument('--refiner', action='store_true', help='is refiner model')
 parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--vega', action='store_true', help='is vega model')
 parser.add_argument('--sd2', action='store_true', help='is sd 2 model')

 args = parser.parse_args()
@@ -66,15 +67,15 @@ print(f'Loading diffusers model')

 ignore_ldm_begins_with = []

-diffusers_file_path = file_path
+diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
 if args.ssd:
     diffusers_file_path = "segmind/SSD-1B"
+if args.vega:
+    diffusers_file_path = "segmind/Segmind-Vega"

 # if args.refiner:
 #     diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"

-diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]

 if not args.refiner:

     diffusers_model_config = ModelConfig(
@@ -82,6 +83,7 @@ if not args.refiner:
         is_xl=args.sdxl,
         is_v2=args.sd2,
         is_ssd=args.ssd,
+        is_vega=args.vega,
         dtype=dtype,
     )
     diffusers_sd = StableDiffusion(
@@ -157,7 +159,7 @@ te_suffix = ''
 proj_pattern_weight = None
 proj_pattern_bias = None
 text_proj_layer = None
-if args.sdxl or args.ssd:
+if args.sdxl or args.ssd or args.vega:
     te_suffix = '1'
     ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
     proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
@@ -176,10 +178,13 @@ if args.sd2:
     proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
     text_proj_layer = "cond_stage_model.model.text_projection"

-if args.sdxl or args.sd2 or args.ssd or args.refiner:
+if args.sdxl or args.sd2 or args.ssd or args.refiner or args.vega:
     if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
         # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
         d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
+    elif "conditioner.embedders.1.model.text_projection.weight" in ldm_dict_keys:
+        # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
+        d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection.weight"].shape[0])
     elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
         # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
         d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
@@ -191,6 +196,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
         try:
             match = re.match(proj_pattern_weight, ldm_key)
             if match:
+                if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight":
+                    print("here")
                 number = int(match.group(1))
                 new_val = torch.cat([
                     diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
@@ -217,6 +224,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
                     ],
                 }

+                matched_ldm_keys.append(ldm_key)
+
                 # text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
                 # text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
                 # text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
@@ -266,6 +275,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
                     ],
                 }

+                matched_ldm_keys.append(ldm_key)
+
                 # add diffusers operators
                 diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"] = {
                     "slice": [
@@ -298,6 +309,9 @@ for ldm_key in ldm_dict_keys:
     ldm_shape_tuple = ldm_state_dict[ldm_key].shape
     ldm_reduced_shape_tuple = get_reduced_shape(ldm_shape_tuple)
     for diffusers_key in diffusers_dict_keys:
+        if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight" and diffusers_key == "te1_text_model.encoder.layers.0.self_attn.q_proj.weight":
+            print("here")
+
         diffusers_shape_tuple = diffusers_state_dict[diffusers_key].shape
         diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)

@@ -356,6 +370,8 @@ if args.sdxl:
     name += '_sdxl'
 elif args.ssd:
     name += '_ssd'
+elif args.vega:
+    name += '_vega'
 elif args.refiner:
     name += '_refiner'
 elif args.sd2:
@@ -247,6 +247,7 @@ class ModelConfig:
         self.is_v2: bool = kwargs.get('is_v2', False)
         self.is_xl: bool = kwargs.get('is_xl', False)
         self.is_ssd: bool = kwargs.get('is_ssd', False)
+        self.is_vega: bool = kwargs.get('is_vega', False)
         self.is_v_pred: bool = kwargs.get('is_v_pred', False)
         self.dtype: str = kwargs.get('dtype', 'float16')
         self.vae_path = kwargs.get('vae_path', None)
@@ -267,6 +268,9 @@ class ModelConfig:
             # set sdxl as true since it is mostly the same architecture
             self.is_xl = True

+        if self.is_vega:
+            self.is_xl = True
+

 class ReferenceDatasetConfig:
     def __init__(self, **kwargs):
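The intent mirrors the SSD-1B handling directly above: Segmind Vega is a distilled SDXL-architecture model, so flagging is_vega also flips is_xl and reuses the XL code paths. A hypothetical usage sketch, assuming ModelConfig accepts these flags as keyword arguments as the hunk shows:

cfg = ModelConfig(name_or_path="segmind/Segmind-Vega", is_vega=True)
assert cfg.is_xl  # set automatically; Vega shares the SDXL architecture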
@@ -402,7 +406,7 @@ class DatasetConfig:
         if legacy_caption_type:
             self.caption_ext = legacy_caption_type
         self.caption_type = self.caption_ext
-        self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted_polarity')
+        self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted')


 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
@@ -7,7 +7,7 @@ from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
 from toolkit.stable_diffusion_model import StableDiffusion
 from toolkit.train_tools import get_torch_dtype

-GuidanceType = Literal["targeted", "polarity", "targeted_polarity"]
+GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"]

 DIFFERENTIAL_SCALER = 0.2

@@ -118,8 +118,8 @@ def get_targeted_polarity_loss(
         # )

         # Disable the LoRA network so we can predict parent network knowledge without it
-        sd.network.is_active = False
-        sd.unet.eval()
+        # sd.network.is_active = False
+        # sd.unet.eval()

         # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
         # This acts as our control to preserve the unaltered parts of the image.
@@ -133,15 +133,15 @@ def get_targeted_polarity_loss(

         # conditional_baseline_prediction, unconditional_baseline_prediction = torch.chunk(baseline_prediction, 2, dim=0)

-        negative_network_weights = [weight * -1.0 for weight in network_weight_list]
-        positive_network_weights = [weight * 1.0 for weight in network_weight_list]
-        cat_network_weight_list = positive_network_weights + negative_network_weights
+        # negative_network_weights = [weight * -1.0 for weight in network_weight_list]
+        # positive_network_weights = [weight * 1.0 for weight in network_weight_list]
+        # cat_network_weight_list = positive_network_weights + negative_network_weights

     # turn the LoRA network back on.
     sd.unet.train()
-    sd.network.is_active = True
+    # sd.network.is_active = True

-    sd.network.multiplier = cat_network_weight_list
+    # sd.network.multiplier = cat_network_weight_list

     # do our prediction with LoRA active on the scaled guidance latents
     prediction = sd.predict_noise(
@@ -183,9 +183,7 @@ def get_targeted_polarity_loss(

     return loss

-# targeted
-def get_targeted_guidance_loss(
+def get_direct_guidance_loss(
         noisy_latents: torch.Tensor,
         conditional_embeds: 'PromptEmbeds',
         match_adapter_assist: bool,
@@ -206,81 +204,45 @@ def get_targeted_guidance_loss(
         conditional_latents = batch.latents.to(device, dtype=dtype).detach()
         unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

-        # # apply random offset to both latents
-        # offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
-        # offset = offset * 0.1
-        # conditional_latents = conditional_latents + offset
-        # unconditional_latents = unconditional_latents + offset
-        #
-        # # get random scale 0f 0.8 to 1.2
-        # scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
-        # scale = scale * 0.4
-        # scale = scale + 0.8
-        # conditional_latents = conditional_latents * scale
-        # unconditional_latents = unconditional_latents * scale
-
-        unconditional_diff = (unconditional_latents - conditional_latents)
-
-        # scale it to the timestep
-        unconditional_diff_noise = sd.add_noise(
-            torch.zeros_like(unconditional_latents),
-            unconditional_diff,
-            timesteps
-        )
-        unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
-
-        target_noise = noise + unconditional_diff_noise
-
-        noisy_latents = sd.add_noise(
+        conditional_noisy_latents = sd.add_noise(
             conditional_latents,
-            target_noise,
-            # noise,
+            # target_noise,
+            noise,
             timesteps
         ).detach()
-        # Disable the LoRA network so we can predict parent network knowledge without it
-        sd.network.is_active = False
-        sd.unet.eval()
-
-        # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
-        # This acts as our control to preserve the unaltered parts of the image.
-        baseline_prediction = sd.predict_noise(
-            latents=noisy_latents.to(device, dtype=dtype).detach(),
-            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
-            timestep=timesteps,
-            guidance_scale=1.0,
-            **pred_kwargs  # adapter residuals in here
-        ).detach().requires_grad_(False)
-
-        # determine the error for the baseline prediction
-        baseline_prediction_error = baseline_prediction - noise
-
-        prediction_target = baseline_prediction_error + unconditional_diff_noise
-
-        prediction_target = prediction_target.detach().requires_grad_(False)
-
+        unconditional_noisy_latents = sd.add_noise(
+            unconditional_latents,
+            noise,
+            timesteps
+        ).detach()
     # turn the LoRA network back on.
     sd.unet.train()
-    sd.network.is_active = True
+    # sd.network.is_active = True

-    sd.network.multiplier = network_weight_list
+    # sd.network.multiplier = network_weight_list
     # do our prediction with LoRA active on the scaled guidance latents
     prediction = sd.predict_noise(
-        latents=noisy_latents.to(device, dtype=dtype).detach(),
-        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
-        timestep=timesteps,
+        latents=torch.cat([unconditional_noisy_latents, conditional_noisy_latents]).to(device, dtype=dtype).detach(),
+        conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
+        timestep=torch.cat([timesteps, timesteps]),
         guidance_scale=1.0,
         **pred_kwargs  # adapter residuals in here
     )

-    prediction_error = prediction - noise
+    noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)
+
+    guidance_scale = 1.0
+    guidance_pred = noise_pred_uncond + guidance_scale * (
+        noise_pred_cond - noise_pred_uncond
+    )

     guidance_loss = torch.nn.functional.mse_loss(
-        prediction_error.float(),
-        # unconditional_diff_noise.float(),
-        prediction_target.float(),
+        guidance_pred.float(),
+        noise.detach().float(),
         reduction="none"
     )

     guidance_loss = guidance_loss.mean([1, 2, 3])

     guidance_loss = guidance_loss.mean()
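One property of the reworked direct-guidance loss worth noting: with guidance_scale = 1.0 the classifier-free-guidance combination collapses to the conditional prediction, so the unconditional half of the batch only changes the target if that scale is raised. A self-contained check:

import torch

# uncond + s * (cond - uncond) == cond exactly when s == 1.0
noise_pred_uncond = torch.randn(2, 4, 8, 8)
noise_pred_cond = torch.randn(2, 4, 8, 8)
guidance_pred = noise_pred_uncond + 1.0 * (noise_pred_cond - noise_pred_uncond)
assert torch.allclose(guidance_pred, noise_pred_cond)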
@@ -297,6 +259,242 @@ def get_targeted_guidance_loss(
     return loss


+# targeted
+def get_targeted_guidance_loss(
+        noisy_latents: torch.Tensor,
+        conditional_embeds: 'PromptEmbeds',
+        match_adapter_assist: bool,
+        network_weight_list: list,
+        timesteps: torch.Tensor,
+        pred_kwargs: dict,
+        batch: 'DataLoaderBatchDTO',
+        noise: torch.Tensor,
+        sd: 'StableDiffusion',
+        **kwargs
+):
+    with torch.no_grad():
+        dtype = get_torch_dtype(sd.torch_dtype)
+        device = sd.device_torch
+
+        # create the differential mask from the actual tensors
+        conditional_imgs = batch.tensor.to(device, dtype=dtype).detach()
+        unconditional_imgs = batch.unconditional_tensor.to(device, dtype=dtype).detach()
+        differential_mask = torch.abs(conditional_imgs - unconditional_imgs)
+        differential_mask = differential_mask - differential_mask.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
+        differential_mask = differential_mask / differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
+
+        # differential_mask is (bs, 3, width, height)
+        # latents are (bs, 4, width, height)
+        # reduce the mean on dim 1 to get a single channel mask and stack it to match latents
+        differential_mask = differential_mask.mean(dim=1, keepdim=True)
+        differential_mask = torch.cat([differential_mask] * 4, dim=1)
+
+        # scale the mask down to latent size
+        differential_mask = torch.nn.functional.interpolate(
+            differential_mask,
+            size=noisy_latents.shape[2:],
+            mode="nearest"
+        )
+
+        conditional_noisy_latents = noisy_latents
+
+        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
+        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+
+        # unconditional_as_noise = unconditional_latents - conditional_latents
+        # conditional_as_noise = conditional_latents - unconditional_latents
+
+        # Encode the unconditional image into latents
+        unconditional_noisy_latents = sd.noise_scheduler.add_noise(
+            unconditional_latents,
+            noise,
+            timesteps
+        )
+        conditional_noisy_latents = sd.noise_scheduler.add_noise(
+            conditional_latents,
+            noise,
+            timesteps
+        )
+
+        # was_network_active = self.network.is_active
+        sd.network.is_active = False
+        sd.unet.eval()
+
+        # calculate the differential between our conditional (target image) and our unconditional ("bad" image)
+        # target_differential = unconditional_noisy_latents - conditional_noisy_latents
+        target_differential = unconditional_latents - conditional_latents
+        # target_differential = conditional_latents - unconditional_latents
+
+        # scale the target differential by the scheduler
+        # todo, scale it the right way
+        # target_differential = sd.noise_scheduler.add_noise(
+        #     torch.zeros_like(target_differential),
+        #     target_differential,
+        #     timesteps
+        # )
+
+        # noise_abs_mean = torch.abs(noise + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
+
+        # target_differential = target_differential.detach()
+        # target_differential_abs_mean = torch.abs(target_differential + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
+        # # determines scaler to adjust to same abs mean as noise
+        # scaler = noise_abs_mean / target_differential_abs_mean
+
+        target_differential_knowledge = target_differential
+        target_differential_knowledge = target_differential_knowledge.detach()
+
+        # add the target differential to the target latents as if it were noise with the scheduler scaled to
+        # the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done
+        # properly
+        # guidance_latents = sd.noise_scheduler.add_noise(
+        #     conditional_noisy_latents,
+        #     target_differential,
+        #     timesteps
+        # )
+
+        # guidance_latents = conditional_noisy_latents + target_differential
+        # target_noise = conditional_noisy_latents + target_differential
+
+        # With LoRA network bypassed, predict noise to get a baseline of what the network
+        # wants to do with the latents + noise. Pass our target latents here for the input.
+        target_unconditional = sd.predict_noise(
+            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
+            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
+            timestep=timesteps,
+            guidance_scale=1.0,
+            **pred_kwargs  # adapter residuals in here
+        ).detach()
+        # target_conditional = sd.predict_noise(
+        #     latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
+        #     conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
+        #     timestep=timesteps,
+        #     guidance_scale=1.0,
+        #     **pred_kwargs  # adapter residuals in here
+        # ).detach()
+
+        # we calculate the network's current knowledge so we do not overlearn what we know
+        # parent_knowledge = target_unconditional - target_conditional
+        # parent_knowledge = parent_knowledge.detach()
+        # del target_conditional
+        # del target_unconditional
+
+        # we now have the differential noise prediction needed to create our convergence target
+        # target_unknown_knowledge = target_differential + parent_knowledge
+        # del parent_knowledge
+        prior_prediction_loss = torch.nn.functional.mse_loss(
+            target_unconditional.float(),
+            noise.float(),
+            reduction="none"
+        ).detach().clone()
+
+    # turn the LoRA network back on.
+    sd.unet.train()
+    sd.network.is_active = True
+    sd.network.multiplier = network_weight_list
+
+    # with LoRA active, predict the noise with the scaled differential latents added. This will allow us
+    # the opportunity to predict the differential + noise that was added to the latents.
+    prediction_conditional = sd.predict_noise(
+        latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
+        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
+        timestep=timesteps,
+        guidance_scale=1.0,
+        **pred_kwargs  # adapter residuals in here
+    )
+
+    # remove the baseline conditional prediction. This will leave only the divergence from the baseline and
+    # the prediction of the added differential noise
+    # prediction_positive = prediction_unconditional - target_unconditional
+    # current_knowledge = target_unconditional - prediction_conditional
+    # current_differential_knowledge = prediction_conditional - target_unconditional
+
+    # current_unknown_knowledge = parent_knowledge - current_knowledge
+    #
+    # current_unknown_knowledge_abs_mean = torch.abs(current_unknown_knowledge + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
+    # current_unknown_knowledge_std = current_unknown_knowledge / current_unknown_knowledge_abs_mean
+
+    # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
+    # this is the diffusion training process.
+    # This will guide the network to make identical predictions it previously did for everything EXCEPT our
+    # differential between the conditional and unconditional images
+
+    # positive_loss = torch.nn.functional.mse_loss(
+    #     current_differential_knowledge.float(),
+    #     target_differential_knowledge.float(),
+    #     reduction="none"
+    # )
+
+    normal_loss = torch.nn.functional.mse_loss(
+        prediction_conditional.float(),
+        noise.float(),
+        reduction="none"
+    )
+    #
+    # # scale positive and neutral loss to the same scale
+    # positive_loss_abs_mean = torch.abs(positive_loss + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
+    # normal_loss_abs_mean = torch.abs(normal_loss + 1e-6).mean(dim=[1, 2, 3], keepdim=True)
+    # scaler = normal_loss_abs_mean / positive_loss_abs_mean
+    # positive_loss = positive_loss * scaler
+
+    # positive_loss = positive_loss * differential_mask
+    # positive_loss = positive_loss
+    # masked_normal_loss = normal_loss * differential_mask
+
+    prior_loss = torch.abs(
+        normal_loss.float() - prior_prediction_loss.float(),
+        # ) * (1 - differential_mask)
+    )
+
+    decouple = True
+
+    # positive_loss_full = positive_loss
+    # prior_loss_full = prior_loss
+    #
+    # current_scaler = (prior_loss_full.max() / positive_loss_full.max())
+    # # positive_loss = positive_loss * current_scaler
+    # avg_scaler_arr.append(current_scaler.item())
+    # avg_scaler = sum(avg_scaler_arr) / len(avg_scaler_arr)
+    # print(f"avg scaler: {avg_scaler}, current scaler: {current_scaler.item()}")
+    # # remove extra scalers more than 100
+    # if len(avg_scaler_arr) > 100:
+    #     avg_scaler_arr.pop(0)
+    #
+    # # positive_loss = positive_loss * avg_scaler
+    # positive_loss = positive_loss * avg_scaler * 0.1

+    if decouple:
+        # positive_loss = positive_loss.mean([1, 2, 3])
+        prior_loss = prior_loss.mean([1, 2, 3])
+        # masked_normal_loss = masked_normal_loss.mean([1, 2, 3])
+        positive_loss = prior_loss
+        # positive_loss = positive_loss + prior_loss
+    else:
+        # positive_loss = positive_loss + prior_loss
+        positive_loss = prior_loss
+        positive_loss = positive_loss.mean([1, 2, 3])
+
+    # positive_loss = positive_loss + adain_loss.mean([1, 2, 3])
+    # send it backwards BEFORE switching network polarity
+    # positive_loss = self.apply_snr(positive_loss, timesteps)
+    positive_loss = positive_loss.mean()
+    positive_loss.backward()
+    # loss = positive_loss.detach() + negative_loss.detach()
+    loss = positive_loss.detach()
+
+    # add a grad so other backward does not fail
+    loss.requires_grad_(True)
+
+    # restore network
+    sd.network.multiplier = network_weight_list
+
+    return loss
+
 def get_guided_loss_polarity(
         noisy_latents: torch.Tensor,
         conditional_embeds: PromptEmbeds,
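A note on the mask normalization in the restored get_targeted_guidance_loss above: the chained .min(dim=1)...min(dim=3) / .max(...) reductions are a per-sample min/max over channels and both spatial dims, i.e. equivalent to torch.amin/torch.amax over dims (1, 2, 3). Self-contained check:

import torch

x = torch.rand(2, 3, 16, 16)
chained_min = x.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
assert torch.equal(chained_min, torch.amin(x, dim=(1, 2, 3), keepdim=True))

# shift-then-divide produces a per-sample mask in [0, 1]
shifted = x - chained_min
mask = shifted / torch.amax(shifted, dim=(1, 2, 3), keepdim=True)
assert mask.min() >= 0 and mask.max() <= 1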
@@ -360,17 +558,17 @@ def get_guided_loss_polarity(
         noise.float(),
         reduction="none"
     )
-    pred_loss = pred_loss.mean([1, 2, 3])
+    # pred_loss = pred_loss.mean([1, 2, 3])

     pred_neg_loss = torch.nn.functional.mse_loss(
         pred_neg.float(),
         noise.float(),
         reduction="none"
     )
-    pred_neg_loss = pred_neg_loss.mean([1, 2, 3])

     loss = pred_loss + pred_neg_loss

+    loss = loss.mean([1, 2, 3])
     loss = loss.mean()
     loss.backward()

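The reduction reorder in get_guided_loss_polarity (sum the two per-pixel losses first, reduce once afterwards) is value-preserving, since the mean over dims (1, 2, 3) is linear:

import torch

a, b = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
assert torch.allclose((a + b).mean([1, 2, 3]), a.mean([1, 2, 3]) + b.mean([1, 2, 3]))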
@@ -437,5 +635,18 @@ def get_guidance_loss(
             sd,
             **kwargs
         )
+    elif guidance_type == "direct":
+        return get_direct_guidance_loss(
+            noisy_latents,
+            conditional_embeds,
+            match_adapter_assist,
+            network_weight_list,
+            timesteps,
+            pred_kwargs,
+            batch,
+            noise,
+            sd,
+            **kwargs
+        )
     else:
         raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")
toolkit/keymaps/stable_diffusion_vega.json | 3039 ++ (new file)
File diff suppressed because it is too large

toolkit/keymaps/stable_diffusion_vega_ldm_base.safetensors | BIN (new file)
Binary file not shown.
@@ -111,8 +111,11 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
 class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
     NUM_OF_BLOCKS = 12  # number of up/down blocks in the full-model equivalent

-    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
-    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "ResnetBlock2D"]
+    UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"]
+    # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -230,8 +233,90 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         )
         loras = []
         skipped = []
+        attached_modules = []
         for name, module in root_module.named_modules():
-            if module.__class__.__name__ in target_replace_modules:
+            if is_unet:
+                module_name = module.__class__.__name__
+                if module not in attached_modules:
+                    # if module.__class__.__name__ in target_replace_modules:
+                    # for child_name, child_module in module.named_modules():
+                    is_linear = module_name == 'LoRACompatibleLinear'
+                    is_conv2d = module_name == 'LoRACompatibleConv'
+
+                    if is_linear and self.lora_dim is None:
+                        continue
+                    if is_conv2d and self.conv_lora_dim is None:
+                        continue
+
+                    is_conv2d_1x1 = is_conv2d and module.kernel_size == (1, 1)
+
+                    if is_conv2d_1x1:
+                        pass
+
+                    skip = False
+                    if any([word in name for word in self.ignore_if_contains]):
+                        skip = True
+
+                    # see if it is over threshold
+                    if count_parameters(module) < parameter_threshold:
+                        skip = True
+
+                    if (is_linear or is_conv2d) and not skip:
+                        lora_name = prefix + "." + name
+                        lora_name = lora_name.replace(".", "_")
+
+                        dim = None
+                        alpha = None
+
+                        if modules_dim is not None:
+                            # per-module dims were specified
+                            if lora_name in modules_dim:
+                                dim = modules_dim[lora_name]
+                                alpha = modules_alpha[lora_name]
+                        elif is_unet and block_dims is not None:
+                            # block_dims specified for the U-Net
+                            block_idx = get_block_index(lora_name)
+                            if is_linear or is_conv2d_1x1:
+                                dim = block_dims[block_idx]
+                                alpha = block_alphas[block_idx]
+                            elif conv_block_dims is not None:
+                                dim = conv_block_dims[block_idx]
+                                alpha = conv_block_alphas[block_idx]
+                        else:
+                            # default: target everything
+                            if is_linear or is_conv2d_1x1:
+                                dim = self.lora_dim
+                                alpha = self.alpha
+                            elif self.conv_lora_dim is not None:
+                                dim = self.conv_lora_dim
+                                alpha = self.conv_alpha
+                            else:
+                                dim = None
+                                alpha = None
+
+                        if dim is None or dim == 0:
+                            # record which modules were skipped
+                            if is_linear or is_conv2d_1x1 or (
+                                    self.conv_lora_dim is not None or conv_block_dims is not None):
+                                skipped.append(lora_name)
+                            continue
+
+                        lora = module_class(
+                            lora_name,
+                            module,
+                            self.multiplier,
+                            dim,
+                            alpha,
+                            dropout=dropout,
+                            rank_dropout=rank_dropout,
+                            module_dropout=module_dropout,
+                            network=self,
+                            parent=module,
+                            use_bias=use_bias,
+                        )
+                        loras.append(lora)
+                        attached_modules.append(module)
+            elif module.__class__.__name__ in target_replace_modules:
                 for child_name, child_module in module.named_modules():
                     is_linear = child_module.__class__.__name__ in LINEAR_MODULES
                     is_conv2d = child_module.__class__.__name__ in CONV_MODULES
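The net effect of the resolver rework above: instead of descending from container classes like Transformer2DModel, the unet path now walks every module and hooks diffusers' LoRACompatibleLinear / LoRACompatibleConv layers directly, filtered by a name blacklist and a parameter-count threshold. A simplified sketch of the selection rule (an assumed distillation, not the full implementation):

import torch

def iter_lora_targets(unet: torch.nn.Module, ignore_if_contains=(), parameter_threshold=0):
    # Yield (name, module) pairs the resolver would wrap with a LoRA module.
    for name, module in unet.named_modules():
        if module.__class__.__name__ not in ('LoRACompatibleLinear', 'LoRACompatibleConv'):
            continue
        if any(word in name for word in ignore_if_contains):
            continue  # explicit name blacklist
        if sum(p.numel() for p in module.parameters()) < parameter_threshold:
            continue  # too small to be worth adapting
        yield name, module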
@@ -340,6 +340,7 @@ class ToolkitNetworkMixin:
             is_sdxl=False,
             is_v2=False,
             is_ssd=False,
+            is_vega=False,
             network_config: Optional[NetworkConfig] = None,
             is_lorm=False,
             **kwargs
@@ -351,6 +352,7 @@ class ToolkitNetworkMixin:
         self.is_active: bool = False
         self.is_sdxl = is_sdxl
         self.is_ssd = is_ssd
+        self.is_vega = is_vega
         self.is_v2 = is_v2
         self.is_merged_in = False
         self.is_lorm = is_lorm
@@ -365,6 +367,9 @@ class ToolkitNetworkMixin:
         if self.is_ssd:
             keymap_tail = 'ssd'
             use_weight_mapping = True
+        elif self.is_vega:
+            keymap_tail = 'vega'
+            use_weight_mapping = True
         elif self.is_sdxl:
             keymap_tail = 'sdxl'
         elif self.is_v2:
@@ -46,8 +46,11 @@ sdxl_sampler_config = {

 def get_sampler(
         sampler: str,
+        kwargs: dict = None,
 ):
     sched_init_args = {}
+    if kwargs is not None:
+        sched_init_args.update(kwargs)

     if sampler.startswith("k_"):
         sched_init_args["use_karras_sigmas"] = True
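Assumed usage of the extended signature, matching the call sites changed elsewhere in this commit: extra scheduler-constructor kwargs ride along in a plain dict.

sampler = get_sampler('ddpm', {"prediction_type": "v_prediction"})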
@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(

 def get_ldm_state_dict_from_diffusers(
         state_dict: 'OrderedDict',
-        sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
+        sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega', 'sdxl_refiner'] = '2',
         device='cpu',
         dtype=get_torch_dtype('fp32'),
 ):
@@ -115,6 +115,10 @@ def get_ldm_state_dict_from_diffusers(
         # load our base
         base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
         mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
+    elif sd_version == 'vega':
+        # load our base
+        base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega_ldm_base.safetensors')
+        mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega.json')
     elif sd_version == 'sdxl_refiner':
         # load our base
         base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
@@ -137,7 +141,7 @@ def save_ldm_model_from_diffusers(
         output_file: str,
         meta: 'OrderedDict',
         save_dtype=get_torch_dtype('fp16'),
-        sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
+        sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
 ):
     converted_state_dict = get_ldm_state_dict_from_diffusers(
         sd.state_dict(),
@@ -156,11 +160,11 @@ def save_lora_from_diffusers(
         output_file: str,
         meta: 'OrderedDict',
         save_dtype=get_torch_dtype('fp16'),
-        sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
+        sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
 ):
     converted_state_dict = OrderedDict()
     # only handle sdxl for now
-    if sd_version != 'sdxl' and sd_version != 'ssd':
+    if sd_version != 'sdxl' and sd_version != 'ssd' and sd_version != 'vega':
         raise ValueError(f"Invalid sd_version {sd_version}")
     for key, value in lora_state_dict.items():
         # todo verify if this works with ssd
@@ -84,5 +84,8 @@ def get_train_sd_device_state_preset(
         preset['adapter']['training'] = True
         preset['adapter']['device'] = device
         preset['unet']['training'] = True
+        preset['unet']['requires_grad'] = False
+        preset['unet']['device'] = device
+        preset['text_encoder']['device'] = device

     return preset
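What the preset addition amounts to once applied (a hedged sketch; the actual application happens elsewhere in the trainer): during adapter training the unet stays on-device in train mode so adapter residuals flow through it, but with gradients disabled so the base weights never update.

sd.unet.to(device)
sd.unet.train()
for p in sd.unet.parameters():
    p.requires_grad_(False)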
@@ -137,6 +137,7 @@ class StableDiffusion:
         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2
         self.is_ssd = model_config.is_ssd
+        self.is_vega = model_config.is_vega

         self.use_text_encoder_1 = model_config.use_text_encoder_1
         self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -149,7 +150,10 @@ class StableDiffusion:
         dtype = get_torch_dtype(self.dtype)
         # sch = KDPM2DiscreteScheduler
         if self.noise_scheduler is None:
-            scheduler = get_sampler('ddpm')
+            scheduler = get_sampler(
+                'ddpm', {
+                    "prediction_type": self.prediction_type,
+                })
             self.noise_scheduler = scheduler

         # move the betas, alphas and alphas_cumprod to device. Sometimes they get stuck on cpu, not sure why
@@ -169,7 +173,7 @@ class StableDiffusion:
         if self.model_config.vae_path is not None:
             load_args['vae'] = load_vae(self.model_config.vae_path, dtype)

-        if self.model_config.is_xl or self.model_config.is_ssd:
+        if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
             if self.custom_pipeline is not None:
                 pipln = self.custom_pipeline
             else:
@@ -358,9 +362,17 @@ class StableDiffusion:
         if sampler is not None:
             if sampler.startswith("sample_"):  # sample_dpmpp_2m
                 # using ksampler
-                noise_scheduler = get_sampler('lms')
+                noise_scheduler = get_sampler(
+                    'lms', {
+                        "prediction_type": self.prediction_type,
+                    })
             else:
-                noise_scheduler = get_sampler(sampler)
+                noise_scheduler = get_sampler(
+                    sampler,
+                    {
+                        "prediction_type": self.prediction_type,
+                    }
+                )

             try:
                 noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
@@ -674,7 +686,6 @@ class StableDiffusion:
             'EulerDiscreteSchedulerOutput',
         ]

-
         # todo handle if timestep is single value

         original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
@@ -692,10 +703,12 @@ class StableDiffusion:
             noise_timesteps = timesteps_chunks[idx]
             if scheduler_class_name == 'DPMSolverMultistepScheduler':
                 # Make sure sigmas and timesteps have the same device and dtype as original_samples
-                sigmas = self.noise_scheduler.sigmas.to(device=original_samples_chunks[idx].device, dtype=original_samples_chunks[idx].dtype)
+                sigmas = self.noise_scheduler.sigmas.to(device=original_samples_chunks[idx].device,
+                                                        dtype=original_samples_chunks[idx].dtype)
                 if original_samples_chunks[idx].device.type == "mps" and torch.is_floating_point(noise_timesteps):
                     # mps does not support float64
-                    schedule_timesteps = self.noise_scheduler.timesteps.to(original_samples_chunks[idx].device, dtype=torch.float32)
+                    schedule_timesteps = self.noise_scheduler.timesteps.to(original_samples_chunks[idx].device,
+                                                                           dtype=torch.float32)
                     noise_timesteps = noise_timesteps.to(original_samples_chunks[idx].device, dtype=torch.float32)
                 else:
                     schedule_timesteps = self.noise_scheduler.timesteps.to(original_samples_chunks[idx].device)
@@ -719,7 +732,8 @@ class StableDiffusion:
                 noisy_samples = alpha_t * original_samples + sigma_t * noise_chunks[idx]
                 noisy_latents = noisy_samples
             else:
-                noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], noise_timesteps)
+                noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx],
+                                                               noise_timesteps)
             noisy_latents_chunks.append(noisy_latents)

         noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
@@ -777,7 +791,6 @@ class StableDiffusion:
         else:
             timestep = timestep.repeat(latents.shape[0], 0)

-
         def scale_model_input(model_input, timestep_tensor):
             if is_input_scaled:
                 return model_input
@@ -986,7 +999,6 @@ class StableDiffusion:
     ):
         timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]

-
         for timestep in tqdm(timesteps_to_run, leave=False):
             timestep = timestep.unsqueeze_(0)
             noise_pred = self.predict_noise(
@@ -1290,7 +1302,6 @@ class StableDiffusion:
             output_config_path = f"{output_path_no_ext}.yaml"
             shutil.copyfile(self.config_file, output_config_path)

-
     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
         version_string = '1'
         if self.is_v2:
@@ -1300,6 +1311,8 @@ class StableDiffusion:
         if self.is_ssd:
             # overwrite sdxl because both will be true here
             version_string = 'ssd'
+        if self.is_ssd and self.is_vega:
+            version_string = 'vega'
         # if output file does not end in .safetensors, then it is a directory and we are
         # saving in diffusers format
         if not output_file.endswith('.safetensors'):
@@ -776,15 +776,19 @@ def apply_snr_weight(
 ):
     # will get it from the noise scheduler if it exists, or calculate it if not
     all_snr = get_all_snr(noise_scheduler, loss.device)
-    step_indices = []
-    for t in timesteps:
-        for i, st in enumerate(noise_scheduler.timesteps):
-            if st == t:
-                step_indices.append(i)
-                break
+    # step_indices = []
+    # for t in timesteps:
+    #     for i, st in enumerate(noise_scheduler.timesteps):
+    #         if st == t:
+    #             step_indices.append(i)
+    #             break
     # this breaks on some schedulers
     # step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
-    snr = torch.stack([all_snr[t] for t in step_indices])
+    offset = 0
+    if noise_scheduler.timesteps[0] == 1000:
+        offset = 1
+    snr = torch.stack([all_snr[t - offset] for t in timesteps])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     if fixed:
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
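The replacement SNR lookup assumes all_snr is a 1000-entry table indexed by training timestep; schedulers whose timestep schedule starts at 1000 get a one-step offset so t=1000 maps to the last valid index instead of overflowing. Self-contained illustration:

import torch

all_snr = torch.rand(1000)                # SNR per timestep, indices 0..999
timesteps = torch.tensor([1000, 500, 1])  # as produced by a schedule starting at 1000
offset = 1 if timesteps[0] == 1000 else 0
snr = torch.stack([all_snr[t - offset] for t in timesteps])
assert snr.shape == (3,)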