Added a basic torch profiler that can be used in config during development to find some obvious issues.

This commit is contained in:
Jaret Burkett
2025-06-17 13:03:39 -06:00
parent ff617fdaea
commit 989ebfaa11
4 changed files with 53 additions and 66 deletions

View File

@@ -112,86 +112,57 @@ def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
return 0, 8 # Int8 doesn't have mantissa bits
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
# adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
# create a random 16 bit integer
result = torch.randint_like(
source,
dtype=torch.int32,
low=0,
high=(1 << 16),
)
# add the random number to the lower 16 bit of the mantissa
result.add_(source.view(dtype=torch.int32))
# mask off the lower 16 bit of the mantissa
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
# copy the higher 16 bit into the target tensor
target.copy_(result.view(dtype=torch.float32))
del result
def copy_stochastic(
target: torch.Tensor,
source: torch.Tensor,
eps: Optional[float] = None
) -> None:
"""
Performs stochastic rounding from source tensor to target tensor.
Args:
target: Destination tensor (determines the target format)
source: Source tensor (typically float32)
eps: Optional minimum value for stochastic rounding (for numerical stability)
"""
def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
with torch.no_grad():
# If target is float32, just copy directly
# assert if target is on cpu, throw error
assert target.device.type != 'cpu', "Target is on cpu!"
assert source.device.type != 'cpu', "Source is on cpu!"
if target.dtype == torch.float32:
target.copy_(source)
return
# Special handling for int8
if target.dtype == torch.int8:
# Scale the source values to utilize the full int8 range
scaled = source * 127.0 # Scale to [-127, 127]
# Add random noise for stochastic rounding
noise = torch.rand_like(scaled) - 0.5
rounded = torch.round(scaled + noise)
# Clamp to int8 range
clamped = torch.clamp(rounded, -127, 127)
target.copy_(clamped.to(torch.int8))
if target.dtype == torch.bfloat16:
copy_stochastic_bf16(target, source)
return
mantissa_bits, _ = get_format_params(target.dtype)
round_factor = 2 ** (23 - mantissa_bits)
# Convert source to int32 view
source_int = source.view(dtype=torch.int32)
# Add uniform noise for stochastic rounding
noise = torch.rand_like(source, device=source.device) - 0.5
rounded = torch.round(source * round_factor + noise)
result_float = rounded / round_factor
# Calculate number of bits to round
bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits
# Create random integers for stochastic rounding
rand = torch.randint_like(
source,
dtype=torch.int32,
low=0,
high=(1 << bits_to_round),
)
# Add random values to the bits that will be rounded off
result = source_int.clone()
result.add_(rand)
# Mask to keep only the bits we want
# Create mask with 1s in positions we want to keep
mask = (-1) << bits_to_round
result.bitwise_and_(mask)
# Handle minimum value threshold if specified
if eps is not None:
eps_int = torch.tensor(
eps, dtype=torch.float32).view(dtype=torch.int32)
zero_mask = (result.abs() < eps_int)
result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int
# Convert back to float32 view
result_float = result.view(dtype=torch.float32)
# Special handling for float8 formats
# Clamp for float8
if target.dtype == torch.float8_e4m3fn:
result_float.clamp_(-448.0, 448.0)
elif target.dtype == torch.float8_e5m2:
result_float.clamp_(-57344.0, 57344.0)
# Copy the result to the target tensor
update_parameter(target, result_float)
# target.copy_(result_float)
del result, rand, source_int
class Auto8bitTensor: