mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
add kSaveMeanInvStd, kUpdateMovingAverage in Traits
This commit is contained in:
@@ -159,24 +159,35 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
|
||||
using Shape = ck_tile::BatchnormShape<BlockWarps, BlockTile, WarpTile, Vector>;
|
||||
using Problem = ck_tile::BatchnormProblem<XDataType, ComputeDataType, YDataType, Shape>;
|
||||
|
||||
// Define traits (compile-time configuration)
|
||||
using Traits = ck_tile::BatchnormFwdTraits<false, false>; // No save, no update
|
||||
|
||||
// Define problem with all types
|
||||
using Problem = ck_tile::BatchnormProblem<XDataType, // input type
|
||||
ComputeDataType, // gamma type
|
||||
ComputeDataType, // beta type
|
||||
ComputeDataType, // compute type
|
||||
YDataType, // output type
|
||||
ComputeDataType, // mean/var type
|
||||
Shape,
|
||||
Traits>;
|
||||
using Kernel = ck_tile::BatchnormFwd<Problem>;
|
||||
|
||||
// Prepare host arguments
|
||||
// Note: save/update behavior is determined by Traits (compile-time), not runtime args
|
||||
ck_tile::BatchnormFwdHostArgs hargs{
|
||||
x_buf.GetDeviceBuffer(), // p_x
|
||||
gamma_buf.GetDeviceBuffer(), // p_gamma
|
||||
beta_buf.GetDeviceBuffer(), // p_beta
|
||||
y_buf.GetDeviceBuffer(), // p_y
|
||||
nullptr, // p_running_mean
|
||||
nullptr, // p_running_mean (not used, Traits::kUpdateMovingAverage=false)
|
||||
nullptr, // p_running_var
|
||||
nullptr, // p_save_mean
|
||||
nullptr, // p_save_mean (not used, Traits::kSaveMeanInvStd=false)
|
||||
nullptr, // p_save_inv_std
|
||||
epsilon, // epsilon
|
||||
0.1f, // momentum (not used yet)
|
||||
N, C, H, W, // dimensions
|
||||
false, // update_moving_average
|
||||
false // save_mean_inv_std
|
||||
0.1f, // momentum
|
||||
N, C, H, W // dimensions
|
||||
};
|
||||
|
||||
// Validate arguments
|
||||
|
||||
@@ -5,5 +5,6 @@
|
||||
|
||||
#include "ck_tile/ops/batchnorm/block/block_welford.hpp"
|
||||
#include "ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_traits.hpp"
|
||||
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_problem.hpp"
|
||||
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_shape.hpp"
|
||||
|
||||
@@ -29,8 +29,7 @@ struct BatchnormFwdHostArgs
|
||||
|
||||
index_t N, C, H, W;
|
||||
|
||||
bool update_moving_average; // If true, p_running_mean/var must be valid
|
||||
bool save_mean_inv_std; // If true, p_save_mean/inv_std must be valid
|
||||
// Note: save/update flags are now in Traits (compile-time), not here (runtime)
|
||||
};
|
||||
|
||||
// BatchnormFwd: Forward pass batch normalization kernel
|
||||
@@ -62,8 +61,7 @@ struct BatchnormFwd
|
||||
|
||||
index_t N, C, H, W;
|
||||
|
||||
bool update_moving_average;
|
||||
bool save_mean_inv_std;
|
||||
// Note: save/update flags now come from Problem::Traits (compile-time)
|
||||
};
|
||||
|
||||
using Kargs = BatchnormFwdKargs; // Alias for convenience
|
||||
@@ -85,9 +83,7 @@ struct BatchnormFwd
|
||||
hargs.N,
|
||||
hargs.C,
|
||||
hargs.H,
|
||||
hargs.W,
|
||||
hargs.update_moving_average,
|
||||
hargs.save_mean_inv_std};
|
||||
hargs.W};
|
||||
}
|
||||
|
||||
// Grid size calculation
|
||||
@@ -183,7 +179,7 @@ struct BatchnormFwd
|
||||
const index_t n = idx / spatial_size;
|
||||
const index_t hw = idx % spatial_size;
|
||||
|
||||
const index_t offset = n * C * H * W + c * H * W + hw;
|
||||
const index_t offset = (n * C * H * W) + (c * H * W) + hw;
|
||||
ComputeDataType val = type_convert<ComputeDataType>(p_x[offset]);
|
||||
|
||||
// Apply batch normalization with scale and bias
|
||||
@@ -191,6 +187,41 @@ struct BatchnormFwd
|
||||
|
||||
p_y[offset] = type_convert<YDataType>(normalized);
|
||||
}
|
||||
|
||||
// Save mean and inverse std for backward pass (compile-time check)
|
||||
if constexpr(Problem::Traits::kSaveMeanInvStd)
|
||||
{
|
||||
if(thread_id == 0)
|
||||
{
|
||||
using MeanVarDataType = typename Problem::MeanVarDataType;
|
||||
MeanVarDataType* p_save_mean = static_cast<MeanVarDataType*>(kargs.p_save_mean);
|
||||
MeanVarDataType* p_save_inv_std = static_cast<MeanVarDataType*>(kargs.p_save_inv_std);
|
||||
|
||||
p_save_mean[c] = type_convert<MeanVarDataType>(block_mean);
|
||||
p_save_inv_std[c] = type_convert<MeanVarDataType>(inv_std);
|
||||
}
|
||||
}
|
||||
|
||||
// Update running mean and variance (compile-time check)
|
||||
if constexpr(Problem::Traits::kUpdateMovingAverage)
|
||||
{
|
||||
if(thread_id == 0)
|
||||
{
|
||||
using MeanVarDataType = typename Problem::MeanVarDataType;
|
||||
MeanVarDataType* p_running_mean = static_cast<MeanVarDataType*>(kargs.p_running_mean);
|
||||
MeanVarDataType* p_running_var = static_cast<MeanVarDataType*>(kargs.p_running_var);
|
||||
|
||||
const ComputeDataType momentum = static_cast<ComputeDataType>(kargs.momentum);
|
||||
const ComputeDataType one_minus_momentum = type_convert<ComputeDataType>(1) - momentum;
|
||||
|
||||
// Exponential moving average: new = (1-momentum)*old + momentum*current
|
||||
ComputeDataType old_mean = type_convert<ComputeDataType>(p_running_mean[c]);
|
||||
ComputeDataType old_var = type_convert<ComputeDataType>(p_running_var[c]);
|
||||
|
||||
p_running_mean[c] = type_convert<MeanVarDataType>(one_minus_momentum * old_mean + momentum * block_mean);
|
||||
p_running_var[c] = type_convert<MeanVarDataType>(one_minus_momentum * old_var + momentum * block_var);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate arguments
|
||||
@@ -209,8 +240,8 @@ struct BatchnormFwd
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate optional pointers based on flags
|
||||
if(hargs.update_moving_average)
|
||||
// Validate optional pointers based on Traits (compile-time)
|
||||
if constexpr(Problem::Traits::kUpdateMovingAverage)
|
||||
{
|
||||
if(hargs.p_running_mean == nullptr || hargs.p_running_var == nullptr)
|
||||
{
|
||||
@@ -218,7 +249,7 @@ struct BatchnormFwd
|
||||
}
|
||||
}
|
||||
|
||||
if(hargs.save_mean_inv_std)
|
||||
if constexpr(Problem::Traits::kSaveMeanInvStd)
|
||||
{
|
||||
if(hargs.p_save_mean == nullptr || hargs.p_save_inv_std == nullptr)
|
||||
{
|
||||
|
||||
@@ -5,29 +5,32 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_shape.hpp"
|
||||
#include "ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// BatchnormProblem defines the computational problem for batch normalization
|
||||
// Input: x with shape [N, C, H, W]
|
||||
// Output: y with shape [N, C, H, W]
|
||||
// Reduction over spatial dimensions (H, W) per channel
|
||||
// Reduction over batch (N) and spatial dimensions (H, W) per channel (C)
|
||||
template <typename XDataType_,
|
||||
typename GammaDataType_,
|
||||
typename BetaDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename Shape_>
|
||||
typename MeanVarDataType_,
|
||||
typename BlockShape_,
|
||||
typename Traits_>
|
||||
struct BatchnormProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<Shape_>;
|
||||
|
||||
// For now, start with simple forward pass without scale/bias
|
||||
// We'll add these later:
|
||||
// using GammaDataType = ... // scale parameter
|
||||
// using BetaDataType = ... // bias parameter
|
||||
// using MeanVarDataType = ... // for saving mean/variance
|
||||
using MeanVarDataType = remove_cvref_t<MeanVarDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user