Various features and fixes. Too much brain fog to do a proper description

This commit is contained in:
Jaret Burkett
2024-07-18 07:34:14 -06:00
parent 58dffd43a8
commit 11e426fdf1
6 changed files with 119 additions and 25 deletions

View File

@@ -356,7 +356,7 @@ class SDTrainer(BaseSDTrainProcess):
# we have to encode images into latents for now
# we also denoise as the unaugmented tensor is not a noisy diffirental
with torch.no_grad():
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
target = unaugmented_latents.detach()
@@ -907,6 +907,17 @@ class SDTrainer(BaseSDTrainProcess):
self.timer.start('preprocess_batch')
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
if encoder.dtype != self.sd.te_torch_dtype:
encoder.to(self.sd.te_torch_dtype)
else:
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
self.sd.text_encoder.to(self.sd.te_torch_dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.train_config.do_cfg or self.train_config.do_random_cfg:
# pick random negative prompts

View File

@@ -83,7 +83,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
self.network_config = None
self.train_config = TrainConfig(**self.get_conf('train', {}))
self.model_config = ModelConfig(**self.get_conf('model', {}))
model_config = self.get_conf('model', {})
# update modelconfig dtype to match train
model_config['dtype'] = self.train_config.dtype
self.model_config = ModelConfig(**model_config)
self.save_config = SaveConfig(**self.get_conf('save', {}))
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
first_sample_config = self.get_conf('first_sample', None)
@@ -723,6 +728,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
dtype=noise.dtype) * 2 - 1
# multiply by shift amount
noise_shift *= self.train_config.random_noise_shift
# add to noise
noise += noise_shift
return noise
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):

View File

@@ -160,6 +160,8 @@ class AdapterConfig:
self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
if self.train_only_image_encoder:
self.train_image_encoder = True
self.train_only_image_encoder_positional_embedding: bool = kwargs.get(
'train_only_image_encoder_positional_embedding', False)
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
self.safe_channels: int = kwargs.get('safe_channels', 2048)
@@ -260,6 +262,7 @@ class TrainConfig:
# multiplier applied to loos on regularization images
self.reg_weight = kwargs.get('reg_weight', 1.0)
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
# dropout that happens before encoding. It functions independently per text encoder
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
@@ -385,6 +388,11 @@ class ModelConfig:
self.text_encoder_bits = kwargs.get('text_encoder_bits', 16) # 16, 8, 4
self.unet_path = kwargs.get("unet_path", None)
self.unet_sample_size = kwargs.get("unet_sample_size", None)
self.vae_device = kwargs.get("vae_device", None)
self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
self.te_device = kwargs.get("te_device", None)
self.te_dtype = kwargs.get("te_dtype", self.dtype)
pass
class EMAConfig:

View File

@@ -394,7 +394,7 @@ class IPAdapter(torch.nn.Module):
elif adapter_config.type == 'ip+':
heads = 12 if not sd.is_xl else 20
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
embedding_dim = self.image_encoder.config.target_hidden_size if not self.config.image_encoder_arch.startswith(
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
'convnext') else \
self.image_encoder.config.hidden_sizes[-1]
@@ -964,7 +964,10 @@ class IPAdapter(torch.nn.Module):
def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.train_only_image_encoder:
yield from self.image_encoder.parameters(recurse)
if self.config.train_only_image_encoder_positional_embedding:
yield from self.image_encoder.vision_model.embeddings.position_embedding.parameters(recurse)
else:
yield from self.image_encoder.parameters(recurse)
return
if self.config.train_scaler:
# no params

View File

@@ -21,7 +21,7 @@ from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
from toolkit.custom_adapter import CustomAdapter
@@ -202,6 +202,10 @@ class TEAdapterAttnProcessor(nn.Module):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
# remove attn mask if doing clip
if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip":
attention_mask = None
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -246,7 +250,7 @@ class TEAdapter(torch.nn.Module):
if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
self.token_size = self.te_ref().config.d_model
else:
self.token_size = self.te_ref().config.target_hidden_size
self.token_size = self.te_ref().config.hidden_size
# add text projection if is sdxl
self.text_projection = None
@@ -388,8 +392,17 @@ class TEAdapter(torch.nn.Module):
# ).input_ids.to(te.device)
# outputs = te(input_ids=input_ids)
# outputs = outputs.last_hidden_state
if self.adapter_ref().config.text_encoder_arch == "clip":
embeds = train_tools.encode_prompts(
tokenizer,
te,
text,
truncate=True,
max_length=self.adapter_ref().config.num_tokens,
)
attention_mask = torch.ones(embeds.shape[:2], device=embeds.device)
if self.adapter_ref().config.text_encoder_arch == "pile-t5":
elif self.adapter_ref().config.text_encoder_arch == "pile-t5":
# just use aura pile
embeds, attention_mask = train_tools.encode_prompts_auraflow(
tokenizer,
@@ -407,7 +420,8 @@ class TEAdapter(torch.nn.Module):
truncate=True,
max_length=self.adapter_ref().config.num_tokens,
)
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
if attention_mask is not None:
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
if self.text_projection is not None:
# pool the output of embeds ignoring 0 in the attention mask
pooled_output = embeds * attn_mask_float.unsqueeze(-1)
@@ -420,19 +434,19 @@ class TEAdapter(torch.nn.Module):
pooled_embeds = self.text_projection(pooled_output)
t5_embeds = PromptEmbeds(
prompt_embeds = PromptEmbeds(
(embeds, pooled_embeds),
attention_mask=attention_mask,
).detach()
else:
t5_embeds = PromptEmbeds(
prompt_embeds = PromptEmbeds(
embeds,
attention_mask=attention_mask,
).detach()
return t5_embeds
return prompt_embeds

View File

@@ -123,6 +123,13 @@ class StableDiffusion:
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
self.device_torch = torch.device(self.device)
self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
self.model_config = model_config
self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
@@ -220,11 +227,13 @@ class StableDiffusion:
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
if self.model_config.experimental_xl:
print("Experimental XL mode enabled")
print("Loading and injecting alt weights")
@@ -333,6 +342,8 @@ class StableDiffusion:
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
if self.model_config.is_pixart_sigma:
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(
@@ -375,6 +386,8 @@ class StableDiffusion:
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
tokenizer = pipe.tokenizer
pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
elif self.model_config.is_auraflow:
te_kwargs = {}
@@ -427,7 +440,7 @@ class StableDiffusion:
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
# patch auraflow so it can handle other aspect ratios
patch_auraflow_pos_embed(pipe.transformer.pos_embed)
# patch_auraflow_pos_embed(pipe.transformer.pos_embed)
flush()
# text_encoder = pipe.text_encoder
@@ -442,6 +455,31 @@ class StableDiffusion:
else:
pipln = StableDiffusionPipeline
if self.model_config.text_encoder_bits < 16:
# this is only supported for T5 models for now
te_kwargs = {}
# handle quantization of TE
te_is_quantized = False
if self.model_config.text_encoder_bits == 8:
te_kwargs['load_in_8bit'] = True
te_kwargs['device_map'] = "auto"
te_is_quantized = True
elif self.model_config.text_encoder_bits == 4:
te_kwargs['load_in_4bit'] = True
te_kwargs['device_map'] = "auto"
te_is_quantized = True
text_encoder = T5EncoderModel.from_pretrained(
model_path,
subfolder="text_encoder",
torch_dtype=self.te_torch_dtype,
**te_kwargs
)
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
load_args['text_encoder'] = text_encoder
# see if path exists
if not os.path.exists(model_path) or os.path.isdir(model_path):
# try to load with default diffusers
@@ -455,7 +493,7 @@ class StableDiffusion:
# variant="fp16",
trust_remote_code=True,
**load_args
).to(self.device_torch)
)
else:
pipe = pipln.from_single_file(
model_path,
@@ -467,12 +505,12 @@ class StableDiffusion:
safety_checker=None,
trust_remote_code=True,
**load_args
).to(self.device_torch)
)
flush()
pipe.register_to_config(requires_safety_checker=False)
text_encoder = pipe.text_encoder
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
tokenizer = pipe.tokenizer
@@ -488,7 +526,7 @@ class StableDiffusion:
self.unet = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
self.vae.eval()
self.vae.requires_grad_(False)
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -707,7 +745,7 @@ class StableDiffusion:
feature_extractor=None,
requires_safety_checker=False,
**extra_args
).to(self.device_torch)
)
flush()
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
@@ -873,6 +911,9 @@ class StableDiffusion:
if not self.is_xl:
raise ValueError("Refiner is only supported for XL models")
conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
if self.is_xl:
# fix guidance rescale for sdxl
# was trained on 0.7 (I believe)
@@ -1014,6 +1055,7 @@ class StableDiffusion:
self.network.train()
self.network.multiplier = start_multiplier
self.unet.to(self.device_torch, dtype=self.torch_dtype)
if network.is_merged_in:
network.merge_out(merge_multiplier)
# self.tokenizer.to(original_device_dict['tokenizer'])
@@ -1655,18 +1697,18 @@ class StableDiffusion:
dtype=None
):
if device is None:
device = self.device
device = self.vae_device_torch
if dtype is None:
dtype = self.torch_dtype
dtype = self.vae_torch_dtype
latent_list = []
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
self.vae.to(self.device)
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
# move to device and dtype
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
image_list = [image.to(device, dtype=dtype) for image in image_list]
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -2158,7 +2200,7 @@ class StableDiffusion:
# vae
state['vae'] = {
'training': 'vae' in training_modules,
'device': self.device_torch if 'vae' in active_modules else 'cpu',
'device': self.vae_device_torch if 'vae' in active_modules else 'cpu',
'requires_grad': 'vae' in training_modules,
}
@@ -2182,13 +2224,13 @@ class StableDiffusion:
for i, encoder in enumerate(self.text_encoder):
state['text_encoder'].append({
'training': 'text_encoder' in training_modules,
'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
'requires_grad': 'text_encoder' in training_modules,
})
else:
state['text_encoder'] = {
'training': 'text_encoder' in training_modules,
'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
'requires_grad': 'text_encoder' in training_modules,
}