Added IP adapter training. Not functioning correctly yet

This commit is contained in:
Jaret Burkett
2023-09-24 02:39:43 -06:00
parent 19255cdc7c
commit 830e87cb87
9 changed files with 336 additions and 53 deletions

View File

@@ -15,6 +15,7 @@ from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
@@ -22,7 +23,8 @@ from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
load_ip_adapter_model
from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
@@ -118,9 +120,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, None] = None
self.embedding: Union[Embedding, None] = None
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
# get the device state preset based on what we are training
self.train_device_state_preset = get_train_sd_device_state_preset(
device=self.device_torch,
@@ -128,17 +132,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_text_encoder=self.train_config.train_text_encoder,
cached_latents=self.is_latents_cached,
train_lora=self.network_config is not None,
train_adapter=self.adapter_config is not None,
train_adapter=is_training_adapter,
train_embedding=self.embed_config is not None,
)
# fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
self.is_fine_tuning = True
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
if self.network_config is not None or is_training_adapter or self.embed_config is not None:
self.is_fine_tuning = False
self.named_lora = False
if self.embed_config is not None or self.adapter_config is not None:
if self.embed_config is not None or is_training_adapter:
self.named_lora = True
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -179,7 +183,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
extra_args = {}
if self.adapter_config is not None:
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
extra_args['adapter_image_path'] = self.adapter_config.test_img_path
gen_img_config_list.append(GenerateImageConfig(
@@ -318,22 +322,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(emb_file_path)
if self.adapter is not None:
if self.adapter is not None and self.adapter_config.train:
adapter_name = self.job.name
if self.network_config is not None or self.embedding is not None:
# add _lora to name
adapter_name += '_t2i'
if self.adapter_config.type == 't2i':
adapter_name += '_t2i'
else:
adapter_name += '_ip'
filename = f'{adapter_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
if self.adapter_config.type == 't2i':
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
save_ip_adapter_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
self.sd.save(
file_path,
@@ -527,6 +542,50 @@ class BaseSDTrainProcess(BaseTrainProcess):
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def setup_adapter(self):
dtype = get_torch_dtype(self.train_config.dtype)
is_t2i = self.adapter_config.type == 't2i'
if is_t2i:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
else:
self.adapter = IPAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
suffix = 't2i' if is_t2i else 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
if is_t2i:
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device,
dtype=dtype
)
else:
loaded_state_dict = load_ip_adapter_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
if self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
# set trainable params
self.sd.adapter = self.adapter
def run(self):
# torch.autograd.set_detect_anomaly(True)
# run base process run
@@ -741,35 +800,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
if self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_t2i"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
self.setup_adapter()
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
self.sd.adapter = self.adapter
flush()
params = self.load_additional_training_modules(params)
@@ -785,6 +821,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
# we may be using it for prompt injections
if self.adapter_config is not None:
self.setup_adapter()
flush()
### HOOK ###
params = self.hook_add_extra_train_params(params)