Numerous fixes for time sampling. Still not perfect

This commit is contained in:
Jaret Burkett
2023-11-28 07:34:43 -07:00
parent d7e55b6ad4
commit 792a5e37e2
7 changed files with 160 additions and 91 deletions

View File

@@ -34,6 +34,7 @@ class SDTrainer(BaseSDTrainProcess):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
self.do_long_prompts = False
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
@@ -126,6 +127,7 @@ class SDTrainer(BaseSDTrainProcess):
# we also denoise as the unaugmented tensor is not a noisy diffirental
with torch.no_grad():
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
target = unaugmented_latents.detach()
# Get the target for loss depending on the prediction type
@@ -492,7 +494,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=True).to(
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
else:
@@ -506,7 +508,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=True).to(
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)

View File

@@ -666,7 +666,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
unconditional_imgs = batch.unconditional_tensor
unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
unconditional_latents = self.sd.encode_images(unconditional_imgs)
batch.unconditional_latents = unconditional_latents
batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier
unaugmented_latents = None
if self.train_config.loss_target == 'differential_noise':
@@ -715,36 +715,41 @@ class BaseSDTrainProcess(BaseTrainProcess):
orig_timesteps = torch.rand((batch_size,), device=latents.device)
if self.train_config.content_or_style == 'content':
timesteps = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'style':
timesteps = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timesteps = value_map(
timesteps,
timestep_indices = value_map(
timestep_indices,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
min_noise_steps,
max_noise_steps
max_noise_steps - 1
)
timesteps = timesteps.long().clamp(
timestep_indices = timestep_indices.long().clamp(
min_noise_steps + 1,
max_noise_steps - 1
)
elif self.train_config.content_or_style == 'balanced':
if min_noise_steps == max_noise_steps:
timesteps = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
else:
timesteps = torch.randint(
min_noise_steps,
max_noise_steps,
# todo, some schedulers use indices, otheres use timesteps. Not sure what to do here
timestep_indices = torch.randint(
min_noise_steps + 1,
max_noise_steps - 1,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
timestep_indices = timestep_indices.long()
else:
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
@@ -765,9 +770,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise = noise * noise_multiplier
img_multiplier = self.train_config.img_multiplier
latents = latents * self.train_config.latent_multiplier
latents = latents * img_multiplier
if batch.unconditional_latents is not None:
batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)

View File

@@ -30,9 +30,12 @@ parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()
dataset_folder = args.dataset_folder
resolution = 512
resolution = 1024
bucket_tolerance = 64
batch_size = 4
batch_size = 1
##
dataset_config = DatasetConfig(
dataset_path=dataset_folder,
@@ -41,8 +44,31 @@ dataset_config = DatasetConfig(
default_caption='default',
buckets=True,
bucket_tolerance=bucket_tolerance,
augments=['ColorJitter'],
poi='person'
poi='person',
augmentations=[
{
'method': 'RandomBrightnessContrast',
'brightness_limit': (-0.3, 0.3),
'contrast_limit': (-0.3, 0.3),
'brightness_by_max': False,
'p': 1.0
},
{
'method': 'HueSaturationValue',
'hue_shift_limit': (-0, 0),
'sat_shift_limit': (-40, 40),
'val_shift_limit': (-40, 40),
'p': 1.0
},
# {
# 'method': 'RGBShift',
# 'r_shift_limit': (-20, 20),
# 'g_shift_limit': (-20, 20),
# 'b_shift_limit': (-20, 20),
# 'p': 1.0
# },
]
)

View File

@@ -193,6 +193,7 @@ class TrainConfig:
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
self.negative_prompt = kwargs.get('negative_prompt', None)
# multiplier applied to loos on regularization images
self.reg_weight = kwargs.get('reg_weight', 1.0)

View File

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
# def get_associated_caption_from_img_path(img_path):
# https://demo.albumentations.ai/
class Augments:
def __init__(self, **kwargs):
self.method_name = kwargs.get('method', None)
@@ -167,10 +167,11 @@ class BucketsMixin:
width = int(file_item.width * file_item.dataset_config.scale)
height = int(file_item.height * file_item.dataset_config.scale)
did_process_poi = False
if file_item.has_point_of_interest:
# let the poi module handle the bucketing
file_item.setup_poi_bucket()
else:
# Attempt to process the poi if we can. It wont process if the image is smaller than the resolution
did_process_poi = file_item.setup_poi_bucket()
if not did_process_poi:
bucket_resolution = get_bucket_for_image_size(
width, height,
resolution=resolution,
@@ -323,7 +324,7 @@ class CaptionProcessingDTOMixin:
if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
# add random triggers
caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption
caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
if self.dataset_config.shuffle_tokens:
# shuffle again
@@ -803,79 +804,68 @@ class PoiFileItemDTOMixin:
self.poi_y = self.height - self.poi_y - self.poi_height
def setup_poi_bucket(self: 'FileItemDTO'):
# we are using poi, so we need to calculate the bucket based on the poi
# TODO this will allow poi to be smaller than resolution. Could affect training image size
poi_resolution = min(
self.dataset_config.resolution,
get_resolution(
self.poi_width * self.dataset_config.scale,
self.poi_height * self.dataset_config.scale
)
)
resolution = min(self.dataset_config.resolution, poi_resolution)
bucket_tolerance = self.dataset_config.bucket_tolerance
initial_width = int(self.width * self.dataset_config.scale)
initial_height = int(self.height * self.dataset_config.scale)
# we are using poi, so we need to calculate the bucket based on the poi
# if img resolution is less than dataset resolution, just return and let the normal bucketing happen
img_resolution = get_resolution(initial_width, initial_height)
if img_resolution <= self.dataset_config.resolution:
return False # will trigger normal bucketing
bucket_tolerance = self.dataset_config.bucket_tolerance
poi_x = int(self.poi_x * self.dataset_config.scale)
poi_y = int(self.poi_y * self.dataset_config.scale)
poi_width = int(self.poi_width * self.dataset_config.scale)
poi_height = int(self.poi_height * self.dataset_config.scale)
# expand poi to fit resolution
if poi_width < resolution:
width_difference = resolution - poi_width
poi_x = poi_x - int(width_difference / 2)
poi_width = resolution
# make sure we dont go out of bounds
if poi_x < 0:
# loop to keep expanding until we are at the proper resolution. This is not ideal, we can probably handle it better
num_loops = 0
while True:
# crop left
if poi_x > 0:
poi_x = random.randint(0, poi_x)
else:
poi_x = 0
# if total width too much, crop
if poi_x + poi_width > initial_width:
poi_width = initial_width - poi_x
if poi_height < resolution:
height_difference = resolution - poi_height
poi_y = poi_y - int(height_difference / 2)
poi_height = resolution
# make sure we dont go out of bounds
if poi_y < 0:
# crop right
cr_min = poi_x + poi_width
if cr_min < initial_width:
crop_right = random.randint(poi_x + poi_width, initial_width)
else:
crop_right = initial_width
poi_width = crop_right - poi_x
if poi_y > 0:
poi_y = random.randint(0, poi_y)
else:
poi_y = 0
# if total height too much, crop
if poi_y + poi_height > initial_height:
poi_height = initial_height - poi_y
# crop left
if poi_x > 0:
crop_left = random.randint(0, poi_x)
else:
crop_left = 0
if poi_y + poi_height < initial_height:
crop_bottom = random.randint(poi_y + poi_height, initial_height)
else:
crop_bottom = initial_height
# crop right
cr_min = poi_x + poi_width
if cr_min < initial_width:
crop_right = random.randint(poi_x + poi_width, initial_width)
else:
crop_right = initial_width
poi_height = crop_bottom - poi_y
# now we have our random crop, but it may be smaller than resolution. Check and expand if needed
current_resolution = get_resolution(poi_width, poi_height)
if current_resolution >= self.dataset_config.resolution:
# We can break now
break
else:
num_loops += 1
if num_loops > 100:
print(
f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.")
return False
if poi_y > 0:
crop_top = random.randint(0, poi_y)
else:
crop_top = 0
if poi_y + poi_height < initial_height:
crop_bottom = random.randint(poi_y + poi_height, initial_height)
else:
crop_bottom = initial_height
new_width = crop_right - crop_left
new_height = crop_bottom - crop_top
new_width = poi_width
new_height = poi_height
bucket_resolution = get_bucket_for_image_size(
new_width, new_height,
resolution=resolution,
resolution=self.dataset_config.resolution,
divisibility=bucket_tolerance
)
@@ -888,8 +878,10 @@ class PoiFileItemDTOMixin:
self.scale_to_height = int(initial_height * max_scale_factor)
self.crop_width = bucket_resolution['width']
self.crop_height = bucket_resolution['height']
self.crop_x = int(crop_left * max_scale_factor)
self.crop_y = int(crop_top * max_scale_factor)
self.crop_x = int(poi_x * max_scale_factor)
self.crop_y = int(poi_y * max_scale_factor)
return True
class ArgBreakMixin:

View File

@@ -193,6 +193,9 @@ class StableDiffusion:
device=self.device_torch,
torch_dtype=self.torch_dtype,
)
if 'vae' in load_args and load_args['vae'] is not None:
pipe.vae = load_args['vae']
flush()
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
@@ -679,6 +682,18 @@ class StableDiffusion:
text_embeddings = text_embeddings.to(self.device_torch)
timestep = timestep.to(self.device_torch)
def scale_model_input(model_input, timestep_tensor):
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
out_chunks = []
for idx in range(model_input.shape[0]):
# if scheduler has step_index
if hasattr(self.noise_scheduler, '_step_index'):
self.noise_scheduler._step_index = None
out_chunks.append(
self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_tensor[idx])
)
return torch.cat(out_chunks, dim=0)
if self.is_xl:
with torch.no_grad():
# 16, 6 for bs of 4
@@ -691,10 +706,11 @@ class StableDiffusion:
if do_classifier_free_guidance:
latent_model_input = torch.cat([latents] * 2)
timestep = torch.cat([timestep] * 2)
else:
latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
latent_model_input = scale_model_input(latent_model_input, timestep)
added_cond_kwargs = {
# todo can we zero here the second text encoder? or match a blank string?
@@ -784,10 +800,11 @@ class StableDiffusion:
if do_classifier_free_guidance:
# if we are doing classifier free guidance, need to double up
latent_model_input = torch.cat([latents] * 2)
timestep = torch.cat([timestep] * 2)
else:
latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
latent_model_input = scale_model_input(latent_model_input, timestep)
# check if we need to concat timesteps
if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
@@ -823,6 +840,23 @@ class StableDiffusion:
return noise_pred
def step_scheduler(self, model_input, latent_input, timestep_tensor):
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
for idx in range(model_input.shape[0]):
# Reset it so it is unique for the
if hasattr(self.noise_scheduler, '_step_index'):
self.noise_scheduler._step_index = None
if hasattr(self.noise_scheduler, 'is_scale_input_called'):
self.noise_scheduler.is_scale_input_called = True
out_chunks.append(
self.noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
0]
)
return torch.cat(out_chunks, dim=0)
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
def diffuse_some_steps(
self,
@@ -839,6 +873,7 @@ class StableDiffusion:
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
for timestep in tqdm(timesteps_to_run, leave=False):
timestep = timestep.unsqueeze_(0)
noise_pred = self.predict_noise(
latents,
text_embeddings,
@@ -847,7 +882,9 @@ class StableDiffusion:
add_time_ids=add_time_ids,
**kwargs,
)
latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
# some schedulers need to run separately, so do that. (euler for example)
latents = self.step_scheduler(noise_pred, latents, timestep)
# if not last step, and bleeding, bleed in some latents
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:

View File

@@ -584,8 +584,13 @@ def encode_prompts_xl(
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
if max_length is None and not truncate:
raise ValueError("max_length must be set if truncate is True")
tokens = tokens.to(text_encoder.device)
try:
tokens = tokens.to(text_encoder.device)
except Exception as e:
print(e)
print("tokens.device", tokens.device)
print("text_encoder.device", text_encoder.device)
raise e
if truncate:
return text_encoder(tokens)[0]
@@ -771,8 +776,8 @@ def apply_snr_weight(
):
# will get it from noise scheduler if exist or will calculate it if not
all_snr = get_all_snr(noise_scheduler, loss.device)
snr = torch.stack([all_snr[t] for t in timesteps])
step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
snr = torch.stack([all_snr[t] for t in step_indices])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
if fixed:
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr