mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Added bucketting capabilities to dataloader. Finally have full planned capability. noice
This commit is contained in:
@@ -288,7 +288,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
imgs, prompts, dataset_config = batch
|
||||
|
||||
# convert the 0 or 1 for is reg to a bool list
|
||||
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
|
||||
if isinstance(dataset_config, list):
|
||||
is_reg_list = [x.get('is_reg', 0) for x in dataset_config]
|
||||
else:
|
||||
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
|
||||
if isinstance(is_reg_list, torch.Tensor):
|
||||
is_reg_list = is_reg_list.numpy().tolist()
|
||||
is_reg_list = [bool(x) for x in is_reg_list]
|
||||
|
||||
Reference in New Issue
Block a user