mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed issues with dataloader bucketing. Allow using standard base image for t2i adapters.
This commit is contained in:
@@ -75,7 +75,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.has_first_sample_requested = False
|
||||
self.first_sample_config = self.sample_config
|
||||
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
|
||||
self.optimizer = None
|
||||
self.optimizer: torch.optim.Optimizer = None
|
||||
self.lr_scheduler = None
|
||||
self.data_loader: Union[DataLoader, None] = None
|
||||
self.data_loader_reg: Union[DataLoader, None] = None
|
||||
@@ -543,28 +543,39 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
|
||||
|
||||
def setup_adapter(self):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# t2i adapter
|
||||
is_t2i = self.adapter_config.type == 't2i'
|
||||
suffix = 't2i' if is_t2i else 'ip'
|
||||
adapter_name = self.name
|
||||
if self.network_config is not None:
|
||||
adapter_name = f"{adapter_name}_{suffix}"
|
||||
latest_save_path = self.get_latest_save_path(adapter_name)
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
if is_t2i:
|
||||
self.adapter = T2IAdapter(
|
||||
in_channels=self.adapter_config.in_channels,
|
||||
channels=self.adapter_config.channels,
|
||||
num_res_blocks=self.adapter_config.num_res_blocks,
|
||||
downscale_factor=self.adapter_config.downscale_factor,
|
||||
adapter_type=self.adapter_config.adapter_type,
|
||||
)
|
||||
# if we do not have a last save path and we have a name_or_path,
|
||||
# load from that
|
||||
if latest_save_path is None and self.adapter_config.name_or_path is not None:
|
||||
self.adapter = T2IAdapter.from_pretrained(
|
||||
self.adapter_config.name_or_path,
|
||||
torch_dtype=get_torch_dtype(self.train_config.dtype),
|
||||
varient="fp16",
|
||||
# use_safetensors=True,
|
||||
)
|
||||
else:
|
||||
self.adapter = T2IAdapter(
|
||||
in_channels=self.adapter_config.in_channels,
|
||||
channels=self.adapter_config.channels,
|
||||
num_res_blocks=self.adapter_config.num_res_blocks,
|
||||
downscale_factor=self.adapter_config.downscale_factor,
|
||||
adapter_type=self.adapter_config.adapter_type,
|
||||
)
|
||||
else:
|
||||
self.adapter = IPAdapter(
|
||||
sd=self.sd,
|
||||
adapter_config=self.adapter_config,
|
||||
)
|
||||
self.adapter.to(self.device_torch, dtype=dtype)
|
||||
# t2i adapter
|
||||
suffix = 't2i' if is_t2i else 'ip'
|
||||
adapter_name = self.name
|
||||
if self.network_config is not None:
|
||||
adapter_name = f"{adapter_name}_{suffix}"
|
||||
latest_save_path = self.get_latest_save_path(adapter_name)
|
||||
if latest_save_path is not None:
|
||||
# load adapter from path
|
||||
print(f"Loading adapter from {latest_save_path}")
|
||||
|
||||
Reference in New Issue
Block a user