Added a basic torch profiler that can be enabled in the config during development to find obvious issues.

Author: Jaret Burkett
Date: 2025-06-17 13:03:39 -06:00
Parent: ff617fdaea
Commit: 989ebfaa11
4 changed files with 53 additions and 66 deletions
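The flag is read from the process config via self.get_conf('torch_profiler', False) in the diff below, so turning it on is just a matter of setting that key on the process entry. A minimal sketch of what that might look like, assuming a plain dict-style process config (the exact nesting depends on the job config in use, so treat the surrounding structure as hypothetical):

    process_config = {
        # ... usual training options omitted; only the new key matters here
        'torch_profiler': True,  # enable the torch profiler for the train loop
    }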


@@ -236,6 +236,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.ema: ExponentialMovingAverage = None
         validate_configs(self.train_config, self.model_config, self.save_config)
+        do_profiler = self.get_conf('torch_profiler', False)
+        self.torch_profiler = None if not do_profiler else torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+        )
 
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -2058,6 +2066,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # flush()
             ### HOOK ###
+            if self.torch_profiler is not None:
+                self.torch_profiler.start()
             with self.accelerator.accumulate(self.modules_being_trained):
                 try:
                     loss_dict = self.hook_train_loop(batch_list)
@@ -2069,7 +2079,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
                         for item in batch.file_items:
                             print(f" - {item.path}")
                     raise e
+            if self.torch_profiler is not None:
+                torch.cuda.synchronize()  # Make sure all CUDA ops are done
+                self.torch_profiler.stop()
+                print("\n==== Profile Results ====")
+                print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
             self.timer.stop('train_loop')
             if not did_first_flush:
                 flush()
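
For reference, the same profiling pattern can be exercised outside the trainer. The following is a standalone sketch, not project code: it builds the profiler with CPU (and, when available, CUDA) activities, wraps a dummy workload between start() and stop(), synchronizes so queued CUDA kernels are counted, and prints the aggregated operator table, mirroring the key_averages().table(...) call added above.

    import torch

    # Build the profiler the same way the trainer does; only include the CUDA
    # activity when a GPU is actually present so the sketch also runs on CPU.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    profiler = torch.profiler.profile(activities=activities)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.nn.Linear(512, 512).to(device)
    x = torch.randn(64, 512, device=device)

    profiler.start()
    loss = model(x).sum()  # stand-in for one pass of the train loop
    loss.backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure all CUDA ops are done before stopping
    profiler.stop()

    print("\n==== Profile Results ====")
    print(profiler.key_averages().table(sort_by="cpu_time_total", row_limit=50))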