Fix issues with dataloader bucketing. Allow using a standard base image for t2i adapters.

This commit is contained in:
Jaret Burkett
2023-09-24 05:19:57 -06:00
parent 830e87cb87
commit e5153d87c9
5 changed files with 41 additions and 20 deletions

View File

@@ -21,7 +21,8 @@ def flush():
gc.collect()
adapter_transforms = transforms.Compose([
transforms.PILToTensor(),
# transforms.PILToTensor(),
transforms.ToTensor(),
])
@@ -44,6 +45,13 @@ class SDTrainer(BaseSDTrainProcess):
flush()
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
if self.adapter_config.image_dir is None:
# adapter needs 0 to 1 values, batch is -1 to 1
adapter_batch = batch.tensor.clone().to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
adapter_batch = (adapter_batch + 1) / 2
return adapter_batch
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
adapter_folder_path = self.adapter_config.image_dir
adapter_images = []

View File

@@ -75,7 +75,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.has_first_sample_requested = False
self.first_sample_config = self.sample_config
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.optimizer = None
self.optimizer: torch.optim.Optimizer = None
self.lr_scheduler = None
self.data_loader: Union[DataLoader, None] = None
self.data_loader_reg: Union[DataLoader, None] = None
@@ -543,28 +543,39 @@ class BaseSDTrainProcess(BaseTrainProcess):
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def setup_adapter(self):
dtype = get_torch_dtype(self.train_config.dtype)
# t2i adapter
is_t2i = self.adapter_config.type == 't2i'
suffix = 't2i' if is_t2i else 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
latest_save_path = self.get_latest_save_path(adapter_name)
dtype = get_torch_dtype(self.train_config.dtype)
if is_t2i:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
# if we do not have a last save path and we have a name_or_path,
# load from that
if latest_save_path is None and self.adapter_config.name_or_path is not None:
self.adapter = T2IAdapter.from_pretrained(
self.adapter_config.name_or_path,
torch_dtype=get_torch_dtype(self.train_config.dtype),
varient="fp16",
# use_safetensors=True,
)
else:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
else:
self.adapter = IPAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
suffix = 't2i' if is_t2i else 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")

View File

@@ -83,7 +83,8 @@ def get_bucket_for_image_size(
width: int,
height: int,
bucket_size_list: List[BucketResolution] = None,
resolution: Union[int, None] = None
resolution: Union[int, None] = None,
divisibility: int = 8
) -> BucketResolution:
if bucket_size_list is None and resolution is None:
@@ -93,7 +94,7 @@ def get_bucket_for_image_size(
# if real resolution is smaller, use that instead
real_resolution = get_resolution(width, height)
resolution = min(resolution, real_resolution)
bucket_size_list = get_bucket_sizes(resolution=resolution)
bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)
# Check for exact match first
for bucket in bucket_size_list:

View File

@@ -281,7 +281,8 @@ class PairedImageDataset(Dataset):
bucket_resolution = get_bucket_for_image_size(
width=img2.width,
height=img2.height,
resolution=self.size
resolution=self.size,
# divisibility=self.
)
# images will be same base dimension, but may be trimmed. We need to shrink and then central crop

View File

@@ -121,7 +121,7 @@ class BucketsMixin:
width = file_item.crop_width
height = file_item.crop_height
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution)
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution, divisibility=bucket_tolerance)
# set the scaling height and with to match smallest size, and keep aspect ratio
if width > height: