mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Fixed issues with dataloader bucketing. Allow using standard base image for t2i adapters.
This commit is contained in:
@@ -21,7 +21,8 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
adapter_transforms = transforms.Compose([
|
||||
transforms.PILToTensor(),
|
||||
# transforms.PILToTensor(),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
|
||||
@@ -44,6 +45,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
flush()
|
||||
|
||||
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
|
||||
if self.adapter_config.image_dir is None:
|
||||
# adapter needs 0 to 1 values, batch is -1 to 1
|
||||
adapter_batch = batch.tensor.clone().to(
|
||||
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
|
||||
)
|
||||
adapter_batch = (adapter_batch + 1) / 2
|
||||
return adapter_batch
|
||||
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
adapter_folder_path = self.adapter_config.image_dir
|
||||
adapter_images = []
|
||||
|
||||
Reference in New Issue
Block a user