mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-05-01 03:31:35 +00:00)

Commit: Numerous fixes for time sampling. Still not perfect
@@ -34,6 +34,7 @@ class SDTrainer(BaseSDTrainProcess):
         super().__init__(process_id, job, config, **kwargs)
         self.assistant_adapter: Union['T2IAdapter', None]
         self.do_prior_prediction = False
+        self.do_long_prompts = False
         if self.train_config.inverted_mask_prior:
             self.do_prior_prediction = True
 
@@ -126,6 +127,7 @@ class SDTrainer(BaseSDTrainProcess):
                 # we also denoise as the unaugmented tensor is not a noisy diffirental
                 with torch.no_grad():
                     unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
+                    unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
                     target = unaugmented_latents.detach()
 
                 # Get the target for loss depending on the prediction type
@@ -492,7 +494,7 @@ class SDTrainer(BaseSDTrainProcess):
                 conditional_embeds = self.sd.encode_prompt(
                     conditioned_prompts, prompt_2,
                     dropout_prob=self.train_config.prompt_dropout_prob,
-                    long_prompts=True).to(
+                    long_prompts=self.do_long_prompts).to(
                     self.device_torch,
                     dtype=dtype)
             else:
@@ -506,7 +508,7 @@ class SDTrainer(BaseSDTrainProcess):
                 conditional_embeds = self.sd.encode_prompt(
                     conditioned_prompts, prompt_2,
                     dropout_prob=self.train_config.prompt_dropout_prob,
-                    long_prompts=True).to(
+                    long_prompts=self.do_long_prompts).to(
                     self.device_torch,
                     dtype=dtype)
 
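Both call sites above previously hard-coded long_prompts=True; they now route through the new self.do_long_prompts flag (initialized to False in the first hunk), making long-prompt encoding opt-in. Long-prompt handling generally means chunking tokens past CLIP's 77-token window and concatenating the per-chunk embeddings; a minimal standalone sketch of that idea, not this repo's encode_prompt implementation, which remains authoritative:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    def encode_long_prompt(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel,
                           prompt: str) -> torch.Tensor:
        """Encode a prompt longer than CLIP's 77-token window by chunking."""
        ids = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        inner = ids[1:-1]                       # strip BOS/EOS
        chunk = tokenizer.model_max_length - 2  # leave room to re-add them per chunk
        bos = torch.tensor([tokenizer.bos_token_id])
        eos = torch.tensor([tokenizer.eos_token_id])
        embeds = []
        for i in range(0, len(inner), chunk):
            piece = torch.cat([bos, inner[i:i + chunk], eos]).unsqueeze(0)
            embeds.append(text_encoder(piece)[0])
        return torch.cat(embeds, dim=1)  # concatenate along the sequence dimension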
@@ -666,7 +666,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     unconditional_imgs = batch.unconditional_tensor
                     unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
                     unconditional_latents = self.sd.encode_images(unconditional_imgs)
-                    batch.unconditional_latents = unconditional_latents
+                    batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier
 
                 unaugmented_latents = None
                 if self.train_config.loss_target == 'differential_noise':
@@ -715,36 +715,41 @@ class BaseSDTrainProcess(BaseTrainProcess):
             orig_timesteps = torch.rand((batch_size,), device=latents.device)
 
             if self.train_config.content_or_style == 'content':
-                timesteps = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
+                timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
             elif self.train_config.content_or_style == 'style':
-                timesteps = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
+                timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
 
-            timesteps = value_map(
-                timesteps,
+            timestep_indices = value_map(
+                timestep_indices,
                 0,
                 self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
                 min_noise_steps,
-                max_noise_steps
+                max_noise_steps - 1
             )
-            timesteps = timesteps.long().clamp(
+            timestep_indices = timestep_indices.long().clamp(
                 min_noise_steps + 1,
                 max_noise_steps - 1
             )
 
         elif self.train_config.content_or_style == 'balanced':
             if min_noise_steps == max_noise_steps:
-                timesteps = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
+                timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
             else:
-                timesteps = torch.randint(
-                    min_noise_steps,
-                    max_noise_steps,
+                # todo, some schedulers use indices, otheres use timesteps. Not sure what to do here
+                timestep_indices = torch.randint(
+                    min_noise_steps + 1,
+                    max_noise_steps - 1,
                     (batch_size,),
                     device=self.device_torch
                 )
-            timesteps = timesteps.long()
+            timestep_indices = timestep_indices.long()
         else:
             raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
 
+        # convert the timestep_indices to a timestep
+        timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
+        timesteps = torch.stack(timesteps, dim=0)
 
         # get noise
         noise = self.sd.get_latent_noise(
             height=latents.shape[2],
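The core fix of this commit is in the hunk above: the cubic 'content'/'style' sampling produces indices into the schedule, not timestep values, so the values are now looked up through noise_scheduler.timesteps before being handed to add_noise. For a plain DDPM schedule the two coincide, but for shifted or rediscretized schedules they do not. A minimal sketch of the distinction, assuming a diffusers-style scheduler:

    import torch
    from diffusers import DDPMScheduler  # stand-in; any scheduler exposing .timesteps works

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(1000)

    batch_size = 4
    u = torch.rand(batch_size)
    # cubic bias: u ** 3 concentrates near zero, i.e. toward low indices
    indices = (u ** 3 * (len(scheduler.timesteps) - 1)).long()

    # indices are positions in scheduler.timesteps, not timestep values
    timesteps = torch.stack([scheduler.timesteps[i] for i in indices])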
@@ -765,9 +770,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         noise = noise * noise_multiplier
 
-        img_multiplier = self.train_config.img_multiplier
-
-        latents = latents * img_multiplier
+        latents = latents * self.train_config.latent_multiplier
+
+        if batch.unconditional_latents is not None:
+            batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
 
         noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
 
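This hunk, together with the earlier latent hunks and the TrainConfig addition below, replaces img_multiplier (applied only to the conditional latents) with one latent_multiplier applied to every latent that enters the loss: conditional, unconditional, and unaugmented alike. The invariant is that all tensors meeting in the loss must share one scale, on top of the VAE's own normalization. A hedged sketch, assuming a diffusers-style VAE:

    import torch
    from diffusers import AutoencoderKL

    def encode_scaled(vae: AutoencoderKL, images: torch.Tensor,
                      latent_multiplier: float) -> torch.Tensor:
        # the VAE scaling factor (e.g. 0.18215 for SD 1.x) already normalizes latents;
        # latent_multiplier is an extra knob that must hit every latent identically
        latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor
        return latents * latent_multiplier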
@@ -30,9 +30,12 @@ parser.add_argument('--epochs', type=int, default=1)
 args = parser.parse_args()
 
 dataset_folder = args.dataset_folder
-resolution = 512
+resolution = 1024
 bucket_tolerance = 64
-batch_size = 4
+batch_size = 1
 
 
+##
+
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
@@ -41,8 +44,31 @@ dataset_config = DatasetConfig(
     default_caption='default',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
-    augments=['ColorJitter'],
-    poi='person'
+    poi='person',
+    augmentations=[
+        {
+            'method': 'RandomBrightnessContrast',
+            'brightness_limit': (-0.3, 0.3),
+            'contrast_limit': (-0.3, 0.3),
+            'brightness_by_max': False,
+            'p': 1.0
+        },
+        {
+            'method': 'HueSaturationValue',
+            'hue_shift_limit': (-0, 0),
+            'sat_shift_limit': (-40, 40),
+            'val_shift_limit': (-40, 40),
+            'p': 1.0
+        },
+        # {
+        #     'method': 'RGBShift',
+        #     'r_shift_limit': (-20, 20),
+        #     'g_shift_limit': (-20, 20),
+        #     'b_shift_limit': (-20, 20),
+        #     'p': 1.0
+        # },
+    ]
+
+
 )
 
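The new augmentations list replaces the earlier augments=['ColorJitter'] shorthand with explicit per-transform dictionaries: each 'method' names an albumentations transform (https://demo.albumentations.ai/, referenced in the next file) and the remaining keys are its keyword arguments. A sketch of the likely mechanism; the repo's Augments class below is the authoritative mapping:

    import albumentations as A

    def build_transforms(entries: list) -> A.Compose:
        transforms = []
        for entry in entries:
            entry = dict(entry)  # don't mutate the config
            cls = getattr(A, entry.pop('method'))  # e.g. A.RandomBrightnessContrast
            transforms.append(cls(**entry))
        return A.Compose(transforms)

    augs = build_transforms([
        {'method': 'RandomBrightnessContrast',
         'brightness_limit': (-0.3, 0.3), 'contrast_limit': (-0.3, 0.3),
         'brightness_by_max': False, 'p': 1.0},
        {'method': 'HueSaturationValue',
         'hue_shift_limit': (0, 0), 'sat_shift_limit': (-40, 40),
         'val_shift_limit': (-40, 40), 'p': 1.0},
    ])
    # usage: augmented = augs(image=numpy_image)['image']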
@@ -193,6 +193,7 @@ class TrainConfig:
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
+        self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
         self.negative_prompt = kwargs.get('negative_prompt', None)
         # multiplier applied to loos on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
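TrainConfig reads every knob straight from kwargs with a default, so the new latent_multiplier becomes settable from a job config with no extra plumbing. A minimal illustration of the pattern (assumed usage; in practice the kwargs arrive from the parsed job file):

    train_config = TrainConfig(
        noise_multiplier=1.0,
        img_multiplier=1.0,      # still parsed, though the hunks above stop applying it
        latent_multiplier=1.0,   # new in this commit, defaults to 1.0
    )
    assert train_config.latent_multiplier == 1.0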
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 
 
 # def get_associated_caption_from_img_path(img_path):
+# https://demo.albumentations.ai/
 class Augments:
     def __init__(self, **kwargs):
         self.method_name = kwargs.get('method', None)
@@ -167,10 +167,11 @@ class BucketsMixin:
             width = int(file_item.width * file_item.dataset_config.scale)
             height = int(file_item.height * file_item.dataset_config.scale)
 
+            did_process_poi = False
             if file_item.has_point_of_interest:
-                # let the poi module handle the bucketing
-                file_item.setup_poi_bucket()
-            else:
+                # Attempt to process the poi if we can. It wont process if the image is smaller than the resolution
+                did_process_poi = file_item.setup_poi_bucket()
+            if not did_process_poi:
                 bucket_resolution = get_bucket_for_image_size(
                     width, height,
                     resolution=resolution,
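Turning the else: into if not did_process_poi: makes POI bucketing a fallible fast path: setup_poi_bucket() may now decline (per a later hunk it returns False when the image is at or below the dataset resolution) and the item falls through to normal bucketing. The shape of the pattern, with a hypothetical stand-in for the generic path:

    did_process_poi = False
    if file_item.has_point_of_interest:
        did_process_poi = file_item.setup_poi_bucket()  # may return False instead of guessing
    if not did_process_poi:
        assign_default_bucket(file_item)  # hypothetical name for the generic bucketing shown above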
@@ -323,7 +324,7 @@ class CaptionProcessingDTOMixin:
 
         if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
             # add random triggers
-            caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption
+            caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
 
         if self.dataset_config.shuffle_tokens:
             # shuffle again
@@ -803,79 +804,68 @@ class PoiFileItemDTOMixin:
             self.poi_y = self.height - self.poi_y - self.poi_height
 
     def setup_poi_bucket(self: 'FileItemDTO'):
-        # we are using poi, so we need to calculate the bucket based on the poi
-
-        # TODO this will allow poi to be smaller than resolution. Could affect training image size
-        poi_resolution = min(
-            self.dataset_config.resolution,
-            get_resolution(
-                self.poi_width * self.dataset_config.scale,
-                self.poi_height * self.dataset_config.scale
-            )
-        )
-
-        resolution = min(self.dataset_config.resolution, poi_resolution)
-
-        bucket_tolerance = self.dataset_config.bucket_tolerance
         initial_width = int(self.width * self.dataset_config.scale)
         initial_height = int(self.height * self.dataset_config.scale)
+        # we are using poi, so we need to calculate the bucket based on the poi
+
+        # if img resolution is less than dataset resolution, just return and let the normal bucketing happen
+        img_resolution = get_resolution(initial_width, initial_height)
+        if img_resolution <= self.dataset_config.resolution:
+            return False  # will trigger normal bucketing
+
+        bucket_tolerance = self.dataset_config.bucket_tolerance
         poi_x = int(self.poi_x * self.dataset_config.scale)
         poi_y = int(self.poi_y * self.dataset_config.scale)
         poi_width = int(self.poi_width * self.dataset_config.scale)
         poi_height = int(self.poi_height * self.dataset_config.scale)
 
-        # expand poi to fit resolution
-        if poi_width < resolution:
-            width_difference = resolution - poi_width
-            poi_x = poi_x - int(width_difference / 2)
-            poi_width = resolution
-            # make sure we dont go out of bounds
-            if poi_x < 0:
-                poi_x = 0
-            # if total width too much, crop
-            if poi_x + poi_width > initial_width:
-                poi_width = initial_width - poi_x
-
-        if poi_height < resolution:
-            height_difference = resolution - poi_height
-            poi_y = poi_y - int(height_difference / 2)
-            poi_height = resolution
-            # make sure we dont go out of bounds
-            if poi_y < 0:
-                poi_y = 0
-            # if total height too much, crop
-            if poi_y + poi_height > initial_height:
-                poi_height = initial_height - poi_y
-
-        # crop left
-        if poi_x > 0:
-            crop_left = random.randint(0, poi_x)
-        else:
-            crop_left = 0
-
-        # crop right
-        cr_min = poi_x + poi_width
-        if cr_min < initial_width:
-            crop_right = random.randint(poi_x + poi_width, initial_width)
-        else:
-            crop_right = initial_width
-
-        if poi_y > 0:
-            crop_top = random.randint(0, poi_y)
-        else:
-            crop_top = 0
-
-        if poi_y + poi_height < initial_height:
-            crop_bottom = random.randint(poi_y + poi_height, initial_height)
-        else:
-            crop_bottom = initial_height
-
-        new_width = crop_right - crop_left
-        new_height = crop_bottom - crop_top
+        # loop to keep expanding until we are at the proper resolution. This is not ideal, we can probably handle it better
+        num_loops = 0
+        while True:
+            # crop left
+            if poi_x > 0:
+                poi_x = random.randint(0, poi_x)
+            else:
+                poi_x = 0
+
+            # crop right
+            cr_min = poi_x + poi_width
+            if cr_min < initial_width:
+                crop_right = random.randint(poi_x + poi_width, initial_width)
+            else:
+                crop_right = initial_width
+            poi_width = crop_right - poi_x
+
+            if poi_y > 0:
+                poi_y = random.randint(0, poi_y)
+            else:
+                poi_y = 0
+
+            if poi_y + poi_height < initial_height:
+                crop_bottom = random.randint(poi_y + poi_height, initial_height)
+            else:
+                crop_bottom = initial_height
+            poi_height = crop_bottom - poi_y
+
+            # now we have our random crop, but it may be smaller than resolution. Check and expand if needed
+            current_resolution = get_resolution(poi_width, poi_height)
+            if current_resolution >= self.dataset_config.resolution:
+                # We can break now
+                break
+            else:
+                num_loops += 1
+                if num_loops > 100:
+                    print(f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.")
+                    return False
+
+        new_width = poi_width
+        new_height = poi_height
 
         bucket_resolution = get_bucket_for_image_size(
             new_width, new_height,
-            resolution=resolution,
+            resolution=self.dataset_config.resolution,
            divisibility=bucket_tolerance
         )
 
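The rewritten setup_poi_bucket drops the deterministic center-and-expand logic in favor of randomized crops retried until the crop reaches the dataset resolution, with a 100-try bail-out. A condensed standalone sketch of the same idea, using independent retries rather than the in-place expansion above, and approximating get_resolution as the geometric mean of the sides (an assumption, not this repo's definition):

    import random

    def random_crop_around_poi(img_w, img_h, poi_x, poi_y, poi_w, poi_h,
                               min_resolution, max_tries=100):
        """Pick a random crop containing the POI whose resolution >= min_resolution."""
        for _ in range(max_tries):
            left = random.randint(0, poi_x) if poi_x > 0 else 0
            right = random.randint(poi_x + poi_w, img_w) if poi_x + poi_w < img_w else img_w
            top = random.randint(0, poi_y) if poi_y > 0 else 0
            bottom = random.randint(poi_y + poi_h, img_h) if poi_y + poi_h < img_h else img_h
            w, h = right - left, bottom - top
            if (w * h) ** 0.5 >= min_resolution:
                return left, top, w, h
        return None  # caller falls back to normal bucketing, as setup_poi_bucket does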
@@ -888,8 +878,10 @@ class PoiFileItemDTOMixin:
         self.scale_to_height = int(initial_height * max_scale_factor)
         self.crop_width = bucket_resolution['width']
         self.crop_height = bucket_resolution['height']
-        self.crop_x = int(crop_left * max_scale_factor)
-        self.crop_y = int(crop_top * max_scale_factor)
+        self.crop_x = int(poi_x * max_scale_factor)
+        self.crop_y = int(poi_y * max_scale_factor)
+
+        return True
 
 
 class ArgBreakMixin:
@@ -193,6 +193,9 @@ class StableDiffusion:
                 device=self.device_torch,
                 torch_dtype=self.torch_dtype,
             )
 
+            if 'vae' in load_args and load_args['vae'] is not None:
+                pipe.vae = load_args['vae']
+
             flush()
 
             text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
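The added lines let a VAE preloaded into load_args replace the pipeline's own. A common reason to do this with SDXL is swapping in an fp16-safe VAE; a hypothetical illustration with diffusers (model ids assumed, not taken from this commit):

    import torch
    from diffusers import AutoencoderKL, StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    # e.g. the fp16-safe SDXL VAE; any preloaded AutoencoderKL works the same way
    load_args = {'vae': AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)}
    if 'vae' in load_args and load_args['vae'] is not None:
        pipe.vae = load_args['vae']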
@@ -679,6 +682,18 @@ class StableDiffusion:
         text_embeddings = text_embeddings.to(self.device_torch)
         timestep = timestep.to(self.device_torch)
 
+        def scale_model_input(model_input, timestep_tensor):
+            mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+            out_chunks = []
+            for idx in range(model_input.shape[0]):
+                # if scheduler has step_index
+                if hasattr(self.noise_scheduler, '_step_index'):
+                    self.noise_scheduler._step_index = None
+                out_chunks.append(
+                    self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_tensor[idx])
+                )
+            return torch.cat(out_chunks, dim=0)
+
         if self.is_xl:
             with torch.no_grad():
                 # 16, 6 for bs of 4
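scale_model_input exists because schedulers like Euler divide the sample by a sigma looked up from the timestep, and recent diffusers schedulers cache that lookup in _step_index. With per-sample timesteps in a training batch, the batch has to be scaled chunk by chunk with the cache cleared each time, which is what the closure above does. A small demonstration, assuming a recent diffusers EulerDiscreteScheduler:

    import torch
    from diffusers import EulerDiscreteScheduler

    sched = EulerDiscreteScheduler(num_train_timesteps=1000)
    sched.set_timesteps(10)

    sample = torch.randn(2, 4, 8, 8)
    ts = torch.stack([sched.timesteps[0], sched.timesteps[5]])  # two different timesteps

    chunks = []
    for i in range(sample.shape[0]):
        if hasattr(sched, '_step_index'):
            sched._step_index = None  # force the sigma lookup to rerun for this sample
        chunks.append(sched.scale_model_input(sample[i:i + 1], ts[i]))
    scaled = torch.cat(chunks, dim=0)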
@@ -691,10 +706,11 @@ class StableDiffusion:
 
                 if do_classifier_free_guidance:
                     latent_model_input = torch.cat([latents] * 2)
+                    timestep = torch.cat([timestep] * 2)
                 else:
                     latent_model_input = latents
 
-                latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
+                latent_model_input = scale_model_input(latent_model_input, timestep)
 
                 added_cond_kwargs = {
                     # todo can we zero here the second text encoder? or match a blank string?
@@ -784,10 +800,11 @@ class StableDiffusion:
             if do_classifier_free_guidance:
                 # if we are doing classifier free guidance, need to double up
                 latent_model_input = torch.cat([latents] * 2)
+                timestep = torch.cat([timestep] * 2)
             else:
                 latent_model_input = latents
 
-            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
+            latent_model_input = scale_model_input(latent_model_input, timestep)
 
             # check if we need to concat timesteps
             if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
@@ -823,6 +840,23 @@ class StableDiffusion:
 
         return noise_pred
 
+    def step_scheduler(self, model_input, latent_input, timestep_tensor):
+        mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+        latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
+        timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+        out_chunks = []
+        for idx in range(model_input.shape[0]):
+            # Reset it so it is unique for the
+            if hasattr(self.noise_scheduler, '_step_index'):
+                self.noise_scheduler._step_index = None
+            if hasattr(self.noise_scheduler, 'is_scale_input_called'):
+                self.noise_scheduler.is_scale_input_called = True
+            out_chunks.append(
+                self.noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[0]
+            )
+        return torch.cat(out_chunks, dim=0)
+
     # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
     def diffuse_some_steps(
         self,
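step_scheduler applies the same chunking trick to scheduler.step: diffusers schedulers assume every sample in a batch sits at the same timestep and keep per-call state (_step_index, and Euler's is_scale_input_called check), so stepping a batch with mixed timesteps has to happen one sample at a time. A runnable demonstration of the pattern, again assuming a recent diffusers EulerDiscreteScheduler:

    import torch
    from diffusers import EulerDiscreteScheduler

    sched = EulerDiscreteScheduler(num_train_timesteps=1000)
    sched.set_timesteps(10)

    latents = torch.randn(2, 4, 8, 8)
    model_output = torch.randn_like(latents)
    ts = torch.stack([sched.timesteps[2], sched.timesteps[7]])  # mixed timesteps

    out = []
    for i in range(latents.shape[0]):
        if hasattr(sched, '_step_index'):
            sched._step_index = None  # forget the previous sample's schedule position
        if hasattr(sched, 'is_scale_input_called'):
            sched.is_scale_input_called = True  # silence Euler's "did you scale?" warning
        out.append(sched.step(model_output[i:i + 1], ts[i], latents[i:i + 1],
                              return_dict=False)[0])
    stepped = torch.cat(out, dim=0)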
@@ -839,6 +873,7 @@ class StableDiffusion:
         timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
 
         for timestep in tqdm(timesteps_to_run, leave=False):
+            timestep = timestep.unsqueeze_(0)
             noise_pred = self.predict_noise(
                 latents,
                 text_embeddings,
@@ -847,7 +882,9 @@ class StableDiffusion:
                 add_time_ids=add_time_ids,
                 **kwargs,
             )
-            latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
+            # some schedulers need to run separately, so do that. (euler for example)
+            latents = self.step_scheduler(noise_pred, latents, timestep)
+
             # if not last step, and bleeding, bleed in some latents
             if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
@@ -584,8 +584,13 @@ def encode_prompts_xl(
 def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
     if max_length is None and not truncate:
         raise ValueError("max_length must be set if truncate is True")
-    tokens = tokens.to(text_encoder.device)
+    try:
+        tokens = tokens.to(text_encoder.device)
+    except Exception as e:
+        print(e)
+        print("tokens.device", tokens.device)
+        print("text_encoder.device", text_encoder.device)
+        raise e
 
     if truncate:
         return text_encoder(tokens)[0]
@@ -771,8 +776,8 @@ def apply_snr_weight(
 ):
     # will get it from noise scheduler if exist or will calculate it if not
     all_snr = get_all_snr(noise_scheduler, loss.device)
-    snr = torch.stack([all_snr[t] for t in timesteps])
+    step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
+    snr = torch.stack([all_snr[t] for t in step_indices])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     if fixed:
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
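This last hunk follows from the sampling fix: all_snr is a per-position table, but timesteps now carries timestep values, so each value is converted back to its position in noise_scheduler.timesteps before the lookup. For reference, a hedged sketch of min-SNR-gamma loss weighting (Hang et al., 2023) built around the same lookup; all_snr is assumed aligned with scheduler_timesteps, as the hunk above assumes:

    import torch

    def min_snr_gamma_weight(all_snr: torch.Tensor, scheduler_timesteps: torch.Tensor,
                             timesteps: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
        # map timestep values back to schedule positions before indexing the SNR table
        step_indices = [(scheduler_timesteps == t).nonzero().item() for t in timesteps]
        snr = torch.stack([all_snr[i] for i in step_indices])
        # min(SNR, gamma) / SNR, the epsilon-prediction form of the weight
        return torch.minimum(snr, torch.full_like(snr, gamma)) / snr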