Added masking to slider training. Something is still weird though

Jaret Burkett
2023-11-01 14:51:29 -06:00
parent a899ec91c8
commit 7d707b2fe6
6 changed files with 97 additions and 25 deletions
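The core of this commit is per-pixel loss masking for slider training: the batch's focus mask is downsampled to latent resolution and used to weight the slider loss, while an extra detached prediction (unmasked_target) pins the regions outside the mask in place. A minimal standalone sketch of the idea, with hypothetical stand-in tensors for the values computed in the training loop:

import torch
import torch.nn.functional as F

# Hypothetical stand-ins for values built in hook_train_loop:
# latent-space predictions (bs, c, h, w) and a pixel-space mask (bs, 1, H, W).
bs, c, h, w = 2, 4, 64, 64
target_latents = torch.randn(bs, c, h, w)
offset_neutral = torch.randn(bs, c, h, w)
unmasked_target = torch.randn(bs, c, h, w).detach()  # prediction with the network disabled
mask = torch.rand(bs, 1, h * 8, w * 8)

# Downsample the mask to latent resolution and broadcast it across channels.
mask = F.interpolate(mask, size=(h, w)).expand(-1, c, -1, -1)

# Slider loss only where the mask is white ...
loss = F.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none") * mask
# ... plus a preservation term matching the detached unmasked prediction elsewhere.
loss = loss + F.mse_loss(target_latents.float(), unmasked_target.float(), reduction="none") * (1.0 - mask)
loss = loss.mean([1, 2, 3])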


@@ -152,7 +152,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
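For context, apply_snr_weight applies Min-SNR-style loss weighting; a rough sketch of that technique (assumed from the Min-SNR-gamma approach, not copied from the toolkit):

import torch

def min_snr_weight(loss: torch.Tensor, snr: torch.Tensor, gamma: float) -> torch.Tensor:
    # loss: (bs,) per-sample loss; snr: (bs,) signal-to-noise ratio at each
    # sample's timestep. Clamping at gamma down-weights low-noise timesteps.
    return loss * (torch.clamp(snr, max=gamma) / snr)

The if -> elif change makes this weighting mutually exclusive with the learnable SNR path above it, so a run with learnable_snr_gos enabled no longer has both weights applied.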


@@ -13,7 +13,7 @@ from toolkit.basic import value_map
from toolkit.config_modules import SliderConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_learnable_snr_gos
import gc
from toolkit import train_tools
from toolkit.prompt_utils import \
@@ -35,6 +35,7 @@ adapter_transforms = transforms.Compose([
transforms.ToTensor(),
])
class TrainSliderProcess(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -273,7 +274,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
return adapter_tensors
def hook_train_loop(self, batch):
def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]):
# set to eval mode
self.sd.set_device_state(self.eval_slider_device_state)
with torch.no_grad():
@@ -309,7 +310,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
if dbr_batch_size != dn.shape[0]:
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
down_kwargs['down_block_additional_residuals'] = [
torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals']
torch.cat([sample.clone()] * amount_to_add) for sample in
down_kwargs['down_block_additional_residuals']
]
return self.sd.predict_noise(
latents=dn,
@@ -325,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
with torch.no_grad():
adapter_images = None
# for a complete slider, the batch size now starts at 4
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
from_batch = False
@@ -370,7 +373,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
1, self.train_config.max_denoising_steps, (1,)
).item()
# get noise
noise = self.sd.get_latent_noise(
pixel_height=height,
@@ -401,7 +403,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
noise_scheduler.set_timesteps(1000)
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
@@ -410,6 +411,33 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
# flush() # 4.2GB to 3GB on 512x512
mask_multiplier = torch.ones((denoised_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
has_mask = False
if batch and batch.mask_tensor is not None:
with self.timer('get_mask_multiplier'):
# upsampling not supported for bfloat16
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
# scale down to the size of the latents; mask_multiplier shape (bs, 1, h, w), noisy_latents shape (bs, channels, h, w)
mask_multiplier = torch.nn.functional.interpolate(
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
)
# expand to match latents
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
has_mask = True
if has_mask:
unmasked_target = get_noise_pred(
prompt_pair.positive_target, # negative prompt
prompt_pair.target_class, # positive prompt
1,
current_timestep,
denoised_latents
)
unmasked_target = unmasked_target.detach()
unmasked_target.requires_grad = False
else:
unmasked_target = None
# 4.20 GB RAM for 512x512
positive_latents = get_noise_pred(
@@ -504,19 +532,30 @@ class TrainSliderProcess(BaseSDTrainProcess):
anchor.to("cpu")
with torch.no_grad():
if self.slider_config.high_ram:
if self.slider_config.low_ram:
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
unconditional_latents_chunks = torch.chunk(
unconditional_latents.detach(),
self.prompt_chunk_size,
dim=0
)
mask_multiplier_chunks = torch.chunk(mask_multiplier, self.prompt_chunk_size, dim=0)
if unmasked_target is not None:
unmasked_target_chunks = torch.chunk(unmasked_target, self.prompt_chunk_size, dim=0)
else:
unmasked_target_chunks = [None for _ in range(self.prompt_chunk_size)]
else:
# run through in one instance
prompt_pair_chunks = [prompt_pair.detach()]
denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
positive_latents_chunks = [positive_latents.detach()]
neutral_latents_chunks = [neutral_latents.detach()]
unconditional_latents_chunks = [unconditional_latents.detach()]
else:
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
unconditional_latents_chunks = torch.chunk(unconditional_latents.detach(), self.prompt_chunk_size, dim=0)
mask_multiplier_chunks = [mask_multiplier]
unmasked_target_chunks = [unmasked_target]
# flush()
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
@@ -528,13 +567,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latent_chunk, \
positive_latents_chunk, \
neutral_latents_chunk, \
unconditional_latents_chunk \
unconditional_latents_chunk, \
mask_multiplier_chunk, \
unmasked_target_chunk \
in zip(
prompt_pair_chunks,
denoised_latent_chunks,
positive_latents_chunks,
neutral_latents_chunks,
unconditional_latents_chunks,
mask_multiplier_chunks,
unmasked_target_chunks
):
self.network.multiplier = prompt_pair_chunk.multiplier_list
target_latents = get_noise_pred(
@@ -568,17 +611,43 @@ class TrainSliderProcess(BaseSDTrainProcess):
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
# do inverted mask to preserve non masked
if has_mask and unmasked_target_chunk is not None:
loss = loss * mask_multiplier_chunk
# match the mask unmasked_target_chunk
mask_target_loss = torch.nn.functional.mse_loss(
target_latents.float(),
unmasked_target_chunk.float(),
reduction="none"
)
mask_target_loss = mask_target_loss * (1.0 - mask_multiplier_chunk)
loss += mask_target_loss
loss = loss.mean([1, 2, 3])
if self.train_config.learnable_snr_gos:
if from_batch:
# match batch size
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
self.train_config.min_snr_gamma)
else:
# match batch size
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps_index_list, self.snr_gos)
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
if from_batch:
# match batch size
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
self.train_config.min_snr_gamma)
else:
# match batch size
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler,
self.train_config.min_snr_gamma)
loss = loss.mean() * prompt_pair_chunk.weight
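The chunked path above now runs when low_ram is set (the flag replaces high_ram with its sense inverted): every per-prompt tensor is split into prompt_chunk_size slices along the batch dimension and processed one slice at a time, so gradients are only held for one slice. A small sketch of the chunking pattern, with hypothetical shapes:

import torch

prompt_chunk_size = 2  # hypothetical value
positive_latents = torch.randn(4, 4, 64, 64)
mask_multiplier = torch.ones(4, 1, 1, 1)

# torch.chunk splits along dim 0; zipping the resulting lists keeps every
# per-prompt tensor (latents, masks, targets) aligned slice by slice.
for pos_chunk, mask_chunk in zip(
        torch.chunk(positive_latents, prompt_chunk_size, dim=0),
        torch.chunk(mask_multiplier, prompt_chunk_size, dim=0),
):
    chunk_loss = (pos_chunk * mask_chunk).mean()  # placeholder per-slice work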


@@ -283,7 +283,7 @@ class SliderConfig:
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
self.high_ram = kwargs.get('high_ram', False)
self.low_ram = kwargs.get('low_ram', False)
# expand targets if shuffling
from toolkit.prompt_utils import get_slider_target_permutations
@@ -334,6 +334,7 @@ class DatasetConfig:
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path',
None) # focus mask (black and white. White has higher loss than black)
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for the mask. 0 - 1
self.poi: Union[str, None] = kwargs.get('poi',
None) # if one is set and in json data, will be used as auto crop scale point of interest
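A hypothetical dataset config exercising the new mask options (the key names come from the diff, and DatasetConfig's kwargs.get calls suggest it accepts them as keyword arguments; the values are illustrative):

from toolkit.config_modules import DatasetConfig

dataset_config = DatasetConfig(
    mask_path='/path/to/masks',  # focus masks: white = higher loss, black = lower
    invert_mask=True,            # flip the mask before it is applied
    mask_min_value=0.01,         # floor so fully masked-out pixels keep some loss
)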


@@ -18,7 +18,7 @@ from toolkit.buckets import get_bucket_for_image_size
from toolkit.metadata import get_meta_for_safetensors
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image, ImageFilter
from PIL import Image, ImageFilter, ImageOps
from PIL.ImageOps import exif_transpose
import albumentations as A
@@ -612,6 +612,8 @@ class MaskFileItemDTOMixin:
img = Image.fromarray(np_img)
img = img.convert('RGB')
if self.dataset_config.invert_mask:
img = ImageOps.invert(img)
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
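PIL's ImageOps.invert flips every channel value (255 - v), so a white-focus mask becomes a black-focus one. A quick standalone check:

from PIL import Image, ImageOps

mask = Image.new('RGB', (4, 4), (255, 255, 255))  # all-white focus mask
inverted = ImageOps.invert(mask)
print(inverted.getpixel((0, 0)))  # (0, 0, 0)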


@@ -535,7 +535,7 @@ class StableDiffusion:
text_embeddings: Union[PromptEmbeds, None] = None,
timestep: Union[int, torch.Tensor] = 1,
guidance_scale=7.5,
guidance_rescale=0, # 0.7 sdxl
guidance_rescale=0,
add_time_ids=None,
conditional_embeddings: Union[PromptEmbeds, None] = None,
unconditional_embeddings: Union[PromptEmbeds, None] = None,
@@ -674,7 +674,7 @@ class StableDiffusion:
add_time_ids=add_time_ids,
**kwargs,
)
latents = self.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
# return latents_steps
return latents
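In diffusers, both call forms return the same denoised sample; return_dict=False yields a plain tuple instead of a SchedulerOutput dataclass, which indexes as [0] and plays better with tracing. A standalone check (using the final timestep, where the DDPM step adds no variance noise, so both calls are deterministic):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(50)
sample = torch.randn(1, 4, 8, 8)
noise_pred = torch.randn(1, 4, 8, 8)
t = scheduler.timesteps[-1]  # t == 0: no variance noise is added

prev_a = scheduler.step(noise_pred, t, sample).prev_sample
prev_b = scheduler.step(noise_pred, t, sample, return_dict=False)[0]
assert torch.equal(prev_a, prev_b)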


@@ -691,12 +691,12 @@ class LearnableSNRGamma:
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
self.device = device
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
self.offset = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)
self.buffer = []
self.max_buffer_size = 100
self.max_buffer_size = 20
def forward(self, loss, timesteps):
# do our train loop for lsnr here and return our values detached
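The new defaults replace the near-identity initialization (offset 1.0, scale 0.001, gamma 1.0) with values that presumably sit near a converged fit, alongside a lower learning rate and a shorter loss buffer. The body of forward() is not shown past this point; purely as an assumed illustration of how such a learnable weighting could work (the parameter names match the diff, but the fitting objective below is invented):

import torch

class LearnableSNRGammaSketch:
    # Assumed sketch only, not the toolkit's forward(): fit a parametric
    # curve over per-timestep SNR to the observed loss, then reuse the
    # fitted curve (detached) as a loss weight.
    def __init__(self, all_snr: torch.Tensor, device='cpu'):
        self.all_snr = all_snr.to(device)  # SNR value for every timestep index
        self.offset = torch.nn.Parameter(torch.tensor(0.777, device=device))
        self.scale = torch.nn.Parameter(torch.tensor(4.14, device=device))
        self.gamma = torch.nn.Parameter(torch.tensor(2.03, device=device))
        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)

    def forward(self, loss: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        snr = self.all_snr[timesteps]
        # Predicted loss as a function of SNR, shaped by the three parameters.
        predicted = self.scale / (snr + self.offset.abs() + 1e-6) ** self.gamma
        fit_loss = torch.nn.functional.mse_loss(predicted, loss.detach())
        self.optimizer.zero_grad()
        fit_loss.backward()
        self.optimizer.step()
        # Weight the real loss by the (detached, normalized) fitted curve.
        weight = (predicted / (predicted.mean() + 1e-6)).detach()
        return loss * weight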