Change from NCHW to MHWC based on old-ck and manage verifying for c > 1

This commit is contained in:
Mohsen Saffari
2025-12-02 15:47:46 +00:00
parent 3194b653f7
commit c4199307ec
4 changed files with 69 additions and 40 deletions

View File

@@ -6,8 +6,8 @@
#include <cstring>
#include <iomanip>
// Simple POC for batchnorm forward pass
// Tests basic functionality with a small tensor
// Batch normalization forward pass - NHWC layout
// NOTE: Using NHWC (not NCHW) for contiguous channel access
auto create_args(int argc, char* argv[])
{
@@ -52,13 +52,14 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor<XDataType>& x,
{
for(ck_tile::index_t w = 0; w < W; ++w)
{
ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w;
ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC indexing
sum += ck_tile::type_convert<ComputeDataType>(x.mData[idx]);
}
}
}
ComputeDataType mean = sum / static_cast<ComputeDataType>(per_channel_size);
// Compute variance across all N samples and H×W positions for this channel
ComputeDataType var_sum = 0;
for(ck_tile::index_t n = 0; n < N; ++n)
@@ -67,7 +68,7 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor<XDataType>& x,
{
for(ck_tile::index_t w = 0; w < W; ++w)
{
ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w;
ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC
ComputeDataType val = ck_tile::type_convert<ComputeDataType>(x.mData[idx]);
ComputeDataType diff = val - mean;
var_sum += diff * diff;
@@ -76,6 +77,13 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor<XDataType>& x,
}
ComputeDataType variance = var_sum / static_cast<ComputeDataType>(per_channel_size);
// DEBUG: Print reference variance
if(c < 4)
{
ComputeDataType inv_std = static_cast<ComputeDataType>(1.0) / std::sqrt(variance + epsilon);
std::cout << ", var=" << variance << ", inv_std=" << inv_std << std::endl;
}
// Load gamma and beta for this channel
ComputeDataType gamma_val = static_cast<ComputeDataType>(1.0);
ComputeDataType beta_val = static_cast<ComputeDataType>(0.0);
@@ -100,7 +108,7 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor<XDataType>& x,
{
for(ck_tile::index_t w = 0; w < W; ++w)
{
ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w;
ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC
ComputeDataType val = ck_tile::type_convert<ComputeDataType>(x.mData[idx]);
ComputeDataType normalized = gamma_val * ((val - mean) * inv_std) + beta_val;
y.mData[idx] = ck_tile::type_convert<YDataType>(normalized);
@@ -129,13 +137,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "Batchnorm POC: N=" << N << ", C=" << C << ", H=" << H << ", W=" << W
<< ", epsilon=" << epsilon << std::endl;
// Allocate host tensors
// Allocate host tensors in NHWC layout
ck_tile::index_t total_size = N * C * H * W;
ck_tile::HostTensor<XDataType> x_host({N, C, H, W});
ck_tile::HostTensor<XDataType> x_host({N, H, W, C}); // NHWC!
ck_tile::HostTensor<ComputeDataType> gamma_host({C});
ck_tile::HostTensor<ComputeDataType> beta_host({C});
ck_tile::HostTensor<YDataType> y_host_ref({N, C, H, W});
ck_tile::HostTensor<YDataType> y_host_dev({N, C, H, W});
ck_tile::HostTensor<YDataType> y_host_ref({N, H, W, C}); // NHWC!
ck_tile::HostTensor<YDataType> y_host_dev({N, H, W, C}); // NHWC!
// Fill input with random data
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
@@ -240,13 +248,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
std::cout << "Channel " << c << ":" << std::endl;
// Print 2 sample values from first sample (n=0)
for(ck_tile::index_t sample = 0; sample < 2 && sample < H * W; ++sample)
// Print 2 sample values from first sample (n=0, h=0, w=0,1)
for(ck_tile::index_t w = 0; w < 2; ++w)
{
ck_tile::index_t idx = 0 * C * H * W + c * H * W + sample;
ck_tile::index_t idx = 0*H*W*C + 0*W*C + w*C + c; // NHWC
float ref_val = ck_tile::type_convert<float>(y_host_ref.mData[idx]);
float dev_val = ck_tile::type_convert<float>(y_host_dev.mData[idx]);
std::cout << " Sample[" << sample << "]: "
std::cout << " Sample[" << w << "]: "
<< "Ref=" << std::fixed << std::setprecision(6) << ref_val
<< ", Kernel=" << dev_val
<< ", Diff=" << std::abs(ref_val - dev_val) << std::endl;

View File

@@ -14,11 +14,11 @@ namespace ck_tile {
// Host-side arguments for batchnorm forward pass
struct BatchnormFwdHostArgs
{
const void* p_x; // [N, C, H, W] input tensor (required)
const void* p_x; // [N, H, W, C] input tensor (required, NHWC layout)
const void* p_gamma; // [C] scale parameter (required, use all 1.0 if not needed)
const void* p_beta; // [C] bias parameter (required, use all 0.0 if not needed)
void* p_y; // [N, C, H, W] output tensor (required)
void* p_y; // [N, H, W, C] output tensor (required, NHWC layout)
void* p_running_mean; // [C] running mean (nullptr if not used)
void* p_running_var; // [C] running variance (nullptr if not used)
@@ -130,18 +130,20 @@ struct BatchnormFwd
static constexpr index_t Block_M = BlockShape::Block_M;
static constexpr index_t Block_N = BlockShape::Block_N;
// Create tensor views WITHOUT distributions (will be applied in pipeline)
// NHWC layout: channels are contiguous!
const auto x_window = [&]() {
const XDataType* p_x = static_cast<const XDataType*>(kargs.p_x);
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
p_x + c * spatial_size,
make_tuple(N, spatial_size),
make_tuple(C * spatial_size, 1),
const XDataType* p_x_channel = p_x + c; // Offset by c (channel stride = 1!)
const auto x_view = make_naive_tensor_view<address_space_enum::global>(
p_x_channel,
make_tuple(N, spatial_size), // [N, H×W]
make_tuple(spatial_size*C, C), // NHWC strides: [H×W×C, C]
number<1>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
x_view, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0, 0});
}();
@@ -170,15 +172,17 @@ struct BatchnormFwd
auto y_window = [&]() {
YDataType* p_y = static_cast<YDataType*>(kargs.p_y);
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
p_y + c * spatial_size,
make_tuple(N, spatial_size),
make_tuple(C * spatial_size, 1),
YDataType* p_y_channel = p_y + c; // Offset by c (NHWC)
const auto y_view = make_naive_tensor_view<address_space_enum::global>(
p_y_channel,
make_tuple(N, spatial_size), // [N, H×W]
make_tuple(spatial_size*C, C), // NHWC strides
number<1>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
y_view, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0, 0});
}();

View File

@@ -51,20 +51,19 @@ struct BatchnormFwdPipeline
{
const index_t thread_id = get_thread_id();
// Apply tile distributions (like layernorm2d does)
// Note: x_window and y_window are NOT const (need to move them)
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto beta_window = make_tile_window(
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
auto y_window =
make_tile_window(y_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Windows from transform are 2D [N, H×W]
// Apply 2D distribution (like C=1 that worked!)
auto x_window = make_tile_window(
x_window_,
Policy::template MakeXBlockTileDistribution<Problem>());
// Load gamma/beta once (constant per channel)
[[maybe_unused]]const auto gamma = load_tile(gamma_window);
[[maybe_unused]]const auto beta = load_tile(beta_window);
auto y_window = make_tile_window(
y_window_,
Policy::template MakeXBlockTileDistribution<Problem>());
// Gamma/beta windows passed in but not used yet (gamma=1, beta=0 in test)
[[maybe_unused]] const auto gamma_window = gamma_window_;
[[maybe_unused]] const auto beta_window = beta_window_;
// Calculate how many tiles needed (like layernorm2d two-pass)
constexpr index_t Block_N = Problem::BlockShape::Block_N;
@@ -108,7 +107,7 @@ struct BatchnormFwdPipeline
BlockWelford<ComputeDataType>::template Run<index_t, kBlockSize>(
block_mean, block_var, block_count, smem);
// ==========================================
// PHASE 2: COMPUTE INVERSE STD
// ==========================================

View File

@@ -44,6 +44,24 @@ struct BatchnormFwdPipelineDefaultPolicy
sequence<0, 3>>{});
}
// Simple 1D tile distribution for transformed [N×H×W] windows
template <typename Problem>
CK_TILE_DEVICE static constexpr auto Make1DBlockTileDistribution()
{
// For merged 1D data, use simple pass-through distribution
// All threads collaborate on the Block_N elements
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0>>,
tuple<sequence<0>>,
sequence<1>,
sequence<0>>{});
}
// Calculate shared memory size
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()