Added reference adapters, many bug fixes, more IP adapter work, and more customizability

Jaret Burkett
2024-01-01 17:15:53 -07:00
parent bafacf3b65
commit afc231efc1
7 changed files with 510 additions and 30 deletions


@@ -31,6 +31,7 @@ from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.progress_bar import ToolkitProgressBar
+from toolkit.reference_adapter import ReferenceAdapter
from toolkit.sampler import get_sampler
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
load_ip_adapter_model
@@ -140,7 +141,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
-self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, None] = None
+self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, None] = None
self.embedding: Union[Embedding, None] = None
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
@@ -771,8 +772,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
num_train_timesteps, device=self.device_torch
)
+content_or_style = self.train_config.content_or_style
+if is_reg:
+content_or_style = self.train_config.content_or_style_reg
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
-if self.train_config.content_or_style in ['style', 'content']:
+if content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
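
Note: the lines added above resolve the timestep-bias mode per batch, so regularization batches can use their own setting via content_or_style_reg. A toy illustration of the override (the SimpleNamespace stand-in for train_config is hypothetical):

from types import SimpleNamespace

train_config = SimpleNamespace(content_or_style='style', content_or_style_reg='balanced')

for is_reg in (False, True):
    content_or_style = train_config.content_or_style
    if is_reg:
        content_or_style = train_config.content_or_style_reg
    print(is_reg, '->', content_or_style)
# False -> style
# True -> balanced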
@@ -783,9 +788,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
orig_timesteps = torch.rand((batch_size,), device=latents.device)
-if self.train_config.content_or_style == 'content':
+if content_or_style == 'content':
timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
-elif self.train_config.content_or_style == 'style':
+elif content_or_style == 'style':
timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timestep_indices = value_map(
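
For reference, the cubic sampling above in standalone form (a minimal sketch; the helper name is illustrative, the 'balanced' branch is simplified to a plain uniform draw, and the real code goes on to remap the indices into [min_noise_steps, max_noise_steps) with value_map):

import torch

def sample_timestep_indices(batch_size, num_train_timesteps, content_or_style='balanced'):
    u = torch.rand((batch_size,))
    if content_or_style == 'content':
        # u ** 3 clusters indices near 0
        idx = u ** 3 * num_train_timesteps
    elif content_or_style == 'style':
        # 1 - u ** 3 clusters indices near num_train_timesteps
        idx = (1 - u ** 3) * num_train_timesteps
    else:
        # uniform over the full range
        idx = torch.rand((batch_size,)) * num_train_timesteps
    return idx.clamp(0, num_train_timesteps - 1).long()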
@@ -800,7 +805,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
max_noise_steps - 1
)
-elif self.train_config.content_or_style == 'balanced':
+elif content_or_style == 'balanced':
if min_noise_steps == max_noise_steps:
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
else:
@@ -813,7 +818,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
timestep_indices = timestep_indices.long()
else:
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
raise ValueError(f"Unknown content_or_style {content_or_style}")
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
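
Incidentally, the list comprehension in the context above has a vectorized equivalent; a sketch, assuming scheduler.timesteps is a 1-D CPU tensor and allowing for the list-vs-tensor difference downstream:

# one indexing op instead of a Python-level loop
timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.cpu()].to(self.device_torch)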
@@ -824,9 +829,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
-noise_offset=self.train_config.noise_offset
+noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
+# add dynamic noise offset. Dynamic noise offsets the noise toward the same channel-wise mean as the latents
+# this will effectively negate any fixed noise offset
+if self.train_config.dynamic_noise_offset and not is_reg:
+latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2
+# add half the per-channel mean of the latents so the noise offset compensates for it channel by channel
+noise = noise + latents_channel_mean
if self.train_config.loss_target == 'differential_noise':
differential = latents - unaugmented_latents
# add noise to differential
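
The new dynamic_noise_offset block in isolation (a minimal sketch; the function name is illustrative and shapes are assumed to be (B, C, H, W)):

import torch

def apply_dynamic_noise_offset(noise, latents):
    # per-sample, per-channel mean of the latents, halved as in the source
    latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2  # (B, C, 1, 1)
    # shift the noise toward the latents' channel-wise mean
    return noise + latents_channel_mean

noise = apply_dynamic_noise_offset(torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64) + 0.3)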
@@ -912,6 +924,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
suffix = 't2i'
elif self.adapter_config.type == 'clip':
suffix = 'clip'
+elif self.adapter_config.type == 'reference':
+suffix = 'ref'
else:
suffix = 'ip'
adapter_name = self.name
@@ -943,6 +957,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
sd=self.sd,
adapter_config=self.adapter_config,
)
+elif self.adapter_config.type == 'reference':
+self.adapter = ReferenceAdapter(
+sd=self.sd,
+adapter_config=self.adapter_config,
+)
else:
self.adapter = IPAdapter(
sd=self.sd,
@@ -1441,6 +1460,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
did_first_flush = True
# flush()
# setup the networks to gradient checkpointing and everything works
+if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
+self.adapter.clear_memory()
with torch.no_grad():
# torch.cuda.empty_cache()