Mirror of https://github.com/ostris/ai-toolkit.git
Fixed big issue with bucketing dataloader and added random cropping to a point of interest
@@ -12,7 +12,7 @@ import torch
 import torch.backends.cuda
 from toolkit.basic import value_map
-from toolkit.data_loader import get_dataloader_from_datasets
+from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
 from toolkit.ip_adapter import IPAdapter

@@ -931,16 +931,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     batch = next(dataloader_iterator_reg)
                 except StopIteration:
                     # hit the end of an epoch, reset
+                    self.progress_bar.pause()
                     dataloader_iterator_reg = iter(dataloader_reg)
+                    trigger_dataloader_setup_epoch(dataloader_reg)
                     batch = next(dataloader_iterator_reg)
+                    self.progress_bar.unpause()
                 is_reg_step = True
             elif dataloader is not None:
                 try:
                     batch = next(dataloader_iterator)
                 except StopIteration:
                     # hit the end of an epoch, reset
+                    self.progress_bar.pause()
                     dataloader_iterator = iter(dataloader)
+                    trigger_dataloader_setup_epoch(dataloader)
                     batch = next(dataloader_iterator)
+                    self.progress_bar.unpause()
             else:
                 batch = None
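
The change above does more than rebuild an exhausted iterator: when a dataloader hits the end of an epoch, it pauses the progress bar, re-creates the iterator, and calls trigger_dataloader_setup_epoch so the bucketing dataset can redo its per-epoch setup before the next batch is drawn. Below is a minimal, self-contained sketch of that reset pattern. The EpochAwareDataset class, its setup_epoch() hook, and the simplified trigger_dataloader_setup_epoch here are illustrative assumptions, not the actual ai-toolkit implementation (which also rebuilds buckets/crops and manages the progress bar).

# Minimal sketch of the epoch-reset pattern, with hypothetical names.
from torch.utils.data import DataLoader, Dataset


class EpochAwareDataset(Dataset):
    """Toy dataset that must redo some setup at the start of every epoch."""

    def __init__(self, items):
        self.items = list(items)
        self.epoch = 0

    def setup_epoch(self):
        # In a real bucketing dataloader this is where buckets/crops would be rebuilt.
        self.epoch += 1

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx], self.epoch


def trigger_dataloader_setup_epoch(dataloader: DataLoader):
    # Assumed behavior: forward the per-epoch setup to the underlying dataset.
    dataset = dataloader.dataset
    if hasattr(dataset, "setup_epoch"):
        dataset.setup_epoch()


def next_batch(dataloader: DataLoader, iterator):
    """Fetch the next batch, resetting the iterator and re-running epoch setup on exhaustion."""
    try:
        batch = next(iterator)
    except StopIteration:
        # hit the end of an epoch: rebuild the iterator and re-run per-epoch setup
        iterator = iter(dataloader)
        trigger_dataloader_setup_epoch(dataloader)
        batch = next(iterator)
    return batch, iterator


if __name__ == "__main__":
    loader = DataLoader(EpochAwareDataset(range(4)), batch_size=2)
    it = iter(loader)
    for _ in range(5):  # more steps than one epoch holds, to exercise the reset path
        batch, it = next_batch(loader, it)
        print(batch)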