mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-28 00:03:57 +00:00
Fixed issues with dataloader bucketing. Allow using standard base image for t2i adapters.
This commit is contained in:
@@ -21,7 +21,8 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
adapter_transforms = transforms.Compose([
|
||||
transforms.PILToTensor(),
|
||||
# transforms.PILToTensor(),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
|
||||
@@ -44,6 +45,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
flush()
|
||||
|
||||
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
|
||||
if self.adapter_config.image_dir is None:
|
||||
# adapter needs 0 to 1 values, batch is -1 to 1
|
||||
adapter_batch = batch.tensor.clone().to(
|
||||
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
|
||||
)
|
||||
adapter_batch = (adapter_batch + 1) / 2
|
||||
return adapter_batch
|
||||
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
adapter_folder_path = self.adapter_config.image_dir
|
||||
adapter_images = []
|
||||
|
||||
@@ -75,7 +75,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.has_first_sample_requested = False
|
||||
self.first_sample_config = self.sample_config
|
||||
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
|
||||
self.optimizer = None
|
||||
self.optimizer: torch.optim.Optimizer = None
|
||||
self.lr_scheduler = None
|
||||
self.data_loader: Union[DataLoader, None] = None
|
||||
self.data_loader_reg: Union[DataLoader, None] = None
|
||||
@@ -543,28 +543,39 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
|
||||
|
||||
def setup_adapter(self):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# t2i adapter
|
||||
is_t2i = self.adapter_config.type == 't2i'
|
||||
suffix = 't2i' if is_t2i else 'ip'
|
||||
adapter_name = self.name
|
||||
if self.network_config is not None:
|
||||
adapter_name = f"{adapter_name}_{suffix}"
|
||||
latest_save_path = self.get_latest_save_path(adapter_name)
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
if is_t2i:
|
||||
self.adapter = T2IAdapter(
|
||||
in_channels=self.adapter_config.in_channels,
|
||||
channels=self.adapter_config.channels,
|
||||
num_res_blocks=self.adapter_config.num_res_blocks,
|
||||
downscale_factor=self.adapter_config.downscale_factor,
|
||||
adapter_type=self.adapter_config.adapter_type,
|
||||
)
|
||||
# if we do not have a last save path and we have a name_or_path,
|
||||
# load from that
|
||||
if latest_save_path is None and self.adapter_config.name_or_path is not None:
|
||||
self.adapter = T2IAdapter.from_pretrained(
|
||||
self.adapter_config.name_or_path,
|
||||
torch_dtype=get_torch_dtype(self.train_config.dtype),
|
||||
varient="fp16",
|
||||
# use_safetensors=True,
|
||||
)
|
||||
else:
|
||||
self.adapter = T2IAdapter(
|
||||
in_channels=self.adapter_config.in_channels,
|
||||
channels=self.adapter_config.channels,
|
||||
num_res_blocks=self.adapter_config.num_res_blocks,
|
||||
downscale_factor=self.adapter_config.downscale_factor,
|
||||
adapter_type=self.adapter_config.adapter_type,
|
||||
)
|
||||
else:
|
||||
self.adapter = IPAdapter(
|
||||
sd=self.sd,
|
||||
adapter_config=self.adapter_config,
|
||||
)
|
||||
self.adapter.to(self.device_torch, dtype=dtype)
|
||||
# t2i adapter
|
||||
suffix = 't2i' if is_t2i else 'ip'
|
||||
adapter_name = self.name
|
||||
if self.network_config is not None:
|
||||
adapter_name = f"{adapter_name}_{suffix}"
|
||||
latest_save_path = self.get_latest_save_path(adapter_name)
|
||||
if latest_save_path is not None:
|
||||
# load adapter from path
|
||||
print(f"Loading adapter from {latest_save_path}")
|
||||
|
||||
@@ -83,7 +83,8 @@ def get_bucket_for_image_size(
|
||||
width: int,
|
||||
height: int,
|
||||
bucket_size_list: List[BucketResolution] = None,
|
||||
resolution: Union[int, None] = None
|
||||
resolution: Union[int, None] = None,
|
||||
divisibility: int = 8
|
||||
) -> BucketResolution:
|
||||
|
||||
if bucket_size_list is None and resolution is None:
|
||||
@@ -93,7 +94,7 @@ def get_bucket_for_image_size(
|
||||
# if real resolution is smaller, use that instead
|
||||
real_resolution = get_resolution(width, height)
|
||||
resolution = min(resolution, real_resolution)
|
||||
bucket_size_list = get_bucket_sizes(resolution=resolution)
|
||||
bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)
|
||||
|
||||
# Check for exact match first
|
||||
for bucket in bucket_size_list:
|
||||
|
||||
@@ -281,7 +281,8 @@ class PairedImageDataset(Dataset):
|
||||
bucket_resolution = get_bucket_for_image_size(
|
||||
width=img2.width,
|
||||
height=img2.height,
|
||||
resolution=self.size
|
||||
resolution=self.size,
|
||||
# divisibility=self.
|
||||
)
|
||||
|
||||
# images will be same base dimension, but may be trimmed. We need to shrink and then central crop
|
||||
|
||||
@@ -121,7 +121,7 @@ class BucketsMixin:
|
||||
width = file_item.crop_width
|
||||
height = file_item.crop_height
|
||||
|
||||
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution)
|
||||
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution, divisibility=bucket_tolerance)
|
||||
|
||||
# set the scaling height and with to match smallest size, and keep aspect ratio
|
||||
if width > height:
|
||||
|
||||
Reference in New Issue
Block a user