Fixed issues with dataloader bucketing. Allow using standard base image for t2i adapters.

This commit is contained in:
Jaret Burkett
2023-09-24 05:19:57 -06:00
parent 830e87cb87
commit e5153d87c9
5 changed files with 41 additions and 20 deletions

View File

@@ -21,7 +21,8 @@ def flush():
gc.collect()
adapter_transforms = transforms.Compose([
transforms.PILToTensor(),
# transforms.PILToTensor(),
transforms.ToTensor(),
])
@@ -44,6 +45,13 @@ class SDTrainer(BaseSDTrainProcess):
flush()
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
if self.adapter_config.image_dir is None:
# adapter needs 0 to 1 values, batch is -1 to 1
adapter_batch = batch.tensor.clone().to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
adapter_batch = (adapter_batch + 1) / 2
return adapter_batch
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
adapter_folder_path = self.adapter_config.image_dir
adapter_images = []