diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index c4366f6662..b7512b2999 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,54 +75,17 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 9cf43a986e..9a4ec64242 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -304,6 +304,14 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; +template +struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill +{ + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template struct GemmTypeConfig; @@ -344,6 +352,24 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template <> struct GemmTypeConfig { diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 2b8f8b32ae..0f323cb0e3 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -91,7 +91,11 @@ int main(int argc, char* argv[]) try { +#if CK_TILE_USE_WMMA + return !run_gemm_example(arg_parser); +#else return !run_gemm_example(arg_parser); +#endif } catch(const std::runtime_error& e) { diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 49d9a34f17..cc980a75f7 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck_tile/host/permute_pk_int4.hpp" + template static constexpr inline auto is_row_major(Layout layout_) { @@ -90,61 +92,6 @@ void permute_tensor_b(Tensor& tensor) } } -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 6, i) = i4x2; - } - } - } -} - template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase + constexpr int divisor = 2; + constexpr int kABK0PerLane = 2; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + kABK0PerLane, + GemmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template @@ -399,7 +373,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, BLayout, CLayout>(b_k_n_dev); } - permute_vectors_i4x4_b(b_k_n_dev); + ck_tile::permute_vectors_i4x4_b(b_k_n_dev); b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } else diff --git a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh index 951f8aa63a..c2ee7a1c3e 100755 --- a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh +++ b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh @@ -5,7 +5,7 @@ KNAME=1 export CK_WARMUP=0 export CK_REPEAT=1 -COMMON_ARGS='-v=2 -warmup=0 -repeat=1' +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' run_tests() { for m in 512 1024; do @@ -32,5 +32,8 @@ run_tests "fp16" run_tests "bf16" run_tests "fp8" run_tests "bf8" +run_tests "fp16i4" +run_tests "fp8i4" +run_tests "bf8i4" set +x diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 8e0bc40494..f9a7263a5f 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -5,11 +5,8 @@ #include #include -#include #include -#include -#include "ck_tile/host.hpp" #include "gemm_utils.hpp" #include "run_gemm_example.inc" #include "run_gemm_example_common.hpp" @@ -58,7 +55,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) ck_tile::int8_t, ck_tile::int32_t>(a_layout, b_layout, arg_parser); } - else if(data_type == "pk_int4_t") + else if(data_type == "fp16i4") { // TODO: Add support for bhalf_t ADataType if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) @@ -74,6 +71,36 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) throw std::runtime_error("Unsupported pipeline for this operation !!!"); } } + else if(data_type == "fp8i4") + { + if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + { + return run_gemm_example_prec_type, + Invoker, + ck_tile::fp8_t, + ck_tile::pk_int4_t, + ck_tile::half_t>(a_layout, b_layout, arg_parser); + } + else + { + throw std::runtime_error("Unsupported pipeline for this operation !!!"); + } + } + else if(data_type == "bf8i4") + { + if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + { + return run_gemm_example_prec_type, + Invoker, + ck_tile::bf8_t, + ck_tile::pk_int4_t, + ck_tile::half_t>(a_layout, b_layout, arg_parser); + } + else + { + throw std::runtime_error("Unsupported pipeline for this operation !!!"); + } + } else { throw std::runtime_error("Unsupported data type for this operation !!!"); diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 511efeeaec..2ca5157eda 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -71,11 +71,11 @@ bool run(const ck_tile::ArgParser& arg_parser) constexpr bool kTwoPass = true; - using BlockWarps = ck_tile::sequence<2, 2>; - using BlockTile = ck_tile::sequence<2, 128>; - using WarpTile = ck_tile::sequence<1, 64>; - using Vector = ck_tile::sequence<1, 1>; - using Shape = ck_tile::Generic2dBlockShape; + using BlockTile = ck_tile::sequence<2, 128>; + using Vector = ck_tile::sequence<1, 1>; + using ThreadPerBlock = ck_tile::sequence<2, 128>; + + using Shape = ck_tile::Generic2dBlockShape; using PipelineTraits = ck_tile::Rmsnorm2dFwdTraits; using UnquantYDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using ThreadPerBlock = ck_tile::sequence; + + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveInvRms = kSaveInvRms_; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index faa134e5c4..b7bd7ac7df 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,55 +80,17 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using InputDataType = ck_tile::remove_cvref_t; using QuantizedDataType = ck_tile::remove_cvref_t; - static constexpr auto WarpSize = ck_tile::get_warp_size(); - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); - } - else - { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveX = kSaveX_; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index ace5fe0c4f..ca94bc1b71 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -99,12 +99,11 @@ bool run(const ck_tile::ArgParser& arg_parser) constexpr bool kThreePass = true; - using BlockWarps = ck_tile::sequence<4, 1>; - using BlockTile = ck_tile::sequence<4, 128>; - using WarpTile = ck_tile::sequence<1, 64>; - using Vector = ck_tile::sequence<1, 1>; + using BlockTile = ck_tile::sequence<4, 128>; + using Vector = ck_tile::sequence<1, 1>; + using ThreadPerBlock = ck_tile::sequence<4, 64>; - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem; - using BlockTile = ck_tile::sequence<2, 128>; - using WarpTile = ck_tile::sequence<1, 64>; - using Vector = ck_tile::sequence<1, 1>; + using BlockTile = ck_tile::sequence<2, 128>; + using Vector = ck_tile::sequence<1, 1>; + using ThreadPerBlock = ck_tile::sequence<2, 128>; - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::SmoothquantPipelineProblem; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp index 36cf477a42..eff4bef025 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp @@ -38,54 +38,17 @@ struct moe_smoothquant_traits_ using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index f8e21d5ee4..1fb53909ac 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -190,6 +190,30 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr bool kPadK = true; }; +template +struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 32 / sizeof(PrecType); + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadK = true; + + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; +}; + template struct PipelineTypeTraits; @@ -266,16 +290,43 @@ template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase + constexpr int divisor = 2; + constexpr int kABK0PerLane = 2; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + kABK0PerLane, + GemmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template (argc, argv); +#else return !run_grouped_gemm_example(argc, argv); +#endif } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 93117e5b75..280da8d333 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -268,6 +268,9 @@ int main(int argc, char* argv[]) try { +#if defined(CK_TILE_USE_WMMA) + return !run_flatmm_example(argc, argv); +#else int warp_tile = arg_parser.get_int("warp_tile"); if(warp_tile == 0) { @@ -285,6 +288,7 @@ int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } +#endif } catch(const std::runtime_error& e) { diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 64e141860e..8f8f65e214 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -86,6 +86,14 @@ struct FlatmmConfig16_950 : public FlatmmConfig16 static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128; }; +template +struct FlatmmConfig16_Wmma : public FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template struct GemmBasicTypeConfig; @@ -183,8 +191,10 @@ auto create_args(int argc, char* argv[]) .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") +#if !defined(CK_TILE_USE_WMMA) .insert( "warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)") +#endif .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "flatmm_basic.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index b6b92b5801..63d0a80555 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -43,15 +43,40 @@ auto shuffle_b(const ck_tile::HostTensor& t) int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2) - : (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4); - ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, - FlatmmConfig::N_Warp_Tile, - k_ / FlatmmConfig::K_Warp_Tile, - divisor, - FlatmmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + if(ck_tile::is_gfx12_supported()) + { + // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase + constexpr int divisor = 2; + constexpr int kABK0PerLane = 2; + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / FlatmmConfig::K_Warp_Tile, + divisor, + kABK0PerLane, + FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / FlatmmConfig::K_Warp_Tile, + divisor, + FlatmmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp index f4e63d54da..52a815dfdb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp @@ -228,4 +228,4 @@ int run_gemm_example(int argc, char* argv[]) } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc index 5f1d528b2b..477cb125be 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc @@ -5,6 +5,7 @@ #pragma once #include #include +#include "../00_shared/host_tensor_utils.hpp" template static constexpr inline auto is_row_major(Layout layout_) @@ -217,7 +218,16 @@ int run_gemm_example_with_layouts(int argc, aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); } - a_m_k_dev_buf.ToDevice(a_m_k.data()); + if constexpr(std::is_same_v) + { + ck_tile::HostTensor a_m_k_dev = a_m_k; + ck_tile::permute_vectors_i4x4_b(a_m_k_dev); + a_m_k_dev_buf.ToDevice(a_m_k_dev.data()); + } + else + { + a_m_k_dev_buf.ToDevice(a_m_k.data()); + } b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_bquant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_bquant_example.inc index 4500e2e874..f30039b1b2 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_bquant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_bquant_example.inc @@ -3,6 +3,7 @@ #pragma once #include +#include "ck_tile/host/permute_pk_int4.hpp" template static constexpr inline auto is_row_major(Layout layout_) @@ -208,7 +209,17 @@ int run_gemm_example_with_layouts(int argc, ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + ck_tile::permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 28ae401717..54f139dd74 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -4,6 +4,7 @@ #pragma once #include #include +#include "ck_tile/host/permute_pk_int4.hpp" template static constexpr inline auto is_row_major(Layout layout_) @@ -308,7 +309,17 @@ int run_gemm_example_with_layouts(int argc, aq_dev_buf.ToDevice(aq_tensor.data()); } - a_m_k_dev_buf.ToDevice(a_m_k.data()); + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor a_m_k_dev = a_m_k; + ck_tile::permute_vectors_i4x4_b(a_m_k_dev); + a_m_k_dev_buf.ToDevice(a_m_k_dev.data()); + } + else + { + a_m_k_dev_buf.ToDevice(a_m_k.data()); + } b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index ad7956d32a..fc1caf13ff 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -125,7 +125,7 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_in float x_h = ((x_u8 & 0xf0) >> 4); x_l = x_l > 7 ? x_l - 16 : x_l; - x_h = x_l > 7 ? x_l - 16 : x_l; + x_h = x_h > 7 ? x_h - 16 : x_h; #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE fp32x2_t res = {x_h, x_l}; diff --git a/include/ck_tile/host/permute_pk_int4.hpp b/include/ck_tile/host/permute_pk_int4.hpp new file mode 100644 index 0000000000..b770edddca --- /dev/null +++ b/include/ck_tile/host/permute_pk_int4.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c), Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "ck_tile/core/utility/bit_cast.hpp" +namespace ck_tile { + +/** + * @brief Permute packed int4 vectors for device implementation compatibility + * + * This function transforms 4 pk_int4_t values from original layout to hardware-optimized layout: + * - Original layout (4 pk_int4_t): 0x76543210 + * - Transformed layout (4 pk_int4_t): 0x75316420 + * + * Each pk_int4_t contains two 4-bit values packed in the high and low nibbles of an int8_t + * + * Example: + * - Input: 0x76, 0x54, 0x32, 0x10 + * - Output: 0x75, 0x31, 0x64, 0x20 + * + * @note Input tensor length must be a multiple of 4 + * + * This transformation is required before transferring B matrix data (of type pk_int4_t) to device. + * The device conversion functions (i4_to_half4, i4_to_bhalf4, amd_assembly_i4_to_fp8x8, + * amd_assembly_i4_to_bf8x8) require data in 0x75316420 order to correctly convert pk_int4_t to + * other numeric types. + */ +template +void permute_vectors_i4x4_b(Tensor& tensor) +{ + auto tensor_row_buf = tensor.data(); + for(size_t idx = 0; idx < tensor.size(); idx += 4) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = bit_cast(tensor_row_buf[idx + k]); + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 0x76543210 => 0x75316420 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + + tensor_row_buf[idx + 0] = bit_cast(i4x2); + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + + tensor_row_buf[idx + 1] = bit_cast(i4x2); + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + + tensor_row_buf[idx + 2] = bit_cast(i4x2); + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + + tensor_row_buf[idx + 3] = bit_cast(i4x2); + } + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 5a26cfbcd2..caa00e5994 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -50,7 +50,7 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor& a_m_k, if constexpr(std::is_same_v) { const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); if(k % 2 == 1) v_a = fp32_val.hi; else @@ -63,7 +63,7 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor& a_m_k, if constexpr(std::is_same_v) { const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); if(k % 2 == 1) v_b = fp32_val.hi; else diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index c0bfd93198..333762e5d7 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -35,43 +35,69 @@ namespace ck_tile { +-----------+-----------+-----------+-----------+-----------+ // clang-format on */ -template - typename WarpPerBlock_, // num warps along seq - typename WarpTile_, // warp size, seq - typename Vector_> // contiguous pixels(vector size) along seq)> +template + typename ThreadPerBlock_, // num threads along seq + typename Vector_> // contiguous pixels(vector size) along seq)> struct Generic2dBlockShape { // block size - static constexpr index_t Block_M = BlockTile_::at(number<0>{}); - static constexpr index_t Block_N = BlockTile_::at(number<1>{}); - - // num warps along seq, within each block - static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); - static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); - - // warp size - static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); - static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); - - static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); - static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); - // repeat of each thread along seq - static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); - static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + static constexpr index_t Block_M = BlockTile_::at(number<0>{}); + static constexpr index_t Block_N = BlockTile_::at(number<1>{}); + static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{}); + static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{}); + static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; // vector size along seq static constexpr index_t Vector_M = Vector_::at(number<0>{}); static constexpr index_t Vector_N = Vector_::at(number<1>{}); + static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size(); + static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % get_warp_size() == 0); + static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size(); + + // num warps along seq, within each block + static constexpr index_t WarpPerBlock_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(get_warp_size() % ThreadPerBlock_N == 0); + return total_warps * (get_warp_size() / ThreadPerBlock_N); + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N / get_warp_size()); + } + }(); + + // num of warps along n + static constexpr index_t WarpPerBlock_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(get_warp_size() % ThreadPerBlock_N == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N % get_warp_size() == 0); + return ThreadPerBlock_N / get_warp_size(); + } + }(); + + // warp size + static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; + static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; static_assert(Warp_M % Vector_M == 0); static_assert(Warp_N % Vector_N == 0); - // num of threads along seq, within each warp - static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; - static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M; - static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N; + static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); + static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); - static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; + // repeat of each thread along seq + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + // num of threads along seq, within each warp + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index c5cbca4a87..9e3ccb025d 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -4,15 +4,29 @@ #pragma once #include "ck_tile/core.hpp" +#include #include namespace ck_tile { namespace element_wise { -// Fast int4x4 to fp16x8_t data type conversion based on paper -// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] -// (https://arxiv.org/abs/2211.10017) and implementation: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +/** + * @brief Fast int4x4 to fp16x8_t data type conversion based on paper + * "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production" + * @see https://arxiv.org/abs/2211.10017 + * @see + * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + * + * This function converts 4 4-bit integers into 4 fp16 values. + * @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4. + * @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8) + * @note The output ordering differs from input ordering. For example, when input is 0x76543210, + * the output sequence will be fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor + * must be preprocessed with permute_vectors_i4x4_b on the host side before using this + * function. + * + * @see permute_vectors_i4x4_b + */ CK_TILE_DEVICE fp16x4_t i4_to_half4(int q) { const int LO = 0x000f000f; @@ -46,6 +60,18 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4(int q) return res; } +/** + * @brief This function dequantizes 4 int4 values into 4 fp16 values and applies scaling. + * + * @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4. + * @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8) + * @note The output ordering differs from input ordering. For example, when input is 0x76543210, + * the output sequence will be fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor + * must be preprocessed with permute_vectors_i4x4_b on the host side before using this + * function. + * + * @see permute_vectors_i4x4_b + */ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale) { const int LO = 0x000f000f; @@ -81,6 +107,18 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale) return res; } +/** + * @brief This function converts 4 4-bit integers into 4 bf16 values. + * + * @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4. + * @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf16(-8) + * @note The output ordering differs from input ordering. For example, when input is 0x76543210, + * the output sequence will be bf16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor + * must be preprocessed with permute_vectors_i4x4_b on the host side before using this + * function. + * + * @see permute_vectors_i4x4_b + */ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) { uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); @@ -110,37 +148,55 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) return res; } +/** + * @brief This function converts 8 packed 4-bit integers into 8 fp8 values. + * + * @note `int q` contains 4 bytes, each byte represents 2 int4. + * @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp8(-8) + * @note The output ordering differs from input ordering. For example, when input is 0x76543210, + * the output sequence will be fp8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor + * must be preprocessed with permute_vectors_i4x4_b on the host side before using this + * function. + * + * @see permute_vectors_i4x4_b + */ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) { - uint32_t src = static_cast(a), src_hi; - uint32_t fp8x4_lo, fp8x4_hi; - float tmp_0, tmp_1; + // register values [3, 2, 1, 0] + static constexpr uint32_t reg0 = 0xd2d4d6d8; + // register values [7, 6, 5, 4] + static constexpr uint32_t reg1 = 0xc0c8ccd0; + // register values [-1, -2, -3, -4] + static constexpr uint32_t reg2 = 0x4C484000; + // register values [-5, -6, -7, -8] + static constexpr uint32_t reg3 = 0x56545250; - asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n" - "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n" - "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n" - "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; - "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n" - "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n" - "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n" + uint32_t dict_sel = a & 0x07070707; + uint32_t sign = a >> 1; + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(final_sel) + : "v"(sign), "v"(0x04040404), "v"(0x03020100)); - "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" - "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n" - "v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); + tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); - "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" - "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n" - "v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n" - : [v_tmp_0] "+v"(tmp_0), - [v_tmp_1] "+v"(tmp_1), - [v_hi_src] "+v"(src_hi), - [v_dst_lo] "+v"(fp8x4_lo), - [v_dst_hi] "+v"(fp8x4_hi), - [v_src] "+v"(src) - :); + a >>= 4; + dict_sel = a & 0x07070707; + sign = a >> 1; + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(final_sel) + : "v"(sign), "v"(0x04040404), "v"(0x03020100)); - return bit_cast(((static_cast(fp8x4_hi) << 32) | fp8x4_lo)); + tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); + tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); + auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200); + auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301); + + return bit_cast((static_cast(tmp_res_high) << 32) | tmp_res_low); } CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src) @@ -157,37 +213,55 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src) return res; } -CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a) +/** + * @brief This function converts 8 packed 4-bit integers into 8 bf8 values. + * + * @note `int q` contains 4 bytes, each byte represents 2 int4. + * @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf8(-8) + * @note The output ordering differs from input ordering. For example, when input is 0x76543210, + * the output sequence will be bf8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor + * must be preprocessed with permute_vectors_i4x4_b on the host side before using this + * function. + * + * @see permute_vectors_i4x4_b + */ +CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) { - uint32_t src = static_cast(a), src_hi; - uint32_t bf8x4_lo, bf8x4_hi; - float tmp_0, tmp_1; + // register values [3, 2, 1, 0] + static constexpr uint32_t reg0 = 0Xc9cacbcc; + // register values [7, 6, 5, 4] + static constexpr uint32_t reg1 = 0Xc0c4c6c8; + // register values [11, 10, 9, 8] + static constexpr uint32_t reg2 = 0X46444000; + // register values [15, 14, 13, 12] + static constexpr uint32_t reg3 = 0X4b4a4948; - asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n" - "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n" - "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n" - "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; - "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n" - "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n" - "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n" + uint32_t dict_sel = a & 0x07070707; + uint32_t sign = a >> 1; + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(final_sel) + : "v"(sign), "v"(0x04040404), "v"(0x03020100)); - "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" - "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n" - "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); + tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); - "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" - "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n" - "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n" - : [v_tmp_0] "+v"(tmp_0), - [v_tmp_1] "+v"(tmp_1), - [v_hi_src] "+v"(src_hi), - [v_dst_lo] "+v"(bf8x4_lo), - [v_dst_hi] "+v"(bf8x4_hi), - [v_src] "+v"(src) - :); + a >>= 4; + dict_sel = a & 0x07070707; + sign = a >> 1; + asm volatile("v_and_or_b32 %0, %1, %2, %3" + : "=v"(final_sel) + : "v"(sign), "v"(0x04040404), "v"(0x03020100)); - return bit_cast(((static_cast(bf8x4_hi) << 32) | bf8x4_lo)); + tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); + tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); + auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200); + auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301); + + return bit_cast((static_cast(tmp_res_high) << 32) | tmp_res_low); } struct PassThroughPack8 @@ -209,12 +283,12 @@ struct PassThroughPack8 CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const { - y = amd_assembly_i4_to_fp8x8(bit_cast(x)); + y = amd_assembly_i4_to_fp8x8(bit_cast(x)); } CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const { - y = amd_assembly_i4_to_bf8x8(bit_cast(x)); + y = amd_assembly_i4_to_bf8x8(bit_cast(x)); } constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 20ca976590..a924279d52 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -127,7 +127,10 @@ struct FlatmmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize); + } CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 1a28366e24..0cae1a467d 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -185,11 +185,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV } template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const { static_assert( std::is_same_v> && diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 3ca79fc46e..5fd1fb8d39 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -237,8 +237,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { - using TileShape = typename Problem::BlockGemmShape; + using TileShape = typename Problem::BlockGemmShape; +#if defined(__gfx11__) + constexpr index_t scale = 4; +#else constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; +#endif if constexpr(TileShape::WarpTile::at(I1) == 32) { return TileShape::WarpTile::at(I2) * scale / 2; @@ -342,7 +346,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape @@ -350,8 +354,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = GetKBPerLoad(); - constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t KRepeat = 1; static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); @@ -362,16 +371,15 @@ struct UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t NRepeat = 1; constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - return make_static_tile_distribution( tile_distribution_encoding< - sequence, // ? + sequence, // ? tuple, // second direction sequence>, // first direction // wave in blk, // thd in wave // // - tuple, sequence<1, 2>>, // which direction - tuple, sequence<2, 2>>, // which index + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index // sequence<1, 1, 2, 2>, sequence<0, 3, 0, 3>>{}); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 8b95639516..71ca907c07 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -89,14 +89,19 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { using TileShape = typename Problem::BlockGemmShape; +#if defined(__gfx11__) + constexpr index_t scale = 4; +#else + constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; +#endif if constexpr(TileShape::WarpTile::at(I1) == 32) { - return TileShape::WarpTile::at(I2) / 2; + return TileShape::WarpTile::at(I2) * scale / 2; } else { static_assert(TileShape::WarpTile::at(I1) == 16); - return TileShape::WarpTile::at(I2) / 4; + return TileShape::WarpTile::at(I2) * scale / 4; } } @@ -192,7 +197,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -200,8 +205,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = GetKBPerLoad(); - constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t KRepeat = 1; static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); @@ -212,16 +222,15 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t NRepeat = 1; constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - return make_static_tile_distribution( tile_distribution_encoding< - sequence, // ? + sequence, // ? tuple, // second direction sequence>, // first direction // wave in blk, // thd in wave // // - tuple, sequence<1, 2>>, // which direction - tuple, sequence<2, 2>>, // which index + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index // sequence<1, 1, 2, 2>, sequence<0, 3, 0, 3>>{}); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp index b91c211d91..290f24a7f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp @@ -189,11 +189,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 } template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const { static_assert( std::is_same_v> && diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 7104e318d2..129eac6557 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -146,10 +146,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t mfma_per_wg = 1; #endif static constexpr index_t dsread_per_wg = - WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize; - static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0); - - static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp; + max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1); +#if defined(__HIP_DEVICE_COMPILE__) + static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + Problem::VectorLoadSize == + 0); +#endif + static constexpr index_t dsread_num_perK = + WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize; static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp); static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; static constexpr index_t Aload_num_perK = dswrite_num_perK; @@ -499,12 +503,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction> - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const { static_assert( std::is_same_v>, diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 614245f05b..b41f01b951 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -181,9 +181,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase static constexpr index_t MWarp = Traits::MWarp; static constexpr index_t NWarp = Traits::NWarp; - static constexpr auto Scheduler = Traits::Scheduler; - static constexpr uint8_t kA_cvt_scale = std::is_same_v ? 16 : 1; - static constexpr uint8_t kB_cvt_scale = std::is_same_v ? 16 : 1; + static constexpr auto Scheduler = Traits::Scheduler; using AWarpDstr = typename WarpGemm::AWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr; @@ -451,7 +449,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += (c_warp_tensor.get_thread_buffer()[c_row] * - scale_reg_f * kA_cvt_scale * kB_cvt_scale); + scale_reg_f); }); } } @@ -471,7 +469,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase [&](auto c_row) { c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += (c_warp_tensor.get_thread_buffer()[c_row] * - scale_reg_f * kA_cvt_scale * kB_cvt_scale); + scale_reg_f); }); } else @@ -556,7 +554,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase reg_offset_for_row_data] += (c_warp_tensor .get_thread_buffer()[reg_offset_for_row_data] * - scale_reg_f * kA_cvt_scale * kB_cvt_scale); + scale_reg_f); }); } } diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 844c8f6eb0..7e28ea8fa9 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -179,9 +179,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase static constexpr index_t MWarp = Traits::MWarp; static constexpr index_t NWarp = Traits::NWarp; - static constexpr auto Scheduler = Traits::Scheduler; - static constexpr uint8_t kA_cvt_scale = std::is_same_v ? 16 : 1; - static constexpr uint8_t kB_cvt_scale = std::is_same_v ? 16 : 1; + static constexpr auto Scheduler = Traits::Scheduler; using AWarpDstr = typename WarpGemm::AWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr; @@ -384,8 +382,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg); static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) { c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f * - kA_cvt_scale * kB_cvt_scale); + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); }); }); }); diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index faa134e5c4..b7bd7ac7df 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,55 +80,17 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using InputDataType = ck_tile::remove_cvref_t; using QuantizedDataType = ck_tile::remove_cvref_t; - static constexpr auto WarpSize = ck_tile::get_warp_size(); - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return total_warps * (WarpSize / ThreadPerBlock_N_); - } - else - { - // static_assert(WarpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / WarpSize); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(WarpSize % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % WarpSize == 0); - return ThreadPerBlock_N_ / WarpSize; - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveX = kSaveX_; diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc index a870028a37..dbe652ac62 100644 --- a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc +++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc @@ -14,6 +14,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/host.hpp" #include "test_gemm_aquant_utils.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" template ) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor a_m_k_dev = a_m_k; + ck_tile::permute_vectors_i4x4_b(a_m_k_dev); + a_m_k_dev_buf.ToDevice(a_m_k_dev.data()); + } + else + { + a_m_k_dev_buf.ToDevice(a_m_k.data()); + } aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); diff --git a/test/ck_tile/layernorm2d/generate.py b/test/ck_tile/layernorm2d/generate.py index c4366f6662..f7446c0148 100644 --- a/test/ck_tile/layernorm2d/generate.py +++ b/test/ck_tile/layernorm2d/generate.py @@ -75,54 +75,17 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using ThreadPerBlock = ck_tile::sequence; + + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; diff --git a/test/ck_tile/moe_smoothquant/moe_smoothquant.hpp b/test/ck_tile/moe_smoothquant/moe_smoothquant.hpp index ced9b4ef3d..40b09dca00 100644 --- a/test/ck_tile/moe_smoothquant/moe_smoothquant.hpp +++ b/test/ck_tile/moe_smoothquant/moe_smoothquant.hpp @@ -38,54 +38,16 @@ struct moe_smoothquant_traits_ using InputType = ck_tile::remove_cvref_t; using OutputType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; diff --git a/test/ck_tile/rmsnorm2d/generate.py b/test/ck_tile/rmsnorm2d/generate.py index 1a1c842b3c..5eded8b310 100644 --- a/test/ck_tile/rmsnorm2d/generate.py +++ b/test/ck_tile/rmsnorm2d/generate.py @@ -74,54 +74,17 @@ struct rmsnorm2d_fwd_traits_ using YScaleDataType = ck_tile::remove_cvref_t; using UnquantYDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using ThreadPerBlock = ck_tile::sequence; + + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveInvRms = kSaveInvRms_; diff --git a/test/ck_tile/smoothquant/smoothquant.hpp b/test/ck_tile/smoothquant/smoothquant.hpp index b1d5dae3d3..ef0c36aa98 100644 --- a/test/ck_tile/smoothquant/smoothquant.hpp +++ b/test/ck_tile/smoothquant/smoothquant.hpp @@ -49,54 +49,17 @@ struct smoothquant_traits_ { using DataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); - } - else - { - // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); - return ThreadPerBlock_N_ / ck_tile::get_warp_size(); - } - }(); - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + using BlockTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + using ThreadPerBlock = ck_tile::sequence; - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_;