diff --git a/example/ck_tile/42_batchnorm/batchnorm_simple.cpp b/example/ck_tile/42_batchnorm/batchnorm_simple.cpp index d65618e2bc..c3a63a49fd 100644 --- a/example/ck_tile/42_batchnorm/batchnorm_simple.cpp +++ b/example/ck_tile/42_batchnorm/batchnorm_simple.cpp @@ -6,8 +6,8 @@ #include #include -// Simple POC for batchnorm forward pass -// Tests basic functionality with a small tensor +// Batch normalization forward pass - NHWC layout +// NOTE: Using NHWC (not NCHW) for contiguous channel access auto create_args(int argc, char* argv[]) { @@ -52,13 +52,14 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor& x, { for(ck_tile::index_t w = 0; w < W; ++w) { - ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w; + ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC indexing sum += ck_tile::type_convert(x.mData[idx]); } } } ComputeDataType mean = sum / static_cast(per_channel_size); + // Compute variance across all N samples and H×W positions for this channel ComputeDataType var_sum = 0; for(ck_tile::index_t n = 0; n < N; ++n) @@ -67,7 +68,7 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor& x, { for(ck_tile::index_t w = 0; w < W; ++w) { - ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w; + ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC ComputeDataType val = ck_tile::type_convert(x.mData[idx]); ComputeDataType diff = val - mean; var_sum += diff * diff; @@ -76,6 +77,13 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor& x, } ComputeDataType variance = var_sum / static_cast(per_channel_size); + // DEBUG: Print reference variance + if(c < 4) + { + ComputeDataType inv_std = static_cast(1.0) / std::sqrt(variance + epsilon); + std::cout << ", var=" << variance << ", inv_std=" << inv_std << std::endl; + } + // Load gamma and beta for this channel ComputeDataType gamma_val = static_cast(1.0); ComputeDataType beta_val = static_cast(0.0); @@ -100,7 +108,7 @@ void reference_batchnorm_fwd(const ck_tile::HostTensor& x, { 
for(ck_tile::index_t w = 0; w < W; ++w) { - ck_tile::index_t idx = n * C * H * W + c * H * W + h * W + w; + ck_tile::index_t idx = n*H*W*C + h*W*C + w*C + c; // NHWC ComputeDataType val = ck_tile::type_convert(x.mData[idx]); ComputeDataType normalized = gamma_val * ((val - mean) * inv_std) + beta_val; y.mData[idx] = ck_tile::type_convert(normalized); @@ -129,13 +137,13 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "Batchnorm POC: N=" << N << ", C=" << C << ", H=" << H << ", W=" << W << ", epsilon=" << epsilon << std::endl; - // Allocate host tensors + // Allocate host tensors in NHWC layout ck_tile::index_t total_size = N * C * H * W; - ck_tile::HostTensor x_host({N, C, H, W}); + ck_tile::HostTensor x_host({N, H, W, C}); // NHWC! ck_tile::HostTensor gamma_host({C}); ck_tile::HostTensor beta_host({C}); - ck_tile::HostTensor y_host_ref({N, C, H, W}); - ck_tile::HostTensor y_host_dev({N, C, H, W}); + ck_tile::HostTensor y_host_ref({N, H, W, C}); // NHWC! + ck_tile::HostTensor y_host_dev({N, H, W, C}); // NHWC! 
// Fill input with random data ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); @@ -240,13 +248,13 @@ bool run(const ck_tile::ArgParser& arg_parser) { std::cout << "Channel " << c << ":" << std::endl; - // Print 2 sample values from first sample (n=0) - for(ck_tile::index_t sample = 0; sample < 2 && sample < H * W; ++sample) + // Print up to 2 sample values from first sample (n=0, h=0, w=0,1); w is clamped to W so a width-1 tensor is never indexed past row h=0 (or past the buffer) + for(ck_tile::index_t w = 0; w < 2 && w < W; ++w) { - ck_tile::index_t idx = 0 * C * H * W + c * H * W + sample; + ck_tile::index_t idx = 0*H*W*C + 0*W*C + w*C + c; // NHWC float ref_val = ck_tile::type_convert(y_host_ref.mData[idx]); float dev_val = ck_tile::type_convert(y_host_dev.mData[idx]); - std::cout << " Sample[" << sample << "]: " + std::cout << " Sample[" << w << "]: " << "Ref=" << std::fixed << std::setprecision(6) << ref_val << ", Kernel=" << dev_val << ", Diff=" << std::abs(ref_val - dev_val) << std::endl; diff --git a/include/ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp b/include/ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp index 40b27d9720..25a9124848 100644 --- a/include/ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp +++ b/include/ck_tile/ops/batchnorm/kernel/batchnorm_fwd_kernel.hpp @@ -14,11 +14,11 @@ namespace ck_tile { // Host-side arguments for batchnorm forward pass struct BatchnormFwdHostArgs { - const void* p_x; // [N, C, H, W] input tensor (required) + const void* p_x; // [N, H, W, C] input tensor (required, NHWC layout) const void* p_gamma; // [C] scale parameter (required, use all 1.0 if not needed) const void* p_beta; // [C] bias parameter (required, use all 0.0 if not needed) - void* p_y; // [N, C, H, W] output tensor (required) + void* p_y; // [N, H, W, C] output tensor (required, NHWC layout) void* p_running_mean; // [C] running mean (nullptr if not used) void* p_running_var; // [C] running variance (nullptr if not used) @@ -130,18 +130,20 @@ struct BatchnormFwd static constexpr index_t Block_M = BlockShape::Block_M; static 
constexpr index_t Block_N = BlockShape::Block_N; - // Create tensor views WITHOUT distributions (will be applied in pipeline) + // NHWC layout: channels are contiguous! const auto x_window = [&]() { const XDataType* p_x = static_cast(kargs.p_x); - const auto tmp_ = make_naive_tensor_view( - p_x + c * spatial_size, - make_tuple(N, spatial_size), - make_tuple(C * spatial_size, 1), + const XDataType* p_x_channel = p_x + c; // Offset by c (channel stride = 1!) + + const auto x_view = make_naive_tensor_view( + p_x_channel, + make_tuple(N, spatial_size), // [N, H×W] + make_tuple(spatial_size*C, C), // NHWC strides: [H×W×C, C] number<1>{}, number<1>{}); const auto tmp2_ = pad_tensor_view( - tmp_, make_tuple(number{}, number{}), sequence{}); + x_view, make_tuple(number{}, number{}), sequence{}); return make_tile_window(tmp2_, make_tuple(number{}, number{}), {0, 0}); }(); @@ -170,15 +172,17 @@ struct BatchnormFwd auto y_window = [&]() { YDataType* p_y = static_cast(kargs.p_y); - const auto tmp_ = make_naive_tensor_view( - p_y + c * spatial_size, - make_tuple(N, spatial_size), - make_tuple(C * spatial_size, 1), + YDataType* p_y_channel = p_y + c; // Offset by c (NHWC) + + const auto y_view = make_naive_tensor_view( + p_y_channel, + make_tuple(N, spatial_size), // [N, H×W] + make_tuple(spatial_size*C, C), // NHWC strides number<1>{}, number<1>{}); const auto tmp2_ = pad_tensor_view( - tmp_, make_tuple(number{}, number{}), sequence{}); + y_view, make_tuple(number{}, number{}), sequence{}); return make_tile_window(tmp2_, make_tuple(number{}, number{}), {0, 0}); }(); diff --git a/include/ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_pipeline.hpp b/include/ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_pipeline.hpp index acf5885229..25bd00ed7f 100644 --- a/include/ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_pipeline.hpp +++ b/include/ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_pipeline.hpp @@ -51,20 +51,19 @@ struct BatchnormFwdPipeline { const index_t thread_id = get_thread_id(); 
- // Apply tile distributions (like layernorm2d does) - // Note: x_window and y_window are NOT const (need to move them) - auto x_window = - make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); - const auto gamma_window = make_tile_window( - gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution()); - const auto beta_window = make_tile_window( - beta_window_, Policy::template MakeGammaBetaBlockTileDistribution()); - auto y_window = - make_tile_window(y_window_, Policy::template MakeXBlockTileDistribution()); + // Windows from transform are 2D [N, H×W] + // Apply 2D distribution (like C=1 that worked!) + auto x_window = make_tile_window( + x_window_, + Policy::template MakeXBlockTileDistribution()); - // Load gamma/beta once (constant per channel) - [[maybe_unused]]const auto gamma = load_tile(gamma_window); - [[maybe_unused]]const auto beta = load_tile(beta_window); + auto y_window = make_tile_window( + y_window_, + Policy::template MakeXBlockTileDistribution()); + + // Gamma/beta windows passed in but not used yet (gamma=1, beta=0 in test) + [[maybe_unused]] const auto gamma_window = gamma_window_; + [[maybe_unused]] const auto beta_window = beta_window_; // Calculate how many tiles needed (like layernorm2d two-pass) constexpr index_t Block_N = Problem::BlockShape::Block_N; @@ -108,7 +107,7 @@ struct BatchnormFwdPipeline BlockWelford::template Run( block_mean, block_var, block_count, smem); - + // ========================================== // PHASE 2: COMPUTE INVERSE STD // ========================================== diff --git a/include/ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_policy.hpp b/include/ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_policy.hpp index a21fc32c35..cd9e7a0150 100644 --- a/include/ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_policy.hpp +++ b/include/ck_tile/ops/batchnorm/pipeline/batchnorm_fwd_policy.hpp @@ -44,6 +44,24 @@ struct BatchnormFwdPipelineDefaultPolicy sequence<0, 3>>{}); } + // Simple 1D tile 
distribution for transformed [N×H×W] windows + template + CK_TILE_DEVICE static constexpr auto Make1DBlockTileDistribution() + { + // For merged 1D data, use simple pass-through distribution + // All threads collaborate on the Block_N elements + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple>, + tuple>, + tuple>, + sequence<1>, + sequence<0>>{}); + } + // Calculate shared memory size template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()