Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Various features and fixes. Too much brain fog to do a proper description
@@ -356,7 +356,7 @@ class SDTrainer(BaseSDTrainProcess):
     # we have to encode images into latents for now
     # we also denoise as the unaugmented tensor is not a noisy differential
     with torch.no_grad():
-        unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
+        unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
         unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
         target = unaugmented_latents.detach()

@@ -907,6 +907,17 @@ class SDTrainer(BaseSDTrainProcess):
         self.timer.start('preprocess_batch')
         batch = self.preprocess_batch(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
+        # sanity check
+        if self.sd.vae.dtype != self.sd.vae_torch_dtype:
+            self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
+        if isinstance(self.sd.text_encoder, list):
+            for encoder in self.sd.text_encoder:
+                if encoder.dtype != self.sd.te_torch_dtype:
+                    encoder.to(self.sd.te_torch_dtype)
+        else:
+            if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
+                self.sd.text_encoder.to(self.sd.te_torch_dtype)
+
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
         if self.train_config.do_cfg or self.train_config.do_random_cfg:
             # pick random negative prompts

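A note on the sanity check added above: if the VAE or a text encoder has drifted away from the configured dtype, the next matmul fails, so the modules are cast back before the batch is processed. A minimal standalone illustration of that failure mode (shapes and dtypes are made up):

import torch

# A module whose dtype has drifted away from its inputs fails at the first matmul,
# which is what the sanity check above guards against.
layer = torch.nn.Linear(4, 4).to(torch.bfloat16)  # drifted dtype
x = torch.randn(1, 4, dtype=torch.float32)        # expected / configured dtype
try:
    layer(x)
except RuntimeError as err:
    print("dtype mismatch:", err)

layer = layer.to(x.dtype)  # cast the module back, as the check does for the VAE / text encoders
print(layer(x).dtype)      # torch.float32
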
@@ -83,7 +83,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             self.network_config = None
         self.train_config = TrainConfig(**self.get_conf('train', {}))
-        self.model_config = ModelConfig(**self.get_conf('model', {}))
+        model_config = self.get_conf('model', {})
+
+        # update modelconfig dtype to match train
+        model_config['dtype'] = self.train_config.dtype
+        self.model_config = ModelConfig(**model_config)
+
         self.save_config = SaveConfig(**self.get_conf('save', {}))
         self.sample_config = SampleConfig(**self.get_conf('sample', {}))
         first_sample_config = self.get_conf('first_sample', None)

@@ -723,6 +728,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             noise_offset=self.train_config.noise_offset,
         ).to(self.device_torch, dtype=dtype)

+        if self.train_config.random_noise_shift > 0.0:
+            # get random noise -1 to 1
+            noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
+                                     dtype=noise.dtype) * 2 - 1
+
+            # multiply by shift amount
+            noise_shift *= self.train_config.random_noise_shift
+
+            # add to noise
+            noise += noise_shift
+
         return noise

     def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):

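The new random_noise_shift path draws one offset per sample and channel and adds it to the initial noise, shifting each channel's mean slightly. A standalone sketch of the same arithmetic with made-up tensor shapes:

import torch

# Illustrative shapes: a batch of 4 latents with 4 channels at 32x32.
noise = torch.randn(4, 4, 32, 32)
random_noise_shift = 0.1  # would come from train_config.random_noise_shift

# one value in [-1, 1] per (sample, channel), broadcast over height and width
noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), dtype=noise.dtype) * 2 - 1
noise_shift *= random_noise_shift
noise += noise_shift

print(noise.shape)  # torch.Size([4, 4, 32, 32]); each channel's mean is now offset by up to 0.1
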
@@ -160,6 +160,8 @@ class AdapterConfig:
         self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
         if self.train_only_image_encoder:
             self.train_image_encoder = True
+        self.train_only_image_encoder_positional_embedding: bool = kwargs.get(
+            'train_only_image_encoder_positional_embedding', False)
         self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip vit vit_hybrid, safe
         self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
         self.safe_channels: int = kwargs.get('safe_channels', 2048)

@@ -260,6 +262,7 @@ class TrainConfig:
         # multiplier applied to loss on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
         self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
+        self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)

         # dropout that happens before encoding. It functions independently per text encoder
         self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)

@@ -385,6 +388,11 @@ class ModelConfig:
         self.text_encoder_bits = kwargs.get('text_encoder_bits', 16)  # 16, 8, 4
         self.unet_path = kwargs.get("unet_path", None)
         self.unet_sample_size = kwargs.get("unet_sample_size", None)
+        self.vae_device = kwargs.get("vae_device", None)
+        self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
+        self.te_device = kwargs.get("te_device", None)
+        self.te_dtype = kwargs.get("te_dtype", self.dtype)
+
         pass


 class EMAConfig:

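The four new options let the VAE and the text encoder run on their own device and dtype, independent of the main model. A hedged sketch of the kwargs a 'model' config section might pass; only the new keys are shown and the values are illustrative:

# Illustrative 'model' kwargs; keys match the kwargs.get(...) calls added above,
# values are examples only.
model_kwargs = {
    "dtype": "float16",      # main model precision (existing option)
    "vae_device": "cpu",     # keep the VAE off the training GPU
    "vae_dtype": "float32",  # run the VAE in full precision
    "te_device": "cuda:1",   # put the text encoder on a second GPU
    "te_dtype": "float16",
}
# ModelConfig(**model_kwargs) would pick these up; unset keys fall back to the main dtype/device.
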
@@ -394,7 +394,7 @@ class IPAdapter(torch.nn.Module):
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.target_hidden_size if not self.config.image_encoder_arch.startswith(
+            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
                 'convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]

@@ -964,7 +964,10 @@ class IPAdapter(torch.nn.Module):

     def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.train_only_image_encoder:
-            yield from self.image_encoder.parameters(recurse)
+            if self.config.train_only_image_encoder_positional_embedding:
+                yield from self.image_encoder.vision_model.embeddings.position_embedding.parameters(recurse)
+            else:
+                yield from self.image_encoder.parameters(recurse)
             return
         if self.config.train_scaler:
             # no params

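When train_only_image_encoder_positional_embedding is set, the generator above yields only the position-embedding weights, so the optimizer never sees anything else. A minimal sketch of that filtering pattern with a toy module (not the toolkit's API):

import torch

class TinyVisionModel(torch.nn.Module):
    # stand-in for an image encoder, only used to show the parameter filter
    def __init__(self):
        super().__init__()
        self.position_embedding = torch.nn.Embedding(16, 32)
        self.proj = torch.nn.Linear(32, 32)

def positional_params(model):
    # mirror of the branch above: yield only the position-embedding parameters
    yield from model.position_embedding.parameters()

model = TinyVisionModel()
optimizer = torch.optim.AdamW(positional_params(model), lr=1e-4)
print(sum(p.numel() for p in optimizer.param_groups[0]["params"]))  # 16 * 32 = 512
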
@@ -21,7 +21,7 @@ from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0


 if TYPE_CHECKING:
-    from toolkit.stable_diffusion_model import StableDiffusion
+    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
     from toolkit.custom_adapter import CustomAdapter

@@ -202,6 +202,10 @@ class TEAdapterAttnProcessor(nn.Module):

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
+        # remove attn mask if doing clip
+        if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip":
+            attention_mask = None
+
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

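With the mask dropped for CLIP, scaled_dot_product_attention runs fully unmasked. A small self-contained call with made-up shapes in the (batch, num_heads, seq_len, head_dim) layout the comment above describes:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 77, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 77, 64)
v = torch.randn(2, 8, 77, 64)

# attn_mask=None lets every query attend to every key, as in the CLIP branch above
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 77, 64])
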
@@ -246,7 +250,7 @@ class TEAdapter(torch.nn.Module):
         if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
             self.token_size = self.te_ref().config.d_model
         else:
-            self.token_size = self.te_ref().config.target_hidden_size
+            self.token_size = self.te_ref().config.hidden_size

         # add text projection if is sdxl
         self.text_projection = None

@@ -388,8 +392,17 @@ class TEAdapter(torch.nn.Module):
         # ).input_ids.to(te.device)
         # outputs = te(input_ids=input_ids)
         # outputs = outputs.last_hidden_state
+        if self.adapter_ref().config.text_encoder_arch == "clip":
+            embeds = train_tools.encode_prompts(
+                tokenizer,
+                te,
+                text,
+                truncate=True,
+                max_length=self.adapter_ref().config.num_tokens,
+            )
+            attention_mask = torch.ones(embeds.shape[:2], device=embeds.device)

-        if self.adapter_ref().config.text_encoder_arch == "pile-t5":
+        elif self.adapter_ref().config.text_encoder_arch == "pile-t5":
             # just use aura pile
             embeds, attention_mask = train_tools.encode_prompts_auraflow(
                 tokenizer,

@@ -407,7 +420,8 @@ class TEAdapter(torch.nn.Module):
                 truncate=True,
                 max_length=self.adapter_ref().config.num_tokens,
             )
-        attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
+        if attention_mask is not None:
+            attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
         if self.text_projection is not None:
             # pool the output of embeds ignoring 0 in the attention mask
             pooled_output = embeds * attn_mask_float.unsqueeze(-1)

@@ -420,19 +434,19 @@ class TEAdapter(torch.nn.Module):

             pooled_embeds = self.text_projection(pooled_output)

-            t5_embeds = PromptEmbeds(
+            prompt_embeds = PromptEmbeds(
                 (embeds, pooled_embeds),
                 attention_mask=attention_mask,
             ).detach()

         else:

-            t5_embeds = PromptEmbeds(
+            prompt_embeds = PromptEmbeds(
                 embeds,
                 attention_mask=attention_mask,
             ).detach()

-        return t5_embeds
+        return prompt_embeds

@@ -123,6 +123,13 @@ class StableDiffusion:
         self.dtype = dtype
         self.torch_dtype = get_torch_dtype(dtype)
         self.device_torch = torch.device(self.device)
+
+        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
+        self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
+
+        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
+        self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
+
         self.model_config = model_config
         self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"

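The new attributes fall back to the main training device when no override is configured. A tiny sketch of just that fallback pattern (the helper name is illustrative, not toolkit API):

from typing import Optional

import torch

def resolve_device(main_device: str, override: Optional[str]) -> torch.device:
    # same pattern as vae_device_torch / te_device_torch above
    return torch.device(main_device) if override is None else torch.device(override)

print(resolve_device("cuda:0", None))   # cuda:0 -- falls back to the main device
print(resolve_device("cuda:0", "cpu"))  # cpu    -- the override wins
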
@@ -220,11 +227,13 @@ class StableDiffusion:
             text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
             tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
             for text_encoder in text_encoders:
-                text_encoder.to(self.device_torch, dtype=dtype)
+                text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
                 text_encoder.requires_grad_(False)
                 text_encoder.eval()
             text_encoder = text_encoders

+            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+
             if self.model_config.experimental_xl:
                 print("Experimental XL mode enabled")
                 print("Loading and injecting alt weights")

@@ -333,6 +342,8 @@ class StableDiffusion:
             # replace the to function with a no-op since it throws an error instead of a warning
             text_encoder.to = lambda *args, **kwargs: None

+            text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
+
         if self.model_config.is_pixart_sigma:
             # load the transformer only from the save
             transformer = Transformer2DModel.from_pretrained(

@@ -375,6 +386,8 @@ class StableDiffusion:
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
             tokenizer = pipe.tokenizer

+            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+

         elif self.model_config.is_auraflow:
             te_kwargs = {}

@@ -427,7 +440,7 @@ class StableDiffusion:
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)

             # patch auraflow so it can handle other aspect ratios
-            patch_auraflow_pos_embed(pipe.transformer.pos_embed)
+            # patch_auraflow_pos_embed(pipe.transformer.pos_embed)

             flush()
             # text_encoder = pipe.text_encoder

@@ -442,6 +455,31 @@ class StableDiffusion:
             else:
                 pipln = StableDiffusionPipeline

+            if self.model_config.text_encoder_bits < 16:
+                # this is only supported for T5 models for now
+                te_kwargs = {}
+                # handle quantization of TE
+                te_is_quantized = False
+                if self.model_config.text_encoder_bits == 8:
+                    te_kwargs['load_in_8bit'] = True
+                    te_kwargs['device_map'] = "auto"
+                    te_is_quantized = True
+                elif self.model_config.text_encoder_bits == 4:
+                    te_kwargs['load_in_4bit'] = True
+                    te_kwargs['device_map'] = "auto"
+                    te_is_quantized = True
+
+                text_encoder = T5EncoderModel.from_pretrained(
+                    model_path,
+                    subfolder="text_encoder",
+                    torch_dtype=self.te_torch_dtype,
+                    **te_kwargs
+                )
+                # replace the to function with a no-op since it throws an error instead of a warning
+                text_encoder.to = lambda *args, **kwargs: None
+
+                load_args['text_encoder'] = text_encoder
+
             # see if path exists
             if not os.path.exists(model_path) or os.path.isdir(model_path):
                 # try to load with default diffusers

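The quantized path hands transformers a load_in_8bit / load_in_4bit flag and lets device_map place the encoder. A hedged standalone sketch of the 8-bit case; the repo id is a placeholder, bitsandbytes must be installed, and newer transformers releases prefer a BitsAndBytesConfig over the bare flag:

import torch
from transformers import T5EncoderModel

model_path = "some/diffusers-model"  # placeholder: a repo with a "text_encoder" subfolder

text_encoder = T5EncoderModel.from_pretrained(
    model_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
    load_in_8bit=True,     # requires bitsandbytes; the 4-bit path uses load_in_4bit instead
    device_map="auto",
)

# quantized modules reject .to(), so the loader above stubs it out the same way
text_encoder.to = lambda *args, **kwargs: None
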
@@ -455,7 +493,7 @@ class StableDiffusion:
                     # variant="fp16",
                     trust_remote_code=True,
                     **load_args
-                ).to(self.device_torch)
+                )
             else:
                 pipe = pipln.from_single_file(
                     model_path,

@@ -467,12 +505,12 @@ class StableDiffusion:
                     safety_checker=None,
                     trust_remote_code=True,
                     **load_args
-                ).to(self.device_torch)
+                )
             flush()

             pipe.register_to_config(requires_safety_checker=False)
             text_encoder = pipe.text_encoder
-            text_encoder.to(self.device_torch, dtype=dtype)
+            text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
             text_encoder.requires_grad_(False)
             text_encoder.eval()
             tokenizer = pipe.tokenizer

@@ -488,7 +526,7 @@ class StableDiffusion:
             self.unet = pipe.transformer
         else:
             self.unet: 'UNet2DConditionModel' = pipe.unet
-        self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
+        self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
        self.vae.eval()
        self.vae.requires_grad_(False)
        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)

@@ -707,7 +745,7 @@ class StableDiffusion:
                 feature_extractor=None,
                 requires_safety_checker=False,
                 **extra_args
-            ).to(self.device_torch)
+            )
             flush()
             # disable progress bar
             pipeline.set_progress_bar_config(disable=True)

@@ -873,6 +911,9 @@ class StableDiffusion:
         if not self.is_xl:
             raise ValueError("Refiner is only supported for XL models")

+        conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+        unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+
         if self.is_xl:
             # fix guidance rescale for sdxl
             # was trained on 0.7 (I believe)

@@ -1014,6 +1055,7 @@ class StableDiffusion:
             self.network.train()
             self.network.multiplier = start_multiplier

+        self.unet.to(self.device_torch, dtype=self.torch_dtype)
         if network.is_merged_in:
             network.merge_out(merge_multiplier)
         # self.tokenizer.to(original_device_dict['tokenizer'])

@@ -1655,18 +1697,18 @@ class StableDiffusion:
             dtype=None
     ):
         if device is None:
-            device = self.device
+            device = self.vae_device_torch
         if dtype is None:
-            dtype = self.torch_dtype
+            dtype = self.vae_torch_dtype

         latent_list = []
         # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
-            self.vae.to(self.device)
+            self.vae.to(device)
         self.vae.eval()
         self.vae.requires_grad_(False)
         # move to device and dtype
-        image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
+        image_list = [image.to(device, dtype=dtype) for image in image_list]

         VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)

@@ -2158,7 +2200,7 @@ class StableDiffusion:
         # vae
         state['vae'] = {
             'training': 'vae' in training_modules,
-            'device': self.device_torch if 'vae' in active_modules else 'cpu',
+            'device': self.vae_device_torch if 'vae' in active_modules else 'cpu',
             'requires_grad': 'vae' in training_modules,
         }

@@ -2182,13 +2224,13 @@ class StableDiffusion:
             for i, encoder in enumerate(self.text_encoder):
                 state['text_encoder'].append({
                     'training': 'text_encoder' in training_modules,
-                    'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                    'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
                     'requires_grad': 'text_encoder' in training_modules,
                 })
         else:
             state['text_encoder'] = {
                 'training': 'text_encoder' in training_modules,
-                'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
                 'requires_grad': 'text_encoder' in training_modules,
             }