Fixed issue with grad scaling

This commit is contained in:
Jaret Burkett
2024-07-20 08:21:57 -06:00
parent a2301cf28c
commit 22d2f6e28f
4 changed files with 4 additions and 3161 deletions

View File

@@ -1520,9 +1520,9 @@ class SDTrainer(BaseSDTrainProcess):
# I spent weeks on fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
# if self.is_bfloat:
loss.backward()
# loss.backward()
# else:
# self.scaler.scale(loss).backward()
self.scaler.scale(loss).backward()
# flush()
if not self.is_grad_accumulation_step:

View File

@@ -559,7 +559,7 @@ class DatasetConfig:
self.replacements: List[str] = kwargs.get('replacements', [])
self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
self.num_workers: int = kwargs.get('num_workers', 4)
self.num_workers: int = kwargs.get('num_workers', 2)
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
self.extra_values: List[float] = kwargs.get('extra_values', [])
self.square_crop: bool = kwargs.get('square_crop', False)

View File

@@ -400,10 +400,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
# check if dataset_path is a folder or json
if os.path.isdir(self.dataset_path):
file_list = [
os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
else:
# assume json
with open(self.dataset_path, 'r') as f:

File diff suppressed because it is too large Load Diff