add kSaveMeanInvStd, kUpdateMovingAverage in Traits

This commit is contained in:
Mohsen Saffari
2025-11-27 15:06:13 +00:00
parent 5c8e8684ec
commit b732595d9f
4 changed files with 73 additions and 27 deletions

View File

@@ -159,24 +159,35 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::BatchnormShape<BlockWarps, BlockTile, WarpTile, Vector>;
using Problem = ck_tile::BatchnormProblem<XDataType, ComputeDataType, YDataType, Shape>;
// Define traits (compile-time configuration)
using Traits = ck_tile::BatchnormFwdTraits<false, false>; // No save, no update
// Define problem with all types
using Problem = ck_tile::BatchnormProblem<XDataType, // input type
ComputeDataType, // gamma type
ComputeDataType, // beta type
ComputeDataType, // compute type
YDataType, // output type
ComputeDataType, // mean/var type
Shape,
Traits>;
using Kernel = ck_tile::BatchnormFwd<Problem>;
// Prepare host arguments
// Note: save/update behavior is determined by Traits (compile-time), not runtime args
ck_tile::BatchnormFwdHostArgs hargs{
x_buf.GetDeviceBuffer(), // p_x
gamma_buf.GetDeviceBuffer(), // p_gamma
beta_buf.GetDeviceBuffer(), // p_beta
y_buf.GetDeviceBuffer(), // p_y
nullptr, // p_running_mean
nullptr, // p_running_mean (not used, Traits::kUpdateMovingAverage=false)
nullptr, // p_running_var
nullptr, // p_save_mean
nullptr, // p_save_mean (not used, Traits::kSaveMeanInvStd=false)
nullptr, // p_save_inv_std
epsilon, // epsilon
0.1f, // momentum (not used yet)
N, C, H, W, // dimensions
false, // update_moving_average
false // save_mean_inv_std
0.1f, // momentum
N, C, H, W // dimensions
};
// Validate arguments

View File

@@ -5,5 +5,6 @@
#include "ck_tile/ops/batchnorm/block/block_welford.hpp"
#include "ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp"
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_traits.hpp"
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_problem.hpp"
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_shape.hpp"

View File

@@ -29,8 +29,7 @@ struct BatchnormFwdHostArgs
index_t N, C, H, W;
bool update_moving_average; // If true, p_running_mean/var must be valid
bool save_mean_inv_std; // If true, p_save_mean/inv_std must be valid
// Note: save/update flags are now in Traits (compile-time), not here (runtime)
};
// BatchnormFwd: Forward pass batch normalization kernel
@@ -62,8 +61,7 @@ struct BatchnormFwd
index_t N, C, H, W;
bool update_moving_average;
bool save_mean_inv_std;
// Note: save/update flags now come from Problem::Traits (compile-time)
};
using Kargs = BatchnormFwdKargs; // Alias for convenience
@@ -85,9 +83,7 @@ struct BatchnormFwd
hargs.N,
hargs.C,
hargs.H,
hargs.W,
hargs.update_moving_average,
hargs.save_mean_inv_std};
hargs.W};
}
// Grid size calculation
@@ -183,7 +179,7 @@ struct BatchnormFwd
const index_t n = idx / spatial_size;
const index_t hw = idx % spatial_size;
const index_t offset = n * C * H * W + c * H * W + hw;
const index_t offset = (n * C * H * W) + (c * H * W) + hw;
ComputeDataType val = type_convert<ComputeDataType>(p_x[offset]);
// Apply batch normalization with scale and bias
@@ -191,6 +187,41 @@ struct BatchnormFwd
p_y[offset] = type_convert<YDataType>(normalized);
}
// Save mean and inverse std for backward pass (compile-time check)
if constexpr(Problem::Traits::kSaveMeanInvStd)
{
if(thread_id == 0)
{
using MeanVarDataType = typename Problem::MeanVarDataType;
MeanVarDataType* p_save_mean = static_cast<MeanVarDataType*>(kargs.p_save_mean);
MeanVarDataType* p_save_inv_std = static_cast<MeanVarDataType*>(kargs.p_save_inv_std);
p_save_mean[c] = type_convert<MeanVarDataType>(block_mean);
p_save_inv_std[c] = type_convert<MeanVarDataType>(inv_std);
}
}
// Update running mean and variance (compile-time check)
if constexpr(Problem::Traits::kUpdateMovingAverage)
{
if(thread_id == 0)
{
using MeanVarDataType = typename Problem::MeanVarDataType;
MeanVarDataType* p_running_mean = static_cast<MeanVarDataType*>(kargs.p_running_mean);
MeanVarDataType* p_running_var = static_cast<MeanVarDataType*>(kargs.p_running_var);
const ComputeDataType momentum = static_cast<ComputeDataType>(kargs.momentum);
const ComputeDataType one_minus_momentum = type_convert<ComputeDataType>(1) - momentum;
// Exponential moving average: new = (1-momentum)*old + momentum*current
ComputeDataType old_mean = type_convert<ComputeDataType>(p_running_mean[c]);
ComputeDataType old_var = type_convert<ComputeDataType>(p_running_var[c]);
p_running_mean[c] = type_convert<MeanVarDataType>(one_minus_momentum * old_mean + momentum * block_mean);
p_running_var[c] = type_convert<MeanVarDataType>(one_minus_momentum * old_var + momentum * block_var);
}
}
}
// Validate arguments
@@ -209,8 +240,8 @@ struct BatchnormFwd
return false;
}
// Validate optional pointers based on flags
if(hargs.update_moving_average)
// Validate optional pointers based on Traits (compile-time)
if constexpr(Problem::Traits::kUpdateMovingAverage)
{
if(hargs.p_running_mean == nullptr || hargs.p_running_var == nullptr)
{
@@ -218,7 +249,7 @@ struct BatchnormFwd
}
}
if(hargs.save_mean_inv_std)
if constexpr(Problem::Traits::kSaveMeanInvStd)
{
if(hargs.p_save_mean == nullptr || hargs.p_save_inv_std == nullptr)
{

View File

@@ -5,29 +5,32 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_shape.hpp"
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_traits.hpp"
namespace ck_tile {
// BatchnormProblem defines the computational problem for batch normalization
// Input: x with shape [N, C, H, W]
// Output: y with shape [N, C, H, W]
// Reduction over spatial dimensions (H, W) per channel
// Reduction over batch (N) and spatial dimensions (H, W) per channel (C)
template <typename XDataType_,
typename GammaDataType_,
typename BetaDataType_,
typename ComputeDataType_,
typename YDataType_,
typename Shape_>
typename MeanVarDataType_,
typename BlockShape_,
typename Traits_>
struct BatchnormProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using BlockShape = remove_cvref_t<Shape_>;
// For now, start with simple forward pass without scale/bias
// We'll add these later:
// using GammaDataType = ... // scale parameter
// using BetaDataType = ... // bias parameter
// using MeanVarDataType = ... // for saving mean/variance
using MeanVarDataType = remove_cvref_t<MeanVarDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile