Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added a basic torch profiler that can be enabled in the config during development to find obvious issues.
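The flag is read straight from the job config (see the BaseJob hunk below), so enabling the profiler during development presumably just means adding the key to the config that builds the job. A minimal sketch; only the 'torch_profiler' key comes from this commit, the surrounding keys and values are illustrative:

# Hypothetical job config dict; only 'torch_profiler' is introduced by this commit.
config = {
    'job': 'extension',
    'config': {
        'name': 'dev_profiling_run',   # read via get_conf('name', required=True)
        'torch_profiler': True,        # read via get_conf('torch_profiler', False)
    },
}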
@@ -15,7 +15,6 @@ class BaseJob:
        self.config = config['config']
        self.raw_config = config
        self.job = config['job']
        self.torch_profiler = self.get_conf('torch_profiler', False)
        self.name = self.get_conf('name', required=True)
        if 'meta' in config:
            self.meta = config['meta']
@@ -236,6 +236,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.ema: ExponentialMovingAverage = None

        validate_configs(self.train_config, self.model_config, self.save_config)

        do_profiler = self.get_conf('torch_profiler', False)
        self.torch_profiler = None if not do_profiler else torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
        )

    def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
        # override in subclass
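The profiler is only constructed when the config flag is set; otherwise the attribute stays None and the start/stop hooks later in the train loop are skipped. As a side note, torch.profiler.profile accepts a few extra options that can be useful while hunting for issues; a sketch of an equivalent, slightly richer construction (record_shapes and profile_memory are standard torch.profiler arguments, not part of this commit):

import torch

do_profiler = True  # stand-in for get_conf('torch_profiler', False)
torch_profiler = None if not do_profiler else torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,    # also record input shapes per op
    profile_memory=True,   # also record allocator activity per op
)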
@@ -2058,6 +2066,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # flush()
        ### HOOK ###
        if self.torch_profiler is not None:
            self.torch_profiler.start()
        with self.accelerator.accumulate(self.modules_being_trained):
            try:
                loss_dict = self.hook_train_loop(batch_list)
@@ -2069,7 +2079,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    for item in batch.file_items:
                        print(f" - {item.path}")
                    raise e

        if self.torch_profiler is not None:
            torch.cuda.synchronize()  # Make sure all CUDA ops are done
            self.torch_profiler.stop()

            print("\n==== Profile Results ====")
            print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
        self.timer.stop('train_loop')
        if not did_first_flush:
            flush()
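Outside the trainer, the same start/stop-and-report pattern looks roughly like this; the model, batch, and CUDA device are placeholders and only the profiler calls mirror the hook above. Note that the profiler wraps a single training step, so the printed table reflects one iteration:

import torch

model = torch.nn.Linear(1024, 1024).cuda()
batch = torch.randn(64, 1024, device='cuda')

prof = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
)
prof.start()
loss = model(batch).square().mean()   # stand-in for hook_train_loop(batch_list)
loss.backward()
torch.cuda.synchronize()              # make sure all CUDA ops are done
prof.stop()

print("\n==== Profile Results ====")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=50))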
@@ -109,7 +109,9 @@ class Adafactor(torch.optim.Optimizer):
        do_paramiter_swapping=False,
        paramiter_swapping_factor=0.1,
        stochastic_accumulation=True,
        stochastic_rounding=True,
    ):
        self.stochastic_rounding = stochastic_rounding
        if lr is not None and relative_step:
            raise ValueError(
                "Cannot combine manual `lr` and `relative_step=True` options")
@@ -354,7 +356,7 @@ class Adafactor(torch.optim.Optimizer):
                p_data_fp32.add_(-update)

                if p.dtype != torch.float32:
                if p.dtype != torch.float32 and self.stochastic_rounding:
                    # apply stochastic rounding
                    copy_stochastic(p, p_data_fp32)
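The new flag only decides whether the fp32 working copy is written back to a low-precision parameter with stochastic rounding. The reason this matters: with round-to-nearest, an update much smaller than the bf16 spacing is lost every step, while stochastic rounding preserves it in expectation. A self-contained illustration in plain torch (not the optimizer code itself; the helper uses the same bit trick as copy_stochastic_bf16 further down in this diff):

import torch

update = 1e-4  # much smaller than the bf16 spacing just below 1.0 (2 ** -8)

def to_bf16_stochastic(v32: torch.Tensor) -> torch.Tensor:
    # randomize the low 16 bits of the float32, then truncate to the bf16 portion
    r = torch.randint_like(v32, dtype=torch.int32, low=0, high=1 << 16)
    r.add_(v32.view(dtype=torch.int32))
    r.bitwise_and_(-65536)
    return r.view(dtype=torch.float32).to(torch.bfloat16)

p_nearest = torch.tensor([1.0], dtype=torch.bfloat16)
p_stoch = torch.tensor([1.0], dtype=torch.bfloat16)
for _ in range(1000):
    p_nearest = (p_nearest.float() - update).to(torch.bfloat16)  # snaps back to 1.0
    p_stoch = to_bf16_stochastic(p_stoch.float() - update)       # drifts down in expectation

print(p_nearest.item())  # still 1.0: every update was rounded away
print(p_stoch.item())    # roughly 0.9 on average: the updates survive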
@@ -112,86 +112,57 @@ def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
        return 0, 8  # Int8 doesn't have mantissa bits
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
    # adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
    # create a random 16 bit integer
    result = torch.randint_like(
        source,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )

    # add the random number to the lower 16 bit of the mantissa
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bit of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bit into the target tensor
    target.copy_(result.view(dtype=torch.float32))

    del result
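The trick works because a bfloat16 value is exactly the top 16 bits of a float32: adding a uniform random 16-bit integer to the low mantissa bits and then masking them off rounds up with probability proportional to the discarded fraction, so the rounding is unbiased. A quick empirical check, assuming the copy_stochastic_bf16 defined above is in scope:

import torch

source = torch.full((100_000,), 1.00001, dtype=torch.float32)
target = torch.empty_like(source, dtype=torch.bfloat16)
copy_stochastic_bf16(target, source)

print(target.float().unique())       # only the two bf16 neighbours, 1.0 and 1.0078125
print(target.float().mean().item())  # close to 1.00001, i.e. unbiased on average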
def copy_stochastic(
    target: torch.Tensor,
    source: torch.Tensor,
    eps: Optional[float] = None
) -> None:
    """
    Performs stochastic rounding from source tensor to target tensor.

    Args:
        target: Destination tensor (determines the target format)
        source: Source tensor (typically float32)
        eps: Optional minimum value for stochastic rounding (for numerical stability)
    """
def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
    with torch.no_grad():
        # If target is float32, just copy directly
        # assert if target is on cpu, throw error
        assert target.device.type != 'cpu', "Target is on cpu!"
        assert source.device.type != 'cpu', "Source is on cpu!"

        if target.dtype == torch.float32:
            target.copy_(source)
            return

        # Special handling for int8
        if target.dtype == torch.int8:
            # Scale the source values to utilize the full int8 range
            scaled = source * 127.0  # Scale to [-127, 127]

            # Add random noise for stochastic rounding
            noise = torch.rand_like(scaled) - 0.5
            rounded = torch.round(scaled + noise)

            # Clamp to int8 range
            clamped = torch.clamp(rounded, -127, 127)
            target.copy_(clamped.to(torch.int8))
        if target.dtype == torch.bfloat16:
            copy_stochastic_bf16(target, source)
            return

        mantissa_bits, _ = get_format_params(target.dtype)
        round_factor = 2 ** (23 - mantissa_bits)

        # Convert source to int32 view
        source_int = source.view(dtype=torch.int32)
        # Add uniform noise for stochastic rounding
        noise = torch.rand_like(source, device=source.device) - 0.5
        rounded = torch.round(source * round_factor + noise)
        result_float = rounded / round_factor

        # Calculate number of bits to round
        bits_to_round = 23 - mantissa_bits  # 23 is float32 mantissa bits

        # Create random integers for stochastic rounding
        rand = torch.randint_like(
            source,
            dtype=torch.int32,
            low=0,
            high=(1 << bits_to_round),
        )

        # Add random values to the bits that will be rounded off
        result = source_int.clone()
        result.add_(rand)

        # Mask to keep only the bits we want
        # Create mask with 1s in positions we want to keep
        mask = (-1) << bits_to_round
        result.bitwise_and_(mask)

        # Handle minimum value threshold if specified
        if eps is not None:
            eps_int = torch.tensor(
                eps, dtype=torch.float32).view(dtype=torch.int32)
            zero_mask = (result.abs() < eps_int)
            result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int

        # Convert back to float32 view
        result_float = result.view(dtype=torch.float32)

        # Special handling for float8 formats
        # Clamp for float8
        if target.dtype == torch.float8_e4m3fn:
            result_float.clamp_(-448.0, 448.0)
        elif target.dtype == torch.float8_e5m2:
            result_float.clamp_(-57344.0, 57344.0)

        # Copy the result to the target tensor
        update_parameter(target, result_float)
        # target.copy_(result_float)
        del result, rand, source_int
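For formats other than bfloat16 the same integer trick generalizes: randomize the bits_to_round = 23 - mantissa_bits low mantissa bits of the float32 source, then mask them off. A self-contained sketch of that core step for a float8_e4m3-style target (3 mantissa bits), leaving out the clamping, the eps handling, and the update_parameter helper used above:

import torch

mantissa_bits = 3                   # float8_e4m3 keeps 3 mantissa bits
bits_to_round = 23 - mantissa_bits  # low float32 mantissa bits to randomize

source = torch.rand(8, dtype=torch.float32) + 1.0
result = source.view(dtype=torch.int32).clone()
result.add_(torch.randint_like(result, low=0, high=1 << bits_to_round))
result.bitwise_and_((-1) << bits_to_round)

rounded = result.view(dtype=torch.float32)
print(source)
print(rounded)  # each value snapped up or down to the nearest 3-mantissa-bit float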
class Auto8bitTensor: