mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added a basic torch profiler that can be used in config during development to find some obvious issues.
This commit is contained in:
@@ -236,6 +236,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.ema: ExponentialMovingAverage = None
|
||||
|
||||
validate_configs(self.train_config, self.model_config, self.save_config)
|
||||
|
||||
do_profiler = self.get_conf('torch_profiler', False)
|
||||
self.torch_profiler = None if not do_profiler else torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
)
|
||||
|
||||
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
||||
# override in subclass
|
||||
@@ -2058,6 +2066,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# flush()
|
||||
### HOOK ###
|
||||
if self.torch_profiler is not None:
|
||||
self.torch_profiler.start()
|
||||
with self.accelerator.accumulate(self.modules_being_trained):
|
||||
try:
|
||||
loss_dict = self.hook_train_loop(batch_list)
|
||||
@@ -2069,7 +2079,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
for item in batch.file_items:
|
||||
print(f" - {item.path}")
|
||||
raise e
|
||||
|
||||
if self.torch_profiler is not None:
|
||||
torch.cuda.synchronize() # Make sure all CUDA ops are done
|
||||
self.torch_profiler.stop()
|
||||
|
||||
print("\n==== Profile Results ====")
|
||||
print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
|
||||
self.timer.stop('train_loop')
|
||||
if not did_first_flush:
|
||||
flush()
|
||||
|
||||
Reference in New Issue
Block a user