diff --git a/include/ck_tile/core/numeric/pk_fp6.hpp b/include/ck_tile/core/numeric/pk_fp6.hpp index 0de61f6b1f..a8b1d2eea1 100644 --- a/include/ck_tile/core/numeric/pk_fp6.hpp +++ b/include/ck_tile/core/numeric/pk_fp6.hpp @@ -22,7 +22,10 @@ struct pk_fp6_t static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; element_type data_[vector_size]; // packed data using type = pk_fp6_t; - CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value = 0) + + CK_TILE_HOST_DEVICE constexpr pk_fp6_t() : data_{element_type{}} {} + + CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value) { for(size_t i = 0; i < vector_size; ++i) { @@ -59,13 +62,14 @@ struct pk_fp6_t const int bit_offset = bit_pos % num_bits_vec_elem; const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - int32_t bits = pk.data_[arr_idx] >> bit_offset; + uint32_t bits = static_cast(pk.data_[arr_idx]) >> bit_offset; if(overhang > 0 && (arr_idx + 1) < vector_size) { - bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); + bits |= (static_cast(pk.data_[arr_idx + 1]) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); } - return bits & 0x3F; + return static_cast(bits & 0x3F); } CK_TILE_HOST_DEVICE int32_t unpack(const index_t i) const { return unpack(*this, i); } @@ -97,6 +101,22 @@ struct pk_fp6_t } return sign == 1 ? -1 * result : result; } + + CK_TILE_HOST static int32_t float_to_fp6_e2m3(float val) + { + int32_t best = 0; + float best_err = 1e30f; + for(int32_t i = 0; i < 64; i++) + { + float err = std::fabs(val - fp6_e2m3_to_float(i)); + if(err < best_err) + { + best = i; + best_err = err; + } + } + return best; + } }; using pk_fp6x16_t = pk_fp6_t<16>; @@ -105,5 +125,7 @@ template <> struct numeric_traits { static constexpr int PackedSize = 16; + static constexpr int exp = 2; + static constexpr int mant = 3; }; } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 4d0e915685..bba9fd7fb0 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -19,6 +19,14 @@ namespace ck_tile { +// buffer_load_dwordx3 to LDS uses a fixed 16-byte per-thread stride, +// padding each 12-byte element to 16 bytes in LDS. +template +CK_TILE_HOST_DEVICE constexpr index_t lds_padded_sizeof() +{ + return (sizeof(T) == 12) ? 16 : sizeof(T); +} + // T may be scalar or vector // X may be scalar or vector // T and X have same scalar type @@ -840,7 +848,10 @@ struct buffer_view>::scalar_type, scalar_per_t_vector * scalar_per_x_vector>; - auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); + constexpr index_t padded_stride = lds_padded_sizeof(); + const char* base = + reinterpret_cast(p_data_) + (i + linear_offset) * padded_stride; + auto rtn = *c_style_pointer_cast(base); return bit_cast(rtn); } #endif @@ -872,7 +883,8 @@ struct buffer_view = {}) const { - smem_load{}(dst, v_offset * sizeof(T), i_offset * sizeof(T)); + constexpr index_t padded_stride = lds_padded_sizeof(); + smem_load{}(dst, v_offset * padded_stride, i_offset * padded_stride); } template (); + const index_t size_per_buf = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<0>{}, number<0>{})) * - sizeof(LdsDataType); + lds_stride; const index_t size_per_wave = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<1>{}, number<0>{})) * - sizeof(LdsDataType) - + lds_stride - size_per_buf; const index_t size_per_issue = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<1>{}, number<0>{}, number<0>{})) * - sizeof(LdsDataType) - + lds_stride - size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); @@ -780,9 +783,12 @@ struct tile_scatter_gather make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); // Calculate SMEM address using base pointer - CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr + - lds_coord.get_offset() / Traits::PackedSize + - lds_ys_offset / Traits::PackedSize; + // Use byte arithmetic for dwordx3 padding (12-byte elements use 16-byte LDS stride) + CK_TILE_LDS_ADDR LdsDataType* smem = + reinterpret_cast( + reinterpret_cast(lds_base_ptr) + + (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize * + lds_padded_sizeof()); const auto dram_ys_offset = [&]() { if constexpr(static_move_ys) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 3e28544509..4e194a3f54 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -501,21 +501,23 @@ struct tile_window_with_static_distribution // issues * warps * lanes static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + constexpr index_t lds_stride = lds_padded_sizeof(); + const index_t size_per_buf = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<0>{}, number<0>{})) * - sizeof(LdsDataType); + lds_stride; const index_t size_per_wave = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<0>{}, number<1>{}, number<0>{})) * - sizeof(LdsDataType) - + lds_stride - size_per_buf; const index_t size_per_issue = lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( make_tuple(number<1>{}, number<0>{}, number<0>{})) * - sizeof(LdsDataType) - + lds_stride - size_per_buf; // Use VALU so the compiler can optimize redundant/repeated computations @@ -628,8 +630,12 @@ struct tile_window_with_static_distribution make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); // Calculate SMEM address using base pointer + // Use byte arithmetic for dwordx3 padding (12-byte elements use 16-byte LDS stride) CK_TILE_LDS_ADDR LdsDataType* smem = - lds_base_ptr + (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize; + reinterpret_cast( + reinterpret_cast(lds_base_ptr) + + (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize * + lds_padded_sizeof()); const auto dram_ys_offset = [&]() { if constexpr(static_move_ys) diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 458a725379..9e49154b34 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -61,6 +61,7 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1 tf32_t, pk_fp4_t, pk_fp4_raw_t, + pk_fp6x16_t, pk_int4_t, I8, I32, @@ -135,6 +136,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num, tf32_t, pk_fp4_t, pk_fp4_raw_t, + pk_fp6x16_t, pk_int4_t, I8, I32, diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index bddc0ae2d2..44d1913033 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -169,6 +169,41 @@ struct FillUniformDistribution } }; +template <> +struct FillUniformDistribution +{ + float a_{-2.f}; + float b_{2.f}; + std::optional seed_{11939}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::uniform_real_distribution dis(a_, b_); + while(first != last) + { + ck_tile::pk_fp6x16_t pk{}; + for(ck_tile::index_t i = 0; i < ck_tile::pk_fp6x16_t::packed_size; ++i) + { + pk.pack(ck_tile::pk_fp6x16_t::float_to_fp6_e2m3(dis(gen)), i); + } + *first = pk; + ++first; + } + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + namespace impl { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index f43bcbc4b1..dc56f34bc7 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -146,9 +146,10 @@ struct GemmPipelineAgBgCrImplBase auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); // TODO: LDS alignment should come from Policy! - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t a_lds_block_space_size = - sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize; + constexpr index_t APackedSize = numeric_traits::PackedSize; + constexpr index_t a_lds_block_space_size = lds_padded_sizeof() * + a_lds_block_desc.get_element_space_size() / + APackedSize; constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(a_lds_block_space_size, 16); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index b4a8e9e8cb..cdbb2662f9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -834,9 +834,10 @@ struct UniversalGemmBasePolicy using BlockGemm = remove_cvref_t())>; constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); - constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(A)); + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(A)) * + numeric_traits::PackedSize; - return (KPack < VecElems) ? KPack : VecElems; + return ck_tile::min(KPack, VecElems); } template @@ -846,9 +847,10 @@ struct UniversalGemmBasePolicy using BlockGemm = remove_cvref_t())>; constexpr index_t KPack = static_cast(BlockGemm::Traits::KPack); - constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(B)); + constexpr index_t VecElems = static_cast(Problem::VectorLoadSize / sizeof(B)) * + numeric_traits::PackedSize; - return (KPack < VecElems) ? KPack : VecElems; + return ck_tile::min(KPack, VecElems); } template @@ -857,8 +859,10 @@ struct UniversalGemmBasePolicy using ADataType = remove_cvref_t; constexpr auto APackedSize = numeric_traits::PackedSize; constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor(); - constexpr index_t smem_size_a = integer_least_multiple( - a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16); + constexpr index_t smem_size_a = + integer_least_multiple(a_lds_block_desc.get_element_space_size() * + lds_padded_sizeof() / APackedSize, + 16); return smem_size_a; } @@ -871,8 +875,10 @@ struct UniversalGemmBasePolicy typename Problem::BDataType>; constexpr auto BPackedSize = numeric_traits::PackedSize; constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); - constexpr index_t smem_size_b = integer_least_multiple( - b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16); + constexpr index_t smem_size_b = + integer_least_multiple(b_lds_block_desc.get_element_space_size() * + lds_padded_sizeof() / BPackedSize, + 16); return smem_size_b; } diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index ec16a4e8b6..bf00bc0b0f 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -442,10 +442,12 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< MWarp / BlockSize, "BLdsTile size is wrong!"); static_assert(Policy::template GetSmemSizeA() == - MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), + MPerBlock * + (KPerBlock * lds_padded_sizeof() / APackedSize), "SmemSizeA size is wrong!"); static_assert(Policy::template GetSmemSizeB() == - (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, + (KPerBlock * lds_padded_sizeof() / BPackedSize) * + NPerBlock, "SmemSizeB size is wrong!"); ////////////// MX Scale register tiles (ping-pong buffers) ///////////////// diff --git a/test/ck_tile/gemm_mx/CMakeLists.txt b/test/ck_tile/gemm_mx/CMakeLists.txt index 36d2e455ae..31fb3ef8a5 100644 --- a/test/ck_tile/gemm_mx/CMakeLists.txt +++ b/test/ck_tile/gemm_mx/CMakeLists.txt @@ -9,7 +9,8 @@ endif() if(GPU_TARGETS MATCHES "gfx95") add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp) target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) - + add_gtest_executable(test_ck_tile_mx_gemm_fp6 test_mx_gemm_fp6.cpp) + target_compile_options(test_ck_tile_mx_gemm_fp6 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp) target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) else() diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp index 3cce36a85d..41dc249788 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp @@ -87,6 +87,13 @@ struct MXfp4_GemmConfig16 : MxGemmConfig static constexpr ck_tile::index_t K_Tile = 256; }; +struct MXfp6_GemmConfig16 : MxGemmConfig +{ + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; +}; + struct MXfp8_GemmConfig16 : MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 64; diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp new file mode 100644 index 0000000000..c63f1d9156 --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_fp6.cpp @@ -0,0 +1,30 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_config.hpp" +#include "test_mx_gemm_util.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +using MxFp6Types = ::testing::Types< + std::tuple>; + +template +class TestMxGemmFp6 : public TestMxGemmUtil, + std::tuple_element_t<1, TypeParam>, + std::tuple_element_t<2, TypeParam>, + std::tuple_element_t<3, TypeParam>, + std::tuple_element_t<4, TypeParam>, + std::tuple_element_t<5, TypeParam>> +{ +}; + +TYPED_TEST_SUITE(TestMxGemmFp6, MxFp6Types); + +TYPED_TEST(TestMxGemmFp6, BasicSizes) +{ + this->Run(64, 64, 256); + this->Run(128, 128, 256); + this->Run(64, 128, 512); +} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp index 6e7ddfb5d0..cbf2a7ecd7 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp @@ -4,7 +4,6 @@ #pragma once #include - #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/host/check_err.hpp"