From b732595d9f993fe92efc9f2bf2c4ce9f71b8080f Mon Sep 17 00:00:00 2001 From: Mohsen Saffari Date: Thu, 27 Nov 2025 15:06:13 +0000 Subject: [PATCH] add kSaveMeanInvStd, kUpdateMovingAverage in Traits --- .../ck_tile/42_batchnorm/batchnorm_simple.cpp | 25 ++++++--- include/ck_tile/ops/batchnorm.hpp | 1 + .../batchnorm/kernel/batchnorm_fwd_kernel.hpp | 53 +++++++++++++++---- .../batchnorm/pipeline/batchnorm_problem.hpp | 21 ++++---- 4 files changed, 73 insertions(+), 27 deletions(-) diff --git a/example/ck_tile/42_batchnorm/batchnorm_simple.cpp b/example/ck_tile/42_batchnorm/batchnorm_simple.cpp index a150129f02..e38ed2a336 100644 --- a/example/ck_tile/42_batchnorm/batchnorm_simple.cpp +++ b/example/ck_tile/42_batchnorm/batchnorm_simple.cpp @@ -159,24 +159,35 @@ bool run(const ck_tile::ArgParser& arg_parser) using Vector = ck_tile::sequence<1, 1>; using Shape = ck_tile::BatchnormShape; - using Problem = ck_tile::BatchnormProblem; + + // Define traits (compile-time configuration) + using Traits = ck_tile::BatchnormFwdTraits; // No save, no update + + // Define problem with all types + using Problem = ck_tile::BatchnormProblem; using Kernel = ck_tile::BatchnormFwd; // Prepare host arguments + // Note: save/update behavior is determined by Traits (compile-time), not runtime args ck_tile::BatchnormFwdHostArgs hargs{ x_buf.GetDeviceBuffer(), // p_x gamma_buf.GetDeviceBuffer(), // p_gamma beta_buf.GetDeviceBuffer(), // p_beta y_buf.GetDeviceBuffer(), // p_y - nullptr, // p_running_mean + nullptr, // p_running_mean (not used, Traits::kUpdateMovingAverage=false) nullptr, // p_running_var - nullptr, // p_save_mean + nullptr, // p_save_mean (not used, Traits::kSaveMeanInvStd=false) nullptr, // p_save_inv_std epsilon, // epsilon - 0.1f, // momentum (not used yet) - N, C, H, W, // dimensions - false, // update_moving_average - false // save_mean_inv_std + 0.1f, // momentum + N, C, H, W // dimensions }; // Validate arguments diff --git a/include/ck_tile/ops/batchnorm.hpp b/include/ck_tile/ops/batchnorm.hpp index 67a924838a..64a00b336b 100644 --- a/include/ck_tile/ops/batchnorm.hpp +++ b/include/ck_tile/ops/batchnorm.hpp @@ -5,5 +5,6 @@ #include "ck_tile/ops/batchnorm/block/block_welford.hpp" #include "ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp" +#include "ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_traits.hpp" #include "ck_tile/ops/batchnorm/pipeline/batchnorm_problem.hpp" #include "ck_tile/ops/batchnorm/pipeline/batchnorm_shape.hpp" diff --git a/include/ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp b/include/ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp index 0634c93b3b..02caacc839 100644 --- a/include/ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp +++ b/include/ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp @@ -29,8 +29,7 @@ struct BatchnormFwdHostArgs index_t N, C, H, W; - bool update_moving_average; // If true, p_running_mean/var must be valid - bool save_mean_inv_std; // If true, p_save_mean/inv_std must be valid + // Note: save/update flags are now in Traits (compile-time), not here (runtime) }; // BatchnormFwd: Forward pass batch normalization kernel @@ -62,8 +61,7 @@ struct BatchnormFwd index_t N, C, H, W; - bool update_moving_average; - bool save_mean_inv_std; + // Note: save/update flags now come from Problem::Traits (compile-time) }; using Kargs = BatchnormFwdKargs; // Alias for convenience @@ -85,9 +83,7 @@ struct BatchnormFwd hargs.N, hargs.C, hargs.H, - hargs.W, - hargs.update_moving_average, - hargs.save_mean_inv_std}; + hargs.W}; } // Grid size calculation @@ -183,7 +179,7 @@ struct BatchnormFwd const index_t n = idx / spatial_size; const index_t hw = idx % spatial_size; - const index_t offset = n * C * H * W + c * H * W + hw; + const index_t offset = (n * C * H * W) + (c * H * W) + hw; ComputeDataType val = type_convert(p_x[offset]); // Apply batch normalization with scale and bias @@ -191,6 +187,41 @@ struct BatchnormFwd p_y[offset] = type_convert(normalized); } + + // Save mean and inverse std for backward pass (compile-time check) + if constexpr(Problem::Traits::kSaveMeanInvStd) + { + if(thread_id == 0) + { + using MeanVarDataType = typename Problem::MeanVarDataType; + MeanVarDataType* p_save_mean = static_cast(kargs.p_save_mean); + MeanVarDataType* p_save_inv_std = static_cast(kargs.p_save_inv_std); + + p_save_mean[c] = type_convert(block_mean); + p_save_inv_std[c] = type_convert(inv_std); + } + } + + // Update running mean and variance (compile-time check) + if constexpr(Problem::Traits::kUpdateMovingAverage) + { + if(thread_id == 0) + { + using MeanVarDataType = typename Problem::MeanVarDataType; + MeanVarDataType* p_running_mean = static_cast(kargs.p_running_mean); + MeanVarDataType* p_running_var = static_cast(kargs.p_running_var); + + const ComputeDataType momentum = static_cast(kargs.momentum); + const ComputeDataType one_minus_momentum = type_convert(1) - momentum; + + // Exponential moving average: new = (1-momentum)*old + momentum*current + ComputeDataType old_mean = type_convert(p_running_mean[c]); + ComputeDataType old_var = type_convert(p_running_var[c]); + + p_running_mean[c] = type_convert(one_minus_momentum * old_mean + momentum * block_mean); + p_running_var[c] = type_convert(one_minus_momentum * old_var + momentum * block_var); + } + } } // Validate arguments @@ -209,8 +240,8 @@ struct BatchnormFwd return false; } - // Validate optional pointers based on flags - if(hargs.update_moving_average) + // Validate optional pointers based on Traits (compile-time) + if constexpr(Problem::Traits::kUpdateMovingAverage) { if(hargs.p_running_mean == nullptr || hargs.p_running_var == nullptr) { @@ -218,7 +249,7 @@ struct BatchnormFwd } } - if(hargs.save_mean_inv_std) + if constexpr(Problem::Traits::kSaveMeanInvStd) { if(hargs.p_save_mean == nullptr || hargs.p_save_inv_std == nullptr) { diff --git a/include/ck_tile/ops/batchnorm/pipeline/batchnorm_problem.hpp b/include/ck_tile/ops/batchnorm/pipeline/batchnorm_problem.hpp index 51498d2836..e06cb7aa5c 100644 --- a/include/ck_tile/ops/batchnorm/pipeline/batchnorm_problem.hpp +++ b/include/ck_tile/ops/batchnorm/pipeline/batchnorm_problem.hpp @@ -5,29 +5,32 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/batchnorm/pipeline/batchnorm_shape.hpp" +#include "ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_traits.hpp" namespace ck_tile { // BatchnormProblem defines the computational problem for batch normalization // Input: x with shape [N, C, H, W] // Output: y with shape [N, C, H, W] -// Reduction over spatial dimensions (H, W) per channel +// Reduction over batch (N) and spatial dimensions (H, W) per channel (C) template + typename MeanVarDataType_, + typename BlockShape_, + typename Traits_> struct BatchnormProblem { using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using YDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; - - // For now, start with simple forward pass without scale/bias - // We'll add these later: - // using GammaDataType = ... // scale parameter - // using BetaDataType = ... // bias parameter - // using MeanVarDataType = ... // for saving mean/variance + using MeanVarDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + using Traits = remove_cvref_t; }; } // namespace ck_tile