From 241ee59880ebf9e6516201c2df845ebf8eab5eef Mon Sep 17 00:00:00 2001
From: Sami Remes
Date: Fri, 6 Feb 2026 18:07:36 +0000
Subject: [PATCH] clean up example a bit

---
 example/ck_tile/42_mx_gemm/mx_gemm.hpp        |   6 +-
 .../ck_tile/42_mx_gemm/mx_gemm_instance.hpp   |  42 +--
 example/ck_tile/42_mx_gemm/run_mx_gemm.inc    | 259 ++----------------
 3 files changed, 35 insertions(+), 272 deletions(-)

diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.hpp b/example/ck_tile/42_mx_gemm/mx_gemm.hpp
index c80df5d621..ff1c6d60cd 100644
--- a/example/ck_tile/42_mx_gemm/mx_gemm.hpp
+++ b/example/ck_tile/42_mx_gemm/mx_gemm.hpp
@@ -73,9 +73,9 @@ struct MxGemmConfig
 };
 struct MXfp4_GemmConfig16 : MxGemmConfig
 {
-    static constexpr ck_tile::index_t M_Tile = 32;
+    static constexpr ck_tile::index_t M_Tile = 64;
     static constexpr ck_tile::index_t N_Tile = 64;
-    static constexpr ck_tile::index_t K_Tile = 512;
+    static constexpr ck_tile::index_t K_Tile = 256;
 };
 // GEMM config with 16x16 warp tile
@@ -83,5 +83,5 @@ struct MXfp8_GemmConfig16 : MxGemmConfig
 {
     static constexpr ck_tile::index_t M_Tile = 32;
     static constexpr ck_tile::index_t N_Tile = 64;
-    static constexpr ck_tile::index_t K_Tile = 512;
+    static constexpr ck_tile::index_t K_Tile = 256;
 };
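As a quick sanity check on the new tile sizes, here is a standalone sketch of what K_Tile = 256 means for MX scaling, assuming the 32-element e8m0 scale block used throughout this example; the constants below only mirror MXfp4_GemmConfig16 above and are illustrative, not part of the patch.

#include <cstddef>

// Illustrative mirror of the tile constants above; MXBlockSize is the assumed
// 32-element scaling granularity along K.
constexpr std::size_t K_Tile      = 256;
constexpr std::size_t MXBlockSize = 32;

static_assert(K_Tile % MXBlockSize == 0, "K tile must be a multiple of the scale block");

// Each main-loop K step now covers 256 / 32 = 8 scale groups per row of A
// (and per column of B).
constexpr std::size_t scale_groups_per_k_tile = K_Tile / MXBlockSize;
static_assert(scale_groups_per_k_tile == 8, "scales per K tile");

int main() { return 0; }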
diff --git a/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp b/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp
index e055401260..d53a64da4a 100644
--- a/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp
+++ b/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp
@@ -12,28 +12,6 @@
 template
 using is_row_major_t = ck_tile::bool_constant<
     std::is_same_v, ck_tile::tensor_layout::gemm::RowMajor>>;
-// Problem definition for MX GEMM with comp_async pipeline
-// The comp_async pipeline handles MX scaling with OpSel parameters
-template
-struct MXGemmPipelineProblem : ck_tile::GemmPipelineProblem
-{
-    static constexpr auto Scheduler = Scheduler_;
-};
-
-// Epilogue wrapper that adds MemoryOperation member for MX GEMM kernel compatibility
-template
-struct MXGemmEpilogueWrapper : BaseEpilogue_
-{
-    static constexpr ck_tile::memory_operation_enum MemoryOperation = MemOp_;
-    using BaseEpilogue_::BaseEpilogue_;
-    using BaseEpilogue_::operator();
-};
-
 template & args,
     static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
                   "mixed_prec_gemm requires ADataType is a wider type than BDataType");
-    constexpr auto scheduler = GemmConfig::Scheduler;
-    constexpr auto memory_operation =
-        Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set;
-    using MXPipelineProblem = MXGemmPipelineProblem;
+        MXGemmTraits>;
     // Use the new MX comp_async pipeline with MX scaling support
     using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync;
@@ -92,7 +66,7 @@ float mx_gemm_calc(const MXGemmHostArgs& args,
         GemmConfig::TileParitionerGroupNum,
         GemmConfig::TileParitionerM01>;
-    using BaseEpilogue =
+    using GemmEpilogue =
         ck_tile::CShuffleEpilogue, // DsDataType
@@ -108,15 +82,7 @@ float mx_gemm_calc(const MXGemmHostArgs& args,
             GemmConfig::M_Warp_Tile,
             GemmConfig::N_Warp_Tile,
             GemmConfig::K_Warp_Tile,
-            MXPipelineProblem::TransposeC,
-            GemmConfig::NumWaveGroups, // kNumWaveGroups
-            false,                     // FixedVectorSize
-            1,                         // VectorSizeC
-            false,                     // TiledMMAPermuteN
-            1,                         // BlockedXDLN_PerWarp
-            false>>;                   // DoubleSmemBuffer
-
-    using GemmEpilogue = MXGemmEpilogueWrapper;
+            MXPipelineProblem::TransposeC>>;
     using Kernel = ck_tile::MXGemmKernel;
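The deleted MXGemmEpilogueWrapper above was a small adapter: derive from the real epilogue, forward its constructors and call operator, and expose the extra static member the kernel expected. A self-contained sketch of that pattern, with a stand-in BaseEpilogue in place of ck_tile::CShuffleEpilogue (all names here are illustrative, not the ck_tile API):

#include <cstdio>

enum class memory_operation_enum { set, atomic_add };

// Stand-in for the library epilogue type.
struct BaseEpilogue
{
    void operator()(float& dst, float acc) const { dst = acc; }
};

// The adapter the patch removes: inherit the base and bolt on the static member.
template <typename Base, memory_operation_enum MemOp>
struct EpilogueWrapper : Base
{
    static constexpr memory_operation_enum MemoryOperation = MemOp;
    using Base::Base;
    using Base::operator();
};

int main()
{
    EpilogueWrapper<BaseEpilogue, memory_operation_enum::set> epi{};
    float d = 0.f;
    epi(d, 1.f);
    std::printf("MemoryOperation is 'set': %d, d = %g\n",
                static_cast<int>(decltype(epi)::MemoryOperation == memory_operation_enum::set), d);
    return 0;
}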
diff --git a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
index 0e1e08cbef..75bff4c3b7 100644
--- a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
+++ b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
@@ -1,89 +1,7 @@
 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 // SPDX-License-Identifier: MIT
-// Pack 4 consecutive e8m0_t scales in K dimension into int32 for efficient 32-bit loads
-// For Scale A: [M, K/32] → [M, K/32/4] with int32 elements
-// For Scale B: [K/32, N] → [K/32/4, N] with int32 elements
-template
-auto pack_scales_for_k_dimension(const ScaleTensor& scale_unpacked,
-                                 ck_tile::index_t pack_size = 4)
-{
-    using ScaleType = typename ScaleTensor::Data::value_type;
-    static_assert(sizeof(ScaleType) == 1, "Scale type must be 1 byte (e8m0_t)");
-
-    const auto& desc = scale_unpacked.mDesc;
-    ck_tile::index_t dim0    = desc.get_lengths()[0];
-    ck_tile::index_t dim1    = desc.get_lengths()[1];
-    ck_tile::index_t stride1 = desc.get_strides()[1];
-
-    // Determine which dimension is K (the one to pack)
-    // If stride1 == 1, then dim1 is contiguous (K dimension for row-major scale A)
-    // If stride0 == 1, then dim0 is contiguous (K dimension for col-major scale B)
-    bool pack_dim1 = (stride1 == 1);
-
-    ck_tile::index_t packed_k_dim = pack_dim1 ? (dim1 / pack_size) : (dim0 / pack_size);
-    ck_tile::index_t new_dim0     = pack_dim1 ? dim0 : packed_k_dim;
-    ck_tile::index_t new_dim1     = pack_dim1 ? packed_k_dim : dim1;
-
-    // Calculate new strides based on which dimension was packed
-    ck_tile::index_t new_stride0, new_stride1;
-    if (pack_dim1) {
-        // Packed dim1 (K dimension for row-major): new shape [dim0, packed_k_dim]
-        // If original was row-major [dim0, dim1] with stride [dim1, 1]
-        // New should be row-major [dim0, packed_k_dim] with stride [packed_k_dim, 1]
-        new_stride0 = packed_k_dim;
-        new_stride1 = 1;
-    } else {
-        // Packed dim0 (K dimension for col-major): new shape [packed_k_dim, dim1]
-        // If original was col-major [dim0, dim1] with stride [1, dim0]
-        // New should be col-major [packed_k_dim, dim1] with stride [1, packed_k_dim]
-        new_stride0 = 1;
-        new_stride1 = packed_k_dim;
-    }
-
-    ck_tile::HostTensor scale_packed(
-        ck_tile::HostTensorDescriptor({new_dim0, new_dim1}, {new_stride0, new_stride1}));
-
-    // Pack scales: strided packing for K_lane distribution with OpSel
-    // Each int32_t packs 4 strided scales (one per kIter at same K_lane position)
-    // For K=512: 16 unpacked scales [0-15] -> 4 packed int32s
-    //   int32[0] = {scale[0], scale[4], scale[8],  scale[12]} <- K_lane=0, OpSel selects kIter
-    //   int32[1] = {scale[1], scale[5], scale[9],  scale[13]} <- K_lane=1, OpSel selects kIter
-    //   int32[2] = {scale[2], scale[6], scale[10], scale[14]} <- K_lane=2, OpSel selects kIter
-    //   int32[3] = {scale[3], scale[7], scale[11], scale[15]} <- K_lane=3, OpSel selects kIter
-    // OpSel(kIter) selects byte within thread's int32 for current kIter
-    for(ck_tile::index_t i = 0; i < new_dim0; ++i)
-    {
-        for(ck_tile::index_t j = 0; j < new_dim1; ++j)
-        {
-            int32_t packed_value = 0;
-            for(ck_tile::index_t k = 0; k < pack_size; ++k)
-            {
-                // Strided packing: byte k corresponds to kIter=k
-                // The stride is always pack_size (4), not packed_k_dim!
-                // For K=512: 16 unpacked elements [0-15] -> 4 packed int32s
-                //   int32[0] = {unpacked[0], unpacked[4], unpacked[8],  unpacked[12]} (stride=4)
-                //   int32[1] = {unpacked[1], unpacked[5], unpacked[9],  unpacked[13]} (stride=4)
-                // For K=1024: 32 unpacked elements [0-31] -> 8 packed int32s
-                //   int32[0] = {unpacked[0],  unpacked[4],  unpacked[8],  unpacked[12]} (stride=4)
-                //   int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (stride=4)
-                // For row-major (pack_dim1=true):  packed index j, byte k -> unpacked[j + k*4]
-                // For col-major (pack_dim1=false): packed index i, byte k -> unpacked[i*4 + k*4] = unpacked[(i + k)*4]
-                // But we want: packed index i, byte k -> unpacked[i*4 + k] (base i*4, then stride 4)
-                // Actually: int32[i] should pack {unpacked[i*4 + 0*4], unpacked[i*4 + 1*4], unpacked[i*4 + 2*4], unpacked[i*4 + 3*4]}
-                //         = {unpacked[i*4], unpacked[i*4 + 4], unpacked[i*4 + 8], unpacked[i*4 + 12]}
-                ck_tile::index_t src_i = pack_dim1 ? i : (i * pack_size + k * pack_size);
-                ck_tile::index_t src_j = pack_dim1 ? (j * pack_size + k * pack_size) : j;
-
-                uint8_t scale_byte = *reinterpret_cast(&scale_unpacked(src_i, src_j));
-                packed_value |= (static_cast(scale_byte) << (k * 8));
-            }
-            scale_packed(i, j) = packed_value;
-        }
-    }
-
-    return scale_packed;
-}
-
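Since the deleted helper was the only place the strided scale layout was spelled out, here is a minimal standalone illustration of the same packing for one 512-wide K block. Plain uint8_t/uint32_t stand in for e8m0_t and ck_tile::HostTensor; the snippet is illustrative and not part of the example.

#include <array>
#include <cstdint>
#include <cstdio>

int main()
{
    // 16 per-32-element scales of one 512-wide K block (stand-ins for e8m0_t bytes).
    std::array<std::uint8_t, 16> scale{};
    for (std::size_t i = 0; i < scale.size(); ++i)
        scale[i] = static_cast<std::uint8_t>(i);

    // Strided packing: word w holds bytes {w, w+4, w+8, w+12}, so the thread at
    // K_lane == w can select the byte for the current kIter via OpSel.
    constexpr std::size_t pack_size = 4;
    std::array<std::uint32_t, 4> packed{};
    for (std::size_t w = 0; w < packed.size(); ++w)
        for (std::size_t b = 0; b < pack_size; ++b)
            packed[w] |= static_cast<std::uint32_t>(scale[w + b * pack_size]) << (8 * b);

    // Prints 0x0c080400, 0x0d090501, 0x0e0a0602, 0x0f0b0703.
    for (std::size_t w = 0; w < packed.size(); ++w)
        std::printf("int32[%zu] = 0x%08x\n", w, packed[w]);
    return 0;
}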
+// Use e8m0_t directly without packing - simpler and cleaner approach
 template {-1.f, 1.f, seed++}(a_host);
     ck_tile::FillUniformDistribution{-1.f, 1.f, seed++}(b_host);
     ck_tile::FillUniformDistribution{0.001f, 10.f, seed++}(scale_a_host);
     ck_tile::FillUniformDistribution{0.001f, 10.f, seed++}(scale_b_host);
     break;
 case 1:
+    // Initialize A, B, and scales to 1.0
     ck_tile::FillConstant{ADataType(1.f)}(a_host);
     ck_tile::FillConstant{BDataType(1.f)}(b_host);
     ck_tile::FillConstant{ScaleType(1.f)}(scale_a_host);
     ck_tile::FillConstant{ScaleType(1.f)}(scale_b_host);
     break;
 case 2:
+    // Initialize A and B with random values but with constant 1.0 scales
     ck_tile::FillUniformDistribution{-1.f, 1.f, seed++}(a_host);
     ck_tile::FillUniformDistribution{-1.f, 1.f, seed++}(b_host);
     ck_tile::FillConstant{ScaleType(1.f)}(scale_a_host);
     ck_tile::FillConstant{ScaleType(1.f)}(scale_b_host);
     break;
-case 3:
-    // Debug mode: simple power-of-2 pattern for scales (e8m0 format)
-    ck_tile::FillConstant{ADataType(1.f)}(a_host);
-    ck_tile::FillConstant{BDataType(1.f)}(b_host);
-    // Fill scales with power-of-2 pattern: 1.0, 2.0, 4.0, 8.0, 16.0, ...
-    // e8m0 is exponent-only, so these give clear distinct values
-    // for(std::size_t i = 0; i < scale_a_host.mDesc.get_element_space_size(); ++i)
-    // {
-    //     float val = std::pow(2.0f, static_cast(i % 16)); // cycle through 2^0 to 2^15
-    //     scale_a_host.mData[i] = ScaleType(val);
-    // }
-    // for(std::size_t i = 0; i < scale_b_host.mDesc.get_element_space_size(); ++i)
-    // {
-    //     float val = std::pow(2.0f, static_cast(i % 16)); // cycle through 2^0 to 2^15
-    //     scale_b_host.mData[i] = ScaleType(val);
-    // }
-    ck_tile::FillConstant{ScaleType(1.f)}(scale_a_host);
-    ck_tile::FillConstant{ScaleType(1.f)}(scale_b_host);
-
-    // Test data to verify K block loading for K=1024 (2 K blocks)
-    // K block 0: K indices 0-511, scale K indices 0-15, packed into K_packed indices 0-3
-    // K block 1: K indices 512-1023, scale K indices 16-31, packed into K_packed indices 4-7
-
-    // Scale A: [M, K/32] row-major (unpacked K indices in second dim)
-    // Strided packing: int32[j] packs unpacked[j], unpacked[j+4], unpacked[j+8], unpacked[j+12]
-    // K block 0: K indices 0-511 → unpacked K indices 0-15 → packed int32s 0-3
-    //   int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (K_Lane=0)
-    //   int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (K_Lane=1)
-    scale_a_host(0, 0)  = ScaleType(2.f);  // K block 0, int32[0] byte 0 (unpacked[0])
-    scale_a_host(0, 4)  = ScaleType(4.f);  // K block 0, int32[0] byte 1 (unpacked[4])
-    scale_a_host(0, 8)  = ScaleType(8.f);  // K block 0, int32[0] byte 2 (unpacked[8])
-    scale_a_host(0, 12) = ScaleType(16.f); // K block 0, int32[0] byte 3 (unpacked[12])
-    scale_a_host(1, 0)  = ScaleType(32.f); // K block 0, int32[1] byte 0 (unpacked[1])
-
-    // K block 1: K indices 512-1023 → unpacked K indices 16-31 → packed int32s 4-7
-    //   int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (K_Lane=0)
-    //   int32[5] = {unpacked[17], unpacked[21], unpacked[25], unpacked[29]} (K_Lane=1)
-    scale_a_host(0, 16) = ScaleType(256.f);  // K block 1, int32[4] byte 0 (unpacked[16])
-    scale_a_host(0, 20) = ScaleType(512.f);  // K block 1, int32[4] byte 1 (unpacked[20])
-    scale_a_host(0, 24) = ScaleType(1024.f); // K block 1, int32[4] byte 2 (unpacked[24])
-    scale_a_host(0, 28) = ScaleType(2048.f); // K block 1, int32[4] byte 3 (unpacked[28])
-    scale_a_host(1, 16) = ScaleType(4096.f); // K block 1, int32[5] byte 0 (unpacked[17])
-
-    // mIter=1: M rows 16-31 (second XDL block)
-    scale_a_host(16, 0)  = ScaleType(64.f);   // K block 0, unpacked K=0, M=16
-    scale_a_host(16, 16) = ScaleType(8192.f); // K block 1, unpacked K=16, M=16
-
-    // Scale B: [K/32, N] col-major (unpacked K indices in first dim, N in second dim)
-    // Strided packing: int32[i] packs unpacked[i], unpacked[i+8], unpacked[i+16], unpacked[i+24]
-    // K block 0: K indices 0-511 → unpacked K indices 0-15 → packed int32s 0-3
-    //   int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (K_Lane=0)
-    //   int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (K_Lane=1)
-    scale_b_host(0, 0)  = ScaleType(2.f);  // K block 0, int32[0] byte 0 (unpacked[0])
-    scale_b_host(4, 0)  = ScaleType(4.f);  // K block 0, int32[0] byte 1 (unpacked[4])
-    scale_b_host(8, 0)  = ScaleType(8.f);  // K block 0, int32[0] byte 2 (unpacked[8])
-    scale_b_host(12, 0) = ScaleType(16.f); // K block 0, int32[0] byte 3 (unpacked[12])
-    scale_b_host(1, 0)  = ScaleType(32.f); // K block 0, int32[1] byte 0 (unpacked[1])
-
-    // K block 1: K indices 512-1023 → unpacked K indices 16-31 → packed int32s 4-7
-    //   int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (K_Lane=0)
-    //   int32[5] = {unpacked[17], unpacked[21], unpacked[25], unpacked[29]} (K_Lane=1)
-    scale_b_host(16, 0) = ScaleType(256.f);  // K block 1, int32[4] byte 0 (unpacked[16])
-    scale_b_host(20, 0) = ScaleType(512.f);  // K block 1, int32[4] byte 1 (unpacked[20])
-    scale_b_host(24, 0) = ScaleType(1024.f); // K block 1, int32[4] byte 2 (unpacked[24])
-    scale_b_host(28, 0) = ScaleType(2048.f); // K block 1, int32[4] byte 3 (unpacked[28])
-    scale_b_host(17, 0) = ScaleType(4096.f); // K block 1, int32[5] byte 0 (unpacked[17])
-
-    // nIter=1: N rows 16-31 (second XDL block)
-    scale_b_host(0, 16)  = ScaleType(64.f);   // K block 0, unpacked K=0, N=16
-    scale_b_host(16, 16) = ScaleType(8192.f); // K block 1, unpacked K=16, N=16
-    break;
-    }
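The power-of-2 values in the removed debug case are natural for e8m0 because the format stores only a biased exponent, so every representable scale is an exact power of two. A standalone decode sketch, based on the OCP MX description of e8m0 rather than on ck_tile::e8m0_t itself:

#include <cmath>
#include <cstdint>
#include <cstdio>

// e8m0: 8 exponent bits, no mantissa; value = 2^(bits - 127), 0xFF reserved for NaN
// (assumption: OCP MX encoding with bias 127).
float e8m0_to_float(std::uint8_t bits)
{
    if (bits == 0xFF)
        return std::nanf("");
    return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}

int main()
{
    for (int b : {126, 127, 128, 131})
        std::printf("0x%02X -> %g\n", b, e8m0_to_float(static_cast<std::uint8_t>(b)));
    // 0x7E -> 0.5, 0x7F -> 1, 0x80 -> 2, 0x83 -> 16
    return 0;
}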
-
-    // Pack scales: 4 consecutive e8m0_t in K dimension → 1 int32 for efficient 32-bit loads
-    // This enables the GPU to load 4 scales (for 4 K-blocks) with a single 32-bit load
-    // Scale A: [M, K/32] → [M, K/128] with int32 elements (since K/32/4 = K/128)
-    // Scale B: [K/32, N] → [K/128, N] with int32 elements
-    auto scale_a_packed = pack_scales_for_k_dimension(scale_a_host, 4);
-    auto scale_b_packed = pack_scales_for_k_dimension(scale_b_host, 4);
-
-    // DEBUG: Print first few packed scale values
-    if (true ||init_method == 3)
-    {
-        std::cout << "Host: ScaleA packed [0,0]: ";
-        uint8_t* a_bytes = reinterpret_cast(&scale_a_packed(0, 0));
-        std::cout << "[" << static_cast(a_bytes[0]) << "," << static_cast(a_bytes[1]) << ","
-                  << static_cast(a_bytes[2]) << "," << static_cast(a_bytes[3]) << "]\n";
-        std::cout << "Host: ScaleA packed [0,4]: ";
-        uint8_t* a_bytes4 = reinterpret_cast(&scale_a_packed(0, 4));
-        std::cout << "[" << static_cast(a_bytes4[0]) << "," << static_cast(a_bytes4[1]) << ","
-                  << static_cast(a_bytes4[2]) << "," << static_cast(a_bytes4[3]) << "]\n";
-        std::cout << "Host: ScaleB packed [0,0]: ";
-        uint8_t* b_bytes = reinterpret_cast(&scale_b_packed(0, 0));
-        std::cout << "[" << static_cast(b_bytes[0]) << "," << static_cast(b_bytes[1]) << ","
-                  << static_cast(b_bytes[2]) << "," << static_cast(b_bytes[3]) << "]\n";
-        std::cout << "Host: ScaleB packed [4,0]: ";
-        uint8_t* b_bytes4 = reinterpret_cast(&scale_b_packed(4, 0));
-        std::cout << "[" << static_cast(b_bytes4[0]) << "," << static_cast(b_bytes4[1]) << ","
-                  << static_cast(b_bytes4[2]) << "," << static_cast(b_bytes4[3]) << "]\n";
-
-        // Print unpacked first row/col for reference
-        std::cout << "Host: ScaleA unpacked thread 0, every 4th element: [";
-        for (int k = 0; k < std::min(32, static_cast(scale_a_host.mDesc.get_lengths()[1])); k += 4)
-            std::cout << static_cast(*reinterpret_cast(&scale_a_host(0, k))) << ",";
-        std::cout << "]\n";
-        std::cout << "Host: ScaleB unpacked thread 0, every 4th element: [";
-        for (int k = 0; k < std::min(32, static_cast(scale_b_host.mDesc.get_lengths()[0])); k += 4)
-            std::cout << static_cast(*reinterpret_cast(&scale_b_host(k, 0))) << ",";
-        std::cout << "]\n";
-        // Threads 0-15: M rows 0-15, K_Lane cycles through 0,1,2,3
-        // Thread 16: M row 0 again, but next K_Lane group (K_Lane=1 if cycling, or next K group)
-        // Actually, thread 16 goes back to row 0 with a different K index
-        std::cout << "Host: ScaleA unpacked thread 16 (row 0, next K group), every 4th element: [";
-        for (int k = 1; k < std::min(32, static_cast(scale_a_host.mDesc.get_lengths()[1])); k += 4)
-            std::cout << static_cast(*reinterpret_cast(&scale_a_host(0, k))) << ",";
-        std::cout << "]\n";
-        std::cout << "Host: ScaleB unpacked thread 16 (row 0, next K group), every 4th element: [";
-        for (int k = 1; k < std::min(32, static_cast(scale_b_host.mDesc.get_lengths()[0])); k += 4)
-            std::cout << static_cast(*reinterpret_cast(&scale_b_host(k, 0))) << ",";
-        std::cout << "]\n";
     }

+    // Device buffers for A, B, C, and scale tensors
     ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
     a_dev_buf.ToDevice(a_host.data());
     b_dev_buf.ToDevice(b_host.data());
-    scale_a_dev_buf.ToDevice(scale_a_packed.data());
-    scale_b_dev_buf.ToDevice(scale_b_packed.data());
+    scale_a_dev_buf.ToDevice(scale_a_host.data());
+    scale_b_dev_buf.ToDevice(scale_b_host.data());
-    // Scale pointers
-    using ScaleM = ck_tile::MXScalePointer<1, 32>; // in blocks of 32 in K
-    using ScaleN = ck_tile::MXScalePointer<1, 32>;
-
-    ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer()));
-    ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer()));
+    // Scale pointers - use e8m0_t* directly
+    using ScaleM = ck_tile::MXScalePointer; // in blocks of 32 in K
+    using ScaleN = ck_tile::MXScalePointer;
+    ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer()));
+    ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer()));

     float ave_time = invoke_mx_gemm(
         a_host, b_host, c_m_n_host_ref, scale_a_host, scale_b_host);
-    // ck_tile::reference_gemm(
-    //     a_host, b_host, c_m_n_host_ref);
     auto calculate_rtol_atol = [&K, &kbatch](const float max_accumulated_value) {
-        // using ComputeType =
-        //     std::conditional_t;
-        // // Calculate thresholds
-        // const auto rtol = ck_tile::get_relative_threshold(
-        //     ck_tile::integer_divide_ceil(K, kbatch));
-        // const auto atol = ck_tile::get_absolute_threshold(
-        //     max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
-        // // Calculate error due to split_k accumulation
-        // const auto rtol_split_k =
-        //     ck_tile::get_relative_threshold(kbatch);
-        // const auto atol_split_k = ck_tile::get_absolute_threshold(
-        //     max_accumulated_value, kbatch);
-        // // Use higher threshold
-        // return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
-        ck_tile::ignore = K;
-        ck_tile::ignore = kbatch;
-        ck_tile::ignore = max_accumulated_value;
-        return ck_tile::make_tuple(0.1, 1.0);
+        using ComputeType =
+            std::conditional_t;
+        // Calculate thresholds
+        const auto rtol = ck_tile::get_relative_threshold(
+            ck_tile::integer_divide_ceil(K, kbatch));
+        const auto atol = ck_tile::get_absolute_threshold(
+            max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+        // Calculate error due to split_k accumulation
+        const auto rtol_split_k =
+            ck_tile::get_relative_threshold(kbatch);
+        const auto atol_split_k = ck_tile::get_absolute_threshold(
+            max_accumulated_value, kbatch);
+        // Use higher threshold
+        return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
     };

     const float max_accumulated_value =
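For orientation, this is what the host reference conceptually computes: a GEMM in which every 32-wide K block of A and B contributes through its own per-block scale. The sketch below uses plain floats for the fp4/fp8 and e8m0 types and made-up sizes; none of it comes from invoke_mx_gemm itself.

#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 2, K = 64, B = 32;   // B = MX scale block size along K
    std::vector<float> a(M * K, 0.5f);        // A: [M, K], row-major
    std::vector<float> b(K * N, 0.25f);       // B: [K, N], column-major
    std::vector<float> sa(M * (K / B), 2.0f); // scale A: [M, K/32], row-major
    std::vector<float> sb((K / B) * N, 4.0f); // scale B: [K/32, N], column-major
    std::vector<float> c(M * N, 0.0f);

    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int k = 0; k < K; ++k)
                c[m * N + n] += (sa[m * (K / B) + k / B] * a[m * K + k]) *
                                (sb[k / B + (K / B) * n] * b[k + K * n]);

    std::printf("c[0][0] = %g\n", c[0]); // 64 * (2 * 0.5) * (4 * 0.25) = 64
    return 0;
}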
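The restored tolerance logic keeps the larger of two error bounds: accumulation error within one K split and the error from combining kbatch partial sums. A simplified standalone version of that max-of-two-bounds idea, with made-up threshold helpers standing in for ck_tile::get_relative_threshold / get_absolute_threshold (illustrative only, not the library's formulas):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Made-up stand-ins: the bound grows with the number of accumulated terms.
float rel_threshold(int n_accum) { return std::ldexp(1.0f, -8) * std::sqrt(static_cast<float>(n_accum)); }
float abs_threshold(float max_val, int n_accum) { return rel_threshold(n_accum) * std::fabs(max_val); }

int main()
{
    const int K = 4096, kbatch = 2;
    const float max_accumulated_value = 64.f;
    const int k_per_split = (K + kbatch - 1) / kbatch;

    // Per-split accumulation bound vs. split-k combination bound; use the higher one,
    // mirroring the choice made in the lambda above.
    const float rtol = std::max(rel_threshold(k_per_split), rel_threshold(kbatch));
    const float atol = std::max(abs_threshold(max_accumulated_value / kbatch, k_per_split),
                                abs_threshold(max_accumulated_value, kbatch));
    std::printf("rtol = %g, atol = %g\n", rtol, atol);
    return 0;
}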