mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed issue with grad scaling
This commit is contained in:
@@ -1520,9 +1520,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# I spent weeks on fighting this. DON'T DO IT
|
||||
# with fsdp_overlap_step_with_backward():
|
||||
# if self.is_bfloat:
|
||||
loss.backward()
|
||||
# loss.backward()
|
||||
# else:
|
||||
# self.scaler.scale(loss).backward()
|
||||
self.scaler.scale(loss).backward()
|
||||
# flush()
|
||||
|
||||
if not self.is_grad_accumulation_step:
|
||||
|
||||
@@ -559,7 +559,7 @@ class DatasetConfig:
|
||||
self.replacements: List[str] = kwargs.get('replacements', [])
|
||||
self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
|
||||
|
||||
self.num_workers: int = kwargs.get('num_workers', 4)
|
||||
self.num_workers: int = kwargs.get('num_workers', 2)
|
||||
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
|
||||
self.extra_values: List[float] = kwargs.get('extra_values', [])
|
||||
self.square_crop: bool = kwargs.get('square_crop', False)
|
||||
|
||||
@@ -400,10 +400,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
|
||||
# check if dataset_path is a folder or json
|
||||
if os.path.isdir(self.dataset_path):
|
||||
file_list = [
|
||||
os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if
|
||||
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
||||
]
|
||||
file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
|
||||
else:
|
||||
# assume json
|
||||
with open(self.dataset_path, 'r') as f:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user