Added a basic torch profiler that can be used in config during development to find some obvious issues.

This commit is contained in:
Jaret Burkett
2025-06-17 13:03:39 -06:00
parent ff617fdaea
commit 989ebfaa11
4 changed files with 53 additions and 66 deletions

View File

@@ -15,7 +15,6 @@ class BaseJob:
self.config = config['config']
self.raw_config = config
self.job = config['job']
self.torch_profiler = self.get_conf('torch_profiler', False)
self.name = self.get_conf('name', required=True)
if 'meta' in config:
self.meta = config['meta']

View File

@@ -236,6 +236,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.ema: ExponentialMovingAverage = None
validate_configs(self.train_config, self.model_config, self.save_config)
do_profiler = self.get_conf('torch_profiler', False)
self.torch_profiler = None if not do_profiler else torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
)
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -2058,6 +2066,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# flush()
### HOOK ###
if self.torch_profiler is not None:
self.torch_profiler.start()
with self.accelerator.accumulate(self.modules_being_trained):
try:
loss_dict = self.hook_train_loop(batch_list)
@@ -2069,7 +2079,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
for item in batch.file_items:
print(f" - {item.path}")
raise e
if self.torch_profiler is not None:
torch.cuda.synchronize() # Make sure all CUDA ops are done
self.torch_profiler.stop()
print("\n==== Profile Results ====")
print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
self.timer.stop('train_loop')
if not did_first_flush:
flush()

View File

@@ -109,7 +109,9 @@ class Adafactor(torch.optim.Optimizer):
do_paramiter_swapping=False,
paramiter_swapping_factor=0.1,
stochastic_accumulation=True,
stochastic_rounding=True,
):
self.stochastic_rounding = stochastic_rounding
if lr is not None and relative_step:
raise ValueError(
"Cannot combine manual `lr` and `relative_step=True` options")
@@ -354,7 +356,7 @@ class Adafactor(torch.optim.Optimizer):
p_data_fp32.add_(-update)
if p.dtype != torch.float32:
if p.dtype != torch.float32 and self.stochastic_rounding:
# apply stochastic rounding
copy_stochastic(p, p_data_fp32)

View File

@@ -112,86 +112,57 @@ def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
return 0, 8 # Int8 doesn't have mantissa bits
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
# adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
# create a random 16 bit integer
result = torch.randint_like(
source,
dtype=torch.int32,
low=0,
high=(1 << 16),
)
# add the random number to the lower 16 bit of the mantissa
result.add_(source.view(dtype=torch.int32))
# mask off the lower 16 bit of the mantissa
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
# copy the higher 16 bit into the target tensor
target.copy_(result.view(dtype=torch.float32))
del result
def copy_stochastic(
target: torch.Tensor,
source: torch.Tensor,
eps: Optional[float] = None
) -> None:
"""
Performs stochastic rounding from source tensor to target tensor.
Args:
target: Destination tensor (determines the target format)
source: Source tensor (typically float32)
eps: Optional minimum value for stochastic rounding (for numerical stability)
"""
def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
with torch.no_grad():
# If target is float32, just copy directly
# assert if target is on cpu, throw error
assert target.device.type != 'cpu', "Target is on cpu!"
assert source.device.type != 'cpu', "Source is on cpu!"
if target.dtype == torch.float32:
target.copy_(source)
return
# Special handling for int8
if target.dtype == torch.int8:
# Scale the source values to utilize the full int8 range
scaled = source * 127.0 # Scale to [-127, 127]
# Add random noise for stochastic rounding
noise = torch.rand_like(scaled) - 0.5
rounded = torch.round(scaled + noise)
# Clamp to int8 range
clamped = torch.clamp(rounded, -127, 127)
target.copy_(clamped.to(torch.int8))
if target.dtype == torch.bfloat16:
copy_stochastic_bf16(target, source)
return
mantissa_bits, _ = get_format_params(target.dtype)
round_factor = 2 ** (23 - mantissa_bits)
# Convert source to int32 view
source_int = source.view(dtype=torch.int32)
# Add uniform noise for stochastic rounding
noise = torch.rand_like(source, device=source.device) - 0.5
rounded = torch.round(source * round_factor + noise)
result_float = rounded / round_factor
# Calculate number of bits to round
bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits
# Create random integers for stochastic rounding
rand = torch.randint_like(
source,
dtype=torch.int32,
low=0,
high=(1 << bits_to_round),
)
# Add random values to the bits that will be rounded off
result = source_int.clone()
result.add_(rand)
# Mask to keep only the bits we want
# Create mask with 1s in positions we want to keep
mask = (-1) << bits_to_round
result.bitwise_and_(mask)
# Handle minimum value threshold if specified
if eps is not None:
eps_int = torch.tensor(
eps, dtype=torch.float32).view(dtype=torch.int32)
zero_mask = (result.abs() < eps_int)
result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int
# Convert back to float32 view
result_float = result.view(dtype=torch.float32)
# Special handling for float8 formats
# Clamp for float8
if target.dtype == torch.float8_e4m3fn:
result_float.clamp_(-448.0, 448.0)
elif target.dtype == torch.float8_e5m2:
result_float.clamp_(-57344.0, 57344.0)
# Copy the result to the target tensor
update_parameter(target, result_float)
# target.copy_(result_float)
del result, rand, source_int
class Auto8bitTensor: