mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Change from NCHW to NHWC based on old-ck and handle verification for C > 1
This commit is contained in:
@@ -6,8 +6,8 @@
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
|
||||
// Simple POC for batchnorm forward pass
|
||||
// Tests basic functionality with a small tensor
|
||||
// Batch normalization forward pass - NHWC layout
|
||||
// NOTE: Using NHWC (not NCHW) for contiguous channel access
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -52,13 +52,14 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor<XDataType>& x,
|
||||
{
|
||||
for(ck_tile::index_t w = 0; w < W; ++w)
|
||||
{
|
||||
ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w;
|
||||
ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC indexing
|
||||
sum += ck_tile::type_convert<ComputeDataType>(x.mData[idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
ComputeDataType mean = sum / static_cast<ComputeDataType>(per_channel_size);
|
||||
|
||||
|
||||
// Compute variance across all N samples and H×W positions for this channel
|
||||
ComputeDataType var_sum = 0;
|
||||
for(ck_tile::index_t n = 0; n < N; ++n)
|
||||
@@ -67,7 +68,7 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor<XDataType>& x,
|
||||
{
|
||||
for(ck_tile::index_t w = 0; w < W; ++w)
|
||||
{
|
||||
ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w;
|
||||
ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC
|
||||
ComputeDataType val = ck_tile::type_convert<ComputeDataType>(x.mData[idx]);
|
||||
ComputeDataType diff = val - mean;
|
||||
var_sum += diff * diff;
|
||||
@@ -76,6 +77,13 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor<XDataType>& x,
|
||||
}
|
||||
ComputeDataType variance = var_sum / static_cast<ComputeDataType>(per_channel_size);
|
||||
|
||||
// DEBUG: Print reference variance and inverse std for the first 4 channels (c < 4)
|
||||
if(c < 4)
|
||||
{
|
||||
ComputeDataType inv_std = static_cast<ComputeDataType>(1.0) / std::sqrt(variance + epsilon);
|
||||
std::cout << ", var=" << variance << ", inv_std=" << inv_std << std::endl;
|
||||
}
|
||||
|
||||
// Load gamma and beta for this channel
|
||||
ComputeDataType gamma_val = static_cast<ComputeDataType>(1.0);
|
||||
ComputeDataType beta_val = static_cast<ComputeDataType>(0.0);
|
||||
@@ -100,7 +108,7 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor<XDataType>& x,
|
||||
{
|
||||
for(ck_tile::index_t w = 0; w < W; ++w)
|
||||
{
|
||||
ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w;
|
||||
ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC
|
||||
ComputeDataType val = ck_tile::type_convert<ComputeDataType>(x.mData[idx]);
|
||||
ComputeDataType normalized = gamma_val * ((val - mean) * inv_std) + beta_val;
|
||||
y.mData[idx] = ck_tile::type_convert<YDataType>(normalized);
|
||||
@@ -129,13 +137,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "Batchnorm POC: N=" << N << ", C=" << C << ", H=" << H << ", W=" << W
|
||||
<< ", epsilon=" << epsilon << std::endl;
|
||||
|
||||
// Allocate host tensors
|
||||
// Allocate host tensors in NHWC layout
|
||||
ck_tile::index_t total_size = N * C * H * W;
|
||||
ck_tile::HostTensor<XDataType> x_host({N, C, H, W});
|
||||
ck_tile::HostTensor<XDataType> x_host({N, H, W, C}); // NHWC!
|
||||
ck_tile::HostTensor<ComputeDataType> gamma_host({C});
|
||||
ck_tile::HostTensor<ComputeDataType> beta_host({C});
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({N, C, H, W});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({N, C, H, W});
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({N, H, W, C}); // NHWC!
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({N, H, W, C}); // NHWC!
|
||||
|
||||
// Fill input with random data
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
@@ -240,13 +248,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::cout << "Channel " << c << ":" << std::endl;
|
||||
|
||||
// Print 2 sample values from first sample (n=0)
|
||||
for(ck_tile::index_t sample = 0; sample < 2 && sample < H * W; ++sample)
|
||||
// Print 2 sample values from first sample (n=0, h=0, w=0,1)
|
||||
for(ck_tile::index_t w = 0; w < 2; ++w)
|
||||
{
|
||||
ck_tile::index_t idx = 0 * C * H * W + c * H * W + sample;
|
||||
ck_tile::index_t idx = 0*H*W*C + 0*W*C + w*C + c; // NHWC
|
||||
float ref_val = ck_tile::type_convert<float>(y_host_ref.mData[idx]);
|
||||
float dev_val = ck_tile::type_convert<float>(y_host_dev.mData[idx]);
|
||||
std::cout << " Sample[" << sample << "]: "
|
||||
std::cout << " Sample[" << w << "]: "
|
||||
<< "Ref=" << std::fixed << std::setprecision(6) << ref_val
|
||||
<< ", Kernel=" << dev_val
|
||||
<< ", Diff=" << std::abs(ref_val - dev_val) << std::endl;
|
||||
|
||||
@@ -14,11 +14,11 @@ namespace ck_tile {
|
||||
// Host-side arguments for batchnorm forward pass
|
||||
struct BatchnormFwdHostArgs
|
||||
{
|
||||
const void* p_x; // [N, C, H, W] input tensor (required)
|
||||
const void* p_x; // [N, H, W, C] input tensor (required, NHWC layout)
|
||||
const void* p_gamma; // [C] scale parameter (required, use all 1.0 if not needed)
|
||||
const void* p_beta; // [C] bias parameter (required, use all 0.0 if not needed)
|
||||
|
||||
void* p_y; // [N, C, H, W] output tensor (required)
|
||||
void* p_y; // [N, H, W, C] output tensor (required, NHWC layout)
|
||||
|
||||
void* p_running_mean; // [C] running mean (nullptr if not used)
|
||||
void* p_running_var; // [C] running variance (nullptr if not used)
|
||||
@@ -130,18 +130,20 @@ struct BatchnormFwd
|
||||
static constexpr index_t Block_M = BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = BlockShape::Block_N;
|
||||
|
||||
// Create tensor views WITHOUT distributions (will be applied in pipeline)
|
||||
// NHWC layout: channels are contiguous!
|
||||
const auto x_window = [&]() {
|
||||
const XDataType* p_x = static_cast<const XDataType*>(kargs.p_x);
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x + c * spatial_size,
|
||||
make_tuple(N, spatial_size),
|
||||
make_tuple(C * spatial_size, 1),
|
||||
const XDataType* p_x_channel = p_x + c; // Offset by c (channel stride = 1!)
|
||||
|
||||
const auto x_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x_channel,
|
||||
make_tuple(N, spatial_size), // [N, H×W]
|
||||
make_tuple(spatial_size*C, C), // NHWC strides: [H×W×C, C]
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
|
||||
x_view, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0, 0});
|
||||
}();
|
||||
@@ -170,15 +172,17 @@ struct BatchnormFwd
|
||||
|
||||
auto y_window = [&]() {
|
||||
YDataType* p_y = static_cast<YDataType*>(kargs.p_y);
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y + c * spatial_size,
|
||||
make_tuple(N, spatial_size),
|
||||
make_tuple(C * spatial_size, 1),
|
||||
YDataType* p_y_channel = p_y + c; // Offset by c (NHWC)
|
||||
|
||||
const auto y_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y_channel,
|
||||
make_tuple(N, spatial_size), // [N, H×W]
|
||||
make_tuple(spatial_size*C, C), // NHWC strides
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
|
||||
y_view, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0, 0});
|
||||
}();
|
||||
|
||||
@@ -51,20 +51,19 @@ struct BatchnormFwdPipeline
|
||||
{
|
||||
const index_t thread_id = get_thread_id();
|
||||
|
||||
// Apply tile distributions (like layernorm2d does)
|
||||
// Note: x_window and y_window are NOT const (need to move them)
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto beta_window = make_tile_window(
|
||||
beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto y_window =
|
||||
make_tile_window(y_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
// Windows from transform are 2D [N, H×W]
|
||||
// Apply the 2D distribution (the same one that already worked for the C=1 case)
|
||||
auto x_window = make_tile_window(
|
||||
x_window_,
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
// Load gamma/beta once (constant per channel)
|
||||
[[maybe_unused]]const auto gamma = load_tile(gamma_window);
|
||||
[[maybe_unused]]const auto beta = load_tile(beta_window);
|
||||
auto y_window = make_tile_window(
|
||||
y_window_,
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
// Gamma/beta windows passed in but not used yet (gamma=1, beta=0 in test)
|
||||
[[maybe_unused]] const auto gamma_window = gamma_window_;
|
||||
[[maybe_unused]] const auto beta_window = beta_window_;
|
||||
|
||||
// Calculate how many tiles needed (like layernorm2d two-pass)
|
||||
constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
@@ -108,7 +107,7 @@ struct BatchnormFwdPipeline
|
||||
|
||||
BlockWelford<ComputeDataType>::template Run<index_t, kBlockSize>(
|
||||
block_mean, block_var, block_count, smem);
|
||||
|
||||
|
||||
// ==========================================
|
||||
// PHASE 2: COMPUTE INVERSE STD
|
||||
// ==========================================
|
||||
|
||||
@@ -44,6 +44,24 @@ struct BatchnormFwdPipelineDefaultPolicy
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
// Simple 1D tile distribution for transformed [N×H×W] windows.
// Builds a static tile distribution whose encoding has one H dimension and an
// empty R sequence; all factors come from Problem::BlockShape.
// NOTE(review): no call to Make1DBlockTileDistribution is visible in this
// change (the pipeline uses MakeXBlockTileDistribution) — confirm it is needed.
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto Make1DBlockTileDistribution()
|
||||
{
|
||||
// For merged 1D data, use a simple pass-through distribution.
|
||||
// Intent (per the original author): all threads collaborate on the Block_N elements.
// The single H dimension below is factored as
// Repeat_N x WarpPerBlock_N x ThreadPerWarp_N x Vector_N.
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// NOTE(review): both P-mapping tuples are sequence<0> while the R sequence is
// empty. If major index 0 denotes the R group in tile_distribution_encoding,
// the P mapping targets a dimension that does not exist, and only the first
// H factor (Repeat_N) is bound, via the Y mapping <1>/<0>. Verify against the
// ck_tile encoding rules before using this distribution.
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<1>,
|
||||
sequence<0>>{});
|
||||
}
|
||||
|
||||
// Calculate shared memory size
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
|
||||
Reference in New Issue
Block a user