diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt
index c5cecceb9c..43789750d0 100644
--- a/example/ck_tile/18_flatmm/CMakeLists.txt
+++ b/example/ck_tile/18_flatmm/CMakeLists.txt
@@ -21,7 +21,9 @@ if(has_supported_gpu)
     add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES})
     target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm)
 
-    set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
+    # NOTE: we turn off -Wundefined-func-template so that the source compiles without
+    # explicit declarations of the function template specializations, which are auto-generated
+    set(EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template)
     set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
 
     if(CK_USE_OCP_FP8)
diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
index 33a2ba3135..14976b8093 100644
--- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
+++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
@@ -136,7 +136,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
     float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
     float gb_per_sec = num_byte / 1.E6 / ave_time;
 
-    std::cout << "Run MXFP4_Flatmm kernel " //
+    std::cout << "Run " << ck_tile::gemm_prec_str() << " Flatmm kernel " //
              << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
              << " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
              << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
@@ -172,42 +172,47 @@ auto create_args(int argc, char* argv[])
     return std::make_tuple(result, arg_parser);
 }
 
-template
-void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
+template
+auto preShuffleWeight(ck_tile::HostTensor& src)
 {
-    int KPack = 16;
-    int NLane = FlatmmConfig::N_Warp_Tile;
-    int KLane = 64 / NLane;
-    int K_pk  = K / 2;
-    int K0    = K_pk / (KLane * KPack);
+    auto src_lengths = src.get_lengths();
+    const int K      = src_lengths[0];
+    const int N      = src_lengths[1];
+
+    constexpr int packed_size = ck_tile::numeric_traits::PackedSize;
+    int KPack                 = 16 * packed_size; // fp4:32 or fp8:16
+    int NLane                 = N_Warp_Tile;
+    int KLane                 = 64 / NLane;
+    int K0                    = K / (KLane * KPack);
+
+    ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
+
     // K -> K0 KLane KPack
     // N -> N0 NLane
     // N, K -> N0 K0 KLane NLane KPack
-    int tempk;
     for(int n = 0; n < N; ++n)
     {
-        for(int k = 0; k < K_pk; ++k)
+        for(int k = 0; k < K; k += packed_size)
         {
            int n0 = n / NLane;
            int n1 = n % NLane;
 
-            int k0 = k / (KLane * KPack);
-            tempk  = k % (KLane * KPack);
-            int k1 = tempk / KPack;
-            int k2 = tempk % KPack;
+            int k0    = k / (KLane * KPack);
+            int tempk = k % (KLane * KPack);
+            int k1    = tempk / KPack;
+            int k2    = tempk % KPack;
 
            int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                              k1 * KPack * NLane + n1 * KPack + k2;
 
-            dst[outputIndex] = src[n * K_pk + k];
+            shuffled(outputIndex) = src(k, n);
        }
    }
+    return shuffled;
 }
 
-template
-auto preShuffleScale(Src& src)
+template
+auto preShuffleScale(ck_tile::HostTensor& src)
 {
-    using dtype      = typename Src::Data::value_type;
     auto src_lengths = src.get_lengths();
     const auto MN    = KLast ? src_lengths[0] : src_lengths[1];
     const auto K     = KLast ? src_lengths[1] : src_lengths[0];
@@ -261,7 +266,6 @@ auto preShuffleScale(Src& src)
 
 #include "run_mx_flatmm.inc"
 
-template <typename FlatmmConfig>
 int run_mx_flatmm_example(int argc, char* argv[])
 {
     auto [result, arg_parser] = create_args(argc, argv);
@@ -278,24 +282,31 @@ int run_mx_flatmm_example(int argc, char* argv[])
 
     if(a_layout == "R" && b_layout == "C")
     {
-        if(mx_prec == "fp4xfp4")
+        if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
        {
            if(persistent_opt == 0)
                return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{});
            else
                throw std::runtime_error("Only non-persistent kernels are supported currently!");
        }
-        else if(mx_prec == "fp6xfp6")
+        else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
        {
-            throw std::runtime_error("Only support fp4xfp4 now!");
+            throw std::runtime_error("fp6xfp6 is not supported.");
        }
-        else if(mx_prec == "fp8xfp8")
+        else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
        {
-            throw std::runtime_error("Only support fp4xfp4 now!");
+            if(persistent_opt == 0)
+                return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{});
+            else
+                throw std::runtime_error("Only non-persistent kernels are supported currently!");
        }
        else
        {
@@ -306,7 +317,6 @@ int run_mx_flatmm_example(int argc, char* argv[])
     {
         throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
     }
-    return -1;
 }
 
 int main(int argc, char* argv[])
@@ -319,7 +329,7 @@ int main(int argc, char* argv[])
     int warp_tile = arg_parser.get_int("warp_tile");
     if(warp_tile == 0)
     {
-        return run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
+        return run_mx_flatmm_example(argc, argv);
     }
     else if(warp_tile == 1)
     {
diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
index b47d3a95ab..248cf28341 100644
--- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
+++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
@@ -12,4 +12,87 @@
 #include "ck_tile/ops/flatmm.hpp"
 #include "ck_tile/ops/gemm.hpp"
 
-#include "mxfp4_flatmm.hpp"
+// GEMM config with 16x16 warp tile
+struct MXfp4_FlatmmConfig16
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 512;
+    static constexpr ck_tile::index_t K_Tile = 256;
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = 128;
+
+    static constexpr bool kPadM = false;
+    static constexpr bool kPadN = false;
+    static constexpr bool kPadK = false;
+
+    static constexpr bool TransposeC            = false;
+    static constexpr bool UseStructuredSparsity = false;
+
+    static constexpr int kBlockPerCu            = 1;
+    static constexpr int TileParitionerGroupNum = 8;
+    static constexpr int TileParitionerM01      = 4;
+    static constexpr auto Scheduler             = ck_tile::GemmPipelineScheduler::Default;
+    static constexpr ck_tile::index_t NumWaveGroups = 1;
+    static constexpr bool DoubleSmemBuffer          = false;
+
+    static constexpr int N_Repeat          = N_Tile / N_Warp_Tile / N_Warp;
+    static constexpr bool TiledMMAPermuteN = false;
+};
+
+struct MXfp8_FlatmmConfig16
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 256;
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile = 128;
+
+    static constexpr bool kPadM = false;
+    static constexpr bool kPadN = false;
+    static constexpr bool kPadK = false;
+
+    static constexpr bool TransposeC            = false;
+    static constexpr bool UseStructuredSparsity = false;
+
+    static constexpr int kBlockPerCu            = 1;
+    static constexpr int TileParitionerGroupNum = 8;
+    static constexpr int TileParitionerM01      = 4;
+    static constexpr auto Scheduler             = ck_tile::GemmPipelineScheduler::Default;
+    static constexpr ck_tile::index_t NumWaveGroups = 1;
+    static constexpr bool DoubleSmemBuffer          = false;
+
+    static constexpr int N_Repeat          = N_Tile / N_Warp_Tile / N_Warp;
+    static constexpr bool TiledMMAPermuteN = false;
+};
+
+template
+float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args,
+                     const ck_tile::stream_config& s);
diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake
index 950b0c72a6..63158b807f 100644
--- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake
+++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake
@@ -1,24 +1,29 @@
 function(mx_flatmm_instance_generate FILE_LIST)
-    set(FLATMM_CONFIG MXfp4_FlatmmConfig16)
-    set(A_DATA_TYPE FP4)
-    set(B_DATA_TYPE FP4)
     set(C_DATA_TYPE FP16)
     set(A_LAYOUT ROW)
     set(B_LAYOUT COL)
     set(C_LAYOUT ROW)
 
+    set(FLATMM_CONFIG_FP4 "MXfp4_FlatmmConfig16")
+    set(FLATMM_CONFIG_FP8 "MXfp8_FlatmmConfig16")
+
     # foreach(PERSISTENT false true) # TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
     foreach(PERSISTENT false)
-        foreach(SPLIT_K false true)
-            foreach(HAS_HOT_LOOP false true)
-                foreach(TAIL_NUMBER ODD EVEN)
-                    set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
-                    configure_file(
-                        ${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
-                        ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
-                        @ONLY)
-                    list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
+        foreach(DATA_TYPE FP4 FP8)
+            set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
+            set(A_DATA_TYPE ${DATA_TYPE})
+            set(B_DATA_TYPE ${DATA_TYPE})
+            foreach(SPLIT_K false true)
+                foreach(HAS_HOT_LOOP false true)
+                    foreach(TAIL_NUMBER ODD EVEN)
+                        set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${DATA_TYPE}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
+                        string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
+                        configure_file(
+                            ${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
+                            ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
+                            @ONLY)
+                        list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
+                    endforeach()
                 endforeach()
             endforeach()
         endforeach()
     endforeach()
diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in
index 0be9fc7bb7..d9fe78b701 100644
--- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in
+++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in
@@ -18,6 +18,7 @@
 // clang-format on
 
 using FP4  = ck_tile::pk_fp4_t;
+using FP8  = ck_tile::fp8_t;
 using FP16 = ck_tile::fp16_t;
 using BF16 = ck_tile::bf16_t;
diff --git a/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp
deleted file mode 100644
index 02f58a6269..0000000000
--- a/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp
+++ /dev/null
@@ -1,60 +0,0 @@
-
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
-
-#pragma once
-
-#include "ck_tile/core.hpp"
-
-// GEMM config with 16x16 warp tile
-struct MXfp4_FlatmmConfig16
-{
-    static constexpr ck_tile::index_t M_Tile = 128;
-    static constexpr ck_tile::index_t N_Tile = 512;
-    static constexpr ck_tile::index_t K_Tile = 256;
-
-    static constexpr ck_tile::index_t M_Warp = 1;
-    static constexpr ck_tile::index_t N_Warp = 4;
-    static constexpr ck_tile::index_t K_Warp = 1;
-
-    static constexpr ck_tile::index_t M_Warp_Tile = 16;
-    static constexpr ck_tile::index_t N_Warp_Tile = 16;
-    static constexpr ck_tile::index_t K_Warp_Tile = 128;
-
-    static constexpr bool kPadM = false;
-    static constexpr bool kPadN = false;
-    static constexpr bool kPadK = false;
-
-    static constexpr bool TransposeC            = false;
-    static constexpr bool UseStructuredSparsity = false;
-
-    static constexpr int kBlockPerCu            = 1;
-    static constexpr int TileParitionerGroupNum = 8;
-    static constexpr int TileParitionerM01      = 4;
-    static constexpr auto Scheduler             = ck_tile::GemmPipelineScheduler::Default;
-    static constexpr ck_tile::index_t NumWaveGroups = 1;
-    static constexpr bool DoubleSmemBuffer          = false;
-
-    static constexpr int N_Repeat          = N_Tile / N_Warp_Tile / N_Warp;
-    static constexpr bool TiledMMAPermuteN = false;
-};
-
-template
-float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args,
-                     const ck_tile::stream_config& s);
diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
index 0171fc1403..dd522bbcb6 100644
--- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
+++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
@@ -88,10 +88,7 @@ int run_mx_flatmm_with_layouts(int argc,
         throw std::runtime_error("wrong! Unexpected init_method");
     }
 
-    ck_tile::HostTensor b_shuffled_host(
-        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
-    preShuffleWeight(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
-
+    const auto b_shuffled_host  = preShuffleWeight(b_origin_host);
     const auto scale_a_shuffled = preShuffleScale(scale_a);
     const auto scale_b_shuffled = preShuffleScale(scale_b);
diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp
index 1eec80828a..c22fad07f4 100644
--- a/include/ck_tile/core/numeric/integral_constant.hpp
+++ b/include/ck_tile/core/numeric/integral_constant.hpp
@@ -41,6 +41,8 @@ using long_number = constant;
 
 template
 using bool_constant = constant;
+using true_type  = bool_constant<true>;
+using false_type = bool_constant<false>;
 
 #define CK_TILE_LEFT_UNARY_OP(OP) \
     template \
diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp
index 1be4259e97..6b6cad299a 100644
--- a/include/ck_tile/core/tensor/load_tile.hpp
+++ b/include/ck_tile/core/tensor/load_tile.hpp
@@ -21,9 +21,10 @@ namespace ck_tile {
 
 template >>
+          typename offset_t,
+          typename = std::enable_if_t>>
 CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window,
-                                          index_t offset,
+                                          offset_t offset,
                                           number = {},
                                           bool_constant = {})
 {
@@ -67,11 +68,12 @@ template
+          typename offset_t,
           typename = std::enable_if_t> && std::is_class_v>>
 CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile,
                                           const TileWindow_& tile_window,
-                                          index_t offset,
+                                          offset_t offset,
                                           number = {},
                                           bool_constant = {})
 {
@@ -147,29 +149,31 @@ template
           typename = std::enable_if_t> && std::is_class_v>>
-CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
+CK_TILE_DEVICE void async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
                                                 const TileWindow_& tile_window,
                                                 index_t offset,
-                                                number = {},
-                                                bool_constant = {})
+                                                number = {},
+                                                bool_constant occ = {},
+                                                bool_constant smy = {})
 {
-    return tile_window.async_load_with_offset(
-        offset, lds_tile, number{}, bool_constant{});
+    tile_window.async_load_with_offset(offset, lds_tile, number{}, occ, smy);
 }
 
 template
-CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
+          bool oob_conditional_check = true,
+          bool static_move_ys        = false>
+CK_TILE_DEVICE void async_load_tile(LdsTileWindow_&& lds_tile,
                                     const TileWindow_& tile_window,
-                                    number = {},
-                                    bool_constant = {})
+                                    number = {},
+                                    bool_constant occ = {},
+                                    bool_constant smy = {})
 {
-    return async_load_tile_with_offset(
-        lds_tile, tile_window, 0, number{}, bool_constant{});
+    async_load_tile_with_offset(lds_tile, tile_window, 0, number{}, occ, smy);
 }
 
 template
-CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
+CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                         const TileWindow_& tile_window,
                                         number = {},
                                         bool_constant = {},
                                         bool_constant = {})
 {
-    return tile_window.async_load_raw(lds_tile,
-                                      number{},
-                                      bool_constant{},
-                                      bool_constant{});
+    tile_window.async_load_raw(lds_tile,
+                               number{},
+                               bool_constant{},
+                               bool_constant{});
 }
 
-CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
+CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
 {
     asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
 }
diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp
index 7dd2684347..3cdc4ff1cf 100644
--- a/include/ck_tile/core/tensor/tensor_view.hpp
+++ b/include/ck_tile/core/tensor/tensor_view.hpp
@@ -166,8 +166,8 @@ struct tensor_view
     {
         return buf_.template async_get(
             smem,
-            coord.get_offset() / PackedSize,
-            linear_offset / PackedSize,
+            coord.get_offset() / PackedSize + linear_offset / PackedSize,
+            0, // linear_offset needs to be an immediate and is not supported currently
             coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
             bool_constant{});
     }
diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp
index ea459417d2..89a0cc0f53 100644
--- a/include/ck_tile/core/tensor/tile_window.hpp
+++ b/include/ck_tile/core/tensor/tile_window.hpp
@@ -156,8 +156,10 @@ struct tile_window_with_static_distribution
             0, number{}, bool_constant{});
     }
 
-    template
-    CK_TILE_DEVICE auto load_with_offset(index_t offset,
+    template
+    CK_TILE_DEVICE auto load_with_offset(offset_t offset,
                                          number = {},
                                          bool_constant = {}) const
     {
@@ -291,14 +293,16 @@ struct tile_window_with_static_distribution
             0, dst_tensor, number{}, bool_constant{});
     }
 
     template >>>
-    CK_TILE_DEVICE auto load_with_offset(index_t offset,
-                                         DistributedTensor& dst_tensor,
-                                         number = {},
-                                         bool_constant = {}) const
+              typename offset_t>
+    CK_TILE_DEVICE void load_with_offset( //
+        offset_t offset,
+        static_distributed_tensor& dst_tensor,
+        number = {},
+        bool_constant = {}) const
     {
         using Traits   = typename Base::Traits;
         using vector_t = typename Traits::vector_t;
@@ -306,6 +310,19 @@ struct tile_window_with_static_distribution
 
         constexpr auto tile_dstr = typename Base::TileDstr{};
 
+        const index_t linear_off = [&]() {
+            if constexpr(std::is_integral_v)
+                return offset;
+            else if constexpr(is_constant_v)
+                return offset_t::value;
+            else
+            {
+                auto bottom_tensor_idx_off   = to_multi_index(offset_t{});
+                auto bottom_tensor_coord_off = make_tensor_coordinate(
+                    this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
+                return bottom_tensor_coord_off.get_offset();
+            }
+        }();
+
         // loop over thread tensor space [y0, y1, ...]
         static_for<0, NumCoord, 1>{}([&](auto iCoord) {
             /// TODO: use structure binding (to be captured later) if compiled in C++20
@@ -321,7 +338,9 @@ struct tile_window_with_static_distribution
                 // read from bottom tensor
                 const vector_t vec_value =
                     this->get_bottom_tensor_view().template get_vectorized_elements(
-                        bottom_tensor_thread_coord, offset, bool_constant{});
+                        bottom_tensor_thread_coord,
+                        linear_off,
+                        bool_constant{});
                 // write into distributed tensor
                 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                     constexpr auto idx_ys = generate_tuple(
@@ -514,11 +533,13 @@ struct tile_window_with_static_distribution
     template
+              bool static_move_ys = false,
               typename = std::enable_if_t>>>
     CK_TILE_DEVICE void async_load_with_offset(index_t offset,
                                                LdsTileWindow_&& lds_tile,
                                                number = {},
-                                               bool_constant = {}) const
+                                               bool_constant = {},
+                                               bool_constant = {}) const
     {
         using LdsTileWindow = remove_cvref_t;
         using LdsDataType   = typename LdsTileWindow::DataType;
@@ -531,7 +552,7 @@ struct tile_window_with_static_distribution
         const auto window_origin       = lds_tile.get_window_origin();
         const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
         const auto& tensor_descriptor  = bottom_tensor_view.get_tensor_descriptor();
-        auto smem_base_ptr             = bottom_tensor_view.get_buffer_view().p_data_;
+        auto lds_base_ptr              = bottom_tensor_view.get_buffer_view().p_data_;
 
         static_for<0, NumCoord, 1>{}([&](auto iCoord) {
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
@@ -543,22 +564,51 @@ struct tile_window_with_static_distribution
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number{};
 
-                // Use precomputed window origin
+                constexpr auto idx_ys_offset = [&]() {
+                    constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
+                    constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
+                        StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
+                        container_concat(array{0},
+                                         to_array(idx_off_ys)));
+                    return adapter_ys_offset.get_bottom_index();
+                }();
+                const auto lds_ys_offset = [&]() {
+                    if constexpr(static_move_ys)
+                    {
+                        const auto coord_ys_offset =
+                            make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
+                        return coord_ys_offset.get_offset();
+                    }
+                    else
+                        return 0;
+                }();
+
+                // Use precomputed window origin & tensor descriptor
                 auto lds_bottom_tensor_thread_idx =
                     window_origin + window_adaptor_warp_coord.get_bottom_index();
-
-                // Use precomputed tensor descriptor
                 const auto lds_coord =
                     make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
 
                 // Calculate SMEM address using base pointer
-                CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
+                CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
+                                                     lds_coord.get_offset() / Traits::PackedSize +
+                                                     lds_ys_offset / Traits::PackedSize;
+
+                const auto dram_ys_offset = [&]() {
+                    if constexpr(static_move_ys)
+                    {
+                        const auto coord_ys_offset = make_tensor_coordinate(
+                            this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
+                        return coord_ys_offset.get_offset();
+                    }
+                    else
+                        return 0;
+                }();
 
-                // Write into bottom tensor
                 this->get_bottom_tensor_view().template async_get_vectorized_elements(
                     smem,
                     bottom_tensor_thread_coord,
-                    offset,
+                    offset + dram_ys_offset,
                     bool_constant{});
 
                 // Move thread coordinate if not last access
@@ -569,11 +619,15 @@ struct tile_window_with_static_distribution
                         generate_tuple([&](auto) { return number<0>{}; }, number{}),
                         idx_diff_ys);
 
-                    Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+                    if constexpr(!static_move_ys)
+                        Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                            window_adaptor_thread_coord,
+                            bottom_tensor_thread_coord,
+                            idx_diff_ps_ys);
 
-                    Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
+                    if constexpr(!static_move_ys)
+                        Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                            window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
                 }
             });
         });
diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp
index 70ad154556..4d0f92f3e0 100644
--- a/include/ck_tile/host/reference/reference_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_gemm.hpp
@@ -432,6 +432,12 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k,
                 a_m_k_scaled(m, k)     = a_f4_lo * a_scale;
                 a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
             }
+            else
+            {
+                a_m_k_scaled(m, k) =
+                    ck_tile::type_convert((a_m_k(m, k))) *
+                    ck_tile::type_convert(scale_a(m, k / ScaleBlockSize));
+            }
         }
     }
diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp
index f60a7e1441..d1e6813b01 100644
--- a/include/ck_tile/ops/common/utils.hpp
+++ b/include/ck_tile/ops/common/utils.hpp
@@ -19,6 +19,7 @@ template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"; };
 template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
 template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
 template <> struct typeToStr<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
+template <> struct typeToStr<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
 
 template struct memOpToStr;
 template <> struct memOpToStr { static constexpr const char * name = "set"; };
diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
index a807229d9b..3f2560587a 100644
--- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
+++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
@@ -143,16 +143,24 @@ struct MXFlatmmKernel : FlatmmKernel
-        const auto& b_flat_tensor_view = make_naive_tensor_view(
-            b_flat_ptr,
-            make_tuple(kFlatN, kFlatK),
-            make_tuple(kFlatK, 1),
-            number{},
-            number<1>{});
+        constexpr index_t kKPerBlock    = FlatmmPipeline::kKPerBlock;
+        constexpr index_t kNWarpTile    = BlockGemmShape::WarpTile::at(I1);
+        constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
+        const index_t kFlatKBlocks      = kargs.K / kKPerBlock;
+        const index_t kFlatN            = kargs.N / kNWarpTile;
+
+        const auto& b_flat_tensor_view = [&]() {
+            static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
+                          "wrong! vector size for B tensor");
+            auto&& naive_desc = make_naive_tensor_descriptor_packed(
+                make_tuple(kFlatN, kFlatKBlocks, number{}));
+            auto&& desc = transform_tensor_descriptor(
+                naive_desc,
+                make_tuple(make_pass_through_transform(kFlatN),
+                           make_merge_transform_v3_division_mod(
+                               make_tuple(kFlatKBlocks, number{}))),
+                make_tuple(sequence<0>{}, sequence<1, 2>{}),
+                make_tuple(sequence<0>{}, sequence<1>{}));
+            return make_tensor_view(b_flat_ptr, desc);
         }();
 
         const auto& ds_tensor_view = generate_tuple(
diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
index 5bb5436edf..e6ff17952b 100644
--- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
@@ -44,7 +44,10 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
         else if(TailNumber::Odd == tail_num)
             return TailHandler(run_func, has_hot_loop);
         else
+        {
             assert(("Wrong TailNumber!", false));
+            return decltype(TailHandler<>(run_func, true, TailNumber::Even)){};
+        }
     }
 };
diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
index fbed495d25..95b4cfeaca 100644
--- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
@@ -43,7 +43,7 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem
@@ -122,9 +122,10 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
     static constexpr index_t APackedSize = numeric_traits::PackedSize;
     static constexpr index_t BPackedSize = numeric_traits::PackedSize;
 
-    static constexpr index_t MXdlPack = Problem::MXdlPack;
-    static constexpr index_t NXdlPack = Problem::NXdlPack;
-    static constexpr index_t KXdlPack = Problem::KXdlPack;
+    static constexpr index_t MXdlPack          = Problem::MXdlPack;
+    static constexpr index_t NXdlPack          = Problem::NXdlPack;
+    static constexpr index_t KXdlPack          = Problem::KXdlPack;
+    static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
 
     static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
     static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
@@ -138,25 +139,25 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
-    CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
-                                        const AElementFunction& a_element_func,
-                                        const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
-                                        const ScaleADramBlockWindowTmp& scale_a_window,
-                                        const ScaleBDramBlockWindowTmp& scale_b_window,
-                                        index_t num_loop,
-                                        void* p_smem_ping,
-                                        void* p_smem_pong) const
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const ScaleADramBlockWindowTmp& scale_a_window,
+                                   const ScaleBDramBlockWindowTmp& scale_b_window,
+                                   index_t num_loop,
+                                   void* __restrict__ p_smem_ping,
+                                   void* __restrict__ p_smem_pong) const
     {
 #ifndef __gfx950__
         static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
@@ -495,9 +494,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
 {}], "wrong!");
 
-        constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
-        const index_t iMWarp          = get_warp_id() / NWarp;
-        // const index_t iNWarp = get_warp_id() % NWarp;
+        // constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
+        static_assert(NWarp == 4);
 
         using CWarpDstr   = typename WG::CWarpDstr;
         using CWarpTensor = typename WG::CWarpTensor;
@@ -506,6 +504,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
 {};
 
+        auto a_dram_window =
+            make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor(
+                                 a_copy_dram_window_tmp.get_bottom_tensor_view()),
+                             a_copy_dram_window_tmp.get_window_lengths(),
+                             a_copy_dram_window_tmp.get_window_origin(),
+                             PipelinePolicy::template MakeMXFP4_ADramTileDistribution());
+
         __builtin_amdgcn_sched_barrier(0);
 
         // A tile in LDS
@@ -520,93 +525,51 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
 (p_a_lds_pong, a_lds_block_desc);
 
-        auto a_copy_lds_window_ping =
-            make_tile_window(a_lds_block_ping,
-                             make_tuple(number{}, number{}),
-                             {0, 0},
-                             PipelinePolicy::template MakeADramTileDistribution());
-        auto a_copy_lds_window_pong =
-            make_tile_window(a_lds_block_pong,
-                             make_tuple(number{}, number{}),
-                             {0, 0},
-                             PipelinePolicy::template MakeADramTileDistribution());
+        auto a_store_lds_window_ping = make_tile_window(
+            a_lds_block_ping, make_tuple(number{}, number{}), {0, 0});
+        auto a_store_lds_window_pong = make_tile_window(
+            a_lds_block_pong, make_tuple(number{}, number{}), {0, 0});
 
         // ping-pong window for A LDS
-        auto a_warp_window_ping_tmp =
+        auto a_warp_window_ping =
             make_tile_window(a_lds_block_ping,
                              make_tuple(number{}, number{}),
-                             {iMWarp * WG::kM, 0},
+                             {0, 0},
                              PipelinePolicy::template MakeMXF4_ALDS_TileDistribution());
-        auto a_warp_window_pong_tmp =
+        auto a_warp_window_pong =
             make_tile_window(a_lds_block_pong,
                              make_tuple(number{}, number{}),
-                             {iMWarp * WG::kM, 0},
+                             {0, 0},
                              PipelinePolicy::template MakeMXF4_ALDS_TileDistribution());
 
-        statically_indexed_array<
-            statically_indexed_array,
-            MIterPerWarp>
-            a_warp_windows_ping;
-
-        statically_indexed_array<
-            statically_indexed_array,
-            MIterPerWarp>
-            a_warp_windows_pong;
-
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
-                a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
-
-                auto packed_m_idx  = mIter / number{};
-                auto packed_m_rank = mIter % number{};
-
-                move_tile_window(
-                    a_warp_windows_ping(mIter)(kIter),
-                    {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
-                     kIter * KPerBlockPerIter});
-                move_tile_window(
-                    a_warp_windows_pong(mIter)(kIter),
-                    {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
-                     kIter * KPerBlockPerIter});
-            });
-        });
-
         // Block GEMM
         auto block_flatmm = BlockFlatmm();
 
         // Acc register tile
         auto c_block_tile = block_flatmm.MakeCBlockTile();
 
         // B flat DRAM window for load
-        auto b_flat_distribution =
-            PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution();
-
-        auto b_flat_dram_window = make_tile_window(
-            b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
-            make_tuple(number{}, number{}),
-            b_flat_dram_block_window_tmp.get_window_origin(),
-            b_flat_distribution);
-
-        using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window));
-        // use v4i32 as the data type between basicblock to avoid unpack and repack operation.
-        using V4UInt_B_Buffer = thread_buffer;
-        union UnionBuf
-        {
-            V4UInt_B_Buffer u = 0;
-            MXFP4_B_Buffer mxfp4;
-        } ub;
 
         // pingpong buffer for B
+        auto b_flat_dram_windows = generate_tuple(
+            [&](auto nIter) {
+                constexpr auto packed_n_idx  = nIter / number{};
+                constexpr auto packed_n_rank = nIter % number{};
+                auto window_i = make_tile_window(
+                    b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
+                    make_tuple(number{}, number{}),
+                    b_flat_dram_block_window_tmp.get_window_origin(),
+                    PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution());
+                move_tile_window(
+                    window_i,
+                    {number{},
+                     number<0>{}});
+                return window_i;
+            },
+            number{});
+
         statically_indexed_array<
-            statically_indexed_array,
+            statically_indexed_array,
             NIterPerWarp>
-            b_flat_dram_windows;
-        statically_indexed_array,
-                                NIterPerWarp>
-            b_warp_tensor_ping;
-        statically_indexed_array,
-                                NIterPerWarp>
-            b_warp_tensor_pong;
+            b_warp_tensor_ping, b_warp_tensor_pong;
 
         // pingpong buffer for Scale A and Scale B
         auto scale_a_dram_window = make_tile_window(
@@ -649,29 +612,24 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
             scale_b_tile_tensor_pong;
 
+        auto async_load_tile_ = [](auto lds, auto dram) {
+            async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
+        };
+
         // HEAD
         // Prefetch A0
-        auto a_block_tile = load_tile(a_copy_dram_window);
-        // move A window to next k
-        move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+        async_load_tile_(a_store_lds_window_ping, a_dram_window);
+        move_tile_window(a_dram_window, {0, kKPerBlock});
 
         // prefetch B
         static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
             static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                auto packed_n_idx  = nIter / number{};
-                auto packed_n_rank = nIter % number{};
-
-                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
-                move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                 {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
-                                  kIter * KFlatPerBlockPerIter});
-
-                ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
-                b_warp_tensor_ping(nIter)(kIter) = ub.u;
+                b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
+                    b_flat_dram_windows(nIter), number{});
             });
+            // move B window to next flat K
+            move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
         });
-        // move B window to next flat K
-        move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
 
         // prefetch Scale A
         static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -700,71 +658,40 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
-        using MXFP4_A_Buffer_ping =
-            decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{})));
-        // use v4i32 as the data type between basicblock to avoid unpack and repack operation.
-        using V4UInt_A_Buffer = thread_buffer;
-        union UnionBuf_A_ping
-        {
-            V4UInt_A_Buffer u = 0;
-            MXFP4_A_Buffer_ping mxfp4;
-        } ua_ping;
-
-        using MXFP4_A_Buffer_pong =
-            decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{})));
-        union UnionBuf_A_pong
-        {
-            V4UInt_A_Buffer u = 0;
-            MXFP4_A_Buffer_pong mxfp4;
-        } ua_pong;
+        statically_indexed_array a_warp_tensor;
 
         // preload A00,A10... from lds
-        statically_indexed_array a_warp_tensor;
-
+        s_waitcnt_barrier();
         static_for<0, m_preload, 1>{}([&](auto loadIter) {
             constexpr auto mIter = loadIter % MXdlPack;
             constexpr auto kIter = loadIter / MXdlPack;
-            ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number{})(number{}));
-            a_warp_tensor(loadIter) = ua_ping.u;
+            a_warp_tensor(loadIter) = load_tile_with_offset(
+                a_warp_window_ping, tuple, number>{});
        });
 
         __builtin_amdgcn_sched_barrier(0);
 
         // MAIN LOOP
-        index_t iCounter = (num_loop - 1) / 2;
-        while(iCounter > 0)
-        {
+        auto main_body_implx2 = [&]() mutable {
            // prefetch B(2i+1)
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    auto packed_n_idx  = nIter / number{};
-                    auto packed_n_rank = nIter % number{};
-
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
-                    move_tile_window(
-                        b_flat_dram_windows(nIter)(kIter),
-                        {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
-                         kIter * KFlatPerBlockPerIter});
-
-                    ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
-                    b_warp_tensor_pong(nIter)(kIter) = ub.u;
+                    b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
+                        b_flat_dram_windows(nIter), number{});
+                    if constexpr(kIter == KIterPerWarp - 1)
+                        move_tile_window(b_flat_dram_windows(nIter),
+                                         {0, BlockGemmShape::flatKPerBlock});
                });
            });
@@ -791,15 +718,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
                static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -807,30 +725,26 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                    static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                        static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                            constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
+                            constexpr auto m_iter    = mIter_pack * MXdlPack + imxdl;
+                            constexpr auto k_iter    = kIter_pack * KXdlPack + ikxdl;
                            static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
+                                constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
+
                                // read C warp tensor from C block tensor
                                CWarpTensor c_warp_tensor;
                                c_warp_tensor.get_thread_buffer() =
                                    c_block_tile.get_y_sliced_thread_data(
-                                        merge_sequences(
-                                            sequence{},
-                                            c_warp_y_index_zeros),
+                                        merge_sequences(sequence{},
+                                                        c_warp_y_index_zeros),
                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
 
-                                UnionBuf_A_ping ua_compute;
-                                ua_compute.u = a_warp_tensor(number{});
-
-                                UnionBuf ub_compute;
-                                ub_compute.u =
-                                    b_warp_tensor_ping(nIter_pack * number{} + inxdl)(
-                                        kIter_pack * number{} + ikxdl);
                                // warp GEMM
                                WG{}.template operator()(
                                    c_warp_tensor,
-                                    ua_compute.mxfp4,
-                                    ub_compute.mxfp4,
+                                    a_warp_tensor(number{}),
+                                    b_warp_tensor_ping(nIter_pack * number{} + inxdl)(
+                                        kIter_pack * number{} + ikxdl),
                                    scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                        .get_thread_buffer()[0],
                                    scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -838,68 +752,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                                // write C warp tensor into C block tensor
                                c_block_tile.set_y_sliced_thread_data(
-                                    merge_sequences(sequence{},
-                                                    c_warp_y_index_zeros),
+                                    merge_sequences(sequence{},
+                                                    c_warp_y_index_zeros),
                                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                                    c_warp_tensor.get_thread_buffer());
                            });
 
                            // preload next A from lds
-                            constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
-                                                  (kIter_pack * KXdlPack + ikxdl) * 2 +
-                                                  (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
-                                                  m_preload;
+                            constexpr auto addr =
+                                m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
                            if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
                                         (nIter_pack == NIterPerWarp / NXdlPack - 1))
                            {
-                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                constexpr auto AkIter = addr / 2 % 2;
-                                ua_ping.mxfp4 = load_tile(
-                                    a_warp_windows_ping(number{})(number{}));
-                                a_warp_tensor(number{}) = ua_ping.u;
-                            }
-
-                            // barrier
-                            if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
-                                         mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
-                            {
-                                block_sync_lds();
+                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                                constexpr auto AkIter = addr / 2 % 2;
+                                a_warp_tensor(number{}) = load_tile_with_offset(
+                                    a_warp_window_ping,
+                                    tuple, number>{});
                            }
                        });
                    });
                });
            });
 
+            // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
+            s_waitcnt< // vmcnt
+                Bload_num + ScaleAload_num + ScaleBload_num>();
+            block_sync_lds();
+
+            // Prefetch A(2i+2)
+            async_load_tile_(a_store_lds_window_ping, a_dram_window);
+            move_tile_window(a_dram_window, {0, kKPerBlock});
 
            // move B window to next flat K
-            move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
            move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
            move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
 
+            // preload A(2i+1)
            static_for<0, m_preload, 1>{}([&](auto loadIter) {
-                constexpr auto mIter = loadIter % MXdlPack;
-                constexpr auto kIter = loadIter / MXdlPack;
-                ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number{})(number{}));
-                a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
+                constexpr auto mIter    = loadIter % MXdlPack;
+                constexpr auto kIter    = loadIter / MXdlPack;
+                a_warp_tensor(loadIter) = load_tile_with_offset(
+                    a_warp_window_pong, tuple, number>{});
            });
 
            HotLoopScheduler();
 
-            // Next K
+            ////////////////////////////// Next K //////////////////////////////
            // prefetch B(2i+2)
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    auto packed_n_idx  = nIter / number{};
-                    auto packed_n_rank = nIter % number{};
-
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
-                    move_tile_window(
-                        b_flat_dram_windows(nIter)(kIter),
-                        {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
-                         kIter * KFlatPerBlockPerIter});
-
-                    ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
-                    b_warp_tensor_ping(nIter)(kIter) = ub.u;
+                    b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
+                        b_flat_dram_windows(nIter), number{});
+                    if constexpr(kIter == KIterPerWarp - 1)
+                        move_tile_window(b_flat_dram_windows(nIter),
+                                         {0, BlockGemmShape::flatKPerBlock});
                });
            });
@@ -926,15 +832,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
                static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -953,20 +850,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
 
-                                UnionBuf_A_pong ua_compute;
-                                ua_compute.u = a_warp_tensor(number{});
-
-                                UnionBuf ub_compute;
-                                ub_compute.u =
-                                    b_warp_tensor_pong(nIter_pack * number{} + inxdl)(
-                                        kIter_pack * number{} + ikxdl);
-
                                // warp GEMM
                                WG{}.template operator()(
                                    c_warp_tensor,
-                                    ua_compute.mxfp4,
-                                    ub_compute.mxfp4,
+                                    a_warp_tensor(number{}),
+                                    b_warp_tensor_pong(nIter_pack * number{} + inxdl)(
+                                        kIter_pack * number{} + ikxdl),
                                    scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
                                        .get_thread_buffer()[0], // scale A
                                    scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -988,39 +878,47 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                            {
-                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                constexpr auto AkIter = addr / 2 % 2;
-                                ua_pong.mxfp4 = load_tile(
-                                    a_warp_windows_pong(number{})(number{}));
-                                a_warp_tensor(number{}) = ua_pong.u;
-                            }
-
-                            // barrier
-                            if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
-                                         mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
-                            {
-                                block_sync_lds();
+                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                                constexpr auto AkIter = addr / 2 % 2;
+                                a_warp_tensor(number{}) = load_tile_with_offset(
+                                    a_warp_window_pong,
+                                    tuple, number>{});
                            }
                        });
                    });
                });
            });
 
+            // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished
+            s_waitcnt< // vmcnt
+                Bload_num + ScaleAload_num + ScaleBload_num>();
+            block_sync_lds();
 
+            // Prefetch A(2i+3)
+            async_load_tile_(a_store_lds_window_pong, a_dram_window);
+            move_tile_window(a_dram_window, {0, kKPerBlock});
            // move B window to next flat K
-            move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
            move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
            move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
 
+            // preload A(2i+2)
            static_for<0, m_preload, 1>{}([&](auto loadIter) {
-                constexpr auto mIter = loadIter % MXdlPack;
-                constexpr auto kIter = loadIter / MXdlPack;
-                ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number{})(number{}));
-                a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
+                constexpr auto mIter    = loadIter % MXdlPack;
+                constexpr auto kIter    = loadIter / MXdlPack;
+                a_warp_tensor(loadIter) = load_tile_with_offset(
+                    a_warp_window_ping, tuple, number>{});
            });
 
            HotLoopScheduler();
+        };
 
-            iCounter--;
+        if constexpr(HasHotLoop)
+        {
+            index_t iCounter = (num_loop - 1) / 2;
+            do
+            {
+                main_body_implx2();
+                iCounter--;
+            } while(iCounter > 0);
        }
 
        // TAIL
@@ -1029,18 +927,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    auto packed_n_idx  = nIter / number{};
-                    auto packed_n_rank = nIter % number{};
-
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
-
-                    move_tile_window(
-                        b_flat_dram_windows(nIter)(kIter),
-                        {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
-                         kIter * KFlatPerBlockPerIter});
-
-                    ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
-                    b_warp_tensor_pong(nIter)(kIter) = ub.u;
+                    b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
+                        b_flat_dram_windows(nIter),
+                        make_tuple(number<0>{}, number{}));
                });
            });
@@ -1055,7 +944,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
            static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
                static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
                    scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
@@ -1067,10 +955,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
            static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
                static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -1089,20 +973,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
 
-                                UnionBuf_A_ping ua_compute;
-                                ua_compute.u = a_warp_tensor(number{});
-
-                                UnionBuf ub_compute;
-                                ub_compute.u =
-                                    b_warp_tensor_ping(nIter_pack * number{} + inxdl)(
-                                        kIter_pack * number{} + ikxdl);
-
                                // warp GEMM
                                WG{}.template operator()(
                                    c_warp_tensor,
-                                    ua_compute.mxfp4,
-                                    ub_compute.mxfp4,
+                                    a_warp_tensor(number{}),
+                                    b_warp_tensor_ping(nIter_pack * number{} + inxdl)(
+                                        kIter_pack * number{} + ikxdl),
                                    scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                        .get_thread_buffer()[0], // scale A
                                    scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -1124,30 +1001,28 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                            {
-                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                constexpr auto AkIter = addr / 2 % 2;
-                                ua_ping.mxfp4 = load_tile(
-                                    a_warp_windows_ping(number{})(number{}));
-                                a_warp_tensor(number{}) = ua_ping.u;
-                            }
-
-                            // barrier
-                            if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
-                                         mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
-                            {
-                                block_sync_lds();
+                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                                constexpr auto AkIter = addr / 2 % 2;
+                                a_warp_tensor(number{}) = load_tile_with_offset(
+                                    a_warp_window_ping,
+                                    tuple, number>{});
                            }
                        });
                    });
                });
            });
 
+            // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
+            s_waitcnt< // vmcnt
+                Bload_num + ScaleAload_num + ScaleBload_num>();
+            block_sync_lds();
+
+            // preload A(2i+1)
            static_for<0, m_preload, 1>{}([&](auto loadIter) {
-                constexpr auto mIter = loadIter % MXdlPack;
-                constexpr auto kIter = loadIter / MXdlPack;
-                ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number{})(number{}));
-                a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
+                constexpr auto mIter    = loadIter % MXdlPack;
+                constexpr auto kIter    = loadIter / MXdlPack;
+                a_warp_tensor(loadIter) = load_tile_with_offset(
+                    a_warp_window_pong, tuple, number>{});
            });
 
            Last2ndHotLoopScheduler();
@@ -1170,19 +1045,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
 
-                                UnionBuf_A_pong ua_compute;
-                                ua_compute.u = a_warp_tensor(number{});
-
-                                UnionBuf ub_compute;
-                                ub_compute.u =
-                                    b_warp_tensor_pong(nIter_pack * number{} + inxdl)(
-                                        kIter_pack * number{} + ikxdl);
                                // warp GEMM
                                WG{}.template operator()(
                                    c_warp_tensor,
-                                    ua_compute.mxfp4,
-                                    ub_compute.mxfp4,
+                                    a_warp_tensor(number{}),
+                                    b_warp_tensor_pong(nIter_pack * number{} + inxdl)(
+                                        kIter_pack * number{} + ikxdl),
                                    scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
                                        .get_thread_buffer()[0], // scale A
                                    scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -1204,18 +1073,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                            {
-                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                constexpr auto AkIter = addr / 2 % 2;
-                                ua_pong.mxfp4 = load_tile(
-                                    a_warp_windows_pong(number{})(number{}));
-                                a_warp_tensor(number{}) = ua_pong.u;
-                            }
-
-                            // barrier
-                            if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
-                                         mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
-                            {
-                                block_sync_lds();
+                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                                constexpr auto AkIter = addr / 2 % 2;
+                                a_warp_tensor(number{}) = load_tile_with_offset(
+                                    a_warp_window_pong,
+                                    tuple, number>{});
                            }
                        });
                    });
@@ -1244,20 +1106,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
 
-                                UnionBuf_A_ping ua_compute;
-                                ua_compute.u = a_warp_tensor(number{});
-
-                                UnionBuf ub_compute;
-                                ub_compute.u =
-                                    b_warp_tensor_ping(nIter_pack * number{} + inxdl)(
-                                        kIter_pack * number{} + ikxdl);
-
                                // warp GEMM
                                WG{}.template operator()(
                                    c_warp_tensor,
-                                    ua_compute.mxfp4,
-                                    ub_compute.mxfp4,
+                                    a_warp_tensor(number{}),
+                                    b_warp_tensor_ping(nIter_pack * number{} + inxdl)(
+                                        kIter_pack * number{} + ikxdl),
                                    scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                        .get_thread_buffer()[0], // scale A
                                    scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -1279,18 +1134,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                            {
-                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                constexpr auto AkIter = addr / 2 % 2;
-                                ua_ping.mxfp4 = load_tile(
-                                    a_warp_windows_ping(number{})(number{}));
-                                a_warp_tensor(number{}) = ua_ping.u;
-                            }
-
-                            // barrier
-                            if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
-                                         mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
-                            {
-                                block_sync_lds();
+                                constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                                constexpr auto AkIter = addr / 2 % 2;
+                                a_warp_tensor(number{}) = load_tile_with_offset(
+                                    a_warp_window_ping,
+                                    tuple, number>{});
                            }
                        });
                    });
@@ -1299,32 +1147,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
-
-    template
-    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
-                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
-                                   const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp,
-                                   const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp,
-                                   index_t num_loop,
-                                   void* p_smem_ping,
-                                   void* p_smem_pong) const
-    {
-        return operator()(
-            a_dram_block_window_tmp,
-            [](const ADataType& a) { return a; },
-            b_flat_dram_block_window_tmp,
-            scale_a_flat_window_tmp,
-            scale_b_flat_window_tmp,
-            num_loop,
-            p_smem_ping,
-            p_smem_pong);
-    }
 };
 
 } // namespace ck_tile
diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
index f3fc5e9fef..c04622919d 100644
--- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
@@ -13,22 +13,139 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     static constexpr auto I1 = number<1>{};
     static constexpr auto I2 = number<2>{};
 
-    static constexpr index_t KBPerLoad = 32;
+    static constexpr index_t kDramLoadPackBytes = 128;
 
     static constexpr int MXdlPack = 2;
     static constexpr int NXdlPack = 2;
     static constexpr int KXdlPack = 2;
 
     template
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
-    {
-        using namespace ck_tile;
+    static inline constexpr auto wg_attr_num_access =
+        std::is_same_v, pk_fp4_t>
+            ? WGAttrNumAccessEnum::Single
+            : WGAttrNumAccessEnum::Double;
 
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
+    {
+        using ADataType = remove_cvref_t;
+        using BDataType = remove_cvref_t;
+        static_assert(
+            sizeof(ADataType) * numeric_traits::PackedSize ==
+                sizeof(BDataType) * numeric_traits::PackedSize,
+            "sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
+
+        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
+        using WarpTile   = typename Problem::BlockGemmShape::WarpTile;
+        using WarpGemm   = WarpGemmDispatcher< //
+            ADataType,
+            BDataType,
+            typename Problem::CDataType,
+            WarpTile::at(I0),
+            WarpTile::at(I1),
+            WarpTile::at(I2),
+            Problem::TransposeC,
+            false,
+            false,
+            wg_attr_num_access>;
+        using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
+            ADataType,
+            BDataType,
+            typename Problem::CDataType,
+            BlockWarps,
+            WarpGemm>;
+        return BlockFlatmmASmemBSmemCRegV1{};
+    }
+
+    template
+    CK_TILE_DEVICE static constexpr auto
+    MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
+    {
         using ADataType = remove_cvref_t;
         using ALayout   = remove_cvref_t;
 
         constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
         constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
+        static_assert(MPerXdl == 16 && NPerXdl == 16);
+        static_assert(std::is_same_v);
+
+        const auto& naive_desc = naive_view.get_tensor_descriptor();
+        constexpr auto ndims   = remove_cvref_t::get_num_of_dimension();
+        static_assert(ndims == 2, "only support 2D tensor");
+        const auto rows = naive_desc.get_length(number<0>{});
+        const auto cols = naive_desc.get_length(number<1>{});
+
+        constexpr index_t APackedSize = numeric_traits::PackedSize;
+        constexpr index_t K2 = GetSmemPackA() * APackedSize;          // f4=32; f8=16
+        constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
+        const index_t K0     = cols / (K1 * K2);
+        const auto col_lens  = make_tuple(K0, number{}, number{});
+
+        constexpr index_t M1 = 4; // so that we can use an immediate offset to load from LDS
+        const index_t M0     = rows / M1;
+        const auto row_lens  = make_tuple(M0, number{});
+
+        const auto desc_0 =
+            make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
+        const auto desc_1 = transform_tensor_descriptor(
+            desc_0,
+            make_tuple(make_pass_through_transform(M0),
+                       make_xor_transform(make_tuple(number{}, number{})),
+                       make_pass_through_transform(K0),
+                       make_pass_through_transform(number{})),
+            make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
+            make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
+        const auto desc = transform_tensor_descriptor( //
+            desc_1,
+            make_tuple(make_merge_transform_v3_division_mod(row_lens),
+                       make_merge_transform_v3_division_mod(col_lens)),
+            make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));
+        // printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
+
+        return tensor_view,
+                           TensorView::DstInMemOp>{naive_view.buf_, desc};
+    }
+
+    template
+    CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
+    {
+        using ADataType = remove_cvref_t;
+        using ALayout   = remove_cvref_t;
+        static_assert(std::is_same_v);
+
+        constexpr index_t BlockSize   = Problem::kBlockSize;
+        constexpr index_t MPerBlock   = Problem::BlockGemmShape::kM;
+        constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
+        constexpr index_t APackedSize = numeric_traits::PackedSize;
+
+        constexpr index_t K2 = GetSmemPackA() * APackedSize;          // f4=32; f8=16
+        constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
+        constexpr index_t K0 = KPerBlock / (K1 * K2);                 // KPerBlock/256
+
+        constexpr index_t M2 = get_warp_size() / K1;        // 8
+        constexpr index_t M1 = BlockSize / get_warp_size(); // 4
+        constexpr index_t M0 = MPerBlock / (M2 * M1);
+        static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
+        static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
+
+        return make_static_tile_distribution(
+            tile_distribution_encoding< //
+                sequence<1>,
+                tuple, sequence>, // ?,4,8 1,8,32 or 2,8,16
+                tuple, sequence<1, 2>>,     // M1 M2,K1
+                tuple, sequence<2, 1>>,
+                sequence<1, 2, 2>, // M0,K0,K2
+                sequence<0, 0, 2>>{});
+    }
+
+    template
+    CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
+    {
+        using ADataType = remove_cvref_t;
+        using ALayout   = remove_cvref_t;
+        constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
+        constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
         static_assert(MPerXdl == 16 && NPerXdl == 16);
         static_assert(std::is_same_v);
@@ -36,65 +153,70 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
         constexpr index_t MPerBlock   = Problem::BlockGemmShape::kM;
         constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
         constexpr index_t APackedSize = numeric_traits::PackedSize;
-        constexpr index_t KPack       = GetSmemPackA() * APackedSize;
+        constexpr index_t K2 = GetSmemPackA() * APackedSize;          // f4=32; f8=16
+        constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
+        constexpr index_t K0 = KPerBlock / (K1 * K2);                 // KPerBlock/256
+        static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
 
-        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number{}, number{}, number{}),
-            make_tuple(number{}, number{}, number<1>{}),
-            number{},
+        constexpr index_t M3 = 4; // so that we can use an immediate offset to load from LDS
+        constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
+        constexpr index_t M1 = MPerXdl / (M2 * M3);       // 2
+        constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
+        static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
+
+        constexpr index_t Pad = 4 * K2; // 4 * 32
+
+        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
+            make_tuple(number{},
+                       number{},
+                       number{},
+                       number{},
+                       number{},
+                       number{},
+                       number{}),
+            make_tuple(number{},
+                       number{},
+                       number{},
+                       number{},
+                       number{},
+                       number{},
+                       number<1>{}),
+            number{},
             number<1>{});
 
-        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
+        constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
             a_lds_block_desc_0,
-            make_tuple(
-                make_xor_transform(make_tuple(number{}, number{})),
-                make_pass_through_transform(number{})),
-            make_tuple(sequence<1, 0>{}, sequence<2>{}),
-            make_tuple(sequence<1, 0>{}, sequence<2>{}));
-
+            make_tuple(make_pass_through_transform(M0),
+                       make_pass_through_transform(M1),
+                       make_pass_through_transform(K0),
+                       make_pass_through_transform(M2),
+                       make_xor_transform(make_tuple(number{}, number{})),
+                       make_pass_through_transform(number{})),
+            make_tuple(sequence<0>{},
+                       sequence<1>{},
+                       sequence<2>{},
+                       sequence<3>{},
+                       sequence<4, 5>{},
+                       sequence<6>{}),
+            make_tuple(sequence<0>{},
+                       sequence<1>{},
+                       sequence<2>{},
+                       sequence<3>{},
+                       sequence<4, 5>{},
+                       sequence<6>{}));
         constexpr auto a_lds_block_desc = transform_tensor_descriptor(
-            a_lds_block_desc_permuted,
-            make_tuple(make_pass_through_transform(number{}),
+            a_lds_block_desc_1,
+            make_tuple(make_merge_transform_v3_division_mod(
+                           make_tuple(number{}, number{}, number{}, number{})),
                        make_merge_transform_v3_division_mod(
-                           make_tuple(number{}, number{}))),
-            make_tuple(sequence<1>{}, sequence<0, 2>{}),
+                           make_tuple(number{}, number{}, number{}))),
+            make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
 
        // return a_lds_block_desc_permuted;
        return a_lds_block_desc;
     }
 
-    template
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
-    {
-        using ADataType = remove_cvref_t;
-
-        constexpr index_t BlockSize = Problem::kBlockSize;
-
-        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-
-        constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
-        constexpr index_t K0 = KPerBlock / K1;
-        constexpr index_t M2 = get_warp_size() / K0;
-
-        constexpr index_t M1 = BlockSize / get_warp_size();
-        static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
-        static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
-        constexpr index_t M0 = MPerBlock / (M2 * M1);
-        static_assert(M0 * M1 * M2 == MPerBlock,
-                      "Incorrect M0, M2, M1 configuration! "
-                      "M0, M1, M2 must cover whole MPerBlock!");
-
-        return make_static_tile_distribution(
-            tile_distribution_encoding,
-                                       tuple, sequence>,
-                                       tuple, sequence<1, 2>>,
-                                       tuple, sequence<2, 0>>,
-                                       sequence<1, 2>,
-                                       sequence<0, 1>>{});
-    }
-
     template
     CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
     {
@@ -105,20 +227,31 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
         constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
         constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
 
-        constexpr int M_Lane = TileShape::WarpTile::at(I0);
+        constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
 
-        constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
+        constexpr int K_Lane = 64 / M_Lane; // 4
 
-        constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
+        constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
+        constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access);
+        constexpr int K1               = K_Thread / num_access_v; // 16
 
         return make_static_tile_distribution(
-            tile_distribution_encoding<
-                sequence,
-                tuple, sequence>,
-                tuple, sequence<2, 1>>,
-                tuple, sequence<0, 2>>,
-                sequence<2>,
-                sequence<1>>{});
+            std::conditional_t<
+                num_access_v == 1,
+                tile_distribution_encoding<
+                    sequence,
+                    tuple, sequence>,
+                    tuple, sequence<2, 1>>,
+                    tuple, sequence<0, 2>>,
+                    sequence<2>,
+                    sequence<1>>,
+                tile_distribution_encoding< //
+                    sequence,
+                    tuple, sequence>,
+                    tuple, sequence<2, 1>>,
+                    tuple, sequence<1, 2>>,
+                    sequence<2, 2>,
+                    sequence<0, 2>>>{});
     }
 
     template
@@ -132,25 +265,36 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
         constexpr index_t WaveSize = get_warp_size();
         constexpr index_t WaveNum  = BlockSize / WaveSize;
 
-        constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
+        constexpr index_t K1          = WaveSize; // threads cnt in K dim
         constexpr index_t KWavePerBlk = 1;
+        constexpr index_t K0          = KWavePerBlk;
 
         constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
 
-        constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
+        constexpr index_t WaveRepeat   = WaveNum / TileShape::flatNPerWarp;
+        constexpr index_t kKPerThread  = 32;
+        constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access);
+        constexpr index_t K2           = kKPerThread / num_access_v;
 
         return make_static_tile_distribution(
-            tile_distribution_encoding<
-                sequence,
-                tuple,
-                      sequence>, // first direction
-                // wave in blk, // thd in wave
-                //              //
-                tuple, sequence<2>>, // which direction
-                tuple, sequence<1>>, // which index
-                //
-                sequence<2>,
-                sequence<2>>{});
+            std::conditional_t< //
+                num_access_v == 1,
+                tile_distribution_encoding< //
+                    sequence,
+                    tuple,  // 4 2
+                          sequence>, // 1 64 32
+                    tuple, sequence<2>>,
+                    tuple, sequence<1>>,
+                    sequence<2>,
+                    sequence<2>>,
+                tile_distribution_encoding< //
+                    sequence,
+                    tuple,  // 4 2
+                          sequence>, // 2 1 64 16
+                    tuple, sequence<2>>,
+                    tuple, sequence<2>>,
+                    sequence<2, 2>,
+                    sequence<0, 3>>>{});
     }
 
     template
@@ -270,6 +414,21 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
                 sequence<2>,
                 sequence<1>>{});
     }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
+    {
+        using ADataType               = remove_cvref_t;
+        constexpr index_t APackedSize = numeric_traits::PackedSize;
+        return sizeof(ADataType) *
+               MakeMXFP4_ALdsBlockDescriptor().get_element_space_size() / APackedSize;
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        return GetSmemSizeA();
+    }
 };
 
} // namespace ck_tile
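
---
Reviewer note: for reference, below is a minimal host-side sketch (not part of the PR) of the `N, K -> N0 K0 KLane NLane KPack` index mapping that the new `preShuffleWeight` performs. The sketch makes simplifying assumptions: plain `uint8_t` elements with `packed_size = 1` (i.e. the fp8 case, so `KPack = 16`), `N_Warp_Tile = 16` as in both `MXfp4_FlatmmConfig16` and `MXfp8_FlatmmConfig16`, and a K-major `src` laid out as `src[k * N + n]`, matching the `src(k, n)` access in the diff. The function name and flat `std::vector` buffers are hypothetical; the real code operates on `ck_tile::HostTensor` with fp4/fp8 element types.

#include <cstdint>
#include <vector>

std::vector<uint8_t> pre_shuffle_weight(const std::vector<uint8_t>& src, int N, int K)
{
    constexpr int KPack = 16;         // bytes per packed K group (fp8 case; fp4 would be 32)
    constexpr int NLane = 16;         // N_Warp_Tile: lanes covering N within a warp tile
    constexpr int KLane = 64 / NLane; // remaining lanes of a 64-wide wave cover K
    const int K0        = K / (KLane * KPack);

    std::vector<uint8_t> dst(static_cast<size_t>(N) * K);
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
        {
            // decompose N -> N0 NLane and K -> K0 KLane KPack
            const int n0 = n / NLane, n1 = n % NLane;
            const int k0 = k / (KLane * KPack);
            const int r  = k % (KLane * KPack);
            const int k1 = r / KPack, k2 = r % KPack;

            // flat output index in N0 K0 KLane NLane KPack order (same formula as the diff)
            const int out = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                            k1 * KPack * NLane + n1 * KPack + k2;

            dst[out] = src[static_cast<size_t>(k) * N + n]; // src is K-major: src(k, n)
        }
    return dst;
}

The effect is that each wave reads B as contiguous `KPack`-byte vectors, with the 16 N-lanes of a 16x16 warp tile interleaved at `KPack` stride, which is what lets the pipeline fetch one warp-tile slice of B per `load_tile_with_offset` call instead of re-positioning a window per (nIter, kIter) pair.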