Fixed issues with dataloader bucketing. Allow using standard base image for t2i adapters.

This commit is contained in:
Jaret Burkett
2023-09-24 05:19:57 -06:00
parent 830e87cb87
commit e5153d87c9
5 changed files with 41 additions and 20 deletions

View File

@@ -75,7 +75,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.has_first_sample_requested = False
self.first_sample_config = self.sample_config
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.optimizer = None
self.optimizer: torch.optim.Optimizer = None
self.lr_scheduler = None
self.data_loader: Union[DataLoader, None] = None
self.data_loader_reg: Union[DataLoader, None] = None
@@ -543,28 +543,39 @@ class BaseSDTrainProcess(BaseTrainProcess):
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def setup_adapter(self):
dtype = get_torch_dtype(self.train_config.dtype)
# t2i adapter
is_t2i = self.adapter_config.type == 't2i'
suffix = 't2i' if is_t2i else 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
latest_save_path = self.get_latest_save_path(adapter_name)
dtype = get_torch_dtype(self.train_config.dtype)
if is_t2i:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
# if we do not have a last save path and we have a name_or_path,
# load from that
if latest_save_path is None and self.adapter_config.name_or_path is not None:
self.adapter = T2IAdapter.from_pretrained(
self.adapter_config.name_or_path,
torch_dtype=get_torch_dtype(self.train_config.dtype),
varient="fp16",
# use_safetensors=True,
)
else:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
else:
self.adapter = IPAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
suffix = 't2i' if is_t2i else 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")