mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added reference adapters, many bug fixes, more ip adapter work and customizability
This commit is contained in:
@@ -31,6 +31,7 @@ from toolkit.network_mixins import Network
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import CONFIG_ROOT
|
||||
from toolkit.progress_bar import ToolkitProgressBar
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
|
||||
load_ip_adapter_model
|
||||
@@ -140,7 +141,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# to hold network if there is one
|
||||
self.network: Union[Network, None] = None
|
||||
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, None] = None
|
||||
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, None] = None
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
|
||||
@@ -771,8 +772,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
num_train_timesteps, device=self.device_torch
|
||||
)
|
||||
|
||||
content_or_style = self.train_config.content_or_style
|
||||
if is_reg:
|
||||
content_or_style = self.train_config.content_or_style_reg
|
||||
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
if content_or_style in ['style', 'content']:
|
||||
# this is from diffusers training code
|
||||
# Cubic sampling for favoring later or earlier timesteps
|
||||
# For more details about why cubic sampling is used for content / structure,
|
||||
@@ -783,9 +788,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
orig_timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'content':
|
||||
if content_or_style == 'content':
|
||||
timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'style':
|
||||
elif content_or_style == 'style':
|
||||
timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timestep_indices = value_map(
|
||||
@@ -800,7 +805,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
max_noise_steps - 1
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
elif content_or_style == 'balanced':
|
||||
if min_noise_steps == max_noise_steps:
|
||||
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
|
||||
else:
|
||||
@@ -813,7 +818,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
timestep_indices = timestep_indices.long()
|
||||
else:
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
||||
|
||||
# convert the timestep_indices to a timestep
|
||||
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
|
||||
@@ -824,9 +829,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
# add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
|
||||
# this will negate any noise offsets
|
||||
if self.train_config.dynamic_noise_offset and not is_reg:
|
||||
latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2
|
||||
# add half of the per-channel latent mean to the noise so that the noise offset compensates for the mean of the latents per channel
|
||||
noise = noise + latents_channel_mean
|
||||
|
||||
if self.train_config.loss_target == 'differential_noise':
|
||||
differential = latents - unaugmented_latents
|
||||
# add noise to differential
|
||||
@@ -912,6 +924,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
suffix = 't2i'
|
||||
elif self.adapter_config.type == 'clip':
|
||||
suffix = 'clip'
|
||||
elif self.adapter_config.type == 'reference':
|
||||
suffix = 'ref'
|
||||
else:
|
||||
suffix = 'ip'
|
||||
adapter_name = self.name
|
||||
@@ -943,6 +957,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
sd=self.sd,
|
||||
adapter_config=self.adapter_config,
|
||||
)
|
||||
elif self.adapter_config.type == 'reference':
|
||||
self.adapter = ReferenceAdapter(
|
||||
sd=self.sd,
|
||||
adapter_config=self.adapter_config,
|
||||
)
|
||||
else:
|
||||
self.adapter = IPAdapter(
|
||||
sd=self.sd,
|
||||
@@ -1441,6 +1460,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
did_first_flush = True
|
||||
# flush()
|
||||
# setup the networks to gradient checkpointing and everything works
|
||||
if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
|
||||
self.adapter.clear_memory()
|
||||
|
||||
with torch.no_grad():
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user