diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py
index 8efd0097..5b339ebe 100644
--- a/jobs/BaseJob.py
+++ b/jobs/BaseJob.py
@@ -15,7 +15,6 @@ class BaseJob:
         self.config = config['config']
         self.raw_config = config
         self.job = config['job']
-        self.torch_profiler = self.get_conf('torch_profiler', False)
         self.name = self.get_conf('name', required=True)
         if 'meta' in config:
             self.meta = config['meta']
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 313285fc..1528e772 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -236,6 +236,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.ema: ExponentialMovingAverage = None
 
         validate_configs(self.train_config, self.model_config, self.save_config)
+
+        do_profiler = self.get_conf('torch_profiler', False)
+        self.torch_profiler = None if not do_profiler else torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+        )
 
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -2058,6 +2066,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # flush()
 
                 ### HOOK ###
+                if self.torch_profiler is not None:
+                    self.torch_profiler.start()
                 with self.accelerator.accumulate(self.modules_being_trained):
                     try:
                         loss_dict = self.hook_train_loop(batch_list)
@@ -2069,7 +2079,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
                             for item in batch.file_items:
                                 print(f" - {item.path}")
                         raise e
-
+                if self.torch_profiler is not None:
+                    torch.cuda.synchronize()  # Make sure all CUDA ops are done
+                    self.torch_profiler.stop()
+
+                    print("\n==== Profile Results ====")
+                    print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
                 self.timer.stop('train_loop')
                 if not did_first_flush:
                     flush()
diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py
index 4d97b2cd..8897bdc0 100644
--- a/toolkit/optimizers/adafactor.py
+++ b/toolkit/optimizers/adafactor.py
@@ -109,7 +109,9 @@ class Adafactor(torch.optim.Optimizer):
         do_paramiter_swapping=False,
         paramiter_swapping_factor=0.1,
         stochastic_accumulation=True,
+        stochastic_rounding=True,
     ):
+        self.stochastic_rounding = stochastic_rounding
         if lr is not None and relative_step:
             raise ValueError(
                 "Cannot combine manual `lr` and `relative_step=True` options")
@@ -354,7 +356,7 @@ class Adafactor(torch.optim.Optimizer):
 
                 p_data_fp32.add_(-update)
 
-                if p.dtype != torch.float32:
+                if p.dtype != torch.float32 and self.stochastic_rounding:
                     # apply stochastic rounding
                     copy_stochastic(p, p_data_fp32)
 
diff --git a/toolkit/optimizers/optimizer_utils.py b/toolkit/optimizers/optimizer_utils.py
index 67991f21..3adae275 100644
--- a/toolkit/optimizers/optimizer_utils.py
+++ b/toolkit/optimizers/optimizer_utils.py
@@ -112,86 +112,57 @@ def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
         return 0, 8  # Int8 doesn't have mantissa bits
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
+
+def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
+    # adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
+    # create a random 16 bit integer
+    result = torch.randint_like(
+        source,
+        dtype=torch.int32,
+        low=0,
+        high=(1 << 16),
+    )
+
+    # add the random number to the lower 16 bit of the mantissa
+    result.add_(source.view(dtype=torch.int32))
+
+    # mask off the lower 16 bit of the mantissa
+    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
+
+    # copy the higher 16 bit into the target tensor
+    target.copy_(result.view(dtype=torch.float32))
+
+    del result
 
 
-def copy_stochastic(
-    target: torch.Tensor,
-    source: torch.Tensor,
-    eps: Optional[float] = None
-) -> None:
-    """
-    Performs stochastic rounding from source tensor to target tensor.
-
-    Args:
-        target: Destination tensor (determines the target format)
-        source: Source tensor (typically float32)
-        eps: Optional minimum value for stochastic rounding (for numerical stability)
-    """
+def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
     with torch.no_grad():
-        # If target is float32, just copy directly
+        # assert if target is on cpu, throw error
+        assert target.device.type != 'cpu', "Target is on cpu!"
+        assert source.device.type != 'cpu', "Source is on cpu!"
+
        if target.dtype == torch.float32:
             target.copy_(source)
             return
-
-        # Special handling for int8
-        if target.dtype == torch.int8:
-            # Scale the source values to utilize the full int8 range
-            scaled = source * 127.0  # Scale to [-127, 127]
-
-            # Add random noise for stochastic rounding
-            noise = torch.rand_like(scaled) - 0.5
-            rounded = torch.round(scaled + noise)
-
-            # Clamp to int8 range
-            clamped = torch.clamp(rounded, -127, 127)
-            target.copy_(clamped.to(torch.int8))
+        if target.dtype == torch.bfloat16:
+            copy_stochastic_bf16(target, source)
             return
 
         mantissa_bits, _ = get_format_params(target.dtype)
+        round_factor = 2 ** (23 - mantissa_bits)
 
-        # Convert source to int32 view
-        source_int = source.view(dtype=torch.int32)
+        # Add uniform noise for stochastic rounding
+        noise = torch.rand_like(source, device=source.device) - 0.5
+        rounded = torch.round(source * round_factor + noise)
+        result_float = rounded / round_factor
 
-        # Calculate number of bits to round
-        bits_to_round = 23 - mantissa_bits  # 23 is float32 mantissa bits
-
-        # Create random integers for stochastic rounding
-        rand = torch.randint_like(
-            source,
-            dtype=torch.int32,
-            low=0,
-            high=(1 << bits_to_round),
-        )
-
-        # Add random values to the bits that will be rounded off
-        result = source_int.clone()
-        result.add_(rand)
-
-        # Mask to keep only the bits we want
-        # Create mask with 1s in positions we want to keep
-        mask = (-1) << bits_to_round
-        result.bitwise_and_(mask)
-
-        # Handle minimum value threshold if specified
-        if eps is not None:
-            eps_int = torch.tensor(
-                eps, dtype=torch.float32).view(dtype=torch.int32)
-            zero_mask = (result.abs() < eps_int)
-            result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int
-
-        # Convert back to float32 view
-        result_float = result.view(dtype=torch.float32)
-
-        # Special handling for float8 formats
+        # Clamp for float8
         if target.dtype == torch.float8_e4m3fn:
             result_float.clamp_(-448.0, 448.0)
         elif target.dtype == torch.float8_e5m2:
             result_float.clamp_(-57344.0, 57344.0)
 
-        # Copy the result to the target tensor
         update_parameter(target, result_float)
-        # target.copy_(result_float)
-
-        del result, rand, source_int
 
 
 class Auto8bitTensor:
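
Note (not from the diff): profiling is opt-in via the torch_profiler config key read by get_conf('torch_profiler', False) in BaseSDTrainProcess. The start/synchronize/stop/print pattern wired around the train-loop hook can be exercised standalone; a minimal sketch using only the torch.profiler calls shown above, with a matmul standing in for a training step:

import torch

profiler = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
)

profiler.start()
# stand-in for one hook_train_loop() call
x = torch.randn(1024, 1024)
y = x @ x
if torch.cuda.is_available():
    torch.cuda.synchronize()  # make sure queued CUDA ops are captured before stopping
profiler.stop()

print("\n==== Profile Results ====")
print(profiler.key_averages().table(sort_by="cpu_time_total", row_limit=20))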
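
Note (not from the diff): stochastic rounding is unbiased, so averaging many rounded copies recovers low-order bits that a single round-to-nearest bf16 cast discards, which is what the new stochastic_rounding=True default in Adafactor preserves. A small sanity-check sketch that calls the new copy_stochastic_bf16 helper directly (copy_stochastic itself asserts non-CPU tensors); the test value and trial count are arbitrary:

import torch

from toolkit.optimizers.optimizer_utils import copy_stochastic_bf16

torch.manual_seed(0)
# 1.0001 sits between two bf16 grid points (spacing ~0.0039 near 1.0)
source = torch.full((10_000,), 1.0001, dtype=torch.float32)
target = torch.empty_like(source, dtype=torch.bfloat16)

trials = 1000
acc = torch.zeros_like(source)
for _ in range(trials):
    copy_stochastic_bf16(target, source)
    acc += target.float()

nearest_err = (source.to(torch.bfloat16).float() - source).abs().mean().item()
stochastic_err = (acc / trials - source).abs().mean().item()
print(f"round-to-nearest error:        {nearest_err:.2e}")
print(f"mean-of-stochastic-copies err: {stochastic_err:.2e}")  # noticeably smaller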