mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Added IP adapter training. Not functioning correctly yet
This commit is contained in:
@@ -15,6 +15,7 @@ from toolkit.basic import value_map
|
||||
from toolkit.data_loader import get_dataloader_from_datasets
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.embedding import Embedding
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.lycoris_special import LycorisSpecialNetwork
|
||||
from toolkit.network_mixins import Network
|
||||
@@ -22,7 +23,8 @@ from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import CONFIG_ROOT
|
||||
from toolkit.progress_bar import ToolkitProgressBar
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model
|
||||
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
|
||||
load_ip_adapter_model
|
||||
|
||||
from toolkit.scheduler import get_lr_scheduler
|
||||
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
@@ -118,9 +120,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# to hold network if there is one
|
||||
self.network: Union[Network, None] = None
|
||||
self.adapter: Union[T2IAdapter, None] = None
|
||||
self.adapter: Union[T2IAdapter, IPAdapter, None] = None
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
|
||||
|
||||
# get the device state preset based on what we are training
|
||||
self.train_device_state_preset = get_train_sd_device_state_preset(
|
||||
device=self.device_torch,
|
||||
@@ -128,17 +132,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
train_text_encoder=self.train_config.train_text_encoder,
|
||||
cached_latents=self.is_latents_cached,
|
||||
train_lora=self.network_config is not None,
|
||||
train_adapter=self.adapter_config is not None,
|
||||
train_adapter=is_training_adapter,
|
||||
train_embedding=self.embed_config is not None,
|
||||
)
|
||||
|
||||
# fine_tuning here means training the actual SD network (Dreambooth, etc.), not a LoRA, embedding, or other add-on
|
||||
self.is_fine_tuning = True
|
||||
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
|
||||
if self.network_config is not None or is_training_adapter or self.embed_config is not None:
|
||||
self.is_fine_tuning = False
|
||||
|
||||
self.named_lora = False
|
||||
if self.embed_config is not None or self.adapter_config is not None:
|
||||
if self.embed_config is not None or is_training_adapter:
|
||||
self.named_lora = True
|
||||
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
||||
# override in subclass
|
||||
@@ -179,7 +183,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
|
||||
extra_args = {}
|
||||
if self.adapter_config is not None:
|
||||
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
|
||||
extra_args['adapter_image_path'] = self.adapter_config.test_img_path
|
||||
|
||||
gen_img_config_list.append(GenerateImageConfig(
|
||||
@@ -318,22 +322,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
|
||||
self.embedding.save(emb_file_path)
|
||||
|
||||
if self.adapter is not None:
|
||||
if self.adapter is not None and self.adapter_config.train:
|
||||
adapter_name = self.job.name
|
||||
if self.network_config is not None or self.embedding is not None:
|
||||
# add the adapter type suffix (_t2i or _ip) to the name
|
||||
adapter_name += '_t2i'
|
||||
if self.adapter_config.type == 't2i':
|
||||
adapter_name += '_t2i'
|
||||
else:
|
||||
adapter_name += '_ip'
|
||||
|
||||
filename = f'{adapter_name}{step_num}.safetensors'
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
# save adapter
|
||||
state_dict = self.adapter.state_dict()
|
||||
save_t2i_from_diffusers(
|
||||
state_dict,
|
||||
output_file=file_path,
|
||||
meta=save_meta,
|
||||
dtype=get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
if self.adapter_config.type == 't2i':
|
||||
save_t2i_from_diffusers(
|
||||
state_dict,
|
||||
output_file=file_path,
|
||||
meta=save_meta,
|
||||
dtype=get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
else:
|
||||
save_ip_adapter_from_diffusers(
|
||||
state_dict,
|
||||
output_file=file_path,
|
||||
meta=save_meta,
|
||||
dtype=get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
else:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
@@ -527,6 +542,50 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
|
||||
|
||||
def setup_adapter(self):
    """Build the configured adapter, restore its latest checkpoint, and attach it.

    A ``T2IAdapter`` is created when ``self.adapter_config.type`` is ``'t2i'``;
    any other type gets an ``IPAdapter``. The adapter is moved to
    ``self.device_torch`` using the training dtype. If
    ``self.get_latest_save_path`` returns a path, the weights found there are
    loaded into the adapter, and the training state stored in that file's
    metadata is restored only when ``self.adapter_config.train`` is set.
    The adapter is finally assigned to ``self.sd.adapter``.
    """
    torch_dtype = get_torch_dtype(self.train_config.dtype)
    cfg = self.adapter_config
    is_t2i = cfg.type == 't2i'

    if not is_t2i:
        # every non-t2i type is handled as an IP adapter
        self.adapter = IPAdapter(
            sd=self.sd,
            adapter_config=cfg,
        )
    else:
        self.adapter = T2IAdapter(
            in_channels=cfg.in_channels,
            channels=cfg.channels,
            num_res_blocks=cfg.num_res_blocks,
            downscale_factor=cfg.downscale_factor,
            adapter_type=cfg.adapter_type,
        )
    self.adapter.to(self.device_torch, dtype=torch_dtype)

    # When a network is trained alongside the adapter, the adapter checkpoint
    # name carries a type suffix.
    # NOTE(review): the save logic elsewhere in this class also appends the
    # suffix when an embedding is present; confirm whether this lookup name
    # should apply the same condition.
    checkpoint_name = self.name
    if self.network_config is not None:
        type_suffix = 't2i' if is_t2i else 'ip'
        checkpoint_name = f"{checkpoint_name}_{type_suffix}"

    checkpoint_path = self.get_latest_save_path(checkpoint_name)
    if checkpoint_path is not None:
        print(f"Loading adapter from {checkpoint_path}")
        # pick the loader that matches the adapter type
        load_weights = load_t2i_model if is_t2i else load_ip_adapter_model
        restored_state = load_weights(
            checkpoint_path,
            self.device,
            dtype=torch_dtype
        )
        self.adapter.load_state_dict(restored_state)
        # training state is only restored when the adapter itself is trained
        if cfg.train:
            self.load_training_state_from_metadata(checkpoint_path)

    # expose the adapter to the SD model wrapper
    self.sd.adapter = self.adapter
|
||||
|
||||
def run(self):
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
# run base process run
|
||||
@@ -741,35 +800,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
flush()
|
||||
|
||||
if self.adapter_config is not None:
|
||||
self.adapter = T2IAdapter(
|
||||
in_channels=self.adapter_config.in_channels,
|
||||
channels=self.adapter_config.channels,
|
||||
num_res_blocks=self.adapter_config.num_res_blocks,
|
||||
downscale_factor=self.adapter_config.downscale_factor,
|
||||
adapter_type=self.adapter_config.adapter_type,
|
||||
)
|
||||
self.adapter.to(self.device_torch, dtype=dtype)
|
||||
# t2i adapter
|
||||
adapter_name = self.name
|
||||
if self.network_config is not None:
|
||||
adapter_name = f"{adapter_name}_t2i"
|
||||
latest_save_path = self.get_latest_save_path(adapter_name)
|
||||
if latest_save_path is not None:
|
||||
# load adapter from path
|
||||
print(f"Loading adapter from {latest_save_path}")
|
||||
loaded_state_dict = load_t2i_model(
|
||||
latest_save_path,
|
||||
self.device,
|
||||
dtype=dtype
|
||||
)
|
||||
self.adapter.load_state_dict(loaded_state_dict)
|
||||
self.load_training_state_from_metadata(latest_save_path)
|
||||
self.setup_adapter()
|
||||
# set trainable params
|
||||
params.append({
|
||||
'params': self.adapter.parameters(),
|
||||
'lr': self.train_config.adapter_lr
|
||||
})
|
||||
self.sd.adapter = self.adapter
|
||||
flush()
|
||||
|
||||
params = self.load_additional_training_modules(params)
|
||||
@@ -785,6 +821,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
)
|
||||
# we may be using it for prompt injections
|
||||
if self.adapter_config is not None:
|
||||
self.setup_adapter()
|
||||
flush()
|
||||
### HOOK ###
|
||||
params = self.hook_add_extra_train_params(params)
|
||||
|
||||
Reference in New Issue
Block a user