From e40a675f745b210fa7e159d595f815152b1cae4f Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Sat, 28 Mar 2026 04:36:39 +0800 Subject: [PATCH] [CK_TILE ]Revert "[CK_TILE] Enable MXFP6 for MX GEMM op (#5095)" (#5849) This reverts commit 7e55766ddf7e9e20791b0e4e2d7b4026cf16b637. ## Motivation ## Technical Details ## Test Plan ## Test Result ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- include/ck_tile/core/numeric/pk_fp6.hpp | 30 +++------------- include/ck_tile/core/tensor/buffer_view.hpp | 16 ++------- .../core/tensor/tile_scatter_gather.hpp | 18 ++++------ include/ck_tile/core/tensor/tile_window.hpp | 14 +++----- include/ck_tile/host/check_err.hpp | 2 -- include/ck_tile/host/fill.hpp | 35 ------------------- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 7 ++-- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 22 +++++------- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 6 ++-- test/ck_tile/gemm_mx/CMakeLists.txt | 3 +- test/ck_tile/gemm_mx/test_mx_gemm_config.hpp | 7 ---- test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp | 30 ---------------- test/ck_tile/gemm_mx/test_mx_gemm_util.hpp | 1 + 13 files changed, 31 insertions(+), 160 deletions(-) delete mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp diff --git a/include/ck_tile/core/numeric/pk_fp6.hpp b/include/ck_tile/core/numeric/pk_fp6.hpp index a8b1d2eea1..0de61f6b1f 100644 --- a/include/ck_tile/core/numeric/pk_fp6.hpp +++ b/include/ck_tile/core/numeric/pk_fp6.hpp @@ -22,10 +22,7 @@ struct pk_fp6_t static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; element_type data_[vector_size]; // packed data using type = pk_fp6_t; - - CK_TILE_HOST_DEVICE constexpr pk_fp6_t() : data_{element_type{}} {} - - CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value) + CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value = 0) { for(size_t i = 0; i < vector_size; ++i) { @@ -62,14 +59,13 @@ struct pk_fp6_t const int bit_offset = bit_pos % num_bits_vec_elem; const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t bits = static_cast(pk.data_[arr_idx]) >> bit_offset; + int32_t bits = pk.data_[arr_idx] >> bit_offset; if(overhang > 0 && (arr_idx + 1) < vector_size) { - bits |= (static_cast(pk.data_[arr_idx + 1]) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); + bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); } - return static_cast(bits & 0x3F); + return bits & 0x3F; } CK_TILE_HOST_DEVICE int32_t unpack(const index_t i) const { return unpack(*this, i); } @@ -101,22 +97,6 @@ struct pk_fp6_t } return sign == 1 ? -1 * result : result; } - - CK_TILE_HOST static int32_t float_to_fp6_e2m3(float val) - { - int32_t best = 0; - float best_err = 1e30f; - for(int32_t i = 0; i < 64; i++) - { - float err = std::fabs(val - fp6_e2m3_to_float(i)); - if(err < best_err) - { - best = i; - best_err = err; - } - } - return best; - } }; using pk_fp6x16_t = pk_fp6_t<16>; @@ -125,7 +105,5 @@ template <> struct numeric_traits { static constexpr int PackedSize = 16; - static constexpr int exp = 2; - static constexpr int mant = 3; }; } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index bba9fd7fb0..4d0e915685 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -19,14 +19,6 @@ namespace ck_tile { -// buffer_load_dwordx3 to LDS uses a fixed 16-byte per-thread stride, -// padding each 12-byte element to 16 bytes in LDS. -template -CK_TILE_HOST_DEVICE constexpr index_t lds_padded_sizeof() -{ - return (sizeof(T) == 12) ? 16 : sizeof(T); -} - // T may be scalar or vector // X may be scalar or vector // T and X have same scalar type @@ -848,10 +840,7 @@ struct buffer_view>::scalar_type, scalar_per_t_vector * scalar_per_x_vector>; - constexpr index_t padded_stride = lds_padded_sizeof(); - const char* base = - reinterpret_cast(p_data_) + (i + linear_offset) * padded_stride; - auto rtn = *c_style_pointer_cast(base); + auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); return bit_cast(rtn); } #endif @@ -883,8 +872,7 @@ struct buffer_view = {}) const { - constexpr index_t padded_stride = lds_padded_sizeof(); - smem_load{}(dst, v_offset * padded_stride, i_offset * padded_stride); + smem_load{}(dst, v_offset * sizeof(T), i_offset * sizeof(T)); } template (); - const index_t size_per_buf = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<0>{}, number<0>{})) * - lds_stride; + sizeof(LdsDataType); const index_t size_per_wave = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<1>{}, number<0>{})) * - lds_stride - + sizeof(LdsDataType) - size_per_buf; const index_t size_per_issue = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<1>{}, number<0>{}, number<0>{})) * - lds_stride - + sizeof(LdsDataType) - size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); @@ -783,12 +780,9 @@ struct tile_scatter_gather make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); // Calculate SMEM address using base pointer - // Use byte arithmetic for dwordx3 padding (12-byte elements use 16-byte LDS stride) - CK_TILE_LDS_ADDR LdsDataType* smem = - reinterpret_cast( - reinterpret_cast(lds_base_ptr) + - (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize * - lds_padded_sizeof()); + CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr + + lds_coord.get_offset() / Traits::PackedSize + + lds_ys_offset / Traits::PackedSize; const auto dram_ys_offset = [&]() { if constexpr(static_move_ys) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 4e194a3f54..3e28544509 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -501,23 +501,21 @@ struct tile_window_with_static_distribution // issues * warps * lanes static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded - constexpr index_t lds_stride = lds_padded_sizeof(); - const index_t size_per_buf = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<0>{}, number<0>{})) * - lds_stride; + sizeof(LdsDataType); const index_t size_per_wave = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<1>{}, number<0>{})) * - lds_stride - + sizeof(LdsDataType) - size_per_buf; const index_t size_per_issue = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<1>{}, number<0>{}, number<0>{})) * - lds_stride - + sizeof(LdsDataType) - size_per_buf; // Use VALU so the compiler can optimize redundant/repeated computations @@ -630,12 +628,8 @@ struct tile_window_with_static_distribution make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); // Calculate SMEM address using base pointer - // Use byte arithmetic for dwordx3 padding (12-byte elements use 16-byte LDS stride) CK_TILE_LDS_ADDR LdsDataType* smem = - reinterpret_cast( - reinterpret_cast(lds_base_ptr) + - (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize * - lds_padded_sizeof()); + lds_base_ptr + (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize; const auto dram_ys_offset = [&]() { if constexpr(static_move_ys) diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 9e49154b34..458a725379 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -61,7 +61,6 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1 tf32_t, pk_fp4_t, pk_fp4_raw_t, - pk_fp6x16_t, pk_int4_t, I8, I32, @@ -136,7 +135,6 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, tf32_t, pk_fp4_t, pk_fp4_raw_t, - pk_fp6x16_t, pk_int4_t, I8, I32, diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 44d1913033..bddc0ae2d2 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -169,41 +169,6 @@ struct FillUniformDistribution } }; -template <> -struct FillUniformDistribution -{ - float a_{-2.f}; - float b_{2.f}; - std::optional seed_{11939}; - - template - void operator()(ForwardIter first, ForwardIter last) const - { - std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); - std::uniform_real_distribution dis(a_, b_); - while(first != last) - { - ck_tile::pk_fp6x16_t pk{}; - for(ck_tile::index_t i = 0; i < ck_tile::pk_fp6x16_t::packed_size; ++i) - { - pk.pack(ck_tile::pk_fp6x16_t::float_to_fp6_e2m3(dis(gen)), i); - } - *first = pk; - ++first; - } - } - - template - auto operator()(ForwardRange&& range) const - -> std::void_t()( - std::begin(std::forward(range)), - std::end(std::forward(range))))> - { - (*this)(std::begin(std::forward(range)), - std::end(std::forward(range))); - } -}; - namespace impl { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index dc56f34bc7..f43bcbc4b1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -146,10 +146,9 @@ struct GemmPipelineAgBgCrImplBase auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); // TODO: LDS alignment should come from Policy! - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t a_lds_block_space_size = lds_padded_sizeof() * - a_lds_block_desc.get_element_space_size() / - APackedSize; + constexpr index_t APackedSize = numeric_traits::PackedSize; + constexpr index_t a_lds_block_space_size = + sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize; constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(a_lds_block_space_size, 16); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index cdbb2662f9..b4a8e9e8cb 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -834,10 +834,9 @@ struct UniversalGemmBasePolicy using BlockGemm = remove_cvref_t())>; constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); - constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(A)) * - numeric_traits::PackedSize; + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(A)); - return ck_tile::min(KPack, VecElems); + return (KPack < VecElems) ? KPack : VecElems; } template @@ -847,10 +846,9 @@ struct UniversalGemmBasePolicy using BlockGemm = remove_cvref_t())>; constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); - constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(B)) * - numeric_traits::PackedSize; + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(B)); - return ck_tile::min(KPack, VecElems); + return (KPack < VecElems) ? KPack : VecElems; } template @@ -859,10 +857,8 @@ struct UniversalGemmBasePolicy using ADataType = remove_cvref_t; constexpr auto APackedSize = numeric_traits::PackedSize; constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor(); - constexpr index_t smem_size_a = - integer_least_multiple(a_lds_block_desc.get_element_space_size() * - lds_padded_sizeof() / APackedSize, - 16); + constexpr index_t smem_size_a = integer_least_multiple( + a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16); return smem_size_a; } @@ -875,10 +871,8 @@ struct UniversalGemmBasePolicy typename Problem::BDataType>; constexpr auto BPackedSize = numeric_traits::PackedSize; constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); - constexpr index_t smem_size_b = - integer_least_multiple(b_lds_block_desc.get_element_space_size() * - lds_padded_sizeof() / BPackedSize, - 16); + constexpr index_t smem_size_b = integer_least_multiple( + b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16); return smem_size_b; } diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index bf00bc0b0f..ec16a4e8b6 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -442,12 +442,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< MWarp / BlockSize, "BLdsTile size is wrong!"); static_assert(Policy::template GetSmemSizeA() == - MPerBlock * - (KPerBlock * lds_padded_sizeof() / APackedSize), + MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), "SmemSizeA size is wrong!"); static_assert(Policy::template GetSmemSizeB() == - (KPerBlock * lds_padded_sizeof() / BPackedSize) * - NPerBlock, + (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, "SmemSizeB size is wrong!"); ////////////// MX Scale register tiles (ping-pong buffers) ///////////////// diff --git a/test/ck_tile/gemm_mx/CMakeLists.txt b/test/ck_tile/gemm_mx/CMakeLists.txt index 31fb3ef8a5..36d2e455ae 100644 --- a/test/ck_tile/gemm_mx/CMakeLists.txt +++ b/test/ck_tile/gemm_mx/CMakeLists.txt @@ -9,8 +9,7 @@ endif() if(GPU_TARGETS MATCHES "gfx95") add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp) target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_ck_tile_mx_gemm_fp6 test_mx_gemm_fp6.cpp) - target_compile_options(test_ck_tile_mx_gemm_fp6 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp) target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) else() diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp index 41dc249788..3cce36a85d 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp @@ -87,13 +87,6 @@ struct MXfp4_GemmConfig16 : MxGemmConfig static constexpr ck_tile::index_t K_Tile = 256; }; -struct MXfp6_GemmConfig16 : MxGemmConfig -{ - static constexpr ck_tile::index_t M_Tile = 64; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; -}; - struct MXfp8_GemmConfig16 : MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 64; diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp deleted file mode 100644 index c63f1d9156..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_mx_gemm_config.hpp" -#include "test_mx_gemm_util.hpp" - -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - -using MxFp6Types = ::testing::Types< - std::tuple>; - -template -class TestMxGemmFp6 : public TestMxGemmUtil, - std::tuple_element_t<1, TypeParam>, - std::tuple_element_t<2, TypeParam>, - std::tuple_element_t<3, TypeParam>, - std::tuple_element_t<4, TypeParam>, - std::tuple_element_t<5, TypeParam>> -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp6, MxFp6Types); - -TYPED_TEST(TestMxGemmFp6, BasicSizes) -{ - this->Run(64, 64, 256); - this->Run(128, 128, 256); - this->Run(64, 128, 512); -} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp index cbf2a7ecd7..6e7ddfb5d0 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp @@ -4,6 +4,7 @@ #pragma once #include + #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/host/check_err.hpp"