mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Fixed issue with grad scaling
This commit is contained in:
@@ -1520,9 +1520,9 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# I spent weeks on fighting this. DON'T DO IT
|
# I spent weeks on fighting this. DON'T DO IT
|
||||||
# with fsdp_overlap_step_with_backward():
|
# with fsdp_overlap_step_with_backward():
|
||||||
# if self.is_bfloat:
|
# if self.is_bfloat:
|
||||||
loss.backward()
|
# loss.backward()
|
||||||
# else:
|
# else:
|
||||||
# self.scaler.scale(loss).backward()
|
self.scaler.scale(loss).backward()
|
||||||
# flush()
|
# flush()
|
||||||
|
|
||||||
if not self.is_grad_accumulation_step:
|
if not self.is_grad_accumulation_step:
|
||||||
|
|||||||
@@ -559,7 +559,7 @@ class DatasetConfig:
|
|||||||
self.replacements: List[str] = kwargs.get('replacements', [])
|
self.replacements: List[str] = kwargs.get('replacements', [])
|
||||||
self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
|
self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
|
||||||
|
|
||||||
self.num_workers: int = kwargs.get('num_workers', 4)
|
self.num_workers: int = kwargs.get('num_workers', 2)
|
||||||
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
|
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
|
||||||
self.extra_values: List[float] = kwargs.get('extra_values', [])
|
self.extra_values: List[float] = kwargs.get('extra_values', [])
|
||||||
self.square_crop: bool = kwargs.get('square_crop', False)
|
self.square_crop: bool = kwargs.get('square_crop', False)
|
||||||
|
|||||||
@@ -400,10 +400,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
|||||||
|
|
||||||
# check if dataset_path is a folder or json
|
# check if dataset_path is a folder or json
|
||||||
if os.path.isdir(self.dataset_path):
|
if os.path.isdir(self.dataset_path):
|
||||||
file_list = [
|
file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
|
||||||
os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if
|
|
||||||
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
# assume json
|
# assume json
|
||||||
with open(self.dataset_path, 'r') as f:
|
with open(self.dataset_path, 'r') as f:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user