Fixed a big issue with the bucketing dataloader and added random cropping to a point of interest

Jaret Burkett
2023-10-02 18:31:08 -06:00
parent 320e109c5f
commit 579650eaf8
6 changed files with 264 additions and 72 deletions
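
The point-of-interest cropping code itself is not in the hunks shown below. As a rough illustration of the idea only, a crop window biased toward a point of interest might look like the following sketch; the `random_crop_around_poi` name, the `jitter` parameter, and the clamping logic are assumptions for illustration, not the commit's actual code:

```python
import random

def random_crop_around_poi(img_w, img_h, crop_w, crop_h, poi_x, poi_y, jitter=0.25):
    # Center the crop window on the point of interest, then add a little
    # random jitter so the subject is not always perfectly centered.
    # Assumes crop_w <= img_w and crop_h <= img_h.
    max_dx = int(crop_w * jitter)
    max_dy = int(crop_h * jitter)
    x = poi_x - crop_w // 2 + random.randint(-max_dx, max_dx)
    y = poi_y - crop_h // 2 + random.randint(-max_dy, max_dy)
    # Clamp the origin so the window stays inside the image bounds.
    x = max(0, min(x, img_w - crop_w))
    y = max(0, min(y, img_h - crop_h))
    return x, y, x + crop_w, y + crop_h
```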


@@ -12,7 +12,7 @@ import torch
 import torch.backends.cuda
 from toolkit.basic import value_map
-from toolkit.data_loader import get_dataloader_from_datasets
+from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
 from toolkit.ip_adapter import IPAdapter
@@ -931,16 +931,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 batch = next(dataloader_iterator_reg)
             except StopIteration:
                 # hit the end of an epoch, reset
+                self.progress_bar.pause()
                 dataloader_iterator_reg = iter(dataloader_reg)
+                trigger_dataloader_setup_epoch(dataloader_reg)
                 batch = next(dataloader_iterator_reg)
+                self.progress_bar.unpause()
             is_reg_step = True
         elif dataloader is not None:
             try:
                 batch = next(dataloader_iterator)
             except StopIteration:
                 # hit the end of an epoch, reset
+                self.progress_bar.pause()
                 dataloader_iterator = iter(dataloader)
+                trigger_dataloader_setup_epoch(dataloader)
                 batch = next(dataloader_iterator)
+                self.progress_bar.unpause()
         else:
             batch = None
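
The epoch-reset pattern above (pause the progress bar, rebuild the iterator, re-run per-epoch setup, resume) can be reproduced in isolation. Below is a minimal, runnable sketch; the `BucketedDataset` class and the body of `trigger_dataloader_setup_epoch` are stand-ins invented for illustration, and only the call-site pattern mirrors the diff:

```python
import random
from torch.utils.data import DataLoader, Dataset

class BucketedDataset(Dataset):
    def __init__(self, items):
        self.items = items

    def setup_epoch(self):
        # Stand-in for per-epoch bucketing work (re-shuffling,
        # re-binning items into resolution buckets, etc.).
        random.shuffle(self.items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

def trigger_dataloader_setup_epoch(dataloader):
    # Hypothetical helper: forward the epoch-boundary signal to the
    # dataset so bucket assignments are rebuilt before the next pass.
    dataset = dataloader.dataset
    if hasattr(dataset, "setup_epoch"):
        dataset.setup_epoch()

dataloader = DataLoader(BucketedDataset(list(range(8))), batch_size=2)
iterator = iter(dataloader)
for _ in range(8):  # deliberately run past one epoch
    try:
        batch = next(iterator)
    except StopIteration:
        # Epoch boundary: rebuild the iterator and re-run epoch setup,
        # mirroring the reset pattern in the diff above.
        iterator = iter(dataloader)
        trigger_dataloader_setup_epoch(dataloader)
        batch = next(iterator)
```

One design note: the diff rebuilds the iterator before calling `trigger_dataloader_setup_epoch`. With `num_workers=0` that ordering is safe because batches are fetched lazily from the live dataset object, but with worker processes each worker receives a copy of the dataset when `iter()` is called, so mutations made afterward in the main process would not reach them.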