mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added masking to slider training. Something is still weird though
This commit is contained in:
@@ -13,7 +13,7 @@ from toolkit.basic import value_map
|
||||
from toolkit.config_modules import SliderConfig
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_learnable_snr_gos
|
||||
import gc
|
||||
from toolkit import train_tools
|
||||
from toolkit.prompt_utils import \
|
||||
@@ -35,6 +35,7 @@ adapter_transforms = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
|
||||
class TrainSliderProcess(BaseSDTrainProcess):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
@@ -273,7 +274,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
return adapter_tensors
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]):
|
||||
# set to eval mode
|
||||
self.sd.set_device_state(self.eval_slider_device_state)
|
||||
with torch.no_grad():
|
||||
@@ -309,7 +310,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
if dbr_batch_size != dn.shape[0]:
|
||||
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
|
||||
down_kwargs['down_block_additional_residuals'] = [
|
||||
torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals']
|
||||
torch.cat([sample.clone()] * amount_to_add) for sample in
|
||||
down_kwargs['down_block_additional_residuals']
|
||||
]
|
||||
return self.sd.predict_noise(
|
||||
latents=dn,
|
||||
@@ -325,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
with torch.no_grad():
|
||||
adapter_images = None
|
||||
|
||||
# for a complete slider, the batch size is 4 to begin with now
|
||||
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
|
||||
from_batch = False
|
||||
@@ -370,7 +373,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
).item()
|
||||
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
pixel_height=height,
|
||||
@@ -401,7 +403,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
noise_scheduler.set_timesteps(1000)
|
||||
|
||||
|
||||
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
@@ -410,6 +411,33 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
|
||||
|
||||
# flush() # 4.2GB to 3GB on 512x512
|
||||
mask_multiplier = torch.ones((denoised_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
has_mask = False
|
||||
if batch and batch.mask_tensor is not None:
|
||||
with self.timer('get_mask_multiplier'):
|
||||
# upsampling no supported for bfloat16
|
||||
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
||||
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
||||
mask_multiplier = torch.nn.functional.interpolate(
|
||||
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
||||
)
|
||||
# expand to match latents
|
||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||
has_mask = True
|
||||
|
||||
if has_mask:
|
||||
unmasked_target = get_noise_pred(
|
||||
prompt_pair.positive_target, # negative prompt
|
||||
prompt_pair.target_class, # positive prompt
|
||||
1,
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
)
|
||||
unmasked_target = unmasked_target.detach()
|
||||
unmasked_target.requires_grad = False
|
||||
else:
|
||||
unmasked_target = None
|
||||
|
||||
# 4.20 GB RAM for 512x512
|
||||
positive_latents = get_noise_pred(
|
||||
@@ -504,19 +532,30 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
anchor.to("cpu")
|
||||
|
||||
with torch.no_grad():
|
||||
if self.slider_config.high_ram:
|
||||
if self.slider_config.low_ram:
|
||||
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
|
||||
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
|
||||
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||
unconditional_latents_chunks = torch.chunk(
|
||||
unconditional_latents.detach(),
|
||||
self.prompt_chunk_size,
|
||||
dim=0
|
||||
)
|
||||
mask_multiplier_chunks = torch.chunk(mask_multiplier, self.prompt_chunk_size, dim=0)
|
||||
if unmasked_target is not None:
|
||||
unmasked_target_chunks = torch.chunk(unmasked_target, self.prompt_chunk_size, dim=0)
|
||||
else:
|
||||
unmasked_target_chunks = [None for _ in range(self.prompt_chunk_size)]
|
||||
else:
|
||||
# run through in one instance
|
||||
prompt_pair_chunks = [prompt_pair.detach()]
|
||||
denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
|
||||
positive_latents_chunks = [positive_latents.detach()]
|
||||
neutral_latents_chunks = [neutral_latents.detach()]
|
||||
unconditional_latents_chunks = [unconditional_latents.detach()]
|
||||
else:
|
||||
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
|
||||
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
|
||||
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||
unconditional_latents_chunks = torch.chunk(unconditional_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||
mask_multiplier_chunks = [mask_multiplier]
|
||||
unmasked_target_chunks = [unmasked_target]
|
||||
|
||||
# flush()
|
||||
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
|
||||
@@ -528,13 +567,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
denoised_latent_chunk, \
|
||||
positive_latents_chunk, \
|
||||
neutral_latents_chunk, \
|
||||
unconditional_latents_chunk \
|
||||
unconditional_latents_chunk, \
|
||||
mask_multiplier_chunk, \
|
||||
unmasked_target_chunk \
|
||||
in zip(
|
||||
prompt_pair_chunks,
|
||||
denoised_latent_chunks,
|
||||
positive_latents_chunks,
|
||||
neutral_latents_chunks,
|
||||
unconditional_latents_chunks,
|
||||
mask_multiplier_chunks,
|
||||
unmasked_target_chunks
|
||||
):
|
||||
self.network.multiplier = prompt_pair_chunk.multiplier_list
|
||||
target_latents = get_noise_pred(
|
||||
@@ -568,17 +611,43 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
|
||||
|
||||
# do inverted mask to preserve non masked
|
||||
if has_mask and unmasked_target_chunk is not None:
|
||||
loss = loss * mask_multiplier_chunk
|
||||
# match the mask unmasked_target_chunk
|
||||
mask_target_loss = torch.nn.functional.mse_loss(
|
||||
target_latents.float(),
|
||||
unmasked_target_chunk.float(),
|
||||
reduction="none"
|
||||
)
|
||||
mask_target_loss = mask_target_loss * (1.0 - mask_multiplier_chunk)
|
||||
loss += mask_target_loss
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.learnable_snr_gos:
|
||||
if from_batch:
|
||||
# match batch size
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
|
||||
self.train_config.min_snr_gamma)
|
||||
else:
|
||||
# match batch size
|
||||
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
|
||||
# add snr_gamma
|
||||
loss = apply_learnable_snr_gos(loss, timesteps_index_list, self.snr_gos)
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
if from_batch:
|
||||
# match batch size
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
|
||||
self.train_config.min_snr_gamma)
|
||||
else:
|
||||
# match batch size
|
||||
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler,
|
||||
self.train_config.min_snr_gamma)
|
||||
|
||||
|
||||
loss = loss.mean() * prompt_pair_chunk.weight
|
||||
|
||||
|
||||
Reference in New Issue
Block a user