Added bucketting capabilities to dataloader. Finally have full planned capability. noice

This commit is contained in:
Jaret Burkett
2023-08-26 16:36:32 -06:00
parent 2cb27c3f57
commit 8105c05c12
6 changed files with 707 additions and 42 deletions

View File

@@ -288,7 +288,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
imgs, prompts, dataset_config = batch
# convert the 0 or 1 for is reg to a bool list
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
if isinstance(dataset_config, list):
is_reg_list = [x.get('is_reg', 0) for x in dataset_config]
else:
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
if isinstance(is_reg_list, torch.Tensor):
is_reg_list = is_reg_list.numpy().tolist()
is_reg_list = [bool(x) for x in is_reg_list]