From cc1392a4056bbac2636f283dafd4de4aea825a0e Mon Sep 17 00:00:00 2001 From: Joseph Macaranas <145489236+jayhawk-commits@users.noreply.github.com> Date: Fri, 2 Jan 2026 14:23:43 -0500 Subject: [PATCH 01/23] Update TheRock CI SHA 20260102 (#3506) - TheRock CI compilation passed with the changes. --- .github/workflows/therock-ci-linux.yml | 7 ++++--- .github/workflows/therock-test-component.yml | 2 +- .github/workflows/therock-test-packages.yml | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 0baa503334..cc6178b08c 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -54,7 +54,7 @@ jobs: with: repository: "ROCm/TheRock" path: "TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: Setup ccache run: | @@ -78,8 +78,9 @@ jobs: run: | git config --global --add safe.directory '*' # Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo - rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch - git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch + rm ./TheRock/patches/amd-mainline/rocm-libraries/0003-Find-rocm_smi-via-config-files.patch + rm ./TheRock/patches/amd-mainline/rocm-libraries/0007-Remove-Windows-third_party_dlls-copying-code.patch + # git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps run: | diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 565d1d3e54..74f3bb0017 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,7 +51,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: "ROCm/TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index cd255a40b6..e4bd295c95 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: "Configuring CI options" env: From 1da340031c98bfde0f142bf34493d087490ec70d Mon Sep 17 00:00:00 2001 From: John Shumway Date: Fri, 2 Jan 2026 11:36:42 -0800 Subject: [PATCH 02/23] Enable math defines for MSVC. (#3503) The symbol M_PI is breaking the build on Windows. The _USE_MATH_DEFINES macro enables M_PI and other math constants on Windows. (I'm guessing this is more idomatic than the old trick of using PI=acos(-1.0).) https://learn.microsoft.com/en-us/cpp/c-runtime-library/math-constants?view=msvc-170 Co-authored-by: BradPepersAMD --- include/ck/library/utility/device_tensor_generator.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ck/library/utility/device_tensor_generator.hpp b/include/ck/library/utility/device_tensor_generator.hpp index 4da38bf399..ede6d131e7 100644 --- a/include/ck/library/utility/device_tensor_generator.hpp +++ b/include/ck/library/utility/device_tensor_generator.hpp @@ -7,6 +7,7 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/device_tensor_generator.hpp" #include "ck/utility/data_type.hpp" +#define _USE_MATH_DEFINES // Required for M_PI in MSVC #include // use xorshift for now since it is simple. Should be suitable enough, but feel free to switch in From 355ce9230d9c4f2e74776e879f2bee71a26bae4a Mon Sep 17 00:00:00 2001 From: John Shumway Date: Fri, 2 Jan 2026 14:21:46 -0800 Subject: [PATCH 03/23] Remove non-standard M_PI (#3507) Just use PI=acos(-1.0) as a local static constexpr. This has been causing build issues on windows. --- include/ck/library/utility/device_tensor_generator.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck/library/utility/device_tensor_generator.hpp b/include/ck/library/utility/device_tensor_generator.hpp index ede6d131e7..80a91adda4 100644 --- a/include/ck/library/utility/device_tensor_generator.hpp +++ b/include/ck/library/utility/device_tensor_generator.hpp @@ -7,7 +7,6 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/device_tensor_generator.hpp" #include "ck/utility/data_type.hpp" -#define _USE_MATH_DEFINES // Required for M_PI in MSVC #include // use xorshift for now since it is simple. Should be suitable enough, but feel free to switch in @@ -108,6 +107,7 @@ template __global__ void fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_element_size) { + static constexpr PI = std::acos(-1.0); // initial values ran_state_u32 s = ran_init(); float norm[2]; @@ -119,9 +119,9 @@ fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_e float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); norm[0] = - sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * M_PI * u2) + mean; + sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * PI * u2) + mean; norm[1] = - sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * M_PI * u2) + mean; + sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * PI * u2) + mean; } if constexpr(ck::is_same_v) From 4670df5ca606e6e3ee07a085ea61016489bf91ad Mon Sep 17 00:00:00 2001 From: John Shumway Date: Fri, 2 Jan 2026 16:58:35 -0800 Subject: [PATCH 04/23] [CK_BUILDER] Remove cmath include (#3508) Remove the dependency from device_tensor_generator.hpp and fix a typo from a previous force push. The changes replace standard library math functions with their ck::math equivalents and define PI as a local constant instead of computing it using std::acos. Key changes: * Removed #include header dependency * Replaced std::acos(-1.0) with hardcoded PI constant 3.141592653f * Replaced std::sqrt, std::cos, and std::sin with ck::math equivalents --- .../ck/library/utility/device_tensor_generator.hpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/include/ck/library/utility/device_tensor_generator.hpp b/include/ck/library/utility/device_tensor_generator.hpp index 80a91adda4..60bc3110d4 100644 --- a/include/ck/library/utility/device_tensor_generator.hpp +++ b/include/ck/library/utility/device_tensor_generator.hpp @@ -7,7 +7,6 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/device_tensor_generator.hpp" #include "ck/utility/data_type.hpp" -#include // use xorshift for now since it is simple. Should be suitable enough, but feel free to switch in // the future @@ -107,7 +106,7 @@ template __global__ void fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_element_size) { - static constexpr PI = std::acos(-1.0); + static constexpr float PI = 3.141592653f; // initial values ran_state_u32 s = ran_init(); float norm[2]; @@ -116,12 +115,11 @@ fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_e { if(j % (2 / ck::packed_size_v) == 0) { - float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); - float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); - norm[0] = - sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * PI * u2) + mean; - norm[1] = - sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * PI * u2) + mean; + float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + float scale = sigma * ck::math::sqrt(-2.0f * ck::math::log(u1)); + norm[0] = scale * ck::math::cos(2.0f * PI * u2) + mean; + norm[1] = scale * ck::math::sin(2.0f * PI * u2) + mean; } if constexpr(ck::is_same_v) From ec23be0b9d45ff9ca4135090bcd0269184c953a7 Mon Sep 17 00:00:00 2001 From: John Afaganis Date: Fri, 2 Jan 2026 22:16:41 -0700 Subject: [PATCH 05/23] Update unsigned long literals and format specifiers to work correctly in Windows (#3483) Previously, the code used unsigned long for literals and format specifiers to represent 64-bit unsigned values. While this worked on Linux, it caused compatibility issues on Windows. The C++ standard does not guarantee that long is 64 bits. On LP64 systems (e.g., Linux), long maps to 64-bit values, but on LLP64 systems (e.g., Windows), long maps to 32-bit values. This discrepancy led to incorrect behavior when assuming unsigned long was always 64-bit. This commit updates all relevant literals and format specifiers to explicitly use 64-bit unsigned types, ensuring consistent behavior across platforms. --- example/ck_tile/13_moe_sorting/moe_sorting.cpp | 4 ++-- include/ck_tile/host/fill.hpp | 7 ++++--- .../reference_grouped_conv_bwd_data.hpp | 3 ++- .../gemm_universal_pipeline_ag_bg_cr_policy.hpp | 16 ++++++++++------ .../moe_sorting/test_moe_sorting_util.hpp | 4 ++-- test/ck_tile/utility/test_fill.cpp | 2 ++ 6 files changed, 22 insertions(+), 14 deletions(-) diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index d9cb54cf74..a98faf5840 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -334,13 +334,13 @@ bool test_moe_sorting(ck_tile::ArgParser args) if(moe_buf_bytes > 0) { #if MOE_SORTING_FMOE_2D_BUF - printf("moe_buf:%lu(%d,%d), ", + printf("moe_buf:%" PRIu64 "(%d,%d), ", static_cast(moe_buf_bytes), moe_buf_interm_dim, moe_buf_elem_bytes); #else - printf("moe_buf:%lu, ", static_cast(moe_buf_bytes)); + printf("moe_buf:%" PRIu64 ", ", static_cast(moe_buf_bytes)); #endif } diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 4bbf8cbf3f..bddc0ae2d2 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -55,9 +55,10 @@ struct FillUniformDistribution const auto total_bytes = total * sizeof(T_iter); // max 80 threads; at least 2MB per thread - const size_t available_cpu_cores = get_available_cpu_cores(); - const size_t num_thread = - min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL)); + const size_t available_cpu_cores = get_available_cpu_cores(); + constexpr uint64_t MAX_THREAD_COUNT = 80; + const size_t num_thread = min( + MAX_THREAD_COUNT, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL)); constexpr size_t BLOCK_BYTES = 64; constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter); const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES); diff --git a/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp index e141d842dd..95ab1258d6 100644 --- a/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp +++ b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -28,7 +29,7 @@ CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor& input, output.get_num_of_dimension() == NDimSpatial + 3)) { - printf("%lu %lu %lu", + printf("%" PRIu64 " %" PRIu64 " %" PRIu64, input.get_num_of_dimension(), weight.get_num_of_dimension(), output.get_num_of_dimension()); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index a45d41189b..d68da14ac5 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -246,9 +246,11 @@ struct UniversalGemmBasePolicy } else // A is in RowMajor { - constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto MLdsLayer = - max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + max(MinLdsLayer, + get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); constexpr index_t NBanks = get_n_lds_banks(); static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count"); @@ -442,11 +444,13 @@ struct UniversalGemmBasePolicy } else // B is Column Major { - constexpr index_t KPack = GetSmemPackB(); - constexpr auto BK0 = number{}; - constexpr auto DataTypeSize = sizeof(BDataType); + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto NLdsLayer = - max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + max(MinLdsLayer, + get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); constexpr index_t NBanks = get_n_lds_banks(); static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count"); diff --git a/test/ck_tile/moe_sorting/test_moe_sorting_util.hpp b/test/ck_tile/moe_sorting/test_moe_sorting_util.hpp index 37377755ea..de06669063 100644 --- a/test/ck_tile/moe_sorting/test_moe_sorting_util.hpp +++ b/test/ck_tile/moe_sorting/test_moe_sorting_util.hpp @@ -236,13 +236,13 @@ class TestCkTileMoeSorting : public ::testing::Test if(moe_buf_bytes > 0) { #if MOE_SORTING_FMOE_2D_BUF - printf("moe_buf:%lu(%d,%d), ", + printf("moe_buf:%" PRIu64 "(%d,%d), ", static_cast(moe_buf_bytes), moe_buf_interm_dim, moe_buf_elem_bytes); #else - printf("moe_buf:%lu, ", static_cast(moe_buf_bytes)); + printf("moe_buf:%" PRIu64 ", ", static_cast(moe_buf_bytes)); #endif } diff --git a/test/ck_tile/utility/test_fill.cpp b/test/ck_tile/utility/test_fill.cpp index 3633f8bbff..f67dee9757 100644 --- a/test/ck_tile/utility/test_fill.cpp +++ b/test/ck_tile/utility/test_fill.cpp @@ -26,6 +26,7 @@ using TestTypes = ::testing::Types; TYPED_TEST_SUITE(FillUniformDistributionTest, TestTypes); // Test that multiple runs with the same seed produce identical results +#ifndef _WIN32 TYPED_TEST(FillUniformDistributionTest, ConsistencyWithSameSeed) { using T = TypeParam; @@ -53,6 +54,7 @@ TYPED_TEST(FillUniformDistributionTest, ConsistencyWithSameSeed) << "First and second fill should be identical"; } } +#endif // Test consistency across different data sizes (which affects threading) TYPED_TEST(FillUniformDistributionTest, ConsistencyAcrossSizes) From e339101e9c9961fe1bc8305d5c316b39d1980d3e Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Sun, 4 Jan 2026 03:28:14 -0800 Subject: [PATCH 06/23] [CK-Tile] move out memory operation from cshuffle epilogue class (#3359) * initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder * hacky poc * fix tile engine * kill the lambda * maybe fix test build * more fixes * clang-format * save temp * clang-format * mostly fix examples * clang-format * remove dead code * more cleanup * fix fmha bwd build (default epilogue set/add appears to be broken) * fix default epilogue tests but not correctness * clang-format * fix bquant * clang-format * cleanup dead code * rearrange make windows for readability * restore changes to IsSupportedArgument * fix smoke-builder * clang-format * fixup rename class * build fixes * clang-format * fix builder * fixup * remove set from builder tests * fix test * clang-format * re-refactor the kernels * clang-format * fix header license * remove memory operation from conv bwd test * clang-format * clang-format example,include * clang-format test * build fixes * clang-format * solve compilation error * fix the CI * solve compilation error * clang format * solve merge conflict * solve merge conflict * solve the gfx11 error * solve test error * moar build fixes * remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope --------- Co-authored-by: Thomas Ning --- .../ck_tile/03_gemm/gemm_basic_invoker.hpp | 159 ++-- .../03_gemm/gemm_splitk_two_stage_invoker.hpp | 250 +++--- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 177 ++-- example/ck_tile/03_gemm/gemm_utils.hpp | 6 - .../gemm_weight_preshuffle_invoker.hpp | 185 ++--- .../03_gemm/universal_gemm_invoker.hpp | 168 ++-- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 83 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 214 ++--- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 3 +- .../17_grouped_gemm/grouped_gemm_multi_d.cpp | 214 ++--- .../grouped_gemm_preshuffle.cpp | 211 ++--- .../quant_invoke_grouped_gemm_kernel.hpp | 159 ++-- .../run_grouped_gemm_example.inc | 5 +- .../run_grouped_gemm_multi_d_example.inc | 26 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 30 +- example/ck_tile/18_flatmm/grouped_flatmm.cpp | 30 +- .../18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp | 30 +- .../mixed_prec/mixed_prec_flatmm.cpp | 30 +- example/ck_tile/18_flatmm/moe_flatmm.cpp | 36 +- .../18_flatmm/mxgemm/mx_flatmm_instance.hpp | 4 +- .../19_gemm_multi_d/gemm_multi_d_fp16.cpp | 104 +-- ...uped_convolution_backward_data_invoker.hpp | 146 ++-- ...ed_convolution_backward_weight_invoker.hpp | 159 ++-- ...tion_backward_weight_two_stage_invoker.hpp | 260 +++--- .../grouped_convolution_forward_invoker.hpp | 135 ++- ...nvolution_forward_large_tensor_invoker.hpp | 14 +- .../grouped_convolution_utils.hpp | 5 - .../22_gemm_multi_abd/gemm_multi_abd_fp16.cpp | 78 +- .../run_gemm_quant_example.inc | 93 +-- .../40_streamk_gemm/streamk_gemm_basic.cpp | 163 ++-- .../batched_contraction.cpp | 98 +-- .../builder/factory/conv_tile_factory.hpp | 1 - .../test_ckb_conv_bwd_data_2d_fp16_v3.cpp | 1 - .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 1 - .../ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp | 1 - .../test/test_bwd_data_instance_traits.cpp | 1 - .../test/test_bwd_weight_instance_traits.cpp | 1 - .../builder/test/test_fwd_instance_traits.cpp | 1 - .../ops/epilogue/cshuffle_epilogue.hpp | 102 ++- .../ops/epilogue/default_2d_epilogue.hpp | 38 +- .../ops/flatmm/kernel/flatmm_kernel.hpp | 468 ++++++----- .../kernel/mixed_prec_flatmm_kernel.hpp | 366 +++++---- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 3 +- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 407 +++++---- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 129 ++- .../streamk_gemm/streamk_gemm_kernel.hpp | 54 +- .../streamk_gemm_tile_partitioner.hpp | 3 + .../ops/gemm/kernel/universal_gemm_kernel.hpp | 498 ++++++----- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 770 ++++++++++++++++-- .../kernel/grouped_gemm_quant_kernel.hpp | 165 ++-- ...ouped_convolution_backward_data_kernel.hpp | 204 ++++- ...ped_convolution_backward_weight_kernel.hpp | 265 +++--- .../grouped_convolution_forward_kernel.hpp | 243 +++--- .../batched_gemm/test_batched_gemm_util.hpp | 83 +- .../epilogue/test_cshuffle_epilogue_util.hpp | 4 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 106 +-- .../test_gemm_quant_fixtures.hpp | 11 +- .../test_gemm_multi_abd_util.hpp | 127 ++- .../gemm_multi_d/test_gemm_multi_d_util.hpp | 127 ++- .../gemm_streamk/test_gemm_streamk_util.hpp | 103 ++- .../test_gemm_pipeline_util.hpp | 94 +-- .../test_ck_tile_grouped_conv_bwd_weight.cpp | 24 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 200 ++--- .../test_grouped_gemm_multi_d_util.hpp | 227 +++--- .../test_grouped_gemm_preshuffle_util.hpp | 173 ++-- .../test_grouped_gemm_util_quant.hpp | 187 ++--- tile_engine/ops/gemm/gemm_instance_builder.py | 8 +- .../gemm_streamk_instance_builder.py | 25 +- 68 files changed, 4198 insertions(+), 4298 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp index 77a9fe4271..df8351602b 100644 --- a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp @@ -69,107 +69,88 @@ struct BasicInvoker using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - CodegenPipelineProblem::TransposeC, - memory_operation>>; + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index c312a53c2a..d2460193d8 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -72,160 +72,144 @@ struct SplitKTwoStageInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmKernel = ck_tile::GemmKernel; - using GemmKernel = ck_tile::GemmKernel; + ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); + ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); + auto c_ptr = ws_args.c_ptr; + ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); - ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); - ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); - auto c_ptr = ws_args.c_ptr; - ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); + const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) + : GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); - const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) - : GemmKernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = GemmKernel::BlockSize(); + if(!GemmKernel::IsSupportedArgument(gemm_kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!GemmKernel::IsSupportedArgument(gemm_kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; + ck_tile::index_t total_elements = 1; + std::vector shape = {args.M, args.N}; - ck_tile::index_t total_elements = 1; - std::vector shape = {args.M, args.N}; + for(auto d : shape) + total_elements *= d; - for(auto d : shape) - total_elements *= d; + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; - const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); + auto input_size = ck_tile::make_tuple(args.M, args.N); - auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); - auto input_size = ck_tile::make_tuple(args.M, args.N); + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - gemm_kargs.as_ptr[0], - gemm_kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel( - GemmKernel{}, grids, blocks, 0, gemm_kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(args.N, 1), // Input Stride - ck_tile::make_tuple(args.N, 1), // Output Stride - input_tensors, - static_cast(c_ptr))); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + gemm_kargs.as_ptr[0], + gemm_kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel( + GemmKernel{}, grids, blocks, 0, gemm_kargs), + ck_tile::make_kernel(ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(args.N, 1), // Input Stride + ck_tile::make_tuple(args.N, 1), // Output Stride + input_tensors, + static_cast(c_ptr))); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index c06dc457c9..64305b85cf 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -160,110 +160,101 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& args.stride_E); constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&]() { - // use SET operation since each K-split writes to separate memory - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = + ck_tile::CShuffleEpilogue>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(base_args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(base_args); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + const dim3 blocks = Kernel::BlockSize(); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - ck_tile::RotatingMemWrapper rotating_mem( - kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - return ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - }; - - return Run(); + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + return ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } } /** diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index f79494a478..8eff0e7469 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -460,12 +460,6 @@ inline auto create_args() return arg_parser; } -// Type aliases for memory operation integral constants -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = std::integral_constant; - // host API template ::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" - << std::endl; - } - float ave_time = 0.f; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper rotating_mem(kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = - ck_tile::launch_kernel_time_mask(s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; - - if(args.k_batch == 1) + dim3 grids; + if constexpr(Persistent) { - return Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - throw std::runtime_error("split-k is not supported yet!"); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl; + } + float ave_time = 0.f; + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; } }; diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 4a83a2c4ab..fb89e6b4cc 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -60,112 +60,94 @@ struct UniversalInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GemmKernel; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) - : Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) + : Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index c7e37bc8a7..b68c30351d 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -78,63 +78,48 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_batched_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 3ff3f2f10e..a24e4bc8ab 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -62,71 +62,55 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -161,74 +144,55 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, BLayout, CLayout>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - if(!splitk) + if(s.log_level_ > 0) { - return ave_time = Run(ck_tile::integral_constant{}); - } - else - { - return ave_time = - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 67b411c1f0..462f11e405 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -328,5 +328,4 @@ template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk = false); + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 060dd311b5..e5aefad8d1 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -61,72 +61,56 @@ float grouped_gemm_multi_d(const std::vector& gemm_d using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -163,76 +146,55 @@ float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, BLayout, ELayout>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - if(!splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - return ave_time; + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_multi_d_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index 4a5be996c0..b4c10900d6 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -65,70 +65,54 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -167,75 +150,53 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmConfig::NumWaveGroups, GemmConfig::Preshuffle>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, // DsDataType (empty for no D tensors) + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout (empty for no D tensors) + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue, // DsDataType (empty for no D tensors) - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout (empty for no D tensors) - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - - if(splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - return ave_time; + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp index 16352722e1..ea71abb213 100644 --- a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp @@ -72,10 +72,9 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::BQuantGrouped; @@ -137,8 +136,7 @@ float grouped_gemm(const std::vector& gemm_descs, GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; + QuantGemmProblem::TransposeC>>; using Kernel = ck_tile::QuantGroupedGemmKernel; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; - constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; + using GemmPipeline = GemmQuantConfig::template GemmPipeline; - using GemmPipeline = - GemmQuantConfig::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::QuantGroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - return ave_time = Run(ck_tile::integral_constant{}); + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 390a54644b..7a01b1dcea 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -79,8 +79,7 @@ float invoke_gemm(int n_warmup, // earlier stage. std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = args[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : args) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, @@ -109,7 +108,7 @@ float invoke_gemm(int n_warmup, ADataType, BDataType, AccDataType, - CDataType>(stream, group_count, kargs_ptr, splitk); + CDataType>(stream, group_count, kargs_ptr); } return ave_time; diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index ac6ea99db3..4f2bebdf17 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -95,8 +95,7 @@ float invoke_gemm(int n_warmup, else { std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = args[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : args) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr}, @@ -119,18 +118,17 @@ float invoke_gemm(int n_warmup, kargs.size() * sizeof(ck_tile::GemmTransKernelArg), hipMemcpyHostToDevice, stream.stream_id_)); - ave_time = - grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr, splitk); + ave_time = grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr); } return ave_time; } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index cd241a2be0..af46884a90 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -170,13 +170,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, 1, @@ -282,23 +278,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/grouped_flatmm.cpp b/example/ck_tile/18_flatmm/grouped_flatmm.cpp index da85c95dae..780a21ba14 100644 --- a/example/ck_tile/18_flatmm/grouped_flatmm.cpp +++ b/example/ck_tile/18_flatmm/grouped_flatmm.cpp @@ -113,13 +113,10 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. @@ -216,23 +212,7 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index fe7fe4c5d1..708e8a683e 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -113,13 +113,10 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = std::conditional_t{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp index 2b6dbace36..f9f8c0cec7 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp @@ -89,13 +89,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern @@ -128,7 +125,6 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC @@ -201,23 +197,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/moe_flatmm.cpp b/example/ck_tile/18_flatmm/moe_flatmm.cpp index 96b9ae29a4..4cca953066 100644 --- a/example/ck_tile/18_flatmm/moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/moe_flatmm.cpp @@ -144,15 +144,11 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, 1, @@ -261,37 +256,20 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, args.NumTokens * args.TopK * outputN * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, run_flush_cache, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + float ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp index f177ef04ca..01128f8fe8 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -61,8 +61,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = - Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set; + ck_tile::ignore = Splitk; constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern @@ -98,7 +97,6 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, MXPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index 9e2bc3e3fb..1c56295f9f 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -81,87 +81,45 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - // Epilogue selection: set to true for chainer-based, false for standard - // CShuffleEpilogue - constexpr bool UseChainerEpilogue = true; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = std::conditional_t< - UseChainerEpilogue, - // Chainer-based epilogue - ck_tile::EpilogueChainer, - ck_tile::DefaultScheduleTag>>, - // Standard CShuffleEpilogue - ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>>; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y - << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y - << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_gemm_multi_d_fp16_example.inc" diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index d2663b033c..ca8573d6d2 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -59,94 +59,80 @@ struct GroupedConvolutionBackwardDataInvoker ConvConfig::NumWaveGroups>; constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - ck_tile::hip_check_error(hipMemsetAsync( - kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); - }; - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index afe43cd1c0..90874e6018 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -59,104 +59,85 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::NumWaveGroups>; constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - const auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + auto preprocess = [&]() { + if(args.k_batch > 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + ck_tile::hip_check_error(hipMemsetAsync( + kargs.wei_ptr, 0, args.template GetWeightByte(), s.stream_id_)); } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - if(kargs.k_batch > 1) - { - ck_tile::hip_check_error( - hipMemsetAsync(kargs.wei_ptr, - 0, - args.template GetWeightByte(), - s.stream_id_)); - } - }; - - const auto ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - const auto split_k = kargs.k_batch; - - return InvokerResult{ave_time, split_k}; }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return InvokerResult{ave_time, args.k_batch}; } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index ad5e8ae70f..c4d618a0bf 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -65,163 +65,143 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + const ck_tile::index_t spatial_lengths_accum = + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * + sizeof(WorkspaceDataType)); + ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args); + auto c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - const ck_tile::index_t spatial_lengths_accum = - std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * - sizeof(WorkspaceDataType)); - ck_tile::GroupedConvBwdWeightHostArgs ws_args = - ck_tile::GroupedConvBwdWeightHostArgs(args); - auto c_ptr = ws_args.wei_ptr; - ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - const auto kargs = Kernel::MakeKernelArgs(ws_args); + const auto kargs = Kernel::MakeKernelArgs(ws_args); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; + ck_tile::index_t total_elements = 1; + std::vector shape = { + static_cast(args.G_ * args.K_), + static_cast(args.C_ * spatial_lengths_accum)}; - ck_tile::index_t total_elements = 1; - std::vector shape = { - static_cast(args.G_ * args.K_), - static_cast(args.C_ * spatial_lengths_accum)}; + for(auto d : shape) + total_elements *= d; - for(auto d : shape) - total_elements *= d; + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); - const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); + auto input_size = ck_tile::make_tuple(shape[0], shape[1]); - auto input_tensors = - ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); - auto input_size = ck_tile::make_tuple(shape[0], shape[1]); + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - if(kargs.k_batch > 1) - ck_tile::hip_check_error( - hipMemsetAsync(ws_args.wei_ptr, - 0, - shape[0] * shape[1] * sizeof(WorkspaceDataType), - s.stream_id_)); - }; - - const auto ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), - ck_tile::make_kernel( - ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(shape[1], 1), // Input Stride - ck_tile::make_tuple(shape[1], 1), // Output Stride - input_tensors, - static_cast(c_ptr))); - - const auto split_k = kargs.k_batch; - - return InvokerResult{ave_time, split_k}; + auto preprocess = [&]() { + if(args.k_batch > 1) + ck_tile::hip_check_error( + hipMemsetAsync(ws_args.wei_ptr, + 0, + shape[0] * shape[1] * sizeof(WorkspaceDataType), + s.stream_id_)); }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel( + ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); + return InvokerResult{ave_time, kargs.k_batch}; } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 82541bb593..c94466aeb2 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -70,91 +70,74 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Regular Convolution: Simple, no split-image // ===================================================================== - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - // ===================================================================== - // Split-K dispatch - // ===================================================================== - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(MemoryOpSet{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); } - else + + if(s.log_level_ > 0) { - return Run(MemoryOpAtomicAdd{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 4261385a84..5dec340668 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -213,8 +213,7 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Kernel launch lambda: Uses EnableSplitImage based on layout support // ===================================================================== - const auto Run = [&](const auto memory_operation_, const auto enable_split_image_) { - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto enable_split_image_) { constexpr bool EnableSplitImage = enable_split_image_.value; using GroupedConvTraitsType = std::conditional_t>; @@ -332,17 +330,11 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== if(use_split_image) { - if(args.k_batch == 1) - return Run(MemoryOpSet{}, ck_tile::bool_constant{}); - else - return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); + return Run(ck_tile::bool_constant{}); } else { - if(args.k_batch == 1) - return Run(MemoryOpSet{}, ck_tile::bool_constant{}); - else - return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); + return Run(ck_tile::bool_constant{}); } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index 63dd54dcae..a78a880815 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -13,11 +13,6 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "conv_configs.hpp" -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = std::integral_constant; - template auto calculate_rtol_atol(const ck_tile::index_t GemmK, const ck_tile::index_t kbatch, diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp index acb9126d65..9202bf9d98 100644 --- a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -85,60 +85,44 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y - << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y - << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_gemm_multi_abd_fp16_example.inc" diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 47a22cdcba..d8988be7b0 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -173,77 +173,30 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - - // Epilogue selection: use chainer for RowCol/Tensor quant, standard for others - // Toggle to switch between chainer-based and standard CShuffleEpilogue - constexpr bool UseChainerEpilogue = true; - - // Define the schedule tag based on quant mode - using ScheduleTag = - std::conditional_t>; - - using GemmEpilogue = std::conditional_t< - UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant), - // Chainer-based epilogue for RowCol/Tensor quant modes - ck_tile::EpilogueChainer, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>, - ScheduleTag>>, - // Standard CShuffleEpilogue for other modes - ck_tile::CShuffleEpilogue, typename TypeConfig::ADataType, - std::conditional_t< - std::is_same_v, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>>>; - + typename TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index d3ee9fe9c6..828c861349 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -48,112 +48,87 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, GemmConfiguration::NUM_WAVE_GROUPS, GemmConfiguration::PRESHUFFLE>; - const auto runKernel = [&](const auto memory_operation) -> std::tuple { - // We create the GEMM pipeline without specifying has_hot_loop or tail_num. - // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K - // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K - // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::StreamKKernel; + using Kernel = ck_tile::StreamKKernel; - auto kernel_args = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); - ck_tile::DeviceMem workspace_data(workspace_size); + auto kernel_args = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); + + dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kernel_args)) + { + // Clear the output C tensor results after each repetition of the kernel + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); + } + + if(stream_config.log_level_ > 0) + { + // Reset sk flags to zero before each repetition of the kernel workspace_data.SetZero(); - kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); + } - dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); - dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kernel_args)) + auto reset_data_buffers = [&]() { + if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + // Clear the output C tensor results after each repetition of the kernel + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); } - - if(stream_config.log_level_ > 0) + else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + // Reset sk flags to zero before each repetition of the kernel + workspace_data.SetZero(); } - - auto reset_data_buffers = [&]() { - if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) - { - // Clear the output C tensor results after each repetition of the kernel - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); - } - else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) - { - // Reset sk flags to zero before each repetition of the kernel - workspace_data.SetZero(); - } - }; - - std::function preprocess = reset_data_buffers; - - float average_time = - ck_tile::launch_kernel_time_mask(stream_config, - preprocess, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kernel_args)); - - ck_tile::index_t num_wgs_per_tile = - kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); - return std::tuple{average_time, num_wgs_per_tile}; }; - if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy) - { - return runKernel(ck_tile::integral_constant{}); - } - else // We are using ck_tile::StreamKReductionStrategy::Reduction - { - return runKernel(ck_tile::integral_constant{}); - } + std::function preprocess = reset_data_buffers; + + float average_time = + ck_tile::launch_kernel_time_mask(stream_config, + preprocess, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kernel_args)); + + ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); + return std::tuple{average_time, num_wgs_per_tile}; } #include "run_gemm_example.inc" diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp index f9f13c6e85..1e159a5615 100644 --- a/example/ck_tile/41_batched_contraction/batched_contraction.cpp +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -92,67 +92,59 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = GEMM_PIPELINE; - using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = + ck_tile::BatchedContractionKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = - ck_tile::BatchedContractionKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::GetBlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::GetBlockSize(); + if(!Kernel::IsSupportedArguments(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); + } - if(!Kernel::IsSupportedArguments(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); - auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); - - return ck_tile::launch_kernel(s, kernel); - }; - - return Run(); + return ck_tile::launch_kernel(s, kernel); } #define HANDLE_CASE(G, M, N, K) \ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp index cce95cb3f1..6ce508b47d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -116,7 +116,6 @@ struct ConvTileFactory BLOCK_GEMM.warp_tile.k, GroupedConvTraitsType::FixedGemmParams::TransposeC, // TODO:: This template parameter will be moved inside the kernel - ck_tile::memory_operation_enum::set, BLOCK_GEMM.num_wave_groups, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, SCALAR_PER_VECTOR.c>>; diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp index ad31fc52bc..91c75e3e8d 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index 47908e0e5b..e2e165967a 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp index 083d9d9955..5ec73d780f 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index f26b5d7caf..fe94d16a7d 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -81,7 +81,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index c7c4e370e2..dbb3a0a8fc 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -184,7 +184,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 6dd2a4eada..ad0a2cadc6 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -795,7 +795,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 53bfa6041d..c73897f064 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -30,7 +30,6 @@ template struct CShuffleEpilogueProblem { - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t MWave = MWave_; - static constexpr index_t NWave = NWave_; - static constexpr index_t MPerXdl = MPerXdl_; - static constexpr index_t NPerXdl = NPerXdl_; - static constexpr index_t KPerXdl = KPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; - static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; - static constexpr bool FixedVectorSize = FixedVectorSize_; - static constexpr index_t VectorSizeC = VectorSizeC_; - static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; - static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; - static constexpr index_t kNumWaveGroups = kNumWaveGroups_; - static constexpr index_t NumDTensor = DsDataType::size(); + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t MWave = MWave_; + static constexpr index_t NWave = NWave_; + static constexpr index_t MPerXdl = MPerXdl_; + static constexpr index_t NPerXdl = NPerXdl_; + static constexpr index_t KPerXdl = KPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeC = VectorSizeC_; + static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; + static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; + static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; + static constexpr index_t kNumWaveGroups = kNumWaveGroups_; + static constexpr index_t NumDTensor = DsDataType::size(); static_assert(NumDTensor == DsLayout::size(), "The size of DsDataType and DsLayout should be the same"); @@ -105,28 +103,27 @@ struct CShuffleEpilogue ADataType, BDataType>; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kMPerBlock = Problem::kMPerBlock; - static constexpr index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t MWave = Problem::MWave; - static constexpr index_t NWave = Problem::NWave; - static constexpr index_t MPerXdl = Problem::MPerXdl; - static constexpr index_t NPerXdl = Problem::NPerXdl; - static constexpr index_t KPerXdl = Problem::KPerXdl; - static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr bool FixedVectorSize = Problem::FixedVectorSize; - static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; - static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr index_t MPerIteration = MPerXdl * MWave; - static constexpr index_t NPerIteration = NPerXdl * NWave; - static constexpr index_t NumDTensor = Problem::NumDTensor; - static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); - static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t MWave = Problem::MWave; + static constexpr index_t NWave = Problem::NWave; + static constexpr index_t MPerXdl = Problem::MPerXdl; + static constexpr index_t NPerXdl = Problem::NPerXdl; + static constexpr index_t KPerXdl = Problem::KPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; + static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t MPerIteration = MPerXdl * MWave; + static constexpr index_t NPerIteration = NPerXdl * NWave; + static constexpr index_t NumDTensor = Problem::NumDTensor; + static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); + static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); CDElementwise elfunc_; @@ -142,8 +139,7 @@ struct CShuffleEpilogue concat('x', MWave, NWave), concat('x', MPerXdl, NPerXdl, KPerXdl), VectorSizeC, - isCTransposed ? "CTransposed" : "CNotTransposed", - mem_op_string()); + isCTransposed ? "CTransposed" : "CNotTransposed"); // clang-format on } @@ -445,7 +441,8 @@ struct CShuffleEpilogue CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window, const COutTensor& c_out_tensor) { - if constexpr(MemoryOperation == memory_operation_enum::set) + if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } @@ -617,7 +614,8 @@ struct CShuffleEpilogue }); // store/update - if constexpr(MemoryOperation == memory_operation_enum::set) + if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index cc2303582e..aafe7b9f58 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -15,17 +15,15 @@ template + bool UseRawStore_ = true> struct Default2DEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool UseRawStore = UseRawStore_; - static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; - static constexpr index_t NumDTensor = 0; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; + static constexpr index_t NumDTensor = 0; }; template -struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem + bool UseRawStore_ = true> +struct DefaultGemm2DEpilogueProblem + : public Default2DEpilogueProblem { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -81,7 +74,6 @@ struct Default2DEpilogue static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; - static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } @@ -102,7 +94,10 @@ struct Default2DEpilogue // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - if constexpr(MemoryOperation == memory_operation_enum::set) + // FIXME? + // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp == + // memory_operation_enum::set) + if constexpr(true) { if constexpr(is_partition_index) { @@ -123,7 +118,10 @@ struct Default2DEpilogue } else { - if constexpr(MemoryOperation == memory_operation_enum::set) + // FIXME? + // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp == + // memory_operation_enum::set) + if constexpr(true) { if constexpr(is_partition_index) { diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 9a33801c8f..42dab68e91 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -558,21 +558,19 @@ struct FlatmmKernel return DTesnorIsValid; } - template - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); @@ -581,25 +579,81 @@ struct FlatmmKernel { return make_naive_tensor_view( a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } }(); - index_t kFlatK = - FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2)); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - const auto& b_flat_tensor_view = [&]() { - return make_naive_tensor_view( - b_flat_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } }(); + // Step 3: Create tile window + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {block_idx_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, block_idx_m}); + } + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view + index_t kFlatK = + FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2)); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + const auto& b_flat_tensor_view = make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + + // Step 2: No padding needed for b_flat + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -625,7 +679,56 @@ struct FlatmmKernel }, number{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + }, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -647,98 +750,8 @@ struct FlatmmKernel } }(); - constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; - constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; - - constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; - constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; - - auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale - : 1; // per-token scale - auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale - : 1; // per-channel scale - - static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1, - "only support per-tensor or per-row scaling"); - static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1, - "only support per-tensor or per-column scaling"); - - const auto scale_m_view = make_naive_tensor_view( - kargs.scale_m_ptr.ptr, - make_tuple(kargs.M / ScaleGranularityM, - ScaleGranularityKA == 0 - ? 1 - : splitk_batch_offset.splitted_k / - (ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)), - make_tuple(scale_stride_m, 0), - number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, - number<1>{}); - const auto scale_n_view = make_naive_tensor_view( - kargs.scale_n_ptr.ptr, - make_tuple(ScaleGranularityKB == 0 - ? 1 - : (splitk_batch_offset.splitted_k / - (ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)), - kargs.N / ScaleGranularityN), - make_tuple(0, scale_stride_n), - number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, - number<1>{}); - - return make_tuple(a_tensor_view, - b_flat_tensor_view, - ds_tensor_view, - e_tensor_view, - scale_m_view, - scale_n_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -755,93 +768,72 @@ struct FlatmmKernel } }(); - return make_tuple(a_pad_view, - b_flat_tensor_view, - ds_pad_view, - e_pad_view, - views.at(number<4>{}), - views.at(number<5>{})); - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } - constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK; - constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK; + template + CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m) + { + constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; + constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; - auto scale_m_window = make_tile_window(views.at(number<4>{}), - make_tuple(number{}, - number < ScaleGranularityKA == 0 - ? TilePartitioner::NPerBlock - : TilePartitioner::KPerBlock > {}), - {i_m, 0}); - auto scale_n_window = make_tile_window(views.at(number<5>{}), - make_tuple(number < ScaleGranularityKB == 0 - ? TilePartitioner::MPerBlock - : TilePartitioner::KPerBlock > {}, - number{}), - {0, i_n}); + auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale + : 1; // per-token scale - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - scale_m_window, - scale_n_window); + // Step 1: Create tensor view + const auto scale_m_view = make_naive_tensor_view( + kargs.scale_m_ptr.ptr, + make_tuple(kargs.M / ScaleGranularityM, + ScaleGranularityKA == 0 + ? 1 + : (splitk_batch_offset.splitted_k / ScaleGranularityKA)), + make_tuple(scale_stride_m, 0), + number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window(scale_m_view, + make_tuple(number{}, + number < ScaleGranularityKA == 0 + ? TilePartitioner::NPerBlock + : TilePartitioner::KPerBlock > {}), + {block_idx_m, 0}); + } + + template + CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_n) + { + constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; + constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; + + auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale + : 1; // per-channel scale + + // Step 1: Create tensor view + const auto scale_n_view = make_naive_tensor_view( + kargs.scale_n_ptr.ptr, + make_tuple( + ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB), + kargs.N / ScaleGranularityN), + make_tuple(0, scale_stride_n), + number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window(scale_n_view, + make_tuple(number < ScaleGranularityKB == 0 + ? TilePartitioner::MPerBlock + : TilePartitioner::KPerBlock > {}, + number{}), + {0, block_idx_n}); } template @@ -857,45 +849,74 @@ struct FlatmmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m); + const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = FlatmmPipeline{}.template operator()( + const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong); - auto scale_m_window = gemm_tile_windows.at(number<4>{}); - auto scale_n_window = gemm_tile_windows.at(number<5>{}); - - // Run Epilogue Pipeline + // Run Epilogue Pipeline with k_batch dispatching if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1) { - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template - operator()( - c_block_window, - c_block_tile, - d_block_window, - smem_ptr_ping, - scale_m_window, - scale_n_window); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + scale_m_window, + scale_n_window); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + scale_m_window, + scale_n_window); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()( + e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()( + e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -924,8 +945,7 @@ struct FlatmmKernel __shared__ char smem_ptr_ping[GetSmemPingSize()]; __shared__ char smem_ptr_pong[GetSmemPongSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); diff --git a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp index 05d50666a5..61001522b0 100644 --- a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp @@ -100,21 +100,19 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); @@ -123,25 +121,80 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } }(); + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {block_idx_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, block_idx_m}); + } + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); index_t kFlatN = kargs.N * kargs.K / kFlatK; - const auto& b_flat_tensor_view = [&]() { - return make_naive_tensor_view( - b_flat_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - }(); + const auto& b_flat_tensor_view = make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + // Step 2: No padding needed for b_flat + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -167,7 +220,56 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + }, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -189,70 +291,8 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( - reinterpret_cast(scale_n.ptr), - make_tuple(FlatScaleN, FlatScaleK), - make_tuple(FlatScaleK, 1), - number<8>{}, - number<1>{}); - - return make_tuple( - a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -269,77 +309,37 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } - auto scale_block_window = - make_tile_window(views.at(I4), - make_tuple(number{}, - number{}), - {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); + template + CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs, + const index_t block_idx_n) + { + auto scale_n = kargs.scale_n_ptr; - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - scale_block_window); + // Step 1: Create tensor view + index_t FlatScaleK = + (kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1); + index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); + + const auto scale_b_flat_view = make_naive_tensor_view( + reinterpret_cast(scale_n.ptr), + make_tuple(FlatScaleN, FlatScaleK), + make_tuple(FlatScaleK, 1), + number<8>{}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window( + scale_b_flat_view, + make_tuple(number{}, + number{}), + {block_idx_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); } template @@ -355,21 +355,15 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_block_window = MakeScaleBBlockWindow(kargs, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& scale_block_window = gemm_tile_windows.at(I4); - static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK || ScaleM::GranularityMN == -1 // or ScaleA is disable || ScaleN::GranularityMN == -1, // or ScaleB is disable @@ -378,6 +372,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -434,8 +453,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel::value)) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b47ec4a829..604089b7c4 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1476,7 +1476,8 @@ struct MoeFlatmmKernel c_scatter_valids[mIter]); if constexpr(!IsInputGemm || - EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) + decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::atomic_add) c_scatter_tile_window.update(c_out_tensor); else c_scatter_tile_window.store(c_out_tensor); diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index 799f8f26a9..a58d71c790 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -113,32 +113,50 @@ struct MXFlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { static_assert(std::is_same_v, "A tensor for mx must be RowMajor"); return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); }(); + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, 0}); + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view with special flat layout constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock; constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; const index_t kFlatKBlocks = kargs.K / kKPerBlock; const index_t kFlatN = kargs.N / kNWarpTile; - const auto& b_flat_tensor_view = [&]() { + + const auto& b_flat_tensor_view = [&]() { static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0, "wrong! vector size for B tensor"); auto&& naive_desc = make_naive_tensor_descriptor_packed( @@ -153,6 +171,22 @@ struct MXFlatmmKernel : FlatmmKernel(b_flat_ptr, desc); }(); + // Step 2: No padding for flat B + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -178,7 +212,56 @@ struct MXFlatmmKernel : FlatmmKernel{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + }, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -200,92 +283,8 @@ struct MXFlatmmKernel : FlatmmKernel{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view( - reinterpret_cast(scale_a.ptr), scale_a_desc); - }(); - - // B scale tensor view - const auto& scale_b_tensor_view = [&]() { - const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); - const auto scale_b_desc = transform_tensor_descriptor( - scale_b_navie_desc, - make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view( - reinterpret_cast(scale_b.ptr), scale_b_desc); - }(); - - return make_tuple(a_tensor_view, - b_flat_tensor_view, - ds_tensor_view, - e_tensor_view, - scale_a_tensor_view, - scale_b_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - static_assert(std::is_same_v, - "A tensor for mx must be RowMajor"); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -302,79 +301,71 @@ struct MXFlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - static_assert(std::is_same_v, - "A tensor for mx must be RowMajor"); - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } + template + CK_TILE_DEVICE static auto MakeScaleABlockWindow(const KernelArgs& kargs, + const index_t block_idx_m) + { static constexpr int BlockScaleSize = 32; - auto scale_a_block_window = make_tile_window( - views.at(I4), + const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl)); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl); + + // Step 1: Create tensor view + const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl)); + const auto scale_a_desc = transform_tensor_descriptor( + scale_a_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto& scale_a_tensor_view = make_tensor_view( + reinterpret_cast(kargs.scale_m_ptr.ptr), scale_a_desc); + + // Step 2: Create tile window + return make_tile_window( + scale_a_tensor_view, make_tuple(number{}, number{}), - {i_m / MXdlPack, 0}); + {block_idx_m / MXdlPack, 0}); + } - auto scale_b_block_window = make_tile_window( - views.at(I5), + template + CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs, + const index_t block_idx_n) + { + static constexpr int BlockScaleSize = 32; + + const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl)); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl); + + // Step 1: Create tensor view + const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); + const auto scale_b_desc = transform_tensor_descriptor( + scale_b_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto& scale_b_tensor_view = make_tensor_view( + reinterpret_cast(kargs.scale_n_ptr.ptr), scale_b_desc); + + // Step 2: Create tile window + return make_tile_window( + scale_b_tensor_view, make_tuple(number{}, number{}), - {i_n / NXdlPack, 0}); - - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - scale_a_block_window, - scale_b_block_window); + {block_idx_n / NXdlPack, 0}); } template @@ -390,22 +381,16 @@ struct MXFlatmmKernel : FlatmmKernel( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_a_block_window = MakeScaleABlockWindow(kargs, block_idx_m); + const auto& scale_b_block_window = MakeScaleBBlockWindow(kargs, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& scale_a_block_window = gemm_tile_windows.at(I4); - const auto& scale_b_block_window = gemm_tile_windows.at(I5); - static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK || ScaleM::GranularityMN == -1 // or ScaleA is disable || ScaleN::GranularityMN == -1, // or ScaleB is disable @@ -422,22 +407,46 @@ struct MXFlatmmKernel : FlatmmKernel( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -466,27 +475,17 @@ struct MXFlatmmKernel : FlatmmKernel::value)) - { - constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); - RunFlatmm(a_ptr, - b_flat_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_ping, - smem_ptr_pong, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - static_assert(false, - "Unimplemented: atomic_add with odd vector size for fp16/bf16"); - } + constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); + RunFlatmm(a_ptr, + b_flat_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_ping, + smem_ptr_pong, + kargs, + splitk_batch_offset, + i_m, + i_n); partition_idx += gridDim.x; } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 95114e8496..5ba5699dda 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -361,6 +361,7 @@ struct GroupedGemmKernel * * @param a_ptr input A pointer * @param b_ptr input B pointer + * @param ds_ptr input Ds pointer * @param c_ptr output C pointer * @param smem_ptr_0 The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments @@ -381,49 +382,54 @@ struct GroupedGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m) + .at(Base::I0); + const auto& b_block_window = + Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n) + .at(Base::I0); + const auto& d_block_window = + Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM pipeline + // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + a_block_window, b_block_window, num_loop, smem_ptr_0); + // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. + * @note RunGEMM2LDS with two shared memory buffers using the ping pong buffer mechanism. * * @param a_ptr input A pointer * @param b_ptr input B pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param smem_ptr_1 The second start memory pointer of the shared memory block. + * @param ds_ptr input Ds pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. + * @param splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * @@ -440,54 +446,39 @@ struct GroupedGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m) + .at(Base::I0); + const auto& b_block_window = + Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n) + .at(Base::I0); + const auto& d_block_window = + Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM pipeline with compile-time branching - const auto& c_block_tile = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - // Preshuffle version - without has_hot_loop parameter - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - else - { - // Regular version - with has_hot_loop parameter - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - }(); + // Run GEMM cooperatively by whole workgroup. + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp index d1fd32dc1b..47e59c4704 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp @@ -222,19 +222,13 @@ struct StreamKKernel const index_t block_idx_n, const index_t k_size) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - UniversalGemmKernel::template MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size); - - const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - // Run GEMM cooperatively by whole workgroup. - const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); - const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); - const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + // Create block windows using specialized methods + const auto& as_block_window = + UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m); + const auto& bs_block_window = + UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n); + const auto& ds_block_window = + UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this @@ -243,6 +237,7 @@ struct StreamKKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], bs_block_window[UniversalGemmKernel::I0], num_loop, @@ -253,7 +248,9 @@ struct StreamKKernel if(UseDefaultScheduler || (get_warp_id() == 0)) { // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = + UniversalGemmKernel::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); } @@ -525,21 +522,13 @@ struct StreamKKernel const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + i_k_b; CDataType* c_ptr = static_cast(kargs.e_ptr); - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - UniversalGemmKernel::template MakeGemmTensorViews< - EpiloguePipeline::MemoryOperation>( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size); - - const auto& gemm_pad_views = - UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n); - - // Run GEMM cooperatively by whole workgroup. - const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); - const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); - const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + // Create block windows using specialized methods + const auto& as_block_window = + UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m); + const auto& bs_block_window = + UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n); + const auto& ds_block_window = + UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n); // Since num_loop can vary per WG and per iteration of the Stream-K while loop, // we compute has_hot_loop and tail_num here. This is a similar pattern used by @@ -548,6 +537,7 @@ struct StreamKKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk); + // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], bs_block_window[UniversalGemmKernel::I0], num_loop_sk, @@ -594,7 +584,8 @@ struct StreamKKernel } } - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows< + TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n); EpiloguePipeline{}( c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); } @@ -617,7 +608,8 @@ struct StreamKKernel // tensor. if(tile_started && !partner_in_tile) { - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows< + TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n); EpiloguePipeline{}( c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); break; diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index a6022e8b8e..0b0f6c18ef 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -27,6 +27,9 @@ struct StreamKTilePartitionerBase static constexpr index_t NPerBlock = BlockGemmShapeType::kN; static constexpr index_t KPerBlock = BlockGemmShapeType::kK; static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType; + static constexpr auto MemoryOperation = (ReductionStrategy == StreamKReductionStrategy::Atomic) + ? memory_operation_enum::atomic_add + : memory_operation_enum::set; StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid); diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 77952c9afd..65f58a8ca5 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -254,6 +254,8 @@ struct UniversalGemmKernel static_assert(DsLayout::size() == DsDataType::size(), "The size of DsLayout and DsDataType should be the same"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + using KernelArgs = UniversalGemmKernelArgs; @@ -609,17 +611,13 @@ struct UniversalGemmKernel return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid; } - template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const index_t k_size) + MakeABlockWindows(const std::array& as_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t i_m) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - + // Step 1: Create tensor views for A tensors (from MakeGemmTensorViews) const auto& as_tensor_view = generate_tuple( [&](auto i) { using AiLayout = remove_cvref_t>; @@ -645,6 +643,58 @@ struct UniversalGemmKernel }, number{}); + // Step 2: Create padded views (from MakeGemmPadViews) + const auto& as_pad_view = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(as_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(as_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows (from MakeGemmTileWindows) + const auto& as_block_window = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_m}); + } + }, + number{}); + + return as_block_window; + } + + CK_TILE_DEVICE static auto + MakeBBlockWindows(const std::array& bs_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t i_n) + { + // Step 1: Create tensor views for B tensors (from MakeGemmTensorViews) const auto& bs_tensor_view = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -733,96 +783,20 @@ struct UniversalGemmKernel }, number{}); - const auto& ds_tensor_view = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - }, - number{}); - - // TODO: enable vector write for C in ColMajor - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& as_pad_view = generate_tuple( - [&](auto i) { - const auto& a_tensor_view = views.at(I0); - using AiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - const auto& b_flat_pad_view = views.at(I1); - + // Step 2: Create padded views (from MakeGemmPadViews) const auto& bs_pad_view = generate_tuple( [&](auto i) { - const auto& b_tensor_view = views.at(I1); - using BiLayout = remove_cvref_t>; + using BiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { - return pad_tensor_view(b_tensor_view[i], + return pad_tensor_view(bs_tensor_view[i], make_tuple(number{}, number{}), sequence{}); } else { - return pad_tensor_view(b_tensor_view[i], + return pad_tensor_view(bs_tensor_view[i], make_tuple(number{}, number{}), sequence{}); @@ -830,86 +804,7 @@ struct UniversalGemmKernel }, number{}); - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor - const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); - if constexpr(std::is_same_v) - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - if constexpr(GemmPipeline::Preshuffle) - { - // For flatmm, we need to use the flat B tensor view - return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); - } - else - { - return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& as_pad_view = views.at(I0); - const auto& bs_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& as_block_window = generate_tuple( - [&](auto i) { - using AiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(as_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(as_pad_view[i], - make_tuple(number{}, - number{}), - {0, i_m}); - } - }, - number{}); - + // Step 3: Create tile windows (from MakeGemmTileWindows) const auto& bs_block_window = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -942,7 +837,63 @@ struct UniversalGemmKernel }, number{}); - const auto ds_block_window = generate_tuple( + return bs_block_window; + } + + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor views for D tensors (from MakeGemmTensorViews) + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + // Step 2: Create padded views (from MakeGemmPadViews) + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows (from MakeGemmTileWindows) + const auto& ds_block_window = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -962,12 +913,62 @@ struct UniversalGemmKernel }, number{}); + return ds_block_window; + } + + template + CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews) + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_E, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_E), + number<1>{}, + number<1>{}); + } + }(); + + // Step 2: Create padded view (from MakeGemmPadViews) + const auto& e_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window (from MakeGemmTileWindows) auto e_block_window = make_tile_window( e_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); + return e_block_window; } /** @@ -995,30 +996,32 @@ struct UniversalGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& as_block_window = + MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& bs_block_window = + MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. - const auto& as_block_window = gemm_tile_windows.at(I0); - const auto& bs_block_window = gemm_tile_windows.at(I1); - const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0); - if(UseDefaultScheduler || (get_warp_id() == 0)) + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); + // Run Epilogue Pipeline + if(k_batch == 1) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); } } @@ -1051,22 +1054,17 @@ struct UniversalGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& as_block_window = + MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& bs_block_window = + MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. - const auto& as_block_window = gemm_tile_windows.at(I0); - const auto& bs_block_window = gemm_tile_windows.at(I1); - const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window, AElementWise{}, bs_block_window, @@ -1076,9 +1074,20 @@ struct UniversalGemmKernel smem_ptr_1); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + if(kargs.k_batch == 1) + { + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } // Non-persistent kernel entry point @@ -1119,39 +1128,30 @@ struct UniversalGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + + constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } } @@ -1204,40 +1204,28 @@ struct UniversalGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm(as_ptr, + RunGemm2LDS(as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr_0, + smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n); - } + } + else + { + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } // Advance to the next work item block_id += grid_size; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index ba67a9ee4d..8aab756ccf 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -401,6 +401,592 @@ struct QuantGemmKernel index_t splitted_k; }; + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const QuantGemmKernelArgs& kargs, + const index_t k_size, + const index_t i_m) + { + // Step 1: Create tensor view for A + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, k_size), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(k_size, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + return a_block_window; + } + + CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for AQ + const auto& aq_tensor_view = [&]() { + if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) + { + static_assert(std::is_same_v); + const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; + const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ; + const auto aq_desc = + make_naive_tensor_descriptor(make_tuple(aq_y, aq_x), + make_tuple(aq_x, 1), + number{}, + number<1>{}); + + const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ; + const auto aq_pad0_desc = transform_tensor_descriptor( + aq_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1]; + const auto wave_tile_size = + GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; + const auto wave_tile_count_x = + ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); + + const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( + aq_pad0_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto aq_pad1_desc = transform_tensor_descriptor( + aq_unmerge_pad0_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_pass_through_transform(wave_tile_count_x), + make_right_pad_transform( + wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto pad_wave_size = + ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); + const auto aq_merge_pad1_desc = transform_tensor_descriptor( + aq_pad1_desc, + make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)), + make_pass_through_transform(pad_wave_size)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tensor_view(aq_ptr, aq_merge_pad1_desc); + } + else if constexpr((kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) && + !PreshuffleQuant) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK_A), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + } + else // Column major AQ + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.QK_A, kargs.M), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + } + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, 0), // broadcasting over n + number<1>{}, + number<1>{}); + } + else + { + return nullptr; + } + }(); + + // Step 2: Create tile window (no padding for AQ) + const auto& aq_block_window = [&]() { + if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + constexpr auto block_m = TilePartitioner::MPerBlock; + constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto tile_window_width = + ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); + constexpr auto tile_window_height = block_m / warp_m; + auto block_m_idx = i_m / block_m; + return make_tile_window( + aq_tensor_view, + make_tuple(number{}, number{}), + {block_m_idx * tile_window_height, 0}); + } + else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) + { + using QuantGroupSize = remove_cvref_t; + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto block_m = TilePartitioner::MPerBlock; + if constexpr(std::is_same_v) + { + return make_tile_window(aq_tensor_view, + make_tuple(number{}, number{}), + {i_m, 0}); + } + else // Column major AQ + { + return make_tile_window(aq_tensor_view, + make_tuple(number{}, number{}), + {0, i_m}); + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + constexpr auto block_m = TilePartitioner::MPerBlock; + constexpr auto block_k = TilePartitioner::KPerBlock; + return make_tile_window( + aq_tensor_view, + make_tuple(number{}, number{}), + {i_m, 0}); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_tile_window(aq_tensor_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return nullptr; + } + }(); + + return aq_block_window; + } + + CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr, + const QuantGemmKernelArgs& kargs, + const index_t k_size, + const index_t i_n) + { + // Step 1: Create tensor view for B + const auto& b_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(k_size, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + if constexpr(PreshuffleB) + { + index_t kFlatK = + GemmPipeline::flatKPerWarp * + (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + return make_naive_tensor_view( + b_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + if constexpr(std::is_same_v) + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size / 2), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + else + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + } + }(); + + // Step 2: Create padded view (or flat view for PreshuffleB) + const auto& b_pad_view = [&]() { + if constexpr(PreshuffleB) + { + return b_tensor_view; // no padding for preshuffle + } + else if constexpr(std::is_same_v) + { + if constexpr(std::is_same_v) + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + else + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + const auto& b_block_window = [&]() { + if constexpr(PreshuffleB) + { + return make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); + } + else + { + if constexpr(std::is_same_v) + { + if constexpr(std::is_same_v) + return make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + else + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + } + }(); + + return b_block_window; + } + + CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for BQ + const auto& bq_tensor_view = [&]() { + if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_naive_tensor_view( + bq_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(0, 1), // broadcasting over m + number<1>{}, + number<1>{}); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + if constexpr(PreshuffleQuant) + { + static_assert(std::is_same_v, + "PreshuffleQuant with BQuantGrouped currently only supports " + "ColumnMajor BQ layout"); + + return MakePreshuffledQuantTensorView< + GemmPipeline::KPerBlockBQ, + GemmPipeline::NPerBlock, + TilePartitioner::BlockGemmShape::WarpTile::at(I1), + GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); + } + else + { + using QuantGroupSize = remove_cvref_t; + + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), + integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + number{}, + number<1>{}); + } + else + { + static_assert(std::is_same_v); + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + integer_divide_ceil(kargs.K, QuantGroupSize::kK)), + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), + number{}, + number<1>{}); + } + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), + make_tuple(kargs.stride_BQ, 1), + number{}, + number<1>{}); + } + else + { + return nullptr; + } + }(); + + // Step 2: Create tile window (no padding for BQ) + const auto& bq_block_window = [&]() { + if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_tile_window(bq_tensor_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + using QuantGroupSize = remove_cvref_t; + if constexpr(PreshuffleQuant) + { + static_assert(std::is_same_v); + constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); + constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto tile_window_width = + ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); + constexpr auto tile_window_height = block_n / warp_n; + auto block_n_idx = i_n / block_n; + + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, number{}), + {block_n_idx * tile_window_height, 0}); + } + else + { + if constexpr(std::is_same_v) + { + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); + } + else + { + static_assert(std::is_same_v); + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } + else + { + return nullptr; + } + }(); + + return bq_block_window; + } + + template + CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for C + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + // Step 2: Create padded view + const auto& c_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return c_block_window; + } + CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { if(kargs.k_batch != 1) @@ -1143,9 +1729,7 @@ struct QuantGemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). */ - template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, const AQDataType* aq_ptr, @@ -1157,25 +1741,22 @@ struct QuantGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = [&]() { if constexpr(kQuantType == QuantType::AQuantGrouped) { - const auto& aq_block_window = gemm_tile_windows.at(I1); - index_t m = 0; + index_t m = 0; if constexpr(PreshuffleQuant) { m = kargs.M; @@ -1185,8 +1766,7 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::BQuantGrouped) { - const auto& bq_block_window = gemm_tile_windows.at(I3); - index_t n = 0; + index_t n = 0; if constexpr(PreshuffleQuant) { n = kargs.N; @@ -1196,10 +1776,8 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::ABQuantGrouped) { - const auto& aq_block_window = gemm_tile_windows.at(I1); - const auto& bq_block_window = gemm_tile_windows.at(I3); - index_t m = 0; - index_t n = 0; + index_t m = 0; + index_t n = 0; if constexpr(PreshuffleQuant) { m = kargs.M; @@ -1222,86 +1800,111 @@ struct QuantGemmKernel } }(); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I4); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - if constexpr(kQuantType == QuantType::ABQuantGrouped || - kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::BQuantGrouped) + // Run Epilogue Pipeline with k_batch dispatch + if(k_batch == 1) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + if constexpr(kQuantType == QuantType::ABQuantGrouped || + kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + EpiloguePipeline{}(c_block_window, + c_block_tile, + c_block_window, + smem_ptr_0, + aq_block_window, + bq_block_window); + } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + const AccDataType aq_scale = type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + } } - else if constexpr(kQuantType == QuantType::RowColQuant) + else { - const auto& aq_block_window = gemm_tile_windows.at(I1); - const auto& bq_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, - c_block_tile, - c_block_window, - smem_ptr_0, - aq_block_window, - bq_block_window); - } - else if constexpr(kQuantType == QuantType::TensorQuant) - { - // TODO: why doesn't readfirstlane work here? - // const AccDataType aq_scale = - // __builtin_amdgcn_readfirstlane(type_convert(*aq_ptr)); - // const AccDataType bq_scale = - // __builtin_amdgcn_readfirstlane(type_convert(*bq_ptr)); - const AccDataType aq_scale = type_convert(*aq_ptr); - const AccDataType bq_scale = type_convert(*bq_ptr); - EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + if constexpr(kQuantType == QuantType::ABQuantGrouped || + kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + EpiloguePipeline{}(c_block_window, + c_block_tile, + c_block_window, + smem_ptr_0, + aq_block_window, + bq_block_window); + } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + const AccDataType aq_scale = type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + } } } /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * + * @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * * @param a_ptr input A pointer * @param b_ptr input B pointer * @param aq_ptr input AQ pointer + * @param bq_ptr input BQ pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. + * @param splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). */ - template CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, const BDataType* b_ptr, - const AQDataType* aq_ptr, + [[maybe_unused]] const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, - void* smem_ptr_1, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, const QuantGemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const index_t num_loop = + amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = [&]() { if constexpr(kQuantType == QuantType::BQuantGrouped) { - const auto& bq_block_window = gemm_tile_windows.at(I3); - index_t n = 0; + index_t n = 0; if constexpr(PreshuffleQuant) { n = kargs.N; @@ -1320,19 +1923,23 @@ struct QuantGemmKernel } }(); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I4); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); + // Run Epilogue Pipeline with k_batch dispatch if constexpr(kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); - } - else - { - return; - // throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or - // RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped, - // "DoubleSmemBuffer Not implemented"); + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } } } @@ -1343,16 +1950,19 @@ struct QuantGemmKernel const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); - // options - const ADataType* a_ptr = static_cast(kargs.a_ptr); - const BDataType* b_ptr = static_cast(kargs.b_ptr); + + // Apply splitk offset to input pointers + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_ptr = + static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); const BQDataType* bq_ptr = static_cast(kargs.bq_ptr); CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; - assert(kargs.k_batch == 1); + if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 7e246961cb..1c98a372be 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -374,7 +374,7 @@ struct QuantGroupedGemmKernel CK_TILE_DEVICE static void RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, const BDataType* b_ptr, - const AQDataType* aq_ptr, + [[maybe_unused]] const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, void* smem_ptr_0, @@ -385,25 +385,21 @@ struct QuantGroupedGemmKernel const index_t block_idx_n) { static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped"); - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& bq_block_window = + Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = __builtin_amdgcn_readfirstlane( TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I2); - - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + // Run GEMM cooperatively by whole workgroup + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, b_block_window, bq_block_window, num_loop, @@ -411,10 +407,20 @@ struct QuantGroupedGemmKernel smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + // Run Epilogue Pipeline with split_k dispatch + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } } /** @@ -449,16 +455,15 @@ struct QuantGroupedGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I2); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& aq_block_window = + Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = + Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( @@ -466,51 +471,77 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::AQuantGrouped) + // Run GEMM cooperatively by whole workgroup + const auto& c_block_tile = [&]() { + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + bq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::RowColQuant || + kQuantType == QuantType::TensorQuant) + { + return GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + } + }(); + + // Run Epilogue Pipeline with split_k dispatch + if(kargs.k_batch == 1) { - const auto& aq_block_window = gemm_tile_windows.at(Base::I1); - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - aq_block_window, - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); + auto c_block_window = Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - // Run Epilogue Pipeline - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); - - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - // Run Epilogue Pipeline - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + if constexpr(kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + EpiloguePipeline{}(c_block_window, + c_block_tile, + c_block_window, + smem_ptr_0, + aq_block_window, + bq_block_window); + } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + const AccDataType aq_scale = type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + } } else { - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I4); - if constexpr(kQuantType == QuantType::RowColQuant) + auto c_block_window = + Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + if constexpr(kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::RowColQuant) { - const auto& aq_block_window = gemm_tile_windows.at(Base::I1); - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index ad445e17a7..2e5f536ab7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -617,6 +617,117 @@ struct GroupedConvolutionBackwardDataKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } + CK_TILE_DEVICE static auto + MakeABlockWindow(const OutDataType* a_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_m, + const index_t i_k) + { + // Step 1: Create tensor view for A (Out tensor) + const auto& a_tensor_view = + make_tensor_view(a_ptr, kargs.a_grid_descs_m_k[group_id]); + + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto a_block_window = make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {i_m, i_k}); + + return a_block_window; + } + + CK_TILE_DEVICE static auto + MakeBBlockWindow(const InDataType* b_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_n, + const index_t i_k) + { + // Step 1: Create tensor view for B (Weight tensor) + const auto& b_tensor_view = + make_tensor_view(b_ptr, kargs.b_grid_descs_n_k[group_id]); + + // Step 2: Create padded view + const auto& b_pad_view = pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto b_block_window = make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {i_k, i_n}); + + return b_block_window; + } + + CK_TILE_DEVICE static auto + MakeDBlockWindows(const std::array& ds_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_m, + const index_t i_n) + { + // Create D tensor block windows + const auto ds_block_window = generate_tuple( + [&](auto i) { + // Step 1: Create tensor view for D + const auto& d_tensor_view = make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]); + + // Step 2: Create padded view + const auto& d_pad_view = + pad_tensor_view(d_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window(d_pad_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + return ds_block_window; + } + + template + CK_TILE_DEVICE static auto + MakeCBlockWindow(WeiDataType* c_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for C (Input tensor) + const auto& c_tensor_view = make_tensor_view( + c_ptr, kargs.c_grid_descs_m_n[group_id]); + + // Step 2: Create padded view + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return c_block_window; + } + CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs) { @@ -895,38 +1006,49 @@ struct GroupedConvolutionBackwardDataKernel const index_t block_idx_k, const index_t group_id) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k); + const auto& d_block_window = + MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k)); const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + // Run Epilogue Pipeline with k_batch dispatch + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism. * * @param a_ptr input A pointer * @param b_ptr input B pointer @@ -951,23 +1073,19 @@ struct GroupedConvolutionBackwardDataKernel const index_t block_idx_k, const index_t group_id) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k); + const auto& d_block_window = + MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k)); const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, @@ -976,11 +1094,27 @@ struct GroupedConvolutionBackwardDataKernel smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + // Run Epilogue Pipeline with k_batch dispatch + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs, @@ -1066,8 +1200,7 @@ struct GroupedConvolutionBackwardDataKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1086,8 +1219,7 @@ struct GroupedConvolutionBackwardDataKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 4b7ad72ffc..6bcd05e9ba 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -518,25 +518,6 @@ struct GroupedConvolutionBackwardWeightKernel return false; } -#if defined(__gfx11__) - if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set) - { - return false; - } -#endif - - if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add) - { - if(kargs.k_batch == 1) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1."); - } - return false; - } - } - if constexpr(!std::is_same_v && !std::is_same_v) { @@ -704,29 +685,31 @@ struct GroupedConvolutionBackwardWeightKernel template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const OutDataType* a_ptr, - const InDataType* b_ptr, - const std::array& ds_ptr, - WeiDataType* c_ptr, - const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + MakeCBlockWindow(WeiDataType* c_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); - const auto& a_tensor_view = [&]() { - return make_tensor_view(a_ptr, - kargs.a_grid_desc_k_m); // A: out - }(); + const auto& c_tensor_view = + make_tensor_view(c_ptr, kargs.c_grid_desc_m_n); - const auto& b_tensor_view = [&]() { - return make_tensor_view(b_ptr, - kargs.b_grid_desc_k_n); // B: in - }(); + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); - const auto& c_tensor_view = [&]() { - return make_tensor_view(c_ptr, - kargs.c_grid_desc_m_n); - }(); + return make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, block_idx_n}); + } + CK_TILE_DEVICE static auto + MakeDBlockWindows(const std::array& ds_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { const auto& ds_tensor_view = generate_tuple( [&](auto i) { static_assert(std::is_same_v, OutLayout>, @@ -741,30 +724,7 @@ struct GroupedConvolutionBackwardWeightKernel }, number{}); - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& ds_tensor_view = views.at(I2); - const auto& ds_pad_view = generate_tuple( + const auto& ds_pad_view = generate_tuple( [&](auto i) { return pad_tensor_view(ds_tensor_view[i], make_tuple(number{}, @@ -773,67 +733,58 @@ struct GroupedConvolutionBackwardWeightKernel }, number{}); - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I3); - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); - } - - /** - * @brief Create views to the data that each workgroup will process. - * - * @param views padded views of A, B, D and C tensors - * @param i_m block m-index - * @param i_n block n-index - * @param i_k block k-index - * - * @return tuple of tile windows for A, B, D and C tensors - */ - template - CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, - const index_t i_m, - const index_t i_n, - const index_t i_k) - { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& c_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_k, i_m}); - }(); - - const auto& b_block_window = [&]() { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_k, i_n}); - }(); - - const auto ds_block_window = generate_tuple( + return generate_tuple( [&](auto i) { return make_tile_window(ds_pad_view[i], make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); }, number{}); + } - auto c_block_window = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); + CK_TILE_DEVICE static auto + MakeBBlockWindow(const InDataType* b_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_n, + const index_t block_idx_k) + { + static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& b_tensor_view = + make_tensor_view(b_ptr, kargs.b_grid_desc_k_n); - return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + const auto& b_pad_view = + pad_tensor_view(b_tensor_view, + make_tuple(number{} * kargs.k_batch, + number{}), + sequence{}); + + return make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {block_idx_k, block_idx_n}); + } + + CK_TILE_DEVICE static auto + MakeABlockWindow(const OutDataType* a_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_k) + { + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + const auto& a_tensor_view = + make_tensor_view(a_ptr, kargs.a_grid_desc_k_m); + + const auto& a_pad_view = + pad_tensor_view(a_tensor_view, + make_tuple(number{} * kargs.k_batch, + number{}), + sequence{}); + + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_k, block_idx_m}); } /** @@ -859,28 +810,30 @@ struct GroupedConvolutionBackwardWeightKernel const index_t block_idx_n, const index_t block_idx_k) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + // Create block windows using helper methods + const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k); + const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k); + const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(kargs.k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } /** @@ -910,27 +863,33 @@ struct GroupedConvolutionBackwardWeightKernel const index_t block_idx_n, const index_t block_idx_k) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + // Create block windows using helper methods + const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k); + const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k); + const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(kargs.k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { +#if defined(__gfx11__) + return; +#endif + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const @@ -960,12 +919,6 @@ struct GroupedConvolutionBackwardWeightKernel CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const { -#if defined(__gfx11__) - if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set) - { - return; - } -#endif if constexpr(GroupedConvTraitsType_::ExplicitGemm) { CallExplicitGemm(kargs); @@ -1001,9 +954,7 @@ struct GroupedConvolutionBackwardWeightKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1021,9 +972,7 @@ struct GroupedConvolutionBackwardWeightKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 0f143d7ff7..1b81bce34a 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -794,34 +794,53 @@ struct GroupedConvolutionForwardKernel return true; } - template + template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const InDataType* a_ptr, - const WeiDataType* b_ptr, - const std::array& ds_ptr, - OutDataType* c_ptr, - const ADescType& a_desc, - const BDescType& b_desc, - const CDescType& c_desc) + MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); - const auto& a_tensor_view = [&]() { - return make_tensor_view(a_ptr, a_desc); - }(); + // Step 1: Create tensor view + const auto& a_tensor_view = make_tensor_view(a_ptr, a_desc); - const auto& b_tensor_view = [&]() { - return make_tensor_view(b_ptr, b_desc); - }(); + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); - // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - return make_tensor_view(c_ptr, c_desc); - }(); + // Step 3: Create tile window + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, 0}); + } + template + CK_TILE_DEVICE static auto + MakeBBlockWindow(const WeiDataType* b_ptr, const BDescType& b_desc, const index_t block_idx_n) + { + // Step 1: Create tensor view + const auto& b_tensor_view = make_tensor_view(b_ptr, b_desc); + + // Step 2: Create padded view + const auto& b_pad_view = pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {block_idx_n, 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const CDescType& c_desc, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { static_assert(std::is_same_v, OutLayout>, @@ -836,30 +855,8 @@ struct GroupedConvolutionForwardKernel }, number{}); - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& ds_tensor_view = views.at(I2); - const auto& ds_pad_view = generate_tuple( + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( [&](auto i) { return pad_tensor_view(ds_tensor_view[i], make_tuple(number{}, @@ -868,55 +865,38 @@ struct GroupedConvolutionForwardKernel }, number{}); - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I3); - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& c_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - }(); - - const auto& b_block_window = [&]() { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - }(); - - const auto ds_block_window = generate_tuple( + // Step 3: Create tile windows + return generate_tuple( [&](auto i) { return make_tile_window(ds_pad_view[i], make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); }, number{}); + } - auto c_block_window = make_tile_window( + template + CK_TILE_DEVICE static auto MakeCBlockWindow(OutDataType* c_ptr, + const CDescType& c_desc, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view + const auto& c_tensor_view = + make_tensor_view(c_ptr, c_desc); + + // Step 2: Create padded view + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( c_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + {block_idx_m, block_idx_n}); } /** @@ -931,6 +911,7 @@ struct GroupedConvolutionForwardKernel * @param b_desc Weight tensor B descriptor * @param c_desc Output tensor C descriptor * @param gemm_k The GEMM K dimension + * @param k_batch The K batch parameter for split-K * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * @@ -945,34 +926,41 @@ struct GroupedConvolutionForwardKernel const BDescType& b_desc, const CDescType& c_desc, const index_t gemm_k, + const index_t k_batch, const index_t block_idx_m, const index_t block_idx_n, const CDElementwise& elfunc) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m); + const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k)); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); + + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } /** @@ -990,6 +978,7 @@ struct GroupedConvolutionForwardKernel * @param b_desc Weight tensor B descriptor * @param c_desc Output tensor C descriptor * @param gemm_k The GEMM K dimension + * @param k_batch The K batch parameter for split-K * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * @@ -1005,33 +994,41 @@ struct GroupedConvolutionForwardKernel const BDescType& b_desc, const CDescType& c_desc, const index_t gemm_k, + const index_t k_batch, const index_t block_idx_m, const index_t block_idx_n, const CDElementwise& elfunc) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m); + const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k)); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); + + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const @@ -1185,9 +1182,7 @@ struct GroupedConvolutionForwardKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1200,6 +1195,7 @@ struct GroupedConvolutionForwardKernel b_desc, c_desc, kargs.GemmK, + kargs.k_batch, i_m, i_n, kargs.elfunc); @@ -1207,9 +1203,7 @@ struct GroupedConvolutionForwardKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, @@ -1221,6 +1215,7 @@ struct GroupedConvolutionForwardKernel b_desc, c_desc, kargs.GemmK, + kargs.k_batch, i_m, i_n, kargs.elfunc); diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 77eb416532..37005cccc1 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -99,62 +99,47 @@ class TestCkTileBatchedGemm : public ::testing::Test scheduler>; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index 9b90110c07..0572115201 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -120,8 +120,8 @@ using SimpleCShuffleEpilogueProblem = MPerXdl, NPerXdl, KPerXdl, - false, // isCTransposed, - memory_operation_enum::set>; + false // isCTransposed + >; template auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index e949ed45e6..8dc2e88430 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -182,74 +182,58 @@ class TestCkTileGemmPipeline : public ::testing::Test using GemmPipeline = typename GemmPipelineTypeSelector::pipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + const auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + const dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) { - Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - Run(ck_tile::integral_constant{}); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 7d82958acf..6fb1b77fa8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -356,8 +356,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase>; + transpose_c>>; using Kernel = ck_tile::QuantGemmKernel>; + transpose_c>>; using Kernel = ck_tile::QuantGemmKernel>; + transpose_c>>; using Kernel = ck_tile::QuantGemmKernel; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = + std::conditional_t; - using GemmEpilogue = std:: - conditional_t; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 8217f5a3d9..6a6806641a 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -170,88 +170,69 @@ class TestCkTileGemmMultiD : public ::testing::Test using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = + std::conditional_t; - using GemmEpilogue = std:: - conditional_t; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 540109a999..237dc24c3b 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -105,71 +105,60 @@ class TestCkTileStreamK : public ::testing::Test NumWaveGroup, preshuffle>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - // We create the GEMM pipeline without specifying has_hot_loop or tail_num. - // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K - // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K - // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // For initial testing, we will just test with one pipeline. - // More extensive testing is coming later and will test other pipelines. - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // For initial testing, we will just test with one pipeline. + // More extensive testing is coming later and will test other pipelines. + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; - using Kernel = ck_tile::StreamKKernel; + using Kernel = ck_tile::StreamKKernel; - auto kargs = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); - ck_tile::DeviceMem workspace_data(workspace_size); - workspace_data.SetZero(); - kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); + auto kargs = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); - if(!Kernel::IsSupportedArgument(kargs)) - { - EXPECT_TRUE(false); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + EXPECT_TRUE(false); + } - dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); - dim3 block_dims = Kernel::BlockSize(); + dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); + dim3 block_dims = Kernel::BlockSize(); - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); - return kargs.tile_partitioner.estimate_num_wgs_per_tile(); - }; - - return Run(ck_tile::integral_constant{}); + return kargs.tile_partitioner.estimate_num_wgs_per_tile(); } public: diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 7c085b5098..875684ce08 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -180,68 +180,52 @@ class TestCkTileGemmPipeline : public ::testing::Test using GemmPipeline = typename GemmPipelineTypeSelector::pipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + dim3 grids; + if constexpr(Persistent) { - Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - Run(ck_tile::integral_constant{}); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp index bdce90e385..237641a000 100644 --- a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp +++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp @@ -42,8 +42,7 @@ template + index_t NDimSpatial = 2> struct BuildKernel { using GemmShape = TileGemmShape< @@ -123,7 +122,6 @@ struct BuildKernel ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, ConvTraits::FixedGemmParams::TransposeC, - MemOp, ConvConfig::NumWaveGroups, ConvTraits::FixedGemmParams::FixedVectorSize, ConvTraits::VectorSizeC>; @@ -212,26 +210,6 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne) EXPECT_FALSE(Kernel::IsSupportedArgument(kargs)); } -TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne) -{ - using Kernel = typename BuildKernel::type; - - // k_batch = 1 should fail with atomic_add - auto host_args_kbatch_1 = create_2d_host_args(1); - auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1); - EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_1)); - - // k_batch = 2 should pass - auto host_args_kbatch_2 = create_2d_host_args(2); - auto kargs_2 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_2); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2)); -} - TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation) { using Kernel = typename BuildKernel; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + if(s.log_level_ > 0) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { constexpr bool TransposeC = false; constexpr bool DoubleSmemBuffer = false; @@ -212,50 +193,47 @@ class TestCkTileGroupedGemm : public ::testing::Test CLayout, TransposeC>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + ck_tile::ignore = ck_tile::launch_kernel(s, ck_tile::make_kernel( Kernel{}, @@ -264,19 +242,6 @@ class TestCkTileGroupedGemm : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), num_groups)); - }; - - if(splitk) - { - Run(ck_tile::integral_constant{}); - } - else - { - - Run(ck_tile::integral_constant{}); - } } auto calculate_rtol_atol(const ck_tile::index_t K, @@ -422,8 +387,7 @@ class TestCkTileGroupedGemm : public ::testing::Test { // Generate kernel arguments std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = gemm_descs[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : gemm_descs) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, @@ -448,10 +412,10 @@ class TestCkTileGroupedGemm : public ::testing::Test stream.stream_id_)); #if CK_TILE_USE_WMMA invoke_grouped_gemm_persistent( - stream, group_count, kargs_ptr, splitk); + stream, group_count, kargs_ptr); #else invoke_grouped_gemm_persistent( - stream, group_count, kargs_ptr, splitk); + stream, group_count, kargs_ptr); #endif } else diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index b065df6f8a..c6e311a65c 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -96,7 +96,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test const ck_tile::stream_config& s, void* kargs_ptr) { - + EXPECT_TRUE(gemm_descs[0].k_batch == 1); using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, @@ -134,74 +134,56 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test ck_tile::GemmPipelineAgBgCrCompV3, ck_tile::GemmPipelineAgBgCrCompV4>>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -218,78 +200,58 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test BLayout, ELayout>; - float ave_time{0}; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::GemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::GemmPipelineAgBgCrCompV4>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = std::conditional_t< - Config::Pipeline_ == (PipelineType::Memory), - ck_tile::GemmPipelineAgBgCrMem, - std::conditional_t, - ck_tile::GemmPipelineAgBgCrCompV4>>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - if(!splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } public: @@ -445,8 +407,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test if constexpr(Config::Persistent_) { std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = gemm_descs[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : gemm_descs) { kargs.emplace_back( @@ -471,7 +432,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test hipMemcpyHostToDevice, stream.stream_id_)); - invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr, splitk); + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr); } else { diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index a7189e7865..e588ad2cc1 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -127,59 +127,44 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) - { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } private: @@ -226,59 +211,45 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test ck_tile::GemmPipelineScheduler::Default>; using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(gemm_descs[0].k_batch == 1) - { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } struct BShuffleGemmConfig diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index b73221ac28..3d52bca9e0 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -148,10 +148,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; using QuantGemmProblem = std::conditional_t< UseGroupedQuant, @@ -217,8 +216,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test GroupedGemKernelParam::M_Warp_Tile, GroupedGemKernelParam::N_Warp_Tile, GroupedGemKernelParam::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; + QuantGemmProblem::TransposeC>>; using Kernel = ck_tile::QuantGroupedGemmKernel; - using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = memory_operation_.value; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. + using GemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || - QuantType == ck_tile::QuantType::BQuantGrouped; - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; - - using GemmPipeline = std::conditional_t< - UseGroupedQuant, - std::conditional_t< - QuantType == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, - ck_tile::GemmPipelineAgBgCrCompV3>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::QuantGroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + ck_tile::ignore = ck_tile::launch_kernel(s, ck_tile::make_kernel( Kernel{}, @@ -388,10 +379,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), num_groups)); - }; - - Run(ck_tile::integral_constant{}); } template diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 27ca805c2e..81a9b08b70 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -719,8 +719,8 @@ struct SelectedKernel {{ elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: instance_code += f""" - // Kernel type - using GemmKernel = ck_tile::GemmKernel; + // Kernel type + using GemmKernel = ck_tile::GemmKernel; // Kernel arguments auto kargs = GemmKernel::MakeKernelArgs(args); @@ -802,8 +802,8 @@ struct SelectedKernel {{ ck_tile::tuple<>, // DsLayout CLayout, ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ + TileM, // kM_ + TileN, // kN_ WarpPerBlock_M, // MWave_ WarpPerBlock_N, // NWave_ WarpTileM, // MPerXdl_ diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 2225619fad..bea46de067 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -481,8 +481,6 @@ struct SelectedKernel {{ GemmUniversalTraits>; static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{ - const auto Run = [&](const auto memory_operation_) {{ - constexpr auto memory_operation = memory_operation_.value; constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; // kNumWaveGroups_ using GemmEpilogue = ck_tile::CShuffleEpilogue; @@ -558,30 +555,12 @@ struct SelectedKernel {{ workspace_data.SetZero(); }} }}; - - + // Launch kernel - float ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( stream, reset_data_buffers, ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); - return ave_time; - - // ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile(); - // return std::make_tuple(ave_time, num_wgs_per_tile); - }}; - - - if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy) - {{ - return Run(ck_tile::integral_constant{{}}); - }} - else // We are using ck_tile::StreamKReductionStrategy::Reduction - {{ - return Run(ck_tile::integral_constant{{}}); - }} }} }}; """ From cc75a1dc5f18613af29d8821375f79b0f3c6410b Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 5 Jan 2026 18:41:47 +0800 Subject: [PATCH 07/23] [FMHA] Batch Prefill Support Improvements: Change KV Cache Layout & Large Page Size Support (#3442) * add page_block_size parameter * add is_sglang_layout to parameters * add kv_offset_array_transform to batch async for page size 16 * add kv_last_page_lens to kernel * change kv layout to [num_total_pages, page_block_size, hdim] * format * - enable codegen of batch_prefill kernels - create new problem struct BlockFmhaBatchPrefillPipelineProblem for batch prefill kernels - generate different page sizes of batch prefill kernels (1, 16) * 1. fix wrong calculation of page id in kv_offset_array_transform in gfx950 2. support page size 1024 * fix python format * change kv cache layout to [num_blocks, num_kv_heads, head_size/x, block_size, x] and [num_blocks, num_kv_heads, block_size/X, head_size, X] * 1. Introduced `kVectorSize` in BlockFmhaBatchPrefillPipelineProblem instead of using hardcode values 2. Makes batch prefill kernel traits structures inherent from fmha fwd traits 3. Add some static check for Page size, vector size, hdim, ..., etc. * [Refactor] Replace is_sglang_layout with Enums for KV cache configuration Refactored `fmha_batch_prefill` to use `BlockAttentionKVCacheMemoryLayoutEnum` (VECTORIZED/LINEAR) and `BlockAttentionKVCacheLookupTableEnum` (SGLANG_1D/VLLM_2D) instead of a single boolean. **Changes:** * Added Enum definitions in `block_attention_kvcache_layout_enum.hpp`. * Updated Kernel, Pipeline, and Traits to template on these Enums. * Implemented `kv_offset_array_transform` logic based on `kKVMemoryLayout`. * Refactored `PageBlockTableKargs` to adapt to `kKVLookupTable`. * Updated CodeGen scripts to support new parameters. This decouples memory layout from the paging mechanism, enabling flexible KV cache configurations. * 1. remove batch prefill pipeline with sk_pad=false 2. correct some comments 3. add static assert to make sure v offsets is in same page within a tile. * fix vgpr spill count * remove unnecessary t2s functions * add fp8 support for receipt 200 and 600 in fmha_bath_prefill.py * support linear kv cache layout * Remove block_table_ptr from fwd_batch_prefill_args. Instead, reuse kv_page_indices as a pointer of the lookup table. * 1. merge multiple transforms into single transform. 2. add static check to make sure vlayout is row-major. * move FmhaFwdCommonKargs::seqlen_k_ptr to VllmPageTableKargs. * update changelog --------- Co-authored-by: ltqin Co-authored-by: PoYen, Chen --- CHANGELOG.md | 1 + .../01_fmha/codegen/ops/fmha_batch_prefill.py | 214 ++++-- example/ck_tile/01_fmha/fmha_fwd.hpp | 124 +++- include/ck_tile/ops/fmha.hpp | 1 + .../block_attention_kvcache_layout_enum.hpp | 32 + .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 286 +++++--- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 646 +++++++++++------- .../pipeline/block_fmha_pipeline_problem.hpp | 66 ++ .../ops/fmha/pipeline/tile_fmha_traits.hpp | 43 ++ 9 files changed, 983 insertions(+), 430 deletions(-) create mode 100644 include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a9b25b062..ce8e5197a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline. * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. * Added FP8 KV cache support for FMHA batch prefill. +* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 95e8379769..c4c70009d5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,6 +36,19 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} +SUPPORTED_PAGE_SIZE = [128, 256, 1024] +SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] +SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] +KV_MEMORY_LAYOUT_ENUM_MAP = { + "vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT", + "linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT", +} +KV_LOOKUP_TABLE_ENUM_MAP = { + "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D", + "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D", +} + + FMHA_BATCH_PREFILL_PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", } @@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, +using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, @@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_lse}, {F_dropout}, {F_qscale}, - {F_occupancy}>; + {F_occupancy}, + false, + {F_page_size}, + {F_kv_memory_layout}, + {F_kv_lookup_table}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem< typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, typename FmhaFwdTypeConfig::VDataType, @@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< fmha_variant_{F_idx}, fmha_mask_{F_idx}, false, + {F_page_size}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} = using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; +using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; #include @@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; return fmha_batch_prefill_(s, a); }} """ @@ -230,12 +248,15 @@ class FmhaFwdApiTrait: dpad: str dvpad: str constraint: CppConstraint + kv_memory_layout: str + kv_lookup_table: str + page_size: int = 1 # page block size @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" ) @property @@ -322,6 +343,8 @@ class FmhaFwdPipeline: F_dropout: str # F_qscale: str # no/pertensor F_mask: str # value from MASK_MAP + F_kv_memory_layout: str # + F_kv_lookup_table: str # F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -382,6 +405,8 @@ class FmhaFwdPipeline: n += f"_{self.F_qscale}" else: n += "_nqscale" + + n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table return n @@ -440,6 +465,13 @@ class FmhaFwdApiPool: F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype], + F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[ + trait.kv_memory_layout + ], + F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[ + trait.kv_lookup_table + ], + F_page_size=trait.page_size, ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -497,6 +529,7 @@ class FmhaFwdKernel: F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline mask_impl: str + F_page_size: int = 1 # page block size @property def template(self) -> str: @@ -534,17 +567,24 @@ class FmhaFwdKernel: F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale], F_occupancy=self.F_tile.F_occupancy, + F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[ + self.F_pipeline.F_kv_memory_layout + ], + F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[ + self.F_pipeline.F_kv_lookup_table + ], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], + F_page_size=self.F_page_size, ) @property def name(self) -> str: # TODO: we don't encode idx here return ( - f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" + self.F_tile.name + "_" + self.F_pipeline.name @@ -578,6 +618,9 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + kv_memory_layout=self.F_pipeline.F_kv_memory_layout, + kv_lookup_table=self.F_pipeline.F_kv_lookup_table, + page_size=self.F_page_size, ) @@ -604,23 +647,42 @@ class KernelComponentFactory: pipelines = [] if dtype in ["fp16", "bf16"]: qscale = "no" - for logits, mask, bias, lse, dropout in itertools.product( + for ( + logits, + mask, + bias, + lse, + dropout, + kv_memory_layout, + kv_lookup_table, + ) in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], + SUPPORTED_KV_MEMORY_LAYOUT, + SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip elif dtype in ["fp8bf16"]: # no need lse/dropout kernels - for logits, qscale, mask, bias in itertools.product( + for ( + logits, + qscale, + mask, + bias, + kv_memory_layout, + kv_lookup_table, + ) in itertools.product( ["t", "f"], ["pertensor"], get_mask_map(mask_impl).keys(), ["no"], + SUPPORTED_KV_MEMORY_LAYOUT, + SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip else: assert False return pipelines @@ -672,69 +734,73 @@ def get_fwd_blobs( or pipeline.F_logits == "f" ): continue - k = FmhaFwdKernel( - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # Aiter(mha_batch_prefill) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # aiter::mha_batch_prefill C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == "fp32" - if not cond: - continue + # Generate kernels for both page_size=16 and page_size=1024 + for page_size in SUPPORTED_PAGE_SIZE: + k = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + F_page_size=page_size, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_qscale == "no" + if not cond: + continue - api_pool.register_traits(k.api_trait()) - gen.append(k) + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == "fp32" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) return (api_pool, gen) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ba55d6d722..3ff4acfc15 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -529,14 +529,25 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; - // SGLang-style page table - int32_t num_total_pages; - void* kv_indptr; - void* kv_page_indices; -#if 0 // we assume page_block_size=1 for now - void* kv_last_page_lens; - ck_tile::index_t page_block_size; -#endif + // KV cache page table fields (kv_lookup_table selects interpretation): + // - SGLANG_PAGE_TABLE_1D: + // kv_indptr: prefix-sum [batch+1] into kv_page_indices + // kv_page_indices: 1D list of physical page ids, length = num_total_pages + // kv_last_page_lens: per-batch last page lengths [batch] + // - VLLM_BLOCK_TABLE_2D: + // kv_page_indices: block_table [batch, max_blocks_per_seq] (2D) + // batch_stride_block_table: row stride for block_table + // seqlen_k_ptr: per-batch seqlen_k [batch] + int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM) + ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM) + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum + kv_memory_layout; // KV memory layout (SGLang/vLLM) + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector + void* kv_indptr; // SGLang: prefix-sum; vLLM: unused + void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D + void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused + void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused + ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused float scale_s; float scale_p; @@ -1113,6 +1124,22 @@ template auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) { assert(args.nhead_q % args.nhead_k == 0); + using PageTableKargs = typename FmhaKernel::PageBlockTableKargs; + const PageTableKargs page_table = [&]() { + if constexpr(FmhaKernel::kKVLookupTable == + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + return PageTableKargs{reinterpret_cast(args.kv_indptr), + reinterpret_cast(args.kv_page_indices), + reinterpret_cast(args.kv_last_page_lens)}; + } + else + { + return PageTableKargs{reinterpret_cast(args.kv_page_indices), + args.batch_stride_block_table, + reinterpret_cast(args.seqlen_k_ptr)}; + } + }(); auto kargs = [&] { // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) @@ -1133,12 +1160,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, -#if 0 // we assume page_block_size=1 for now - args.kv_last_page_lens, args.page_block_size, -#endif + page_table, args.scale_s, args.scale_p, args.scale_o, @@ -1184,12 +1207,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, -#if 0 // we assume page_block_size=1 for now - args.kv_last_page_lens, args.page_block_size, -#endif + page_table, args.scale_s, args.scale_p, args.scale_o, @@ -1281,6 +1300,65 @@ struct fmha_fwd_traits_ static constexpr bool kHasSink = kHasSink_; }; +template +struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_ +{ + static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; + static constexpr auto kKVLookupTable = kKVLookupTable_; + static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_; + static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout"); +}; + template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); @@ -1527,7 +1605,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&); -using fmha_batch_prefill_traits = fmha_fwd_traits; +struct fmha_batch_prefill_traits : public fmha_fwd_traits +{ + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + int page_size = 1; +}; + float fmha_batch_prefill(fmha_batch_prefill_traits, fmha_batch_prefill_args, const ck_tile::stream_config&); diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 20714397c9..eb4aa16d05 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp new file mode 100644 index 0000000000..c79e639469 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +// KV cache memory layout selector. +// +// Layout summary (kVectorSize = 16 / sizeof(KDataType)): +// - VECTORIZED_LAYOUT (swizzled): +// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize] +// V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize] +// - LINEAR_LAYOUT: +// K: [NumBlocks, PageSize, NumHeads, HeadDim] +// V: [NumBlocks, PageSize, NumHeads, HeadDim] +enum class BlockAttentionKVCacheMemoryLayoutEnum +{ + VECTORIZED_LAYOUT = 0, + LINEAR_LAYOUT = 1, +}; + +// KV cache lookup table layout selector. +// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq] +// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... kv_indptr[b+1]) +enum class BlockAttentionKVCacheLookupTableEnum +{ + VLLM_BLOCK_TABLE_2D = 0, + SGLANG_PAGE_TABLE_1D = 1, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 73b6a329d1..9afd097eed 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" @@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; + static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout; + static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable; + static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize; + static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs @@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. + struct SglangPageTableKargs + { + const int32_t* kv_indptr; + const int32_t* kv_page_indices; + const int32_t* kv_last_page_lens; + }; + + struct VllmPageTableKargs + { + const int32_t* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + const int32_t* seqlen_k_ptr; + }; + + using PageBlockTableKargs = + std::conditional_t; + struct FmhaFwdCommonKargs { const void* q_ptr; @@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t nhead_ratio_qk; int32_t num_total_pages; - const int32_t* kv_indptr; - const int32_t* kv_page_indices; -#if 0 // we assume page_block_size=1 for now - const int32_t* kv_last_page_lens; ck_tile::index_t page_block_size; -#else - static constexpr ck_tile::index_t page_block_size = 1; -#endif + PageBlockTableKargs page_table; float scale_s; @@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, - const void* kv_indptr, - const void* kv_page_indices, -#if 0 // we assume page_block_size=1 for now - const void* kv_last_page_lens, ck_tile::index_t page_block_size, -#endif + const PageBlockTableKargs& page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, @@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel num_head_q, nhead_ratio_qk, num_total_pages, - reinterpret_cast(kv_indptr), - reinterpret_cast(kv_page_indices), -#if 0 // we assume page_block_size=1 for now - reinterpret_cast(kv_last_page_lens), page_block_size, -#endif + page_table, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), #else @@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, - const void* kv_indptr, - const void* kv_page_indices, -#if 0 // we assume page_block_size=1 for now - const void* kv_last_page_lens, ck_tile::index_t page_block_size, -#endif + const PageBlockTableKargs& page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, @@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel num_head_q, nhead_ratio_qk, num_total_pages, - reinterpret_cast(kv_indptr), - reinterpret_cast(kv_page_indices), -#if 0 // we assume page_block_size=1 for now - reinterpret_cast(kv_last_page_lens), page_block_size, -#endif + page_table, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), #else @@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; - const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; -#if 0 // we assume page_block_size=1 for now - const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; -#endif + const index_t seqlen_k = [&]() { + if constexpr(kKVLookupTable == + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + const int32_t page_start = kargs.page_table.kv_indptr[i_batch]; + const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1]; + const int32_t num_page_blocks = page_end - page_start; + const int32_t last_page_len = [&]() { + if constexpr(kPageBlockSize == 1) + return static_cast(kPageBlockSize); + else + return kargs.page_table.kv_last_page_lens[i_batch]; + }(); + return num_page_blocks > 0 + ? static_cast((num_page_blocks - 1) * kargs.page_block_size + + last_page_len) + : 0; + } + else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D + { + if(kargs.page_table.seqlen_k_ptr != nullptr) + return static_cast(kargs.page_table.seqlen_k_ptr[i_batch]); + else + return kargs.seqlen_k; + } + }(); + const int32_t* page_idx = [&]() { + if constexpr(kKVLookupTable == + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch]; + } + else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D + { + return kargs.page_table.block_table_ptr + + static_cast(i_batch) * + kargs.page_table.batch_stride_block_table; + } + }(); + if constexpr(kIsGroupMode) { // get starting offset for each batch @@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel batch_offset_q = query_start * kargs.stride_q; - kargs.kv_page_indices += kargs.kv_indptr[i_batch]; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias; @@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel return; } -#if 0 // we assume page_block_size=1 for now - kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; -#else - kargs.seqlen_k = num_page_blocks; -#endif + kargs.seqlen_k = seqlen_k; } else { batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - kargs.kv_page_indices += kargs.kv_indptr[i_batch]; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; @@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; -#if 0 // we assume page_block_size=1 for now - kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; -#else - kargs.seqlen_k = num_page_blocks; -#endif + kargs.seqlen_k = seqlen_k; } // for simplicity, batch stride we just modify the pointer @@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } }(); const auto k_dram = [&]() { - const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - const auto v_dram = [&]() { - if constexpr(std::is_same_v) + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - number{}, + // Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize] + // Logical View for Pipeline: (TotalSeqK, D) + + // Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize, + // PageBlockSize, kVectorSize) + // Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1) + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_total_pages, + kargs.hdim_q / kVectorSize, + kargs.page_block_size, + kVectorSize), + make_tuple( + kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1), + number{}, number<1>{}); - const auto v_dram_transposed = transform_tensor_view( - v_dram_naive, - make_tuple( - make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)), - make_tuple(sequence<1>{}, sequence<0>{}), + // Merge to (TotalSeqK, D) in a single transform: + // physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D) + auto k_dram_2d = transform_tensor_view( + k_dram_naive, + make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages, + kargs.page_block_size)), // TotalSeqK + make_merge_transform( + make_tuple(static_cast(kargs.hdim_q / kVectorSize), + static_cast(kVectorSize)))), // D + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; return pad_tensor_view( - v_dram_transposed, + k_dram_2d, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + // Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim] + // Logical View for Pipeline: (TotalSeqK, D) + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q), + make_tuple(kargs.batch_stride_k, kargs.stride_k, 1), + number{}, + number<1>{}); + + // Merge to (TotalSeqK, D) in a single transform: + // physical (Page, S, D) -> logical (TotalSeqK, D) + auto k_dram_2d = transform_tensor_view( + k_dram_naive, + make_tuple(make_merge_transform( + make_tuple(kargs.num_total_pages, kargs.page_block_size)), + make_pass_through_transform(kargs.hdim_q)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + k_dram_2d, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto v_dram = [&]() { + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize] + // Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM + + // Define the naive physical view with 4D shape: (NumPages, + // PageBlockSize/kVectorSize, HeadDim, kVectorSize) + // Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1) + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.num_total_pages, + kargs.page_block_size / kVectorSize, + kargs.hdim_v, + kVectorSize), + make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1), + number{}, + number<1>{}); + + // Merge to (D, TotalSeqK) in a single transform: + // physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK) + auto v_dram_final = transform_tensor_view( + v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), // D + make_merge_transform(make_tuple(kargs.num_total_pages, + kargs.page_block_size / kVectorSize, + kVectorSize))), // TotalSeqK + make_tuple(sequence<2>{}, sequence<0, 1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + v_dram_final, make_tuple(number{}, number{}), sequence{}); } else { + // Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim] + // Logical View for Pipeline: (D, TotalSeqK) const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size), - make_tuple(kargs.stride_v, 1), + make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v), + make_tuple(kargs.batch_stride_v, kargs.stride_v, 1), number{}, number<1>{}); - constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; - return pad_tensor_view( + // Merge to (D, TotalSeqK) in a single transform: + // physical (Page, S, D) -> logical (D, TotalSeqK) + auto v_dram_final = transform_tensor_view( v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_merge_transform( + make_tuple(kargs.num_total_pages, kargs.page_block_size))), + make_tuple(sequence<2>{}, sequence<0, 1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + v_dram_final, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); - auto q_dram_window = make_tile_window( q_dram, [&]() { @@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + const index_t stride_k_for_pipeline = + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT + ? kVectorSize + : kargs.stride_k; + const index_t stride_v_for_pipeline = + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT + ? kargs.hdim_v + : kargs.stride_v; + auto o_acc_tile = [&] { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { @@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel variant_params, block_indices, smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, dropout); } else @@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel variant_params, block_indices, smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, dropout); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 2102fe768f..0b47441995 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,12 +6,82 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { +template +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec, + const index_t& stride_kv, + const index_t& page_stride_kv, + const CoordVecType& coord_vec, + OffsetVecType& kv_offset_vec, + index_t global_seq_offset = 0) +{ + const index_t& thread_coord_start = coord_vec[kCoordAxis]; + constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; + if constexpr(kIsKcache) + { + // for k offsets + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t page_id = global_token_idx >> kLog2PageSize; + const index_t page_offset = global_token_idx & kInPageOffsetMask; + kv_offset_vec[k0] = static_cast(page_vec[page_id]) * page_stride_kv + + static_cast(page_offset) * stride_kv; + }); + } + else + { + // for v offsets + const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); + const index_t lane0_page_id = + (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + + const long_index_t page_loc = + static_cast(page_vec[lane0_page_id]) * page_stride_kv; + + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t page_offset = + (global_seq_offset + thread_coord_start + kLoopStart + k0.value) & + kInPageOffsetMask; + + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized layout offset + // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize] + // Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize) + const index_t s = page_offset; + const index_t D = stride_kv; + + const long_index_t s_offset = + static_cast((s / kVectorSize) * (D * kVectorSize)) + + (s % kVectorSize); + + kv_offset_vec[k0] = page_loc + s_offset; + } + else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT + { + kv_offset_vec[k0] = page_loc + static_cast(page_offset) * stride_kv; + } + }); + } +} // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) template {}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto I3 = number<3>{}; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; + static constexpr index_t kLog2PageSize = Problem::kLog2PageSize; + static constexpr index_t kVectorSize = Problem::kVectorSize; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kPageBlockSize % kN0 == 0, + "V offset assumes each tile stays within a page; kPageBlockSize must be " + "divisible by kN0."); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) @@ -68,6 +144,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -196,6 +273,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, DropoutType& dropout) const { static_assert( @@ -325,9 +404,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync using KDstrEncode = typename decltype(k_dist)::DstrEncode; constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; statically_indexed_array k_offsets; - static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; - }); + index_t current_seq_k = seqlen_k_start; + kv_offset_array_transform, + decltype(k_coord), + 0, + kPageBlockSize, + kLog2PageSize, + 0, + NRepeat, + kN0 / NRepeat, + kKVMemoryLayout, + true, + kVectorSize>( + page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); + auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), @@ -360,10 +450,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync using VDstrEncode = typename decltype(v_dist)::DstrEncode; constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; statically_indexed_array v_offsets; - (void)stride_k; - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v; - }); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + 0, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); auto v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -425,13 +523,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync async_load_fence(); __builtin_amdgcn_s_barrier(); - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; - }); - v_dram_window.update_page_idx(v_offsets); - __builtin_amdgcn_sched_barrier(0); { // tail gemm_0( @@ -444,49 +535,67 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } __builtin_amdgcn_sched_barrier(1); - // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); - tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); -#endif - }, - s_acc, - bias_tile); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + v_dram_window.update_page_idx(v_offsets); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto p = [&]() { + const auto bias_tile = load_tile(bias_dram_window); // load bias tile - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); - }); - }); - } - else - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - if constexpr(kHasLogitsSoftCap) + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - auto apply_logits_transform = - [&variant, &variant_params, &block_indices](auto& x) { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = [&variant, &variant_params, &block_indices]( + auto& x) { x = variant.LogitsTransform(variant_params, variant.QueryTransform(variant_params, x), block_indices.batch_idx, @@ -494,216 +603,229 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync block_indices.kv_head_idx); }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } #else - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { #if(defined(__gfx90a__) || defined(__gfx94__)) && \ (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) - // Avoid data hazard if v_mfma is followed by inline asm consumer - // instructions. In this case, compiler won't add s_nop for us - if(i == s_acc.thread_buf_.size() / 2) - { - __builtin_amdgcn_sched_barrier(0); + // Avoid data hazard if v_mfma is followed by inline asm consumer + // instructions. In this case, compiler won't add s_nop for us + if(i == s_acc.thread_buf_.size() / 2) + { + __builtin_amdgcn_sched_barrier(0); + } +#endif + apply_logits_transform(s_acc.thread_buf_[i]); } #endif - apply_logits_transform(s_acc.thread_buf_[i]); - } -#endif - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif - } - } - move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); - - if(need_perpixel_check) - { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }); - } - } - - const auto s = cast_tile(s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max, bool_constant{}); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s.get_tile_distribution()); // Pcompute{j} - - __builtin_amdgcn_sched_barrier(0x7F); - // store & prefetch next v, after the max reduction - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_buf); - - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - - store_tile( - v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store the prefetch - } - - if constexpr(k1_loops > 1) - { - move_tile_window( - v_dram_window, - {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = - page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v; - }); - v_dram_window.update_page_idx(v_offsets); - } - __builtin_amdgcn_sched_barrier(0); - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration. alibi does not have this problem - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); -#endif - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - if constexpr(kHasLogitsSoftCap) +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout([](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, + m, + m_old, + m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, + kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + 2 * kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + v_dram_window.update_page_idx(v_offsets); + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } - } #else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif + }); }); - }); - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } - } - }(); + }(); #else - const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); #endif - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? - o_acc(i_j_idx) *= tmp; + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); }); - }); - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); - } + if constexpr(kHasDropout) + { + auto randval_ptr = reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeKV(); + dropout + .template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } - const auto p = [&]() { #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN // For fp32 to fp16, // impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue, @@ -727,11 +849,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + - v_coord[VPageIndexDim] + k0.value] * - stride_v; - }); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + (2 + i_k1.value) * kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); v_dram_window.update_page_idx(v_offsets); } block_sync_lds(); @@ -772,14 +901,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - page_idx += kN0; + current_seq_k += kN0; // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; - }); + kv_offset_array_transform, + decltype(k_coord), + 0, + kPageBlockSize, + kLog2PageSize, + 0, + NRepeat, + kN0 / NRepeat, + kKVMemoryLayout, + true, + kVectorSize>( + page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) @@ -887,6 +1025,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, @@ -913,6 +1053,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_idx, stride_k, stride_v, + page_stride_k, + page_stride_v, dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index a192e3f7b0..f9dc94bc65 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" namespace ck_tile { @@ -65,6 +66,71 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasSink = Traits::kHasSink; }; +template +struct BlockFmhaBatchPrefillPipelineProblem + : public BlockFmhaPipelineProblem +{ + static constexpr index_t kPageBlockSize = kPageBlockSize_; + static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive"); + static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, + "kPageBlockSize must be power of two"); + static constexpr index_t kLog2PageSize = []() constexpr { + index_t shift = 0; + index_t val = kPageBlockSize_; + while(val > 1) + { + val >>= 1; + shift++; + } + return shift; + }(); + + static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 + static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; + static constexpr auto kKVLookupTable = Traits_::kKVLookupTable; + static constexpr bool kIsVectorizedLayout = + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + + static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0, + "kQKHeaddim must be divisible by kVectorSize"); + static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0, + "kPageBlockSize must be divisible by kVectorSize for vectorized layout"); + static_assert(kIsGroupMode_, "Batch prefill requires group mode"); +}; + template +struct TileFmhaBatchPrefillTraits : public TileFmhaTraits +{ + static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; + static constexpr auto kKVLookupTable = kKVLookupTable_; + static constexpr index_t kPageBlockSize = kPageBlockSize_; + static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT || + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT, + "Batch prefill only supports vectorized or linear KV cache layout."); + static_assert(kPageBlockSize > 0 && ((kPageBlockSize & (kPageBlockSize - 1)) == 0), + "kPageBlockSize should be a power of 2 to support efficient page-based KV cache " + "addressing."); +}; + template Date: Mon, 5 Jan 2026 13:57:34 +0100 Subject: [PATCH 08/23] [CK_BUILDER] validation (#3471) This pull request builds on #3267 by proving the "validation" infrastructure, the means to compare a set of `Outputs`. The design of the validation infrastructure is relatively straight forward: - Each SIGNATURE should come with a `validate()` implementation, which should be implemented in a similar way that the other functions/types from `testing.hpp` are implemented. - `validate()` returns a `ValidationReport`, which is a structure that keeps all relevant information about comparing the tensors from two `Outputs`. Note that crucially, `validate()` should not do any reporting by itself. Rather, glue logic should be implemented by the user to turn `ValidationReport` into a relevant error message. - You can see this clue code for CK-Builder itself in `testing_utils.hpp`, its `MatchesReference()`. This functionality is relatively barebones right now, it will be expanded upon in a different PR to keep the scope of this one down. The comparison is done on the GPU (using an atomic for now), to keep tests relatively quick. Some notable items from this PR: - To help compare the tensors and with writing tests, I've written a generic function `tensor_foreach` which invokes a callback on every element of a tensor. - For that it was useful that the `TensorDescriptor` has a rank which is known at compile-time, so I've changed the implementation of `TensorDescriptor` for that. I felt like it was a better approach than keeping it dynamic, for multiple reasons: - This is C++ and we should use static typing where possible and useful. This way, we don't have to implement runtime assertions about the tensor rank. - We know already know the rank of tensors statically, as it can be derived from the SIGNATURE. - It simpifies the implementation of `tensor_foreach` and other comparison code. - There are a lot of new tests for validating the validation implementation, validating validation validation tests (Only 3 recursive levels though...). For a few of those functions, I felt like it would be useful to expose them to the user. - Doc comments everywhere. --- .../factory/helpers/ck/conv_tensor_type.hpp | 5 + .../ck_tile/builder/testing/conv_fwd.hpp | 75 ++- .../include/ck_tile/builder/testing/error.hpp | 150 ++++++ .../testing/{extent.hpp => filter_extent.hpp} | 17 +- .../ck_tile/builder/testing/tensor_buffer.hpp | 160 +------ .../builder/testing/tensor_descriptor.hpp | 444 ++++++++++++++++++ .../builder/testing/tensor_foreach.hpp | 258 ++++++++++ .../builder/testing/tensor_initialization.hpp | 77 ++- .../ck_tile/builder/testing/testing.hpp | 55 ++- .../ck_tile/builder/testing/validation.hpp | 167 +++++++ experimental/builder/test/CMakeLists.txt | 43 +- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 16 +- .../builder/test/test_inline_diff.cpp | 16 +- experimental/builder/test/testing_utils.hpp | 54 +++ .../builder/test/unit_conv_tensor_type.cpp | 49 +- .../builder/test/unit_device_buffer.cpp | 17 +- experimental/builder/test/unit_error.cpp | 46 ++ .../builder/test/unit_tensor_descriptor.cpp | 155 +++++- .../builder/test/unit_tensor_foreach.cpp | 205 ++++++++ experimental/builder/test/unit_validation.cpp | 277 +++++++++++ 20 files changed, 2001 insertions(+), 285 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/testing/error.hpp rename experimental/builder/include/ck_tile/builder/testing/{extent.hpp => filter_extent.hpp} (50%) create mode 100644 experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp create mode 100644 experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp create mode 100644 experimental/builder/include/ck_tile/builder/testing/validation.hpp create mode 100644 experimental/builder/test/unit_error.cpp create mode 100644 experimental/builder/test/unit_tensor_foreach.cpp create mode 100644 experimental/builder/test/unit_validation.cpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index c819e11d00..9430573cc6 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -47,6 +47,11 @@ struct DataTypeToCK { using type = ck::f8_t; }; +template <> +struct DataTypeToCK +{ + using type = uint8_t; +}; struct CK_empty_tuple { diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp index 62d265894a..8cbafa7efa 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp @@ -7,11 +7,14 @@ #include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/builder/testing/testing.hpp" -#include "ck_tile/builder/testing/extent.hpp" +#include "ck_tile/builder/testing/filter_extent.hpp" #include "ck_tile/builder/testing/tensor_buffer.hpp" #include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/validation.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + /// This file implements common functionality for invoking/testing grouped /// forward convolutions created through the CK Builder API. The main item /// of it is the ConvArgs structure - which contains a complete description @@ -37,12 +40,12 @@ namespace ck_tile::builder::test { template struct ConvTensorLengths { - size_t batch_size = 1; // N - size_t groups = 1; // G - size_t input_channels = 1; // C - size_t output_channels = 1; // K - Extent image = {}; // W, H, D - Extent filter = {}; // X, Y, Z + size_t batch_size = 1; // N + size_t groups = 1; // G + size_t input_channels = 1; // C + size_t output_channels = 1; // K + FilterExtent image = {}; // W, H, D + FilterExtent filter = {}; // X, Y, Z }; /// @brief `Args` specialization for forward convolution. @@ -59,6 +62,14 @@ struct Args constexpr static auto WEIGHT_TYPE = SIGNATURE.data_type; constexpr static auto OUTPUT_TYPE = SIGNATURE.data_type; + constexpr static int INPUT_RANK = 3 + SPATIAL_DIM; + constexpr static int WEIGHT_RANK = 3 + SPATIAL_DIM; + constexpr static int OUTPUT_RANK = 3 + SPATIAL_DIM; + + using InputDescriptor = TensorDescriptor; + using WeightDescriptor = TensorDescriptor; + using OutputDescriptor = TensorDescriptor; + // TODO: We shouldn't need to call into an internal namespace here. using Ops = factory::internal::ElementwiseOps; @@ -73,10 +84,10 @@ struct Args // implementation (based on ConvParam in old CK / CK Tile) does not // support strides at all. - Extent filter_strides; - Extent filter_dilation; - Extent input_left_pad; - Extent input_right_pad; + FilterExtent filter_strides; + FilterExtent filter_dilation; + FilterExtent input_left_pad; + FilterExtent input_right_pad; Ops::AElementwiseOp a_elementwise_op; Ops::BElementwiseOp b_elementwise_op; @@ -85,7 +96,7 @@ struct Args /// This function returns the `TensorDescriptor` corresponding to /// the input-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. - TensorDescriptor make_input_descriptor() const + InputDescriptor make_input_descriptor() const { // TODO: We're using old CK functionality to compute the right // values here, mainly because CK tile does not support the @@ -96,31 +107,37 @@ struct Args const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< typename Layouts::ALayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + using Extent = typename InputDescriptor::Extent; + return InputDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// This function returns the `TensorDescriptor` corresponding to /// the weight-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. - TensorDescriptor make_weight_descriptor() const + WeightDescriptor make_weight_descriptor() const { // See note in implementation of `make_input_descriptor`. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed< typename Layouts::BLayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + using Extent = typename WeightDescriptor::Extent; + return WeightDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// This function returns the `TensorDescriptor` corresponding to /// the output-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. - TensorDescriptor make_output_descriptor() const + OutputDescriptor make_output_descriptor() const { // See note in implementation of `make_input_descriptor`. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed< typename Layouts::ELayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + using Extent = typename OutputDescriptor::Extent; + return OutputDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// Convert the Args structure into a CK conv_param structure. This @@ -245,12 +262,11 @@ UniqueInputs alloc_inputs(const Args& args) /// /// @see alloc_inputs() template - requires ValidConvSignature && ConvDirectionIsForward && - ValidUniqueInputs -void init_inputs(const Args& args, UniqueInputs& inputs) + requires ValidConvSignature && ConvDirectionIsForward +void init_inputs(const Args& args, Inputs inputs) { - init_tensor_buffer_uniform_fp(inputs.input_buf, args.make_input_descriptor(), -2.0f, 2.0f); - init_tensor_buffer_uniform_fp(inputs.weight_buf, args.make_weight_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); } /// @brief `alloc_outputs()` specialization for forward convolution. @@ -268,4 +284,19 @@ UniqueOutputs alloc_outputs(const Args& args) }; } +/// @brief `validate()` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see validate() +template + requires ValidConvSignature && ConvDirectionIsForward +ValidationReport +validate(const Args& args, Outputs actual, Outputs expected) +{ + ValidationReport report; + report.check("output", args.make_output_descriptor(), actual.output, expected.output); + return report; +} + } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/error.hpp b/experimental/builder/include/ck_tile/builder/testing/error.hpp new file mode 100644 index 0000000000..242f2a8e51 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/error.hpp @@ -0,0 +1,150 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +/// This file defines some utilities for dealing with HIP errors. In the CK-Builder +/// testing code, we'd like to just turn them into exceptions: This cleans up testing +/// code as we don't need to think about returning error codes, but its still much +/// cleaner than just creating a hard crash and thereby possibly interrupting other +/// units in the same test. The testing framework can catch these exceptions where +/// necessary. +/// +/// While the exceptions defined in this file are in principle suitable for general +/// usage, HIP functions which return HIP error codes (`hipError_t`) should be +/// checked using the `check_hip` function. + +namespace ck_tile::builder::test { + +/// @brief Generic HIP exception. +/// +/// This is a derivation of `std::runtime_error` which represents a HIP error code. +/// +/// @see std::runtime_error +/// @see hipError_t +struct HipError : std::runtime_error +{ + /// @brief Utility for formatting HIP error messages + /// + /// Returns a human-readable description of a HIP error. Given a description of the + /// activity that the user tried to perform, this function appends the HIP-specific + /// information such as the stringified version of the error code, and the error + /// code itself (for reference). + /// + /// @param user_msg User-given message about the activity at time of error. + /// @param code The status to report. + /// @param src The location where this error was discovered. + static std::string + format_error(std::string_view user_msg, hipError_t code, std::source_location src) + { + std::stringstream msg; + msg << user_msg << ": " << hipGetErrorString(code) << " (" << code << ")"; + if(src.function_name()) + msg << " in function '" << src.function_name(); + msg << "' at " << src.file_name() << ":" << src.line() << ":" << src.column(); + return msg.str(); + } + + /// @brief Construct a generic HIP error. + /// + /// @param msg User-given message about the activity at time of error. + /// @param code The status to report. + /// @param src The location where this error was discovered. Defaults to the caller's + /// location. + HipError(std::string_view msg, + hipError_t code, + std::source_location src = std::source_location::current()) + : std::runtime_error(format_error(msg, code, src)), code_(code) + { + } + + /// @brief Retrieve the inner error code. + /// + /// This function returns the status code that was encountered while checking an + /// operation for errors. + hipError_t code() const { return code_; } + + private: + hipError_t code_; +}; + +/// @brief HIP out of memory error. +/// +/// This a derivation of `HipError` which is specialized for Out-of-memory errors. This +/// makes it easier to attach additional context, and to match on these errors while +/// using `catch` blocks. +/// +/// @see HipError +struct OutOfDeviceMemoryError : HipError +{ + /// @brief Construct an out-of-device-memory error. + /// + /// @param msg User-given message about the activity at time of error. + /// @param src The location where this error was discovered. Defaults to the caller's + /// location. + OutOfDeviceMemoryError(std::string_view msg = "failed to allocate device memory", + std::source_location src = std::source_location::current()) + : HipError(msg, hipErrorOutOfMemory, src) + { + } +}; + +/// @brief Check HIP status for errors. +/// +/// This function checks a HIP status code (obtained from a HIP function call) for any +/// errors. If the status `code` is not `hipSuccess`, this function throws an instance of +/// `HipError`. The exact type thats thrown depends on the status. If `code` represents +/// an out-of-memory error `hipErrorOutOfMemory`, then `OutOfDeviceMemoryError` will be +/// thrown instead. +/// +/// @param msg User-given message about the activity at possible time of error. +/// @param code The HIP status code to examine. +/// @param src The location where this status was set. Defaults to the caller's location. +/// +/// @throws HipError if `code` is not `hipSuccess`. +/// +/// @see HipError +/// @see OutOfDeviceMemoryError +inline void check_hip(std::string_view msg, + hipError_t code, + std::source_location src = std::source_location::current()) +{ + // -Wswitch-enum throws a warning if this code is changed into a switch, even with + // the `default` label... + + if(code == hipSuccess) + // When you beat the error allegations + return; + else if(code == hipErrorOutOfMemory) + throw OutOfDeviceMemoryError(msg, src); + else + throw HipError(msg, code, src); +} + +/// @brief Check HIP status for errors. +/// +/// This function is similar to `check_hip(std::string_view, hipError_t)`, except that a +/// default message is given. +/// +/// @param code The HIP status code to examine. +/// @param src The location where this status was set. Defaults to the caller's location. +/// +/// @throws HipError if `code` is not `hipSuccess`. +/// +/// @see HipError +/// @see OutOfDeviceMemoryError +/// @see check_hip(std::string_view, hipError_t) +inline void check_hip(hipError_t code, std::source_location src = std::source_location::current()) +{ + check_hip(code == hipErrorOutOfMemory ? "failed to allocate device memory" + : "HIP runtime error", + code, + src); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/extent.hpp b/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp similarity index 50% rename from experimental/builder/include/ck_tile/builder/testing/extent.hpp rename to experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp index a2d9b3ff4c..3587ac406f 100644 --- a/experimental/builder/include/ck_tile/builder/testing/extent.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp @@ -5,28 +5,29 @@ namespace ck_tile::builder::test { -/// This structure describes a 1-, 2-, or 3-D extent. Its used to -/// communicate 1-, 2- or 3-D sizes and strides of tensors. -/// Depending on the dimension, the structure will have the `width`, -/// `height`, and `depth` fields available. +/// This structure describes a 1-, 2-, or 3-D extent for convolution +/// filters. Its used to communicate 1-, 2- or 3-D sizes and strides +/// of tensors, specifically for convolution filters. Depending on the +/// dimension, the structure will have the `width`, `height`, and +/// `depth` fields available. template -struct Extent; +struct FilterExtent; template <> -struct Extent<1> +struct FilterExtent<1> { size_t width = 1; }; template <> -struct Extent<2> +struct FilterExtent<2> { size_t width = 1; size_t height = 1; }; template <> -struct Extent<3> +struct FilterExtent<3> { size_t width = 1; size_t height = 1; diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp index 42f85f8017..6043ba2103 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp @@ -3,19 +3,15 @@ #pragma once +#include "ck_tile/builder/testing/error.hpp" +#include #include #include -#include -#include -#include -#include -#include "ck_tile/builder/conv_signature_concepts.hpp" -#include "ck_tile/builder/testing/type_traits.hpp" -#include "ck_tile/host/host_tensor.hpp" +#include -/// This file deals with tensor memory allocation: Both the act of allocating -/// and (automatically) deallocating memory, as well as utilities for managing -/// the layout of tensor data in memory. +/// This file deals with tensor memory management and allocation. The main +/// item is the `DeviceBuffer`: An owned piece of device memory, which is +/// automatically freed when it goes out of scope. namespace ck_tile::builder::test { @@ -39,31 +35,6 @@ struct DeviceMemoryDeleter } }; -/// @brief HIP out of memory error -/// -/// This is a derivation of `std::runtime_error` specialized for HIP -/// out-of-memory errors. -/// -/// @see std::runtime_error -struct OutOfDeviceMemoryError : std::runtime_error -{ - /// @brief Utility for formatting out-of-memory error messages - /// - /// Returns a human-readable description of a HIP out-of-memory error. - /// - /// @param status The status to report - static std::string format_error(hipError_t status) - { - return std::string("failed to allocate hip memory: ") + hipGetErrorString(status) + " (" + - std::to_string(status) + ")"; - } - - /// @brief Construct an out-of-memory error using `status` as message. - /// - /// @param status A HIP error status that was encountered while allocating memory. - OutOfDeviceMemoryError(hipError_t status) : std::runtime_error(format_error(status)) {} -}; - /// @brief Automatically managed GPU memory. /// /// The `DeviceBuffer` is an automatically managed pointer for GPU memory. When @@ -96,117 +67,18 @@ inline DeviceBuffer alloc_buffer(size_t size) std::byte* d_buf = nullptr; if(const auto status = hipMalloc(&d_buf, size); status != hipSuccess) { - throw OutOfDeviceMemoryError(status); + // Add some additional context + + size_t free, total; + check_hip("failed to get HIP memory info", hipMemGetInfo(&free, &total)); + + std::stringstream ss; + ss << "failed to allocate device memory (tried to allocate " << size << " bytes with only " + << free << " available)"; + + throw OutOfDeviceMemoryError(ss.str()); } return DeviceBuffer(d_buf); } -/// @brief Type managing tensor data layout in memory. -/// -/// This structure describes a tensor in memory. It does not actually hold any -/// reference to memory, it just describes how the memory should be laid out if it -/// were. -/// -/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it -/// also includes the data type of the elements of htis tensor. This is mainly to -/// make the descriptor a _complete_ description of a tensor rather than just the -/// dimensions in strides, which helps in reducing clutter in uses of this type. -/// -/// @note All strides are still in _elements_. -/// -/// @tparam DT The conceptual data type of the tensor elements. This need not be the -/// type that the data is actually stored as in memory. -template -struct TensorDescriptor -{ - // For now, the implementation of this type is based on - // `ck_tile::HostTensorDescriptor`, so that we can prototype without - // reimplementing the `HostTensorDescriptor` for the 3rd time. You can regard - // the use of `ck_tile::HostTensorDescriptor` here as an implementation detail. - - /// The conceptual data type of the tensor elements. This need not be the type - /// that the data is actually stored as in memory. - constexpr static DataType data_type = DT; - - /// @brief Create a tensor descriptor from lengths and strides. - /// - /// @param lengths A sequence of tensor lengths, the conceptial dimensions of - /// the tensor in elements. - /// @param strides A sequence of in-memory strides of the tensor, measured in - /// elements. Each element of `strides`` corresponds to one at the same index - /// in `lengths`, the amount of elements to skip in memory to find the next - /// element along that axis. - TensorDescriptor(std::span lengths, std::span strides) - : inner_descriptor_(lengths, strides) - { - // TODO: Validation of strides? For now we just delegate the details of the - // construction to the CK Tile HostTensorDescriptor. - } - - /// Query the conceptual dimensions of the tensor. - /// - /// @returns A span of tensor dimensions, one for every axis. Note that the order - /// does *not* correspond with memory layout, query the in-memory strides for - /// that. - /// - /// @see get_strides() - std::span get_lengths() const { return inner_descriptor_.get_lengths(); } - - /// Query the in-memory strides of the tensor. - /// - /// @returns A span of tensor dimensions, one for every axis. Each element - /// corresponds directly with the stride in elements at the same index in the - /// tensor dimensions. - /// - /// @see get_lengths() - std::span get_strides() const { return inner_descriptor_.get_strides(); } - - /// @brief Compute total tensor size in elements. - /// - /// This function returns the total size of the memory backing a tensor with - /// this descriptor in *elements*, including required extra size for strides. - /// - /// @see get_element_space_size_in_bytes() - size_t get_element_space_size() const { return inner_descriptor_.get_element_space_size(); } - - /// @brief Compute total tensor size in bytes. - /// - /// This function is like `get_element_space_size()`, except that the returned - /// value is measured in *bytes* rather than *elements*. Use this function for - /// figuring out how much memory needs to be allocated for a particular tensor. - /// - /// @see get_element_space_size() - size_t get_element_space_size_in_bytes() const - { - // For now, the backing type is the naive C++-type that represents the data - // type. When we are going to support packed types such as i4 and fp6, this - // is going to become more complicated. - return get_element_space_size() * data_type_sizeof(DT); - } - - private: - ck_tile::HostTensorDescriptor inner_descriptor_; -}; - -/// @brief Allocate automatically managed GPU memory corresponding to a tensor descriptor. -/// -/// This function is similar to `alloc_buffer()`, except that the required size is -/// derived automatically from a tensor descriptor. The returned buffer is valid for -/// tensors with that layout. Strides are also taken into account when computing the -/// required size. -/// -/// @tparam DT The conceptual datatype of the elements of the tensor. -/// @param descriptor A descriptor of the memory layout of the tensor to allocate. -/// @throws OutOfDeviceMemoryError if memory allocation failed. -/// -/// @see TensorDescriptor -/// @see DeviceBuffer -/// @see OutOfDeviceMemoryError -/// @see hipMalloc() -template -DeviceBuffer alloc_tensor_buffer(const TensorDescriptor
& descriptor) -{ - return alloc_buffer(descriptor.get_element_space_size_in_bytes()); -} - } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp new file mode 100644 index 0000000000..0ba01a77ca --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp @@ -0,0 +1,444 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/testing/type_traits.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/host/host_tensor.hpp" + +/// This file deals with tensor memory layout. The `TensorDescriptor` is the +/// main item, which is a type that describes (but not manages!) the layout +/// of tensor memory. There are also some related utilities. + +namespace ck_tile::builder::test { + +/// @brief Tensor dimensions type +/// +/// An Extent describes size in tensor space, usually either the tensor lengths +/// (conceptual size) or the tensor strides (memory layout). This type is mainly +/// used by the `TensorDescriptor`. This type is based on `std::array` +/// and supports all relevant operations on that. +/// +/// @note In practical terms, this type is not just an alias of `std::array` for +/// two reasons: First, writing a separate type allows us to write a custom +/// CTAD deduction guideline. This allows users to write `Extent{1, 2, 3}` and +/// get an instance of the correct type, whereas `std::array{1, 2, 3}` yields an +/// instance of `std::array`. This, in turn, allows inferring the rank +/// from the instance (useful in combination with `make_descriptor`), as it alows +/// us to write `function(Extent{1, 2, 3})`. Note that `function({1, 2, 3})` is +/// not valid before C++26 because `{1, 2, 3}` is an initializer list (even if +/// `function` accepts an instance of `Extent`), which does not have a known size +/// at compile time. Second, creating a separate struct for the `Extent` allows +/// additional (static) member functions. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor that this +/// extent describes a size of. +/// +/// @see TensorDescriptor +/// @see make_descriptor +template +struct Extent : std::array +{ + using Base = std::array; + // Note: Default constructor inherited from std::array. + + /// @brief Construct an extent from an `std::vector`. + /// + /// This function can be used to turn an `std::vector` into an `Extent`. + /// Because this code is mainly intended for testing, the vector's size is + /// checked. If its not equal to `RANK`, an exception is thrown. + /// + /// @throws std::runtime_error if the size of `extent` is not equal to `RANK`. + static Extent from_vector(const std::vector& extent) + { + if(extent.size() != RANK) + { + std::stringstream msg; + msg << "invalid rank! expected: " << RANK << ", got: " << extent.size(); + throw std::runtime_error(msg.str()); + } + + Extent result; + std::copy_n(extent.begin(), RANK, result.begin()); + return result; + } + + // Note: std::array doesn't like generating indexing code when the RANK + // is zero. Looks like there is a missing __device__ overload in ROCm 7.1 + // at least. Its not terribly important, but just override the default + // operator[] to fix it. + + /// @brief Array indexing operator + /// + /// `std::array` has issues with this operator when RANK=0, this version + /// fixes that. + /// + /// @param i The index to index the array with. + /// + /// @see std::array::operator[] + __device__ __host__ size_t operator[](size_t i) const + { + if constexpr(RANK > 0) + { + return Base::operator[](i); + } + else + { + __builtin_unreachable(); + } + } + + /// @brief Array indexing operator + /// + /// `std::array` has issues with this operator when RANK=0, this version + /// fixes that. + /// + /// @param i The index to index the array with. + /// + /// @see std::array::operator[] + __device__ __host__ size_t& operator[](size_t i) + { + if constexpr(RANK > 0) + { + return Base::operator[](i); + } + else + { + __builtin_unreachable(); + } + } +}; + +// This is a deduction guideline necessary to resolve `Extent{1, 2, 3}` to the +// correct type. This definition is practically the same as that of `std::array`. +template +Extent(T...) -> Extent; + +/// @brief Concept for automatically deriving tensor memory layout. +/// +/// A `TensorStridesGenerator` is a type which can be used to automatically +/// derive the strides (memory layout) of a tensor, given the tensor lengths. +/// This is mainly used to avoid manually computing strides. +/// +/// Implementors of this concept are required to implement `operator()`, +/// which accepts an instance of `Extent` (the tensor lengths) and +/// yields another instance of `Extent` (the tensor strides). Note +/// that the returned strides are expected to be "pre-scanned", meaning +/// that the offset in memory of a tensor can be computed as +/// `dot(index * strides)` (where `*` is element-wise multiplication). +/// +/// @see TensorDescriptor +/// @see PackedRightLayout +/// @see PackedLeftLayout +template +concept TensorStridesGenerator = requires(const G& generator, const Extent& lengths) { + { generator(lengths) } -> std::convertible_to>; +}; + +/// @brief Layout generator where right-most dimension has stride 1 and +/// all dimensions are packed. +/// +/// This structure implements a `TensorStridesGenerator` which generates +/// a memory layout which has the right-most dimension equal to 1, and +/// all other strides increase right-to-left as a products of the extent. +/// This corresponds with a row-major layout. +/// +/// @see TensorStridesGenerator +/// @see TensorDescriptor +struct PackedRightLayout +{ + /// @brief Stride generation implementation. + /// + /// This is the main function which implements the stride generation + /// + /// @tparam RANK The rank of the tensor. + /// + /// @param lengths The lengths of the tensor. + /// + /// @returns The tensor's memory layout according to the definition + /// of `PackedRightLayout`. + /// + /// @see TensorStridesGenerator + template + Extent operator()(const Extent& lengths) const + { + Extent strides = {}; + size_t numel = 1; + + for(size_t i = RANK; i > 0; --i) + { + strides[i - 1] = numel; + numel *= lengths[i - 1]; + } + + return strides; + } +}; +static_assert(TensorStridesGenerator, + "PackedRightLayout should be a TensorStridesGenerator!"); + +/// @brief Layout generator where left-most dimension has stride 1 and +/// all dimensions are packed. +/// +/// This structure implements a `TensorStridesGenerator` which generates +/// a memory layout which has the left-most dimension equal to 1, and +/// all other strides increase left-to-right as a products of the extent. +/// This corresponds with a column-major layout. +/// +/// @see TensorStridesGenerator +/// @see TensorDescriptor +struct PackedLeftLayout +{ + /// @brief Stride generation implementation. + /// + /// This is the main function which implements the stride generation + /// + /// @tparam RANK The rank of the tensor. + /// + /// @param lengths The lengths of the tensor. + /// + /// @returns The tensor's memory layout according to the definition + /// of `PackedLeftLayout`. + /// + /// @see TensorStridesGenerator + template + Extent operator()(const Extent& lengths) const + { + Extent strides = {}; + size_t numel = 1; + + for(size_t i = 0; i < RANK; ++i) + { + strides[i] = numel; + numel *= lengths[i]; + } + + return strides; + } +}; +static_assert(TensorStridesGenerator, + "PackedLeftLayout should be a TensorStridesGenerator!"); + +/// @brief Type managing tensor data layout in memory. +/// +/// This structure describes a tensor in memory. It does not actually hold any +/// reference to memory, it just describes how the memory should be laid out if it +/// were. +/// +/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it +/// also includes the data type of the elements of htis tensor. This is mainly to +/// make the descriptor a _complete_ description of a tensor rather than just the +/// dimensions in strides, which helps in reducing clutter in uses of this type. +/// +/// @note All strides are still in _elements_. +/// +/// @tparam DT The conceptual data type of the tensor elements. This need not be the +/// type that the data is actually stored as in memory. +/// @tparam RANK The tensor "rank": the number of conceptial spatial dimensions that +/// the tensor covers. +template +struct TensorDescriptor +{ + // For now, the implementation of this type is based on + // `ck_tile::HostTensorDescriptor`, so that we can prototype without + // reimplementing the `HostTensorDescriptor` for the 3rd time. You can regard + // the use of `ck_tile::HostTensorDescriptor` here as an implementation detail. + + /// @brief Tensor extent alias + /// + /// This alias represents a std::array which holds tensor dimensions. There is one + /// item for each dimension in the tensor, and each item corresponds with the + /// value for that dimension. + using Extent = ::ck_tile::builder::test::Extent; + + /// The conceptual data type of the tensor elements. This need not be the type + /// that the data is actually stored as in memory. + constexpr static DataType data_type = DT; + + /// The tensor "rank": the number of conceptial spatial dimensions that the + /// tensor covers. + constexpr static size_t rank = RANK; + + /// @brief Create a tensor descriptor from lengths and strides. + /// + /// @param lengths A sequence of tensor lengths, the conceptial dimensions of + /// the tensor in elements. + /// @param strides A sequence of in-memory strides of the tensor, measured in + /// elements. Each element of `strides`` corresponds to one at the same index + /// in `lengths`, the amount of elements to skip in memory to find the next + /// element along that axis. + TensorDescriptor(const Extent& lengths, const Extent& strides) + : inner_descriptor_(lengths, strides) + { + // TODO: Validation of strides? For now we just delegate the details of the + // construction to the CK Tile HostTensorDescriptor. + } + + /// @brief Create a tensor descriptor with lengths and automatic layout. + /// + /// This function initializes a tensor descriptor using lengths, and by deriving + /// the memory layout from the layout generator `Generator`. The tensor will be + /// initialized with the strides yielded from `Generator`. + /// + /// @tparam Generator The generator type to generate the strides with. For example, + /// `PackedRightLayout` or `PackedLeftLayout`. + /// + /// @param lengths A sequence of tensor lengths, the conceptial dimensions of + /// the tensor in elements. + /// @param gen An instance of `Generator` to generate the strides with. + /// + /// @see TensorStridesGenerator + /// @see PackedLeftLayout + /// @see PackedRightLayout + template + requires TensorStridesGenerator + TensorDescriptor(const Extent& lengths, const Generator& gen) + : TensorDescriptor(lengths, gen(lengths)) + { + } + + /// Query the conceptual dimensions of the tensor. + /// + /// @returns A span of tensor dimensions, one for every axis. Note that the order + /// does *not* correspond with memory layout, query the in-memory strides for that. + /// + /// @see get_strides() + Extent get_lengths() const + { + // TODO: This is ugly for now. We should ditch the HostTensorDescriptor, and + // after that this can just be `return lengths_;` (and make it const Extent&). + Extent result; + std::copy_n(inner_descriptor_.get_lengths().begin(), RANK, result.begin()); + return result; + } + + /// Query the in-memory strides of the tensor. + /// + /// @returns A span of tensor dimensions, one for every axis. Each element + /// corresponds directly with the stride in elements at the same index in the + /// tensor dimensions. + /// + /// @see get_lengths() + Extent get_strides() const + { + // TODO: This is ugly for now. We should ditch the HostTensorDescriptor, and + // after that this can just be `return strides_;` (and make it const Extent&). + Extent result; + std::copy_n(inner_descriptor_.get_strides().begin(), RANK, result.begin()); + return result; + } + + /// @brief Compute conceptual tensor size in elements. + /// + /// This function returns the size of the tensor in elements. This function only + /// takes the lengths into account, not the strides. In order to allocate memory + /// for the tensor, use `get_element_space_size()`. + /// + /// @see get_lengths + /// @see get_element_space_size + size_t get_element_size() const { return inner_descriptor_.get_element_size(); } + + /// @brief Compute total tensor space size in elements. + /// + /// This function returns the total size of the memory backing a tensor with + /// this descriptor in *elements*, including required extra size for strides. + /// + /// @see get_element_space_size_in_bytes() + size_t get_element_space_size() const { return inner_descriptor_.get_element_space_size(); } + + /// @brief Compute total tensor size in bytes. + /// + /// This function is like `get_element_space_size()`, except that the returned + /// value is measured in *bytes* rather than *elements*. Use this function for + /// figuring out how much memory needs to be allocated for a particular tensor. + /// + /// @see get_element_space_size() + size_t get_element_space_size_in_bytes() const + { + // For now, the backing type is the naive C++-type that represents the data + // type. When we are going to support packed types such as i4 and fp6, this + // is going to become more complicated. + return get_element_space_size() * data_type_sizeof(DT); + } + + /// @brief Get a tensor descriptor for the space backing a tensor. + /// + /// This function returns a tensor descriptor which represents the buffer space + /// required to a tensor with this descriptor. This is mainly useful to process + /// buffers with functions which normally operate over tensor descriptors. The + /// resulting tensor descriptor describes a 1D tensor with the same number of + /// elements as in the space. + /// + /// @see get_element_space_size() + TensorDescriptor get_space_descriptor() const + { + ck_tile::builder::test::Extent<1> lengths = {this->get_element_space_size()}; + ck_tile::builder::test::Extent<1> strides = {1}; + return TensorDescriptor(lengths, strides); + } + + private: + ck_tile::HostTensorDescriptor inner_descriptor_; +}; + +/// @brief Tensor descriptor construction helper. +/// +/// This function can be used to create a tensor descriptor. It accepts the same +/// parameters as the constructor of `TensorDescriptor`, that is, a sequence of +/// lengths and a sequence of strides (or a generator to generate the strides). +/// The main use of this function is that it allows automatic inference of the `RANK` +/// parameter. C++ constructors do not allow partial specification of type parameters, +/// and so its impossible to write `TensorDescriptor
x(Extent{1, 2, 3}, ...)` +/// and have the `RANK` be automatically inferred. Functions do allow this though, +/// so this function can be used to write `make_descriptor(Extent{1, 2, 3}, ...)` +/// +/// @tparam DT The conceptual data type of the tensor elements. This need not be the +/// type that the data is actually stored as in memory. +/// @tparam RANK The tensor "rank": the number of conceptial spatial dimensions that +/// the tensor covers. +/// +/// @param lengths A sequence of tensor lengths, the conceptial dimensions of +/// the tensor in elements. +/// @param strides A sequence of in-memory strides of the tensor, or a generator +/// to generate those strides from the tensor lengths. +/// +/// @see TensorDescriptor +template +TensorDescriptor make_descriptor(const Extent& lengths, const auto& strides) +{ + return TensorDescriptor(lengths, strides); +} + +/// @brief Allocate automatically managed GPU memory corresponding to a tensor descriptor. +/// +/// This function is similar to `alloc_buffer()`, except that the required size is +/// derived automatically from a tensor descriptor. The returned buffer is valid for +/// tensors with that layout. Strides are also taken into account when computing the +/// required size. +/// +/// @tparam DT The conceptual datatype of the elements of the tensor. +/// @tparam RANK The conceptual rank (number of dimensions) of the tensor. +/// +/// @param descriptor A descriptor of the memory layout of the tensor to allocate. +/// +/// @throws OutOfDeviceMemoryError if memory allocation failed. +/// +/// @see TensorDescriptor +/// @see DeviceBuffer +/// @see OutOfDeviceMemoryError +/// @see hipMalloc() +template +DeviceBuffer alloc_tensor_buffer(const TensorDescriptor& descriptor) +{ + return alloc_buffer(descriptor.get_element_space_size_in_bytes()); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp new file mode 100644 index 0000000000..f078a1ac82 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp @@ -0,0 +1,258 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include +#include +#include + +/// This file implements a generic GPU tensor "foreach" function. This +/// functionality turned out useful in separate parts of the testing +/// system, hence its implemented in a separate file. This version is +/// not particularly efficient (but it should at least be readable), +/// but it should be easy to replace the implementation in the future, +/// should that be needed. + +namespace ck_tile::builder::test { + +/// @brief Concept for constraining tensor iteration functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `tensor_foreach` function. +template +concept ForeachFunctor = requires(const F& f, const Extent& index) { + { f(index) } -> std::same_as; +}; + +namespace detail { + +/// @brief Default foreach kernel block size +/// +/// This value is the default number of threads in each block when +/// executing the foreach kernel. This value is mostly arbitrary, +/// 256 is usually a good default for AMD GPUs. +/// +/// @see tensor_foreach +constexpr int DEVICE_FOREACH_BLOCK_SIZE = 256; + +/// @brief Tensor iteration kernel +/// +/// This kernel implements the actual iteration logic, and is intended +/// to be used solely by `tensor_foreach` to iterate & invoke the +/// actual callback. +/// +/// @tparam BLOCK_SIZE The number of threads in each block on the GPU. +/// @tparam RANK The rank (number of spatial dimensions) of the tensor to +/// iterate. +/// @tparam F The type of the callback to invoke. This function must be +/// compatible with execution as a __device__ function. +/// +/// @param numel The total number of elements in the tensor. +/// @param shape_scan A right-exclusive scan of the shape of the tensor. +/// @param f The callback to invoke for each index of the tensor. This +/// functor must be eligible for running on the GPU. +template + requires ForeachFunctor +__global__ __launch_bounds__(BLOCK_SIZE) // + void foreach_kernel(const size_t numel, Extent shape_scan, F f) +{ + const auto gid = blockIdx.x * BLOCK_SIZE + threadIdx.x; + for(size_t flat_idx = gid; flat_idx < numel; flat_idx += gridDim.x * BLOCK_SIZE) + { + // Compute the current index. + Extent index = {}; + + size_t idx = flat_idx; + for(size_t i = 0; i < RANK; ++i) + { + const auto scanned_dim = shape_scan[i]; + index[i] = idx / scanned_dim; + idx %= scanned_dim; + } + + // Then invoke the callback with the index. + f(index); + } +} + +/// @brief A utility to get a C++ type for a CKB type +/// +/// Right now this is just an alias of an internal CKB helper, +/// but this should probably be moved elsewhere. +template +using cpp_type_t = typename builder::factory::internal::DataTypeToCK
::type; + +} // namespace detail + +/// @brief Calculate tensor memory offset given index and strides. +/// +/// This function returns the offset in memory in a tensor, given a particular +/// multi-dimensional index and a particular set of strides. Each value in the +/// index corresponds one-to-one with a value in the strides, which are the +/// index and stride at that dimension in the tensor. These strides must be +/// pre-scanned, meaning that each index is the absolute stride of elements +/// along that axis. In essence, this means that you should pass the output of +/// `TensorDescriptor::get_strides()` into this function. +/// +/// @pre The index must be inside the tensor space. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param index A multi-dimensional index inside the tensor space. +/// @param strides A set of strides, one for each dimension. +/// +/// @see TensorDescriptor +template +__host__ __device__ size_t calculate_offset(const Extent& index, const Extent& strides) +{ + size_t offset = 0; +#pragma unroll + for(size_t i = 0; i < RANK; ++i) + { + offset += index[i] * strides[i]; + } + return offset; +} + +/// @brief Invoke a callback on the GPU for every index in a tensor. +/// +/// This function invokes a callback functor on the GPU, for each index in +/// a tensor. This function _only_ takes care of iterating over all indices +/// in a tensor of a particular shape; this function does not handle or know +/// about actual tensor data. +/// +/// @note This function is currently implemented relatively naively: The +/// iteration order is always row-wise, implemented as a persistent kernel. +/// The main objective of this function is to be used with the CK-Builder +/// testing system, and so readability and correctness should be preferred +/// over performance. If this is ever a source of performance problems, +/// feel free to replace the implementation with something better. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param shape The shape of the tensor to iterate over. +/// @param f The callback to invoke for each index of the tensor. This +/// functor must be eligible for running on the GPU. +/// +/// @see ForeachFunctor +/// @see detail::foreach_kernel +template +void tensor_foreach(const Extent& shape, ForeachFunctor auto f) +{ + constexpr int block_size = detail::DEVICE_FOREACH_BLOCK_SIZE; + const auto kernel = detail::foreach_kernel; + + int occupancy; + check_hip(hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0)); + + int device; + check_hip(hipGetDevice(&device)); + + int multiprocessors; + check_hip( + hipDeviceGetAttribute(&multiprocessors, hipDeviceAttributeMultiprocessorCount, device)); + + // Pre-scan the shape to help indexing in the kernel. + // Note: the order is not that important, so long as the iteration + // order in the kernel is from large-to-small. Right layout is the + // easiest solution for that. + + Extent shape_scan; + size_t numel = 1; + for(int i = RANK; i > 0; --i) + { + shape_scan[i - 1] = numel; + numel *= shape[i - 1]; + } + + // Reset any errors from previous launches. + (void)hipGetLastError(); + + kernel<<>>(numel, shape_scan, f); + check_hip(hipGetLastError()); +} + +/// @brief Concept for tensor initializing functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `fill_tensor` function. +template +concept FillTensorFunctor = requires(const F& f, const Extent& index) { + { f(index) } -> std::convertible_to>; +}; + +/// @brief Utility for initializing tensors. +/// +/// This function is a utility helper for initializing tensors. It accepts a +/// tensor descriptor, buffer, and a callback. The callback is invoked for every +/// coordinate (which is passed to the callback), and the tensor is initialized +/// with resulting value. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param desc The descriptor of the tensor to initialize. +/// @param buffer The memory of the tensor to initialize. +/// @param f A functor used to get the value at a particular coordinate. +/// +/// @see FillTensorFunctor +template +void fill_tensor(const TensorDescriptor& desc, + void* buffer, + FillTensorFunctor auto f) +{ + const auto strides = desc.get_strides(); + tensor_foreach(desc.get_lengths(), [buffer, f, strides](const auto& index) { + using T = detail::cpp_type_t
; + auto* ptr = static_cast(buffer); + const auto offset = calculate_offset(index, strides); + + ptr[offset] = f(index); + }); +} + +/// @brief Concept for tensor buffer initializing functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `fill_tensor_buffer` function. +template +concept FillTensorBufferFunctor = requires(const F& f, size_t index) { + { f(index) } -> std::convertible_to>; +}; + +/// @brief Utility for initializing tensor buffers. +/// +/// This function is a utility for initializing memory backing a tensor buffer. In +/// contrast to `fill_tensor`, this function first extracts the backing space of +/// the tensor, and then invokes the callback for each (flat) index. This function +/// is particular useful for initializing out-of-bounds indices with a known with a +/// known value. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param desc The descriptor of the tensor to initialize. +/// @param buffer The memory of the tensor to initialize. +/// @param f A functor used to get the value at a particular index. +/// +/// @see FillTensorBufferFunctor +template +void fill_tensor_buffer(const TensorDescriptor& desc, + void* buffer, + FillTensorBufferFunctor
auto f) +{ + fill_tensor(desc.get_space_descriptor(), buffer, [f](auto index) { return f(index[0]); }); +} + +template +void clear_tensor_buffer(const TensorDescriptor& desc, + void* buffer, + detail::cpp_type_t
value = detail::cpp_type_t
{0}) +{ + fill_tensor_buffer(desc, buffer, [value]([[maybe_unused]] size_t i) { return value; }); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp index 15cb43f369..2976e6c14b 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp @@ -19,15 +19,30 @@ namespace ck_tile::builder::test { -template -void init_tensor_buffer_uniform_int(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, - int min_val, - int max_val) +/// @brief Initialize tensor data with a uniform int distribution +/// +/// This function initializes a tensor's device memory with random integer data, +/// drawn from a uniform distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param min_value The minimum value of the distribution (inclusive). +/// @param max_value The maximum value of the distribution (exclusive). +template +void init_tensor_buffer_uniform_int(void* buf, + const TensorDescriptor& descriptor, + int min_value, + int max_value) { size_t size = descriptor.get_element_space_size_in_bytes(); - if(max_val - min_val <= 1) + if(max_value - min_value <= 1) { throw std::runtime_error("Error while filling device tensor with random integer data: max " "value must be at least 2 greater than min value, otherwise " @@ -38,19 +53,34 @@ void init_tensor_buffer_uniform_int(const DeviceBuffer& buf, // we might be asked to generate int values on fp data types that don't have the required // precision - if(static_cast(max_val - 1) == static_cast(min_val)) + if(static_cast(max_value - 1) == static_cast(min_value)) { throw std::runtime_error("Error while filling device tensor with random integer data: " "insufficient precision in specified range"); } size_t packed_size = ck::packed_size_v; fill_tensor_uniform_rand_int_values<<<256, 256>>>( - static_cast(buf.get()), min_val, max_val, (size * packed_size) / sizeof(ck_type)); + static_cast(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type)); } -template -void init_tensor_buffer_uniform_fp(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, +/// @brief Initialize tensor data with a uniform float distribution +/// +/// This function initializes a tensor's device memory with random floating data, +/// drawn from a uniform distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param min_value The minimum value of the distribution (inclusive). +/// @param max_value The maximum value of the distribution (exclusive). +template +void init_tensor_buffer_uniform_fp(void* buf, + const TensorDescriptor& descriptor, float min_value, float max_value) { @@ -59,15 +89,30 @@ void init_tensor_buffer_uniform_fp(const DeviceBuffer& buf, using ck_type = factory::internal::DataTypeToCK
::type; size_t packed_size = ck::packed_size_v; - fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast(buf.get()), + fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type)); } -template -void init_tensor_buffer_normal_fp(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, +/// @brief Initialize tensor data with a normal float distribution +/// +/// This function initializes a tensor's device memory with random floating data, +/// drawn from a normal distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param sigma The standard deviation of the distribution. +/// @param mean The mean of the distribution. +template +void init_tensor_buffer_normal_fp(void* buf, + const TensorDescriptor& descriptor, float sigma, float mean) { @@ -76,7 +121,7 @@ void init_tensor_buffer_normal_fp(const DeviceBuffer& buf, using ck_type = factory::internal::DataTypeToCK
::type; size_t packed_size = ck::packed_size_v; fill_tensor_norm_rand_fp_values<<<256, 256>>>( - static_cast(buf.get()), sigma, mean, (size * packed_size) / sizeof(ck_type)); + static_cast(buf), sigma, mean, (size * packed_size) / sizeof(ck_type)); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index a0dfa27409..9c8b858018 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -5,6 +5,8 @@ #include +#include "ck_tile/builder/testing/validation.hpp" + /// This file is the main header for the CK-Builder testing system. A high-level /// description of this testing system is documented in /// `ck_tile/builder/testing/README.md`. This file deals mainly deals with the @@ -78,7 +80,7 @@ namespace ck_tile::builder::test { /// that this structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. template struct Args; @@ -98,7 +100,7 @@ struct Args; /// structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. template struct Inputs; @@ -118,7 +120,7 @@ struct Inputs; /// structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. template struct Outputs; @@ -133,7 +135,7 @@ struct Outputs; /// @note The easiest way to implement this type is to use the `DeviceBuffer` /// type to allocate individual device buffers for each input tensor. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. /// /// @see alloc_inputs() /// @see ValidUniqueInputs @@ -152,7 +154,7 @@ struct UniqueInputs; /// @note The easiest way to implement this type is to use the `DeviceBuffer` /// type to allocate individual device buffers for each output tensor. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. /// /// @see alloc_outputs() /// @see ValidUniqueOutputs @@ -195,7 +197,9 @@ concept ValidUniqueOutputs = requires(UniqueOutputs& inputs) { /// amount of memory required and then allocate it on the device, for example /// using `alloc_buffer` or `alloc_tensor_buffer`. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. +/// +/// @param args The run-time arguments of the operation. /// /// @see Inputs /// @see UniqueInputs @@ -208,16 +212,18 @@ UniqueInputs alloc_inputs(const Args& args); /// @brief Allocate inputs corresponding to a signature. /// /// The `init_inputs()` function is used to initialize pseudo-random data -/// to the tensors specified in the Inputs structure. +/// to the tensors specified in the Inputs structure. Implementors should +/// fill each of the tensors in `inputs` with appropriate random data. /// /// @tparam SIGNATURE the signature to specialize the structure for. /// +/// @param args The run-time arguments of the operation. +/// @param inputs The operation inputs to initialize with random data. +/// /// @see Inputs -/// @see UniqueInputs /// @see tensor_initialization template - requires ValidUniqueInputs -void init_inputs(const Args& args, UniqueInputs& inputs); +void init_inputs(const Args& args, Inputs inputs); /// @brief Allocate outputs corresponding to a signature. /// @@ -226,7 +232,9 @@ void init_inputs(const Args& args, UniqueInputs& inputs); /// amount of memory required and then allocate it on the device, for example /// using `alloc_buffer` or `alloc_tensor_buffer`. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. +/// +/// @param args The run-time arguments of the operation. /// /// @see Outputs /// @see UniqueOutputs @@ -236,6 +244,29 @@ template requires ValidUniqueOutputs UniqueInputs alloc_outputs(const Args& args); +/// @brief Compare device operation outputs. +/// +/// This function implements the main comparison functionality, used to compare +/// the output of one implementation for a particular `SIGNATURE` with that of +/// another. Usually, the `expected` output should be computed by a reference +/// implementation. +/// +/// The implementation of this function generates a "report", which includes +/// detailed information about which tensors are different, how many elements +/// were incorrect, and where (a subset of) those elements are located within +/// the tensor. See `ValidationReport` for more information about the report. +/// +/// @tparam SIGNATURE The signature to specialize the structure for. +/// +/// @param args The run-time arguments of the operation. +/// @param actual The actual results, the results of the operation to-be-tested. +/// @param expected The expected results, the results of the reference implementation. +/// +/// @see ValidationReport +template +ValidationReport +validate(const Args& args, Outputs actual, Outputs expected); + /// @brief Invoke a device operation created by CK Builder. /// /// This is the main function used to invoke a particular device operation @@ -257,7 +288,7 @@ UniqueInputs alloc_outputs(const Args& args); /// @post The tensors in `outputs` are overwritten with the outputs of the device /// operation. /// -/// @tparam SIGNATURE the signature to specialize this function for +/// @tparam SIGNATURE The signature to specialize this function for /// @tparam Operation the kernel of the operation to invoke. This type should be /// one that is created using the Builder API. /// @param operation An instance of the operation to invoke. diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp new file mode 100644 index 0000000000..275fa490eb --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp @@ -0,0 +1,167 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/type_convert.hpp" +#include +#include +#include +#include + +/// This file implements functionality related to "validation", ie, functionality +/// to compare tensors. The functionality in this file should be testing-framework +/// agnostic, and it should NOT generate any error messages by itself. Instead, +/// all relevant information should be stored in the `ValidationReport` structure. +/// This structure should then be used to generate error messages, explainations, +/// etc, by the actual testing framework that the user has chosen. + +namespace ck_tile::builder::test { + +/// @brief Information about how a set of comparisons failed or succeeded. +/// +/// This structure represents a "report" generated by comparing sets of tensors. +/// Its intended to be used as the result of `ckt::validate()`, where `check()` +/// is invoked for each of the output tensors of a particular device operation. +/// The test should be considered successful if _all_ of those checks passes, +/// which can inspected by asserting that `get_errors().size()` is 0. +struct ValidationReport +{ + /// @brief Information related to a single tensor comparison. + /// + /// This structure holds the information about the result of comparing + /// two particular tensors. + struct Case + { + /// The name of the tensor that was compared here, stored here for convenience + /// so that reporting any errors is easier. + std::string tensor_name; + + /// The number of elements which were different between the two compared tensors. + uint64_t wrong_elements; + + /// The total number of elements in each tensor. + uint64_t total_elements; + + /// @brief Return whether the check associated to this case was successful. + /// + /// This function returns whether the check associated to this case was successful, + /// which is directly derived from checking whether the number of incorrect elements + /// was 0. + bool is_ok() const { return wrong_elements == 0; } + }; + + /// @brief Get comparison cases which were incorrect. + /// + /// This function returns a vector of comparison cases that did not succeed, ie, for + /// which `Case::is_ok` return false. In order to check whether validation passed, it + /// is sufficient to assert that this function returns no cases. + std::vector get_errors() const + { + std::vector errors; + std::copy_if(reports_.begin(), + reports_.end(), + std::back_inserter(errors), + [](const auto& report) { return !report.is_ok(); }); + return errors; + } + + /// @brief Compare two tensors and record the results in the report. + /// + /// This is the main function used to compare two tensors. The results of this + /// comparison, including any supplemental information, is recorded into the report. + /// + /// @returns `false` if the comparison failed. If so, the details can be found via + /// `get_errors()`. + /// + /// @tparam DT The data type of the tensors to check. + /// @tparam RANK The rank (number of spatial dimensions) of the tensor to check. + /// + /// @param tensor_name The name of the tensors to check. This should be a value by which + /// whoever is debugging the associated test later can easily find out which of the + /// outputs of a device operation was incorrect. + /// @param descriptor The descriptor (memory layout) of the tensor. + /// @param actual The device buffer with the values of the tensor to-be-tested, ie, the + /// results of the device operation. + /// @param expected The device buffer with the values of the reference tensor. These are + /// treated as a "golden standard", and should usually be generated by a reference + /// implementation. + /// @param rtol The relative acceptable tolerance between two values. + /// @param atol The absolute acceptable tolerance between two values. + template + bool check(std::string_view tensor_name, + const TensorDescriptor& descriptor, + const void* actual, + const void* expected, + double rtol = 1e-3, + double atol = 1e-3); + + private: + std::vector reports_; +}; + +template +bool ValidationReport::check(std::string_view tensor_name, + const TensorDescriptor& descriptor, + const void* actual_data, + const void* expected_data, + double rtol, + double atol) +{ + const auto strides = descriptor.get_strides(); + + // During development and CI, only the kernels that were changed would fail, and so we can + // assume that the average case does not have errors. Therefore, split out testing into a + // quick test which just counts the incorrect elements, and a more in-depth test that also + // returns the indices of the incorrect items. + + // Initial pass: count errors + + // Allocate and reset counter + auto d_error_count = alloc_buffer(sizeof(uint64_t)); + check_hip(hipMemset(d_error_count.get(), 0, sizeof(uint64_t))); + + tensor_foreach(descriptor.get_lengths(), [=, error_count = d_error_count.get()](auto index) { + using CKType = typename factory::internal::DataTypeToCK
::type; + + const auto* actual = static_cast(actual_data); + const auto* expected = static_cast(expected_data); + + static_assert(!std::is_same_v, + "TODO implement compare_kernel() for double"); + + const auto offset = calculate_offset(index, strides); + + const auto o = static_cast(type_convert(actual[offset])); + const auto r = static_cast(type_convert(expected[offset])); + const auto err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + // We expect the number of errors to be very low, so just use an atomic + // for now. + atomicAdd(reinterpret_cast(error_count), 1); + } + }); + + uint64_t error_count = 0; + check_hip( + hipMemcpy(&error_count, d_error_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost)); + + // TODO: Gather detailed coordinates. + + reports_.push_back(Case{ + .tensor_name = std::string(tensor_name), + .wrong_elements = error_count, + .total_elements = descriptor.get_element_size(), + }); + + return error_count == 0; +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 800d485660..d13c8cfdd9 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -80,33 +80,36 @@ add_ck_builder_test(test_ckb_conv_builder test_instance_traits_util.cpp unit_device_buffer.cpp unit_tensor_descriptor.cpp + unit_tensor_foreach.cpp + unit_error.cpp + unit_validation.cpp unit_conv_elementwise_op.cpp unit_conv_tensor_layout.cpp unit_conv_tensor_type.cpp unit_conv_thread_block.cpp unit_conv_tuning_params.cpp) - - # Tests the inline diff utility used for comparing strings in tests assertions - add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) - # GPU reference validation tests (in validation/ folder) - # 1. Reference kernel execution and InstanceTraits - add_ck_builder_test(test_ckb_reference_execution - validation/test_reference_execution.cpp - validation/test_reference_instance_traits.cpp) - target_link_libraries(test_ckb_reference_execution PRIVATE utility) - - # Note: Optimized kernel validation tests will be added after merging dev branch - # with kernel Run() implementation from colleague's work +# Tests the inline diff utility used for comparing strings in tests assertions +add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) + +# GPU reference validation tests (in validation/ folder) +# 1. Reference kernel execution and InstanceTraits +add_ck_builder_test(test_ckb_reference_execution + validation/test_reference_execution.cpp + validation/test_reference_instance_traits.cpp) +target_link_libraries(test_ckb_reference_execution PRIVATE utility) + +# Note: Optimized kernel validation tests will be added after merging dev branch +# with kernel Run() implementation from colleague's work + +# Tests convolution trait selection and configuration +add_ck_builder_test(test_ckb_conv_traits + conv/ck/test_conv_traits.cpp) + +# Tests convolution problem description and parameter handling +add_ck_builder_test(test_ckb_conv_description + test_conv_description.cpp) - # Tests convolution trait selection and configuration - add_ck_builder_test(test_ckb_conv_traits - conv/ck/test_conv_traits.cpp) - - # Tests convolution problem description and parameter handling - add_ck_builder_test(test_ckb_conv_description - test_conv_description.cpp) - ################################################################################ # REGRESSION TESTS - Integration Tests (With Kernel Compilation) ################################################################################ diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index aa53aa9666..5a52b6a9b5 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -6,11 +6,14 @@ #include "utils/conv_algorithm_type_utils.hpp" #include "ck_tile/builder/testing/conv_fwd_ck.hpp" #include "ck_tile/host/device_prop.hpp" +#include "testing_utils.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; +using ck_tile::test::MatchesReference; + constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .direction = ckb::ConvDirection::FORWARD, @@ -78,11 +81,18 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) .cde_elementwise_op = {}, }; - auto inputs = alloc_inputs(args); - auto outputs = alloc_outputs(args); + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); - init_inputs(args, inputs); + ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; ckt::run(conv, args, inputs.get(), outputs.get()); + + // TODO: This should be allocated via ckt::alloc_outputs() and + // initialized via ckt::run() with the reference implementation + // instead. + auto reference = outputs.get(); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference)); } diff --git a/experimental/builder/test/test_inline_diff.cpp b/experimental/builder/test/test_inline_diff.cpp index 8d3a90c95f..6a7a7ac8f7 100644 --- a/experimental/builder/test/test_inline_diff.cpp +++ b/experimental/builder/test/test_inline_diff.cpp @@ -5,8 +5,7 @@ #include "testing_utils.hpp" -namespace ck_tile::builder { -namespace { +using ck_tile::test::inlineDiff; TEST(InlineDiff, simpleColorDiff) { @@ -16,8 +15,8 @@ TEST(InlineDiff, simpleColorDiff) // some easy tests // you can veryfy the ungodly strings are meaningful by running echo -e "" - EXPECT_THAT(test::inlineDiff(str1, str2, true), "hello"); - EXPECT_THAT(test::inlineDiff(str1, str3, true), + EXPECT_THAT(inlineDiff(str1, str2, true), "hello"); + EXPECT_THAT(inlineDiff(str1, str3, true), "[\x1B[36mwor\x1B[0m|\x1B[35mhel\x1B[0m]l[\x1B[36md\x1B[0m|\x1B[35mo\x1B[0m]"); } @@ -28,8 +27,8 @@ TEST(InlineDiff, noColorDiff) std::string str3{"world"}; // some easy tests without color - EXPECT_THAT(test::inlineDiff(str1, str2, false), "hello"); - EXPECT_THAT(test::inlineDiff(str1, str3, false), "[wor|hel]l[d|o]"); + EXPECT_THAT(inlineDiff(str1, str2, false), "hello"); + EXPECT_THAT(inlineDiff(str1, str3, false), "[wor|hel]l[d|o]"); } TEST(InlineDiff, complexColorDiff) @@ -42,11 +41,8 @@ TEST(InlineDiff, complexColorDiff) "this part has degeahc, this part has, this part added, this part has ana extra letter"}; EXPECT_THAT( - test::inlineDiff(str5, str4, true), + inlineDiff(str5, str4, true), "this part has [\x1B[36mchanged\x1B[0m|\x1B[35mdegeahc\x1B[0m], this part has[\x1B[36m " "been left out\x1B[0m|\x1B[35m\x1B[0m], this part[\x1B[36m\x1B[0m|\x1B[35m added\x1B[0m], " "this part has an[\x1B[36m\x1B[0m|\x1B[35ma\x1B[0m] extra letter"); }; - -} // namespace -} // namespace ck_tile::builder diff --git a/experimental/builder/test/testing_utils.hpp b/experimental/builder/test/testing_utils.hpp index 7a03851ac4..b84d53b6df 100644 --- a/experimental/builder/test/testing_utils.hpp +++ b/experimental/builder/test/testing_utils.hpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include +#include "ck_tile/builder/testing/testing.hpp" #include #include #include @@ -21,6 +22,16 @@ /// dedicated function to override to provide printing support. std::ostream& operator<<(std::ostream& os, hipError_t status); +namespace ck_tile::builder::test { + +template +std::ostream& operator<<(std::ostream& os, [[maybe_unused]] Outputs outputs) +{ + return os << ""; +} + +} // namespace ck_tile::builder::test + namespace ck_tile::test { static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); } @@ -150,4 +161,47 @@ struct HipStatusMatcher : public ::testing::MatcherInterface /// @param error The error to expect. ::testing::Matcher HipError(hipError_t error); +template +struct ReferenceOutputMatcher + : public ::testing::MatcherInterface> +{ + ReferenceOutputMatcher(const builder::test::Args& args, + builder::test::Outputs expected) + : args_(&args), expected_(expected) + { + } + + bool MatchAndExplain(builder::test::Outputs actual, + [[maybe_unused]] ::testing::MatchResultListener* listener) const override + { + const auto report = ck_tile::builder::test::validate(*args_, actual, expected_); + const auto errors = report.get_errors(); + + if(listener->IsInterested() && !errors.empty()) + { + *listener << errors.size() << " tensors failed to validate"; + } + + return errors.empty(); + } + + void DescribeTo(std::ostream* os) const override { *os << ""; } + + void DescribeNegationTo(std::ostream* os) const override + { + *os << "isn't equal to "; + } + + const builder::test::Args* args_; + builder::test::Outputs expected_; +}; + +template +::testing::Matcher> +MatchesReference(const builder::test::Args& args, + builder::test::Outputs expected) +{ + return ::testing::MakeMatcher(new ReferenceOutputMatcher(args, expected)); +} + } // namespace ck_tile::test diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index 7ffd446966..b385210cea 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -11,40 +11,27 @@ namespace { namespace ckb = ck_tile::builder; using ck_tile::builder::factory::internal::DataTypeToCK; -TEST(ConvTensorType, AssignsTypesForFP16) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} +template +constexpr auto check_same = std::is_same_v::type, T>; -TEST(ConvTensorType, AssignsTypesForBF16) +TEST(ConvTensorType, Exhaustive) { - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} + using enum ckb::DataType; -TEST(ConvTensorType, AssignsTypesForFP32) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForINT32) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForI8) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForFP8) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); + const auto type = FP32; + // This switch ensures that we get a warning (error with -Werror) if + // a variant is missing. + switch(type) + { + case UNDEFINED_DATA_TYPE: break; + case FP32: EXPECT_TRUE((check_same)); break; + case FP16: EXPECT_TRUE((check_same)); break; + case BF16: EXPECT_TRUE((check_same)); break; + case INT32: EXPECT_TRUE((check_same)); break; + case FP8: EXPECT_TRUE((check_same)); break; + case I8: EXPECT_TRUE((check_same)); break; + case U8: EXPECT_TRUE((check_same)); break; + } } } // namespace diff --git a/experimental/builder/test/unit_device_buffer.cpp b/experimental/builder/test/unit_device_buffer.cpp index 75408acc16..c7180395b7 100644 --- a/experimental/builder/test/unit_device_buffer.cpp +++ b/experimental/builder/test/unit_device_buffer.cpp @@ -2,10 +2,11 @@ // SPDX-License-Identifier: MIT #include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "testing_utils.hpp" #include #include -#include +#include namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; @@ -54,6 +55,11 @@ TEST(DeviceBuffer, AutoFree) // Trying to use a pointer after freeing should return en error in HIP. EXPECT_THAT(hipMemset(ptr, 0xFF, size), HipError(hipErrorInvalidValue)); + + // Reset internal HIP error state. + // Otherwise, the error may leak into other tests, triggering anything that + // checks the output of hipGetLastError(); + (void)hipGetLastError(); } TEST(DeviceBuffer, ThrowsOnOom) @@ -62,13 +68,16 @@ TEST(DeviceBuffer, ThrowsOnOom) auto check = [] { auto buffer = ckt::alloc_buffer(size); }; EXPECT_THAT(check, Throws()); + + // Reset internal HIP error state. + // Otherwise, the error may leak into other tests, triggering anything that + // checks the output of hipGetLastError(); + (void)hipGetLastError(); } TEST(DeviceBuffer, AllocTensorBuffer) { - std::vector lengths = {128, 128, 128}; - std::vector strides = {128 * 128, 128, 1}; - ckt::TensorDescriptor descriptor(lengths, strides); + ckt::TensorDescriptor descriptor({128, 128, 128}, {128 * 128, 128, 1}); auto buffer = ckt::alloc_tensor_buffer(descriptor); diff --git a/experimental/builder/test/unit_error.cpp b/experimental/builder/test/unit_error.cpp new file mode 100644 index 0000000000..b666462385 --- /dev/null +++ b/experimental/builder/test/unit_error.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "testing_utils.hpp" +#include +#include + +namespace ckt = ck_tile::builder::test; + +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::testing::Throws; +using ::testing::ThrowsMessage; + +[[noreturn]] void throw_error() { throw ckt::HipError("test error", hipErrorInvalidValue); } + +TEST(HipError, SourceInfo) +{ + EXPECT_THAT(throw_error, + ThrowsMessage(AllOf( + // The error message should include... + // ...the user message + HasSubstr("test error"), + // ...the HIP message + HasSubstr("invalid argument"), + // ...the HIP status code, + HasSubstr("(1)"), + // ...the filename + HasSubstr("experimental/builder/test/unit_error.cpp"), + // ...the function name + HasSubstr("throw_error"), + // Note: Don't include the row/column so that we can move + // stuff around in this file. + ))); +} + +TEST(CheckHip, BasicUsage) +{ + EXPECT_THAT([] { ckt::check_hip(hipSuccess); }, Not(Throws())); + EXPECT_THAT([] { ckt::check_hip(hipErrorNotMapped); }, Throws()); + EXPECT_THAT([] { ckt::check_hip(hipErrorOutOfMemory); }, Throws()); + EXPECT_THAT([] { ckt::check_hip("test message", hipErrorAlreadyMapped); }, + ThrowsMessage(HasSubstr("test message"))); +} diff --git a/experimental/builder/test/unit_tensor_descriptor.cpp b/experimental/builder/test/unit_tensor_descriptor.cpp index 07abfe44bd..d9e92bf07e 100644 --- a/experimental/builder/test/unit_tensor_descriptor.cpp +++ b/experimental/builder/test/unit_tensor_descriptor.cpp @@ -1,25 +1,28 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "testing_utils.hpp" #include #include +#include #include namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; using ::testing::ElementsAreArray; -using ::testing::Ge; +using ::testing::Eq; +using ::testing::Throws; TEST(TensorDescriptor, Basic) { - constexpr auto dt = ckb::DataType::FP16; - std::vector lengths = {123, 456, 789}; - std::vector strides = {456 * 789, 789, 1}; + constexpr auto dt = ckb::DataType::FP16; + constexpr size_t rank = 3; + ckt::Extent lengths = {123, 456, 789}; + ckt::Extent strides = {456 * 789, 789, 1}; - ckt::TensorDescriptor
descriptor(lengths, strides); + ckt::TensorDescriptor descriptor(lengths, strides); EXPECT_THAT(descriptor.get_lengths(), ElementsAreArray(lengths)); EXPECT_THAT(descriptor.get_strides(), ElementsAreArray(strides)); @@ -27,21 +30,143 @@ TEST(TensorDescriptor, Basic) TEST(TensorDescriptor, ComputeSize) { - constexpr auto dt = ckb::DataType::FP32; - std::vector lengths = {305, 130, 924}; - std::vector strides = {1000 * 1000, 1, 1000}; + constexpr auto dt = ckb::DataType::FP32; + constexpr size_t rank = 3; + ckt::Extent lengths = {305, 130, 924}; + ckt::Extent strides = {1001 * 1000, 1, 1000}; - ckt::TensorDescriptor
descriptor(lengths, strides); + ckt::TensorDescriptor descriptor(lengths, strides); - // Compute the location of the last item in memory, then add one - // to get the minimum size. - size_t expected_size = 1; + // Compute the location of the last item in memory, + // then add one to get the minimum size. + size_t expected_size = 1; + size_t expected_numel = 1; for(size_t i = 0; i < lengths.size(); ++i) { expected_size += (lengths[i] - 1) * strides[i]; + expected_numel *= lengths[i]; } - EXPECT_THAT(descriptor.get_element_space_size(), Ge(expected_size)); + EXPECT_THAT(descriptor.get_element_size(), Eq(expected_numel)); + EXPECT_THAT(descriptor.get_element_space_size(), Eq(expected_size)); EXPECT_THAT(descriptor.get_element_space_size_in_bytes(), - Ge(expected_size * ckt::data_type_sizeof(dt))); + Eq(expected_size * ckt::data_type_sizeof(dt))); +} + +TEST(TensorDescriptor, PackedRightLayout) +{ + const ckt::Extent lengths = {5125, 623, 1177, 1534}; + const auto strides = ckt::PackedRightLayout{}(lengths); + + EXPECT_THAT(strides, ElementsAreArray({623 * 1177 * 1534, 1177 * 1534, 1534, 1})); +} + +TEST(TensorDescriptor, PackedLeftLayout) +{ + const ckt::Extent lengths = {4, 15, 925, 662, 1462}; + const auto strides = ckt::PackedLeftLayout{}(lengths); + + EXPECT_THAT(strides, ElementsAreArray({1, 4, 4 * 15, 4 * 15 * 925, 4 * 15 * 925 * 662})); +} + +TEST(TensorDescriptor, MakeDescriptor) +{ + { + const ckt::Extent lengths = {10, 11, 12, 13, 14}; + + // Note: automatic inference of RANK. + const auto desc = + ckt::make_descriptor(lengths, ckt::PackedRightLayout{}); + + EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths)); + EXPECT_THAT(desc.get_strides(), + ElementsAreArray({11 * 12 * 13 * 14, 12 * 13 * 14, 13 * 14, 14, 1})); + } + + { + const ckt::Extent lengths = {4, 3, 2}; + const ckt::Extent strides = {60, 1, 7}; + + // Note: automatic inference of RANK. + const auto desc = ckt::make_descriptor(lengths, strides); + + EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths)); + EXPECT_THAT(desc.get_strides(), ElementsAreArray(strides)); + } +} + +TEST(TensorDescriptor, GetSpaceDescriptor) +{ + { + const auto desc = ckt::make_descriptor(ckt::Extent{4, 4, 4}, + ckt::PackedLeftLayout{}); + const auto space = desc.get_space_descriptor(); + + const auto expected = 4 * 4 * 4; + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected)); + } + + { + const ckt::Extent lengths = {6, 3, 4}; + const ckt::Extent strides = {102, 1, 2002}; + const auto desc = ckt::make_descriptor(lengths, strides); + const auto space = desc.get_space_descriptor(); + + // Compute the location of the last item in memory, + // then add one to get the minimum size. + size_t expected_size = 1; + for(size_t i = 0; i < lengths.size(); ++i) + { + expected_size += (lengths[i] - 1) * strides[i]; + } + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected_size})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected_size)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected_size)); + } +} + +TEST(TensorDescriptor, EmptyExtent) +{ + // A rank-0 tensor points to a single element + const auto desc = ckt::make_descriptor(ckt::Extent{}, ckt::Extent{}); + EXPECT_THAT(decltype(desc)::rank, Eq(0)); + EXPECT_THAT(desc.get_lengths().size(), Eq(0)); + EXPECT_THAT(desc.get_strides().size(), Eq(0)); + EXPECT_THAT(desc.get_element_size(), Eq(1)); + EXPECT_THAT(desc.get_element_space_size(), Eq(1)); + EXPECT_THAT(desc.get_element_space_size_in_bytes(), Eq(2)); + + // We expect a rank-1 tensor with the one dimension being 1. + const auto space = desc.get_space_descriptor(); + + const auto expected = 1; + + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size_in_bytes(), Eq(2)); +} + +TEST(TensorDescriptor, ExtentFromVector) +{ + EXPECT_THAT(ckt::Extent<4>::from_vector(std::vector{1, 2, 3, 4}), + ElementsAreArray({1, 2, 3, 4})); + + EXPECT_THAT([] { return ckt::Extent<5>::from_vector(std::vector{1, 2}); }, + Throws()); } diff --git a/experimental/builder/test/unit_tensor_foreach.cpp b/experimental/builder/test/unit_tensor_foreach.cpp new file mode 100644 index 0000000000..de635bc09b --- /dev/null +++ b/experimental/builder/test/unit_tensor_foreach.cpp @@ -0,0 +1,205 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "testing_utils.hpp" +#include +#include +#include +#include + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +using ::testing::Each; +using ::testing::Eq; + +TEST(TensorForeach, CalculateOffset) +{ + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{1, 2, 3}, ckt::Extent{100, 10, 1}), Eq(123)); + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{523, 266, 263}, ckt::Extent{1, 545, 10532}), + Eq(2915409)); + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{}, ckt::Extent{}), Eq(0)); + // Note: >4 GB overflow test + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{8, 2, 5, 7, 0, 4, 1, 3, 6, 9}, + ckt::Extent{1'000, + 1'000'000, + 10'000'000, + 1'000'000'000, + 1, + 10'000, + 100, + 10, + 100'000'000, + 100'000}), + Eq(size_t{7'652'948'130})); +} + +TEST(TensorForeach, VisitsCorrectCount) +{ + // tensor_foreach should visit every index exactly once. + // This test checks that the count is at least correct. + + const ckt::Extent shape = {10, 20, 30}; + + auto d_count = ckt::alloc_buffer(sizeof(uint64_t)); + ckt::check_hip(hipMemset(d_count.get(), 0, sizeof(uint64_t))); + + ckt::tensor_foreach(shape, [count = d_count.get()]([[maybe_unused]] const auto& index) { + atomicAdd(reinterpret_cast(count), 1); + }); + + uint64_t actual; + ckt::check_hip(hipMemcpy(&actual, d_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost)); + + const auto expected = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + + EXPECT_THAT(actual, Eq(expected)); +} + +TEST(TensorForeach, VisitsEveryIndex) +{ + const ckt::Extent shape = {5, 6, 7, 8, 9, 10, 11}; + const auto total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + + // We know this is correct due to testing in unit_tensor_descriptor.cpp + const auto stride = ckt::PackedRightLayout{}(shape); + + auto d_output = ckt::alloc_buffer(sizeof(uint32_t) * total); + ckt::check_hip(hipMemset(d_output.get(), 0, sizeof(uint32_t) * total)); + + ckt::tensor_foreach(shape, [output = d_output.get(), stride](const auto& index) { + // We know this is correct due to the CalculateOffset test. + auto offset = ckt::calculate_offset(index, stride); + + // Use atomic add so that we can check that every index is visited exactly once. + atomicAdd(&reinterpret_cast(output)[offset], 1); + }); + + std::vector actual(total); + ckt::check_hip( + hipMemcpy(actual.data(), d_output.get(), sizeof(uint32_t) * total, hipMemcpyDeviceToHost)); + + EXPECT_THAT(actual, Each(Eq(1))); +} + +TEST(TensorForeach, FillTensorBuffer) +{ + auto desc = ckt::make_descriptor(ckt::Extent{31, 54, 13}, + ckt::PackedRightLayout{}); + + auto buffer = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, buffer.get(), [](size_t i) { return static_cast(i); }); + + std::vector h_buffer(desc.get_element_space_size()); + ckt::check_hip(hipMemcpy( + h_buffer.data(), buffer.get(), h_buffer.size() * sizeof(uint32_t), hipMemcpyDeviceToHost)); + + for(size_t i = 0; i < h_buffer.size(); ++i) + { + EXPECT_THAT(h_buffer[i], Eq(static_cast(i))); + } +} + +TEST(TensorForeach, FillTensor) +{ + // FillTensor with non-packed indices should not write out-of-bounds. + const ckt::Extent shape = {4, 23, 35}; + const ckt::Extent pad = {12, 53, 100}; + auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); + const auto strides = desc.get_strides(); + + auto size = desc.get_element_space_size(); + auto buffer = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, buffer.get(), []([[maybe_unused]] size_t i) { return 123; }); + + ckt::fill_tensor(desc, buffer.get(), []([[maybe_unused]] const auto& index) { return 1; }); + + auto d_error = ckt::alloc_buffer(sizeof(uint32_t) * size); + ckt::check_hip(hipMemset(d_error.get(), 0, sizeof(uint32_t))); + + ckt::tensor_foreach( + // Iterate over the entire padding so that we can check out-of-bounds elements + pad, + [shape, pad, strides, size, error = d_error.get(), tensor = buffer.get()]( + const auto& index) { + const auto offset = ckt::calculate_offset(index, strides); + const auto value = reinterpret_cast(tensor)[offset]; + + // Note: The space of the descriptor will not actually be (12, 53, 100) but + // more like (4, 53, 100), as the outer stride is irrelevant. So we have to + // perform an extra bounds check here. + if(offset < size) + { + // Check if the coordinate is within the shape bounds. + bool in_bounds = true; + for(size_t i = 0; i < shape.size(); ++i) + { + if(index[i] >= shape[i]) + { + in_bounds = false; + } + } + + // In-bounds elements are 1, out-of-bounds is 123. + if(in_bounds && value != 1) + { + atomicAdd(reinterpret_cast(error), 1); + } + else if(!in_bounds && value != 123) + { + atomicAdd(reinterpret_cast(error), 1); + } + } + }); + + uint32_t error_count = 0; + ckt::check_hip(hipMemcpy(&error_count, d_error.get(), sizeof(uint32_t), hipMemcpyDeviceToHost)); + + EXPECT_THAT(error_count, Eq(0)); +} + +TEST(TensorForeach, ClearTensorZeros) +{ + const ckt::Extent shape = {5, 4, 5, 4, 5, 4, 5, 6}; + const ckt::Extent pad = {6, 6, 6, 6, 6, 6, 6, 6}; + + const auto desc = + ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); + + auto buffer = ckt::alloc_tensor_buffer(desc); + ckt::clear_tensor_buffer(desc, buffer.get()); + + // Check that all values are zeroed. + auto d_count = ckt::alloc_buffer(sizeof(uint64_t)); + ckt::check_hip(hipMemset(d_count.get(), 0, sizeof(uint64_t))); + + { + const auto size = desc.get_element_space_size(); + const auto strides = desc.get_strides(); + auto* count = d_count.get(); + const auto* tensor = reinterpret_cast(buffer.get()); + // Note: iterate over the entire pad, so that we can check out-of-bounds elements. + ckt::tensor_foreach(pad, + [count, tensor, strides, size]([[maybe_unused]] const auto& index) { + const auto offset = ckt::calculate_offset(index, strides); + + // Note: The space of the descriptor will not actually be (6, 6, + // ...) but more like (5, 6, ...), as the outer stride is + // irrelevant. So we have to perform an extra bounds check here. + if(offset < size && tensor[offset] != 0) + { + atomicAdd(reinterpret_cast(count), 1); + } + }); + } + + uint64_t actual; + ckt::check_hip(hipMemcpy(&actual, d_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost)); + + EXPECT_THAT(actual, Eq(0)); +} diff --git a/experimental/builder/test/unit_validation.cpp b/experimental/builder/test/unit_validation.cpp new file mode 100644 index 0000000000..06736ca624 --- /dev/null +++ b/experimental/builder/test/unit_validation.cpp @@ -0,0 +1,277 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/validation.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/testing/testing.hpp" +#include "testing_utils.hpp" +#include +#include +#include +#include + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +using testing::ElementsAreArray; +using testing::Eq; +using testing::StrEq; + +using ck_tile::test::MatchesReference; +using ck_tile::test::StringEqWithDiff; + +// Googletest cannot have both type AND value parameterized tests. +// For now just act lazy and use value template parameters. +template +struct Param +{ + constexpr static auto data_type = DT; + constexpr static auto shape = SHAPE; + constexpr static auto strides = STRIDES; + + constexpr static auto rank = shape.size(); + + static ckt::TensorDescriptor get_descriptor() + { + return ckt::make_descriptor(shape, strides); + } +}; + +template +struct ValidationReportTests : public ::testing::Test +{ +}; + +using Types = ::testing::Types< + Param, + Param, + Param, + Param>; + +TYPED_TEST_SUITE(ValidationReportTests, Types); + +TYPED_TEST(ValidationReportTests, SingleCorrect) +{ + const auto desc = TypeParam::get_descriptor(); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + // Generate a sort-of-random looking sequence + auto generator = [strides = desc.get_strides()](const auto& index) { + const auto flat_index = ckt::calculate_offset(index, strides); + return static_cast(flat_index * 10'000'019 % 768'351); + }; + + ckt::fill_tensor(desc, a.get(), generator); + ckt::fill_tensor(desc, b.get(), generator); + + ckt::ValidationReport report; + report.check("correct", desc, b.get(), a.get()); + + EXPECT_THAT(report.get_errors().size(), Eq(0)); +} + +TYPED_TEST(ValidationReportTests, SingleIncorrect) +{ + const auto desc = TypeParam::get_descriptor(); + const auto packed_strides = ckt::PackedRightLayout{}(desc.get_lengths()); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + ckt::fill_tensor(desc, a.get(), []([[maybe_unused]] const auto& i) { return 123; }); + ckt::fill_tensor(desc, b.get(), [packed_strides](const auto& index) { + const auto flat_index = ckt::calculate_offset(index, packed_strides); + return flat_index == 0 ? 0 : flat_index == 12345 ? 456 : flat_index == 999999 ? 1 : 123; + }); + + ckt::ValidationReport report; + report.check("incorrect", desc, b.get(), a.get()); + + const auto errors = report.get_errors(); + + const auto flat_size = desc.get_element_size(); + const auto expected_errors = flat_size >= 999999 ? 3 : flat_size >= 12345 ? 2 : 1; + + ASSERT_THAT(errors.size(), Eq(1)); + EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect")); + EXPECT_THAT(errors[0].wrong_elements, Eq(expected_errors)); + EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size())); +} + +TEST(ValidationReportTests, MultipleSomeIncorrect) +{ + ckt::ValidationReport report; + + { + auto desc = ckt::make_descriptor({'R', 'O', 'C', 'm'}, + ckt::PackedLeftLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer( + desc, a.get(), [](size_t i) { return ck::type_convert(i % 100); }); + ckt::fill_tensor_buffer( + desc, b.get(), [](size_t i) { return ck::type_convert(i % 101); }); + + report.check("incorrect 1", desc, b.get(), a.get()); + } + + { + auto desc = + ckt::make_descriptor({'H', 'I', 'P'}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return "ROCm"[i % 4]; }); + ckt::fill_tensor_buffer(desc, b.get(), [](size_t i) { + switch(i % 4) + { + case 0: return 'R'; + case 1: return 'O'; + case 2: return 'C'; + case 3: return 'm'; + default: return 'x'; + } + }); + + report.check("correct", desc, b.get(), a.get()); + } + + { + auto desc = ckt::make_descriptor({'G', 'P', 'U'}, + ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 1; }); + ckt::fill_tensor_buffer(desc, b.get(), []([[maybe_unused]] size_t i) { return 555; }); + + report.check("incorrect 2", desc, b.get(), a.get()); + } + + const auto errors = report.get_errors(); + + ASSERT_THAT(errors.size(), Eq(2)); + EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1")); + EXPECT_THAT(errors[0].wrong_elements, Eq(46840334)); + EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 2")); + EXPECT_THAT(errors[1].wrong_elements, Eq(482800)); +} + +// MatchesReference operates on the types defined in testing.hpp, so just +// quickly define a bunch of dummy values for that. + +struct DummySignature +{ +}; + +constexpr DummySignature DUMMY_SIGNATURE = {}; + +namespace ck_tile::builder::test { +template <> +struct Args +{ + auto make_a_descriptor() const + { + return make_descriptor(Extent{5, 5, 5, 5}, PackedRightLayout{}); + } + + auto make_b_descriptor() const + { + return make_descriptor(Extent{100000}, PackedLeftLayout{}); + } +}; + +template <> +struct Outputs +{ + void* a; + void* b; +}; + +template <> +ValidationReport validate(const Args& args, + Outputs actual, + Outputs expected) +{ + ValidationReport report; + report.check("a", args.make_a_descriptor(), actual.a, expected.a); + report.check("b", args.make_b_descriptor(), actual.b, expected.b); + return report; +} + +} // namespace ck_tile::builder::test + +TEST(MatchesReference, Correct) +{ + const ckt::Args args; + + const auto a_desc = args.make_a_descriptor(); + const auto b_desc = args.make_b_descriptor(); + + auto a_actual = ckt::alloc_tensor_buffer(a_desc); + auto b_actual = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_actual.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_actual.get(), 2); + const auto actual = ckt::Outputs{ + .a = a_actual.get(), + .b = b_actual.get(), + }; + + auto a_expected = ckt::alloc_tensor_buffer(a_desc); + auto b_expected = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_expected.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_expected.get(), 2); + const auto expected = ckt::Outputs{ + .a = a_expected.get(), + .b = b_expected.get(), + }; + + EXPECT_THAT(actual, MatchesReference(args, expected)); +} + +TEST(MatchesReference, Incorrect) +{ + const ckt::Args args; + + const auto a_desc = args.make_a_descriptor(); + const auto b_desc = args.make_b_descriptor(); + + auto a_actual = ckt::alloc_tensor_buffer(a_desc); + auto b_actual = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_actual.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_actual.get(), 2); + const auto actual = ckt::Outputs{ + .a = a_actual.get(), + .b = b_actual.get(), + }; + + auto a_expected = ckt::alloc_tensor_buffer(a_desc); + auto b_expected = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_expected.get(), 2); + ckt::clear_tensor_buffer(b_desc, b_expected.get(), 2); + const auto expected = ckt::Outputs{ + .a = a_expected.get(), + .b = b_expected.get(), + }; + + testing::StringMatchResultListener listener; + EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener)); + + EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate")); +} From bbf0b1a3b377a37547d21f7ee1d9d18ef85853a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 5 Jan 2026 18:42:02 +0100 Subject: [PATCH 09/23] Fix large tensor grouped conv bwd data test (#3513) --- Jenkinsfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index cb2f8631c5..78703ed9aa 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1469,8 +1469,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ - ./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" + make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) From 1224bc0a82fbf47e1452bc4dbd63371471e57d4a Mon Sep 17 00:00:00 2001 From: Estevan Vedovelli Date: Mon, 5 Jan 2026 13:03:30 -0500 Subject: [PATCH 10/23] Add support to gfx1153 and fix gfx115X WMMA config (#3496) * Support for gfx115X * Changes for gfx115X * Add gfx1153 * Update changelog --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- CHANGELOG.md | 1 + include/ck/utility/amd_wmma.hpp | 3 ++- include/ck_tile/core/arch/arch.hpp | 5 +++++ include/ck_tile/core/config.hpp | 8 ++++++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce8e5197a8..b149a74df3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline. * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. * Added FP8 KV cache support for FMHA batch prefill. +* Added support for gfx1153 target. * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. ### Changed diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 35389bda37..057687985d 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -10,7 +10,8 @@ namespace ck { #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ - defined(__gfx1103__) || defined(__gfx11_generic__) + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) #define __gfx11__ #endif diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index a162195390..c5c1a6e2c6 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -87,6 +87,7 @@ enum struct amdgcn_target_id GFX1150 = 0x1150, GFX1151 = 0x1151, GFX1152 = 0x1152, + GFX1153 = 0x1153, GFX11_GENERIC = 0x11FF, GFX1200 = 0x1200, GFX1201 = 0x1201, @@ -282,6 +283,7 @@ constexpr auto get_compiler_target() MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1153, GFX1153); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC); MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200); MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201); @@ -348,6 +350,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const* MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1150", GFX1150); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1151", GFX1151); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1152", GFX1152); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1153", GFX1153); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx11_generic", GFX11_GENERIC); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1200", GFX1200); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1201", GFX1201); @@ -603,6 +606,7 @@ CK_TILE_HOST_DEVICE constexpr auto get_compiler_target() MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1153, GFX1153); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC); MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200); MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201); @@ -683,6 +687,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* tes MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1150", GFX1150); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1151", GFX1151); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1152", GFX1152); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1153", GFX1153); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx11_generic", GFX11_GENERIC); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1200", GFX1200); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1201", GFX1201); diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 7830749efb..fed9209bad 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -315,6 +315,7 @@ namespace ck_tile::core { * @var CK_TILE_ARCH_GFX1102 Indicates if the compiler target architecture is GFX1102. * @var CK_TILE_ARCH_GFX1151 Indicates if the compiler target architecture is GFX1151. * @var CK_TILE_ARCH_GFX1152 Indicates if the compiler target architecture is GFX1152. + * @var CK_TILE_ARCH_GFX1153 Indicates if the compiler target architecture is GFX1153. * @var CK_TILE_ARCH_GFX11_GENERIC Indicates if the compiler target architecture is GFX11 generic. * @var CK_TILE_ARCH_GFX1200 Indicates if the compiler target architecture is GFX1200. * @var CK_TILE_ARCH_GFX1201 Indicates if the compiler target architecture is GFX1201. @@ -468,6 +469,12 @@ struct amdgcn_compiler_target_state static constexpr bool CK_TILE_ARCH_GFX1152 = false; #endif // __gfx1152__ +#if defined(__gfx1153__) + static constexpr bool CK_TILE_ARCH_GFX1153 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1153 = false; +#endif // __gfx1153__ + #if defined(__gfx11_generic__) static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = true; #else @@ -538,6 +545,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1150, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1151, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1152, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1153, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX11_GENERIC, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1200, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1201, \ From 2b563ad04828c5c970f7544d49831f33203587fb Mon Sep 17 00:00:00 2001 From: joyeamd Date: Tue, 6 Jan 2026 05:49:26 +0800 Subject: [PATCH 11/23] Joye/revise wp pipeline (#3493) * [CK_TILE] unify double and single lds implementation (#108) Unify LDS buffer management API for single and double buffering modes This change consolidates the Local Data Store (LDS) buffer management by: Merging single and double LDS buffer APIs into a unified interface Implementing ping-pong address calculation in pipeline when double LDS is enabled Computing pong buffer addresses dynamically using base address offsets --------- Co-authored-by: joye Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update wp_pipeline * fix a c++17 issue * update for ci errors * fix ci issues * include a header to fix ci errors * fix some rebase issues * update with rebase --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- include/ck_tile/ops/gemm.hpp | 1 + .../gemm/block/block_wp_asmem_breg_creg.hpp | 212 +++++ .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 34 +- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 175 ++-- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 127 ++- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 29 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 50 +- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 19 +- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 795 ++++++------------ .../gemm_quant/kernel/gemm_quant_kernel.hpp | 139 +-- .../kernel/grouped_gemm_quant_kernel.hpp | 49 +- ...p_bquant_pipeline_ag_bg_cr_base_policy.hpp | 42 + .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 23 +- 13 files changed, 766 insertions(+), 929 deletions(-) create mode 100644 include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 0eaedbfb3a..2c3a161121 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -25,6 +25,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp new file mode 100644 index 0000000000..4fc180b42b --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp @@ -0,0 +1,212 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on register +// C is block distributed tensor +template +struct BlockWeightPreshuffleASmemBRegCReg +{ + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM; + static constexpr index_t KPerBlockPerIter = WarpGemm::kK; + + static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read + + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? DsReadPreload + : MIterPerWarp * KIterPerWarp; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + statically_indexed_array preloaded_a_warp_tensor; + + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence<1>>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + template + CK_TILE_DEVICE auto MakeALoadWindows(SmemBlockWindow& a_block_window) const + { + constexpr auto a_load_dstr = make_static_tile_distribution(MakeABlockDistributionEncode()); + + // create MIterPerWarp × KIterPerWarp window + return generate_tuple( + [&](auto kIter) { + return generate_tuple( + [&](auto mIter) { + return make_tile_window( + get_slice_tile( + a_block_window, + sequence{}, + sequence<(mIter + 1) * MPerBlockPerIter, + (kIter + 1) * KPerBlockPerIter>{}), + a_load_dstr); + }, + number{}); + }, + number{}); + } + + template + CK_TILE_DEVICE void LocalPrefetch(const ALoadWindows& a_load_windows) + { + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + + load_tile(preloaded_a_warp_tensor(loadIter), + a_load_windows[number{}][number{}]); + }); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ALoadWindows& a_load_windows, + BFlatBlockTensor& b_block_tensor, + const BFlatDistribution&) + { + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + + using CWarpDstr = typename WarpGemm::CWarpDstr; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + using BWarpTensor = typename WarpGemm::BWarpTensor; + + constexpr auto b_block_y_lengths = + to_sequence(BFlatDistribution{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto b_block_y_index_zeros = + uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + BWarpTensor b_warp_tensor; + CWarpTensor c_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + typename sequence_split::right_type{}), + merge_sequences( + sequence<1, 1>{}, + typename sequence_split::right_type{})); + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WarpGemm{}( + c_warp_tensor, preloaded_a_warp_tensor(number{}), b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + + load_tile(preloaded_a_warp_tensor(number{}), + a_load_windows[number{}][number{}]); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 5ba5699dda..3f028ead2b 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -303,24 +303,15 @@ struct GroupedGemmKernel CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // TO DO: // Can we simplify this branching logic? if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - kargs.ds_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); + RunGemmWithPipelineSelection2LDS( + a_ptr, b_ptr, c_ptr, kargs.ds_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } else // SingleSmemBuffer { @@ -331,7 +322,7 @@ struct GroupedGemmKernel b_ptr, kargs.ds_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -343,7 +334,7 @@ struct GroupedGemmKernel {b_ptr}, kargs.ds_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -425,9 +416,7 @@ struct GroupedGemmKernel * @param a_ptr input A pointer * @param b_ptr input B pointer * @param c_ptr output C pointer - * @param ds_ptr input Ds pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. @@ -439,8 +428,7 @@ struct GroupedGemmKernel const BDataType* b_ptr, CDataType* c_ptr, const std::array& ds_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, + void* __restrict__ smem_ptr, const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -460,8 +448,8 @@ struct GroupedGemmKernel amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + const auto& c_block_tile = + GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); // Run Epilogue Pipeline if(kargs.k_batch == 1) @@ -469,7 +457,7 @@ struct GroupedGemmKernel auto c_block_window = Base::template MakeCBlockWindows( c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr); } else { @@ -477,7 +465,7 @@ struct GroupedGemmKernel Base::template MakeCBlockWindows( c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr); } } diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 65f58a8ca5..c77459b4ec 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -978,7 +978,7 @@ struct UniversalGemmKernel * @param bs_ptr input Bs pointer * @param ds_ptr input Ds pointer * @param e_ptr output E pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. @@ -990,7 +990,7 @@ struct UniversalGemmKernel const std::array& bs_ptr, const std::array& ds_ptr, EDataType* e_ptr, - void* smem_ptr_0, + void* smem_ptr, const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -1008,7 +1008,7 @@ struct UniversalGemmKernel // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = GemmPipeline{}.template operator()( - as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0); + as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr); const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); // Run Epilogue Pipeline @@ -1016,77 +1016,63 @@ struct UniversalGemmKernel { auto c_block_window = MakeCBlockWindows( e_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } else { auto c_block_window = MakeCBlockWindows( e_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } } - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param as_ptr input As pointer - * @param bs_ptr input Bs pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) + CK_TILE_DEVICE static auto + GetTileCoordinates(const KernelArgs& kargs) -> tuple { - // Create block windows using specialized methods - const auto& as_block_window = - MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); - const auto& bs_block_window = - MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + index_t iM, iN; - const index_t num_loop = - amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + // Regular launch: use 1D block indexing + const auto blockId = amd_wave_read_first_lane(blockIdx.x); + const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + iM = tile_m; + iN = tile_n; - // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window, - AElementWise{}, - bs_block_window, - BElementWise{}, - num_loop, - smem_ptr_0, - smem_ptr_1); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - // Run Epilogue Pipeline - if(kargs.k_batch == 1) + return make_tuple(i_m, i_n); + } + + // Helper functions + CK_TILE_DEVICE static auto GetBlockId() -> index_t + { + // For 1D regular launch + return amd_wave_read_first_lane(get_block_id()); + } + + CK_TILE_DEVICE static auto GetGridSize() -> index_t + { + // For 1D regular launch + return amd_wave_read_first_lane(get_grid_size()); + } + + // Helper to get total number of tiles, handling both dim3 and index_t return types + template + CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t + { + auto grid_size = TilePartitioner::GridSize(std::forward(args)...); + + using GridSizeType = decltype(grid_size); + + if constexpr(std::is_same_v) { - auto c_block_window = MakeCBlockWindows( - e_ptr, kargs, block_idx_m, block_idx_n); - - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + // GridSize returns dim3: compute total tiles as x * y * z + return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z); } else { - auto c_block_window = MakeCBlockWindows( - e_ptr, kargs, block_idx_m, block_idx_n); - - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + // GridSize returns scalar (index_t): use directly + return amd_wave_read_first_lane(grid_size); } } @@ -1123,36 +1109,12 @@ struct UniversalGemmKernel } // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + constexpr auto scheduler_type = + GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1); + RunGemm( + as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } // Persistent kernel entry point @@ -1199,34 +1161,19 @@ struct UniversalGemmKernel } // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // Run the GEMM - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n); + // Advance to the next work item block_id += grid_size; if(block_id >= num_work) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 343e37ed66..4973d9c941 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -64,12 +64,17 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template + template CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_tile(dst_block_tile, dram_tile_window); + load_int4_tile(dst_block_tile, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } @@ -217,22 +222,17 @@ struct GemmPipelineAgBgCrImplBase return std::move(a_copy_dram_window); } - template - CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const ALdsTensorView& a_lds_block_view, - const ALdsLoadTileDistr&, - const array& offset = {0, 0}) const + template + CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view, + const ALdsLoadTileDistr&) const { - // A DRAM tile window for load - auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); - - // A LDS tile window for store auto a_lds_shape = []() { if constexpr(is_a_load_tr) return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(); + auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}); auto a_lds_load_tile_distr = []() { @@ -244,32 +244,73 @@ struct GemmPipelineAgBgCrImplBase else return ALdsLoadTileDistr{}; }(); + auto a_lds_gemm_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr); + return make_tuple(std::move(a_copy_lds_window), std::move(a_lds_gemm_window)); + } + + template < + typename ADramBlockWindowTmp, + typename ALdsTensorView, + typename ALdsLoadTileDistr, + typename std::enable_if_t::value, bool>* = nullptr> + CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const ALdsTensorView& a_lds_block_view, + const ALdsLoadTileDistr& a_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // A DRAM tile window for load + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); + + // Create LDS windows + auto [a_copy_lds_window, a_lds_gemm_window] = + MakeALdsWindows(a_lds_block_view, a_lds_load_tile_distr); + return make_tuple(std::move(a_copy_dram_window), std::move(a_copy_lds_window), std::move(a_lds_gemm_window)); } - template - CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BLdsTensorView& b_lds_block_view, - const BLdsLoadTileDistr&, + // Unified GetAWindows that supports 1, 2, or 3 LDS buffers + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const ALdsTensorViewsTuple& a_lds_block_views_tuple, + const ALdsLoadTileDistr& a_lds_load_tile_distr, const array& offset = {0, 0}) const { // A DRAM tile window for load - auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); - // TODO: Do we really need those two tile windows??? - // They're exactly same... - // B LDS tile window for store + // Create LDS windows for each buffer + constexpr index_t num_buffers = ALdsTensorViewsTuple::size(); + auto a_lds_windows = generate_tuple( + [&](auto i) { + return MakeALdsWindows(a_lds_block_views_tuple[i], a_lds_load_tile_distr); + }, + number{}); + + // Return: (dram_window, lds_windows_tuple) + // lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i) + return make_tuple(std::move(a_copy_dram_window), std::move(a_lds_windows)); + } + + template + CK_TILE_DEVICE constexpr auto MakeBLdsWindows(const BLdsTensorView& b_lds_block_view, + const BLdsLoadTileDistr&) const + { auto b_lds_shape = []() { if constexpr(is_b_load_tr) return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(); + auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); using BLdsDataType = @@ -286,13 +327,61 @@ struct GemmPipelineAgBgCrImplBase else return BLdsLoadTileDistr{}; }(); + auto b_lds_gemm_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr); + return make_tuple(std::move(b_copy_lds_window), std::move(b_lds_gemm_window)); + } + + template < + typename BDramBlockWindowTmp, + typename BLdsTensorView, + typename BLdsLoadTileDistr, + typename std::enable_if_t::value, bool>* = nullptr> + CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorView& b_lds_block_view, + const BLdsLoadTileDistr& b_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // A DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + + // Create LDS windows + auto [b_copy_lds_window, b_lds_gemm_window] = + MakeBLdsWindows(b_lds_block_view, b_lds_load_tile_distr); + return make_tuple(std::move(b_copy_dram_window), std::move(b_copy_lds_window), std::move(b_lds_gemm_window)); } + + // Unified GetBWindows that supports 1, 2, or 3 LDS buffers + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorViewsTuple& b_lds_block_views_tuple, + const BLdsLoadTileDistr& b_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // B DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + + // Create LDS windows for each buffer + constexpr index_t num_buffers = BLdsTensorViewsTuple::size(); + auto b_lds_windows = generate_tuple( + [&](auto i) { + return MakeBLdsWindows(b_lds_block_views_tuple[i], b_lds_load_tile_distr); + }, + number{}); + + // Return: (dram_window, lds_windows_tuple) + // lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i) + return make_tuple(std::move(b_copy_dram_window), std::move(b_lds_windows)); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 0b2cdde05e..8acfea4580 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -158,6 +158,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; @@ -172,7 +174,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync(); + constexpr index_t smem_size = Policy::template GetSmemSize(); + return 2 * smem_size; } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() @@ -240,8 +243,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync); @@ -303,8 +305,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}); // this pipeline has a pair of LDS buffers per logical tile - auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); - auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + constexpr index_t smem_size = Policy::template GetSmemSize(); + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem); + auto&& [a_lds_block1, b_lds_block1] = + Base::GetABLdsTensorViews(static_cast(p_smem) + smem_size); // set up LDS tile shapes constexpr auto a_lds_shape = []() { @@ -534,21 +538,18 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}.template operator()( a_dram_block_window_tmp, a_element_func, b_dram_block_window_tmp, b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -559,8 +560,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer"); + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() { // clang-format off @@ -191,7 +193,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return Policy::template GetSmemSize(); + constexpr index_t smem_size = Policy::template GetSmemSize(); + return 2 * smem_size; } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() @@ -281,8 +284,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { using ADramBlockWindowTmp = remove_cvref_t{}, AsDramBlockWindowTmp>>; @@ -324,8 +326,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // global read 0 ////////////// LDS desc, window & register ///////////////// - auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); - auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + constexpr index_t smem_size = Policy::template GetSmemSize(); + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem); + auto&& [a_lds_block1, b_lds_block1] = + Base::GetABLdsTensorViews(static_cast(p_smem) + smem_size); constexpr auto a_lds_shape = []() { if constexpr(is_a_load_tr_v()) @@ -680,8 +684,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* p_smem_0, - void* p_smem_1) const + void* p_smem) const { const bool has_hot_loop = Base::BlockHasHotloop(num_loop); const auto tail_number = Base::GetBlockLoopTailNum(num_loop); @@ -693,8 +696,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -708,8 +710,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { const bool has_hot_loop = Base::BlockHasHotloop(num_loop); const auto tail_number = Base::GetBlockLoopTailNum(num_loop); @@ -721,8 +722,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, [](auto& e, const BDataType& b) { e = b; }, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -738,8 +738,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 index_t num_loop, bool has_hot_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; @@ -751,8 +750,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, PassThrough, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } @@ -769,16 +767,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* p_smem_0, - void* p_smem_1) const + void* p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), a_element_func, ck_tile::make_tuple(b_dram_block_window_tmp), b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); } template CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), ck_tile::make_tuple(b_dram_block_window_tmp), num_loop, - p_smem_0, - p_smem_1); + p_smem); } template index_t num_loop, bool has_hot_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), ck_tile::make_tuple(b_dram_block_window_tmp), num_loop, has_hot_loop, tail_number, - p_smem_0, - p_smem_1); + p_smem); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 019a828ec0..e90c6a27d7 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" namespace ck_tile { @@ -201,6 +202,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using TileShape = typename Problem::BlockGemmShape; + constexpr index_t kNPerBlock = TileShape::kN; + constexpr index_t kKPerBlock = TileShape::kK; + constexpr index_t NIterPerWarp = + kNPerBlock / TileShape::BlockWarps::at(I1) / TileShape::WarpTile::at(I1); + constexpr index_t KIterPerWarp = kKPerBlock / TileShape::WarpTile::at(I2); + constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; @@ -213,13 +220,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy #endif constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; - constexpr index_t KRepeat = 1; + constexpr index_t KRepeat = KIterPerWarp; static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); constexpr index_t NBPerLoad = 1; constexpr index_t NThdPerWave = 1; constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp - constexpr index_t NRepeat = 1; + constexpr index_t NRepeat = NIterPerWarp; constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; return make_static_tile_distribution( @@ -232,8 +239,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy tuple, sequence<0, 1, 2>>, // which direction tuple, sequence<1, 2, 2>>, // which index // - sequence<1, 1, 2, 2>, - sequence<0, 3, 0, 3>>{}); + sequence<1, 2, 1, 2>, + sequence<0, 0, 3, 3>>{}); } template @@ -307,7 +314,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy typename Problem::CDataType, BlockWarps, WarpGemm>; - return BlockWeightPreshuffleASmemBSmemCRegV1{}; + return BlockWeightPreshuffleASmemBRegCReg{}; } /** * @brief Get the vector store size for C tensor. @@ -325,7 +332,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { using BlockGemm = remove_cvref_t())>; - using WG_ = typename BlockGemm::WG; + using WG_ = typename BlockGemm::WarpGemm; constexpr bool TransposeC = Problem::TransposeC; using CLayout = typename Problem::CLayout; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index f64901755b..c9499106de 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -32,19 +32,34 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 template CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) { - if(tail_number == TailNumber::Odd) + if(has_hot_loop) { - return run_func(bool_constant{}, - integral_constant{}); + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else // Even tail number + { + return run_func(bool_constant{}, + integral_constant{}); + } } - else // Even tail number + else { - return run_func(bool_constant{}, - integral_constant{}); + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else // Even tail number + { + return run_func(bool_constant{}, + integral_constant{}); + } } - return run_func(bool_constant{}, integral_constant{}); } }; @@ -52,7 +67,8 @@ template { - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -75,11 +91,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 using BlockWeightPreshuffle = remove_cvref_t())>; - static constexpr auto config = - BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read @@ -95,6 +106,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock; + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; @@ -131,12 +144,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 using BlockWarps = remove_cvref_t; using WarpTile = remove_cvref_t; - static constexpr index_t MWarp = config.template at<1>(); - static constexpr index_t NWarp = config.template at<2>(); + static constexpr index_t MWarp = BlockWarps::at(I0); + static constexpr index_t NWarp = BlockWarps::at(I1); - static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + static constexpr index_t WarpTileM = WarpTile::at(I0); + static constexpr index_t WarpTileN = WarpTile::at(I1); + static constexpr index_t WarpTileK = WarpTile::at(I2); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpTileM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpTileN); + static constexpr index_t KIterPerWarp = kKPerBlock / WarpTileK; static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; @@ -154,20 +171,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 #else static constexpr index_t mfma_per_wg = 1; #endif - static constexpr index_t dsread_per_wg = - max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1); + static constexpr index_t dsread_per_wg = max( + index_t(WarpTileM * WarpTileK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1); #if defined(__HIP_DEVICE_COMPILE__) - static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + static_assert((WarpTileM * WarpTileK * sizeof(ADataType) * MIterPerWarp / WaveSize) % Problem::VectorLoadSize == 0); #endif - static constexpr index_t dsread_num_perK = - WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize; + static constexpr index_t dsread_num_perK = WarpTileM * WarpTileK * sizeof(ADataType) * + MIterPerWarp / WaveSize / Problem::VectorLoadSize; static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp); static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; static constexpr index_t Aload_num_perK = dswrite_num_perK; static constexpr index_t Aload_rep = dswrite_rep; - static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize; + static constexpr index_t Bload_num_perK = kNPerBlock * WarpTileK / NWarp / K1 / WaveSize; static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter; @@ -187,7 +204,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // clang-format off return concat('_', "pipeline_AGmemBGmemCRegV2", concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), - concat('x', WG::kM, WG::kN, WG::kK), + concat('x', WarpTileM, WarpTileN, WarpTileK), concat('x', GetVectorSizeA(), GetVectorSizeB()), concat('x', kPadM, kPadN, kPadK)); @@ -195,14 +212,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr index_t Preshuffle = Problem::Preshuffle; using Base::UsePersistentKernel; CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return PipelinePolicy::template GetSmemSize(); + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + return DoubleSmemBuffer ? 2 * smem_size : smem_size; } // dsread_perM: how many LDS reads want to issue in this M-iter @@ -515,515 +534,184 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // __builtin_amdgcn_sched_barrier(0); } - template ::value && - !is_detected::value, - bool>* = nullptr, - index_t UnaryOpSize_ = 8> - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + struct PipelineImpl : public PipelineImplBase { - static_assert( - std::is_same_v>, - "wrong!"); + using Base = PipelineImplBase; - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], - "wrong!"); - static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; - const index_t iMWarp = get_warp_id() / NWarp; - - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - __builtin_amdgcn_sched_barrier(0); - - // A tile in LDS - ADataType* p_a_lds_ping = static_cast(p_smem_ping); - ADataType* p_a_lds_pong = static_cast(p_smem_pong); - - constexpr auto a_lds_block_desc = - PipelinePolicy::template MakeALdsBlockDescriptor(); - - auto a_lds_block_ping = - make_tensor_view(p_a_lds_ping, a_lds_block_desc); - auto a_lds_block_pong = - make_tensor_view(p_a_lds_pong, a_lds_block_desc); - - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); - - auto a_copy_lds_window_ping = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); - - auto a_copy_lds_window_pong = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); - - // ping-pong window for A LDS - auto a_warp_window_ping_tmp = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - - auto a_warp_window_pong_tmp = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows_ping; - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows_pong; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - // Block GEMM - auto block_weight_preshuffle = BlockWeightPreshuffle(); - // Acc register tile - auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); - - // B flat DRAM window for load - auto b_flat_distribution = - PipelinePolicy::template MakeBFlatDramTileDistribution(); - auto b_flat_dram_window = // tile_window_with_static_distribution - make_tile_window( - b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views - make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - b_flat_distribution); - - // pingpong buffer for B - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; - using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); - - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_flat_dram_windows; - - statically_indexed_array, NIterPerWarp> - b_warp_tensor_ping; - - statically_indexed_array, NIterPerWarp> - b_warp_tensor_pong; - - // Prefetch A0 - auto a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // Prefill A0 - auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - - __builtin_amdgcn_sched_barrier(0); - - // Prefetch A1 - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - block_sync_lds(); - - // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor; - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); - }); - __builtin_amdgcn_sched_barrier(0); - - // MAIN LOOP - index_t iCounter = (num_loop - 1) / 2; - while(iCounter > 0) + template ::value && + !is_detected::value, + bool>* = nullptr, + index_t UnaryOpSize_ = 8> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const { - // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_assert( + std::is_same_v>, + "wrong!"); - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); + // A tile in LDS + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); - // Prefill A(2i+1) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); - // Prefetch A(2i+2) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + auto a_lds_blocks = generate_tuple( + [&](auto i) { + ADataType* p_a_lds = static_cast( + static_cast(static_cast(p_smem) + smem_size * i.value)); + return make_tensor_view(p_a_lds, a_lds_block_desc); + }, + number<2>{}); - // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + constexpr auto a_lds_load_tile_distr = make_static_tile_distribution( + BlockWeightPreshuffle::MakeABlockDistributionEncode()); + auto&& windows_result = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_blocks, a_lds_load_tile_distr); + auto&& a_copy_dram_window = windows_result.template get<0>(); + auto&& a_lds_windows = windows_result.template get<1>(); + auto a_copy_lds_windows = generate_tuple( + [&](auto i) -> decltype(auto) { return a_lds_windows[i].template at<0>(); }, + number<2>{}); + // Block GEMM + auto block_weight_preshuffle = BlockWeightPreshuffle(); + // Acc register tile + auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + auto a_load_windows = generate_tuple( + [&](auto i) -> decltype(auto) { + return block_weight_preshuffle.MakeALoadWindows(a_copy_lds_windows[i]); + }, + number<2>{}); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window(b_flat_dram_block_window_tmp + .get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, + number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BFlatBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kflatKPerBlock); - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BBlockTile = + decltype(make_static_distributed_tensor(b_flat_distribution)); + + ABlockTile a_global_tile; + BBlockTile b_global_tile[2]; + + // // Prefetch A0 + Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + + Base::template GlobalPrefetch( + b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + + // Prefill A0 + Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); + + // Prefetch A1 + Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + + // preload A00,A10 from lds + block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]); + + __builtin_amdgcn_sched_barrier(0); + // MAIN LOOP + if constexpr(HasHotLoop) + { + index_t i_global_read = amd_wave_read_first_lane(2); + do + { { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } + Base::template GlobalPrefetch( + b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); + Base::GlobalPrefetch( + a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + block_weight_preshuffle(c_block_tile, + a_load_windows[I0], + b_global_tile[0], + b_flat_distribution); - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]); + HotLoopScheduler(); + } { - block_sync_lds(); + Base::template GlobalPrefetch( + b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); + Base::GlobalPrefetch( + a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + block_weight_preshuffle(c_block_tile, + a_load_windows[I1], + b_global_tile[1], + b_flat_distribution); + + block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]); + HotLoopScheduler(); } - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + i_global_read += 2; + } while(i_global_read < num_loop); + } - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); - }); - HotLoopScheduler(); + // tail + if constexpr(TailNum == TailNumber::Even) + { + { + Base::template GlobalPrefetch( + b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); + block_weight_preshuffle( + c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution); + block_sync_lds(); + block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]); + Last2ndHotLoopScheduler(); + } + { + block_weight_preshuffle( + c_block_tile, a_load_windows[I1], b_global_tile[1], b_flat_distribution); + LastHotLoopScheduler(); + } + } + else if constexpr(TailNum == TailNumber::Odd) + { + block_weight_preshuffle( + c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution); + LastHotLoopScheduler(); + } - // Next K - - // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(2i+2) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - - // Prefetch A(2i+3) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); - }); - HotLoopScheduler(); - - iCounter--; + return c_block_tile; } - - // tail - if constexpr(TailNum == TailNumber::Even) - { - // __builtin_amdgcn_sched_barrier(0); - // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(loopK) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); - - // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - // TailHotLoopScheduler(); - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); - }); - - Last2ndHotLoopScheduler(); - - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - LastHotLoopScheduler(); - } - else if constexpr(TailNum == TailNumber::Odd) - { - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - LastHotLoopScheduler(); - } - - return c_block_tile; - } + }; // called from universal gemm kernel template (a_dram_block_window_tmp[number<0>{}], - PassThrough, - b_flat_dram_block_window_tmp[number<0>{}], - num_loop, - p_smem_ping, - p_smem_pong); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp[number<0>{}], + a_element_func, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } // called from general gemm kernel @@ -1066,23 +751,21 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + void* p_smem) const { - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - const auto RunPipeline = [&](auto bool_val, auto tail_num_) { - (void)bool_val; // Suppress unused parameter warning - constexpr auto tail_num = tail_num_.value; + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr auto PassThrough = [](const ADataType& a) { return a; }; - return operator()(a_dram_block_window_tmp, - PassThrough, - b_flat_dram_block_window_tmp, - num_loop, - p_smem_ping, - p_smem_pong); + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } // called from grouped gemm kernel @@ -1095,21 +778,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { - const auto RunPipeline = [&](auto bool_val, auto tail_num_) { - (void)bool_val; // Suppress unused parameter warning - constexpr auto tail_num = tail_num_.value; + const auto has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr auto PassThrough = [](const auto& x) { return x; }; - return operator()(a_dram_block_window_tmp, - PassThrough, - b_flat_dram_block_window_tmp, - num_loop, - p_smem_0, - p_smem_1); + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } }; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 8aab756ccf..4f79361037 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1723,7 +1723,7 @@ struct QuantGemmKernel * @param aq_ptr input AQ pointer * @param bq_ptr input BQ pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. @@ -1735,7 +1735,7 @@ struct QuantGemmKernel const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, + void* smem_ptr, const QuantGemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -1762,7 +1762,7 @@ struct QuantGemmKernel m = kargs.M; } return GemmPipeline{}.template operator()( - a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m); + a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m); } else if constexpr(kQuantType == QuantType::BQuantGrouped) { @@ -1772,7 +1772,7 @@ struct QuantGemmKernel n = kargs.N; } return GemmPipeline{}.template operator()( - a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n); + a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n); } else if constexpr(kQuantType == QuantType::ABQuantGrouped) { @@ -1788,7 +1788,7 @@ struct QuantGemmKernel aq_block_window, bq_block_window, num_loop, - smem_ptr_0, + smem_ptr, m, n); } @@ -1796,7 +1796,7 @@ struct QuantGemmKernel kQuantType == QuantType::TensorQuant) { return GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0); + a_block_window, b_block_window, num_loop, smem_ptr); } }(); @@ -1812,14 +1812,14 @@ struct QuantGemmKernel kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -1828,7 +1828,7 @@ struct QuantGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } else @@ -1840,14 +1840,14 @@ struct QuantGemmKernel kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -1856,89 +1856,7 @@ struct QuantGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); - } - } - } - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param aq_ptr input AQ pointer - * @param bq_ptr input BQ pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - [[maybe_unused]] const AQDataType* aq_ptr, - const BQDataType* bq_ptr, - CDataType* c_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const QuantGemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create block windows using specialized methods - const auto& a_block_window = - MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); - const auto& b_block_window = - MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); - - const index_t num_loop = - amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = [&]() { - if constexpr(kQuantType == QuantType::BQuantGrouped) - { - index_t n = 0; - if constexpr(PreshuffleQuant) - { - n = kargs.N; - } - return GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - smem_ptr_0, - smem_ptr_1, - n); - } - else - { - return nullptr; - } - }(); - - const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - - // Run Epilogue Pipeline with k_batch dispatch - if constexpr(kQuantType == QuantType::BQuantGrouped) - { - if(k_batch == 1) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); - } - else - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } } @@ -1961,37 +1879,10 @@ struct QuantGemmKernel CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - - RunGemm2LDS(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - RunGemm(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + RunGemm( + a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 1c98a372be..06a80c8b55 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -318,21 +318,18 @@ struct QuantGroupedGemmKernel CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // Only for BQuantGrouped DoubleSmemBuffer is supported if constexpr(GemmPipeline::DoubleSmemBuffer == true && kQuantType == QuantType::BQuantGrouped) { - - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemmWithPipelineSelection2LDS(a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, - smem_ptr_1, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -348,7 +345,7 @@ struct QuantGroupedGemmKernel aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -361,7 +358,7 @@ struct QuantGroupedGemmKernel aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -377,8 +374,7 @@ struct QuantGroupedGemmKernel [[maybe_unused]] const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, - void* smem_ptr_1, + void* smem_ptr, const QuantGroupedGemmKernelArgs& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -399,27 +395,22 @@ struct QuantGroupedGemmKernel const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); // Run GEMM cooperatively by whole workgroup - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, bq_block_window, num_loop, tail_num, smem_ptr); // Run Epilogue Pipeline with split_k dispatch if(kargs.k_batch == 1) { auto c_block_window = Base::template MakeCBlockWindow( c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else { auto c_block_window = Base::template MakeCBlockWindow( c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } } @@ -435,7 +426,7 @@ struct QuantGroupedGemmKernel * @param aq_ptr input AQ pointer * @param bq_ptr input BQ pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k * batch. @@ -449,7 +440,7 @@ struct QuantGroupedGemmKernel const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, + void* smem_ptr, const QuantGroupedGemmKernelArgs& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -481,7 +472,7 @@ struct QuantGroupedGemmKernel num_loop, has_hot_loop, tail_num, - smem_ptr_0); + smem_ptr); } else if constexpr(kQuantType == QuantType::BQuantGrouped) { @@ -491,13 +482,13 @@ struct QuantGroupedGemmKernel num_loop, has_hot_loop, tail_num, - smem_ptr_0); + smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant || kQuantType == QuantType::TensorQuant) { return GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr); } }(); @@ -510,14 +501,14 @@ struct QuantGroupedGemmKernel if constexpr(kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -526,7 +517,7 @@ struct QuantGroupedGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } else @@ -538,14 +529,14 @@ struct QuantGroupedGemmKernel if constexpr(kQuantType == QuantType::AQuantGrouped || kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } else if constexpr(kQuantType == QuantType::RowColQuant) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -554,7 +545,7 @@ struct QuantGroupedGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp index b155297054..b7dc0bd616 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp @@ -29,6 +29,48 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution(); } + // as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed; + // move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here + // temporarily + template + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? + tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant() { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 18b236c29b..43f37ec4d8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -184,8 +184,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV const BQDramBlockWindowTmp& bq_dram_block_window_tmp, index_t n, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + void* p_smem) const { static_assert( std::is_same_v> && @@ -210,8 +209,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV __builtin_amdgcn_sched_barrier(0); // A tile in LDS - ADataType* p_a_lds_ping = static_cast(p_smem_ping); - ADataType* p_a_lds_pong = static_cast(p_smem_pong); + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + ADataType* p_a_lds_ping = static_cast(p_smem); + ADataType* p_a_lds_pong = + reinterpret_cast(static_cast(p_smem) + smem_size); constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor(); @@ -561,9 +562,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, const BQDramBlockWindowTmp& bq_dram_block_window_tmp, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong, - index_t n = 0) const // Default value for non-preshuffle case + void* p_smem, + index_t n = 0) const { return operator()( a_dram_block_window_tmp, @@ -572,8 +572,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV bq_dram_block_window_tmp, n, num_loop, - p_smem_ping, - p_smem_pong); + p_smem); } template Date: Tue, 6 Jan 2026 15:39:00 +0800 Subject: [PATCH 12/23] Merge some updates for ck_tile headers (#3342) * fix some issues from internal branch * update cshuffle_epilogue * update cshuffle_epilogue * update cshuffle * update warp_gemm --- include/ck_tile/core/arch/arch.hpp | 26 ++-- .../ck_tile/core/tensor/transpose_tile.hpp | 29 +--- .../ops/epilogue/cshuffle_epilogue.hpp | 35 +++-- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 124 ++++++++++++------ ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 13 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 8 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 10 ++ .../gemm/warp/warp_gemm_attribute_wmma.hpp | 17 ++- .../warp/warp_gemm_attribute_wmma_impl.hpp | 7 +- ..._gemm_attribute_wmma_impl_16bit_traits.hpp | 8 ++ ...p_gemm_attribute_wmma_impl_8bit_traits.hpp | 10 ++ ...p_gemm_attribute_wmma_impl_base_traits.hpp | 4 + .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 2 + .../gemm/test_gemm_pipeline_ut_cases.inc | 31 +++-- 14 files changed, 205 insertions(+), 119 deletions(-) diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index c5c1a6e2c6..97e962f5a3 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1124,8 +1124,14 @@ CK_TILE_DEVICE static constexpr auto get_device_arch() { // FIXME(0): on all devices except gfx11 it returns gfx12_t // FIXME(1): during the host compilation pass it returns gfx12_t -#if defined(__gfx11__) +#if defined(__gfx103__) + return gfx103_t{}; +#elif defined(__gfx11__) return gfx11_t{}; +#elif defined(__gfx950__) + return gfx950_t{}; +#elif defined(__gfx9__) + return gfx9_t{}; #else return gfx12_t{}; #endif @@ -1146,26 +1152,10 @@ CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; } CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; } -CK_TILE_DEVICE static constexpr auto arch_tag_dispatch() -{ -#if defined(__gfx103__) - return gfx103_t{}; -#elif defined(__gfx11__) - return gfx11_t{}; -#elif defined(__gfx12__) - return gfx12_t{}; -#elif defined(__gfx950__) - return gfx950_t{}; -#elif defined(__gfx9__) - return gfx9_t{}; -#else - return gfx_invalid_t{}; -#endif -} } // namespace detail CK_TILE_DEVICE static constexpr auto get_n_lds_banks() { - return detail::get_n_lds_banks(detail::arch_tag_dispatch()); + return detail::get_n_lds_banks(get_device_arch()); } enum LLVMSchedGroupMask : int32_t diff --git a/include/ck_tile/core/tensor/transpose_tile.hpp b/include/ck_tile/core/tensor/transpose_tile.hpp index e5a0664ec9..50927c5ca4 100644 --- a/include/ck_tile/core/tensor/transpose_tile.hpp +++ b/include/ck_tile/core/tensor/transpose_tile.hpp @@ -34,46 +34,23 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor(); constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor(); - // y_dim_out_to_in - // For swapped Hs tile case I need only get_rh_minor_to_y - // since rh_major are already swapped due to swapped Hs. - constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) { - using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode; - - map rh_minor_to_y_; - - static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) { - constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i]; - - rh_minor_to_y_(rh_minor) = i; - }); - - return rh_minor_to_y_; - }; - // In swapped Hs case -> tile // we have same rh_major, but reversed rh_minor! - constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{}); - constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{}); + constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); - // Is this really needed?? Should we have simple reverse here?? constexpr auto y_dim_out_to_in = [&] { map y_dim_out_to_in_; - for(const auto& [rh_minor, y_out] : rh_minor_to_y_out) - { - y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor]; - } + static_for<0, NDimY, 1>{}([&](auto i) { y_dim_out_to_in_(i) = NDimY - 1 - i; }); return y_dim_out_to_in_; }(); - constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); // input and output vector dim in the order of input Y dims constexpr index_t y_dim_vec_in = NDimY - 1; - constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1]; + constexpr index_t y_dim_vec_out = 0; // vector lengths constexpr index_t vec_length_in = y_lengths[y_dim_vec_in]; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index c73897f064..97f936fde9 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -333,14 +333,30 @@ struct CShuffleEpilogue { constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; // BlockedLayout - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}; + // this branch is for original a16w4 + if constexpr(is_any_of::value || + is_any_of::value) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 1>>{}; + } } }(); constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( @@ -351,7 +367,8 @@ struct CShuffleEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + return lds_block_desc.get_element_space_size() * sizeof(ODataType); } template diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index c77459b4ec..628f5f7dc8 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -423,7 +423,7 @@ struct UniversalGemmKernel const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA() : GemmPipeline::template GetVectorSizeA(); - bool AsTesnorIsValid = {true}; + bool AsTensorIsValid = {true}; static_for<0, NumATensor, 1>{}([&](auto index) { using AiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -437,15 +437,27 @@ struct UniversalGemmKernel "Can't support K that is not a multiple of k_batch * KPerBlock " "without padding!"); } - AsTesnorIsValid = false; + AsTensorIsValid = false; } if(kargs.K % vectorSizeA != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.K % vectorSizeA; + constexpr ck_tile::index_t APackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + AsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + AsTensorIsValid = false; } - AsTesnorIsValid = false; } } else @@ -457,20 +469,33 @@ struct UniversalGemmKernel CK_TILE_ERROR( "Can't support M that is not a multiple of MPerBlock without padding!"); } - AsTesnorIsValid = false; + AsTensorIsValid = false; } if(kargs.M % vectorSizeA != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.M % vectorSizeA; + constexpr ck_tile::index_t APackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + + AsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } + AsTensorIsValid = false; } - AsTesnorIsValid = false; } } }); - bool BsTesnorIsValid = {true}; + bool BsTensorIsValid = {true}; const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB() : GemmPipeline::template GetVectorSizeB(); static_for<0, NumBTensor, 1>{}([&](auto index) { @@ -484,47 +509,72 @@ struct UniversalGemmKernel CK_TILE_ERROR( "Can't support N that is not a multiple of NPerBlock without padding!"); } - BsTesnorIsValid = false; + BsTensorIsValid = false; } if(kargs.N % vectorSizeB != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.N % vectorSizeB; + constexpr ck_tile::index_t BPackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + BsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + BsTensorIsValid = false; } - BsTesnorIsValid = false; } - } - else - { - if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) + else { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) { - CK_TILE_ERROR( - "Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + BsTensorIsValid = false; } - BsTesnorIsValid = false; - } - if(kargs.K % vectorSizeB != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + if(kargs.K % vectorSizeB != 0) { - CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + const auto remainder = kargs.K % vectorSizeB; + constexpr ck_tile::index_t BPackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) + { + BsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "K is not a multiple of vector load size for B tensor!"); + } + BsTensorIsValid = false; + } } - BsTesnorIsValid = false; } } }); - bool DTesnorIsValid = {true}; + bool DTensorIsValid = {true}; static_for<0, NumDTensor, 1>{}([&](auto index) { using DiLayout = remove_cvref_t>; if(std::is_same_v == false) { - DTesnorIsValid = false; + DTensorIsValid = false; } if constexpr(std::is_same_v) { @@ -535,7 +585,7 @@ struct UniversalGemmKernel CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " "NPerBlock without padding!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) { @@ -543,7 +593,7 @@ struct UniversalGemmKernel { CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } } else @@ -555,7 +605,7 @@ struct UniversalGemmKernel CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " "MPerBlock without padding!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) { @@ -563,7 +613,7 @@ struct UniversalGemmKernel { CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } } }); @@ -608,7 +658,7 @@ struct UniversalGemmKernel return false; } } - return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid; + return AsTensorIsValid && BsTensorIsValid && DTensorIsValid; } CK_TILE_DEVICE static auto diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index d68da14ac5..6199142d98 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -845,10 +845,10 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr index_t smem_size_a = - integer_least_multiple(sizeof(typename Problem::ADataType) * - Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK, - 16); + using ADataType = remove_cvref_t; + constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor(); + constexpr index_t smem_size_a = integer_least_multiple( + a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16); return smem_size_a; } @@ -859,8 +859,9 @@ struct UniversalGemmBasePolicy std::conditional_t, typename Problem::ADataType, typename Problem::BDataType>; - constexpr index_t smem_size_b = integer_least_multiple( - sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16); + constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); + constexpr index_t smem_size_b = integer_least_multiple( + b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16); return smem_size_b; } diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 47607a40f5..5b00eb244b 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -53,11 +53,11 @@ struct TileGemmUniversalTraits static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - using AsLayout = AsLayout_; - using BsLayout = BsLayout_; - using CLayout = CLayout_; + using AsLayout = AsLayout_; + using BsLayout = BsLayout_; + using CLayout = CLayout_; + static constexpr bool TransposeC = TransposeC_; - static constexpr bool TransposeC = TransposeC_; static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; static constexpr bool UsePersistentKernel = UsePersistentKernel_; static constexpr index_t NumWaveGroups = NumWaveGroups_; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index c0fbf8e5d3..7bcc9107da 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -306,6 +306,16 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; +using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed = + WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed = + WarpGemmImpl, + 2>>; + template using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index ff2ba501fe..ef31d06c9c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -68,6 +68,19 @@ struct WarpGemmAttributeWmma { using Impl = remove_cvref_t; + // When kTransC is true and A/B types differ, we need an impl with swapped types + using TransposedImpl = + std::conditional_t, + WarpGemmAttributeWmmaImpl>, + Impl>; + using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; using CDataType = typename Impl::CDataType; @@ -104,7 +117,7 @@ struct WarpGemmAttributeWmma { if constexpr(kTransC) { - Impl{}(c_vec, b_vec, a_vec, bool_constant{}); + TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant{}); } else { @@ -117,7 +130,7 @@ struct WarpGemmAttributeWmma { if constexpr(kTransC) { - return Impl{}(b_vec, a_vec); + return TransposedImpl{}(b_vec, a_vec); } else { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp index 0464ffbce4..cf0efbbaae 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -22,9 +22,10 @@ struct WmmaTraits; template struct WarpGemmAttributeWmmaImpl { - using ADataType = typename Traits::ADataType; - using BDataType = typename Traits::BDataType; - using CDataType = typename Traits::CDataType; + using TraitsType = Traits; + using ADataType = typename Traits::ADataType; + using BDataType = typename Traits::BDataType; + using CDataType = typename Traits::CDataType; using AVecType = typename Traits::AVecType; using BVecType = typename Traits::BVecType; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp index 992f0a8783..d9d4ec9430 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp @@ -10,6 +10,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -30,6 +32,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -50,6 +54,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -70,6 +76,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp index 34c4dbe551..eace7e3956 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp @@ -10,6 +10,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -35,6 +37,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -60,6 +64,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -80,6 +86,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -100,6 +108,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp index 524215ddfa..e00b9d772f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp @@ -10,6 +10,8 @@ struct WmmaTraitsBase; template struct WmmaTraitsBase { + using ArchType = gfx11_t; + using ADataType = ADType; using BDataType = BDType; using CDataType = CDType; @@ -57,6 +59,8 @@ struct WmmaTraitsBase template struct WmmaTraitsBase { + using ArchType = gfx12_t; + using ADataType = ADType; using BDataType = BDType; using CDataType = CDType; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 82c6e43834..d6c21e88b5 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -100,6 +100,7 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; @@ -113,6 +114,7 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // scale mfma based f8f6f4 diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index 6e7c086e55..5239b2d888 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -31,7 +31,14 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) if constexpr(std::is_same_v) { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + if(M * sizeof(typename TestFixture::ADataType) % 4 == 0) // oob fit dword + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } else { @@ -84,7 +91,14 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) } else { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + if(M * sizeof(typename TestFixture::ADataType) % 4 == 0) // oob fit dword + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } } else @@ -103,18 +117,7 @@ TYPED_TEST(TEST_SUITE_NAME, PaddK) for(int M : Ms) { - if constexpr(std::is_same_v) - { -#if defined(ARCH_GFX12) || defined(ARCH_GFX11) - this->Run(M, N, K); -#else - EXPECT_THROW(this->Run(M, N, K), std::runtime_error); -#endif - } - else - { - this->Run(M, N, K); - } + this->Run(M, N, K); } } From 1c433c64ec5254d202b7cbf4b8b0e98678ea2a4f Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Tue, 6 Jan 2026 09:29:06 +0100 Subject: [PATCH 13/23] [CK_BUILDER] Integrate reference conv with testing (#3511) * ck-builder: explicitly delete forward declarations Before, these functions were seen as a forward declaration for an existing function. If no actual implementation overload could be found, these would be selected and a linker error or warning would be generated. By marking these functions as explicitly deleted, they incorrect invocations are generated as compile error instead. * ck-builder: ckt::run plumbing for reference conv This implements the ckt::run plumbing for the reference convolution implementation and sets up the first complete end-to-end test. * ck-builder: make validation system check for all-zeros When both the actual and reference output are both all zero bits, there is probably something wrong in the test framework. * ck-builder: proper implementation+tests for TensorDescriptor::is_packed * ck-builder: fix typos --- .../builder/factory/reference_factory.hpp | 45 +++---- .../ck_tile/builder/testing/conv_fwd_ck.hpp | 79 +++++++++--- .../builder/testing/conv_fwd_reference.hpp | 114 ++++++++++++++++++ .../builder/testing/tensor_descriptor.hpp | 30 +++++ .../ck_tile/builder/testing/testing.hpp | 23 +++- .../ck_tile/builder/testing/validation.hpp | 60 +++++++-- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 16 +-- .../builder/test/unit_tensor_descriptor.cpp | 19 +++ experimental/builder/test/unit_validation.cpp | 23 +++- 9 files changed, 349 insertions(+), 60 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp index 0246c805c2..0748725c96 100644 --- a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp @@ -125,9 +125,9 @@ struct ReferenceFactory // Direct Run method (simpler interface, direction-agnostic) template - static void Run(InPtrType input, - WeiPtrType weight, - OutPtrType output, + static void Run(InPtrType* input, + WeiPtrType* weight, + OutPtrType* output, int G, int N, int K, @@ -142,9 +142,9 @@ struct ReferenceFactory if constexpr(ConvDirectionIsForward) { ck_tile::naive_grouped_conv_fwd( - input, - weight, - output, + static_cast(input), + static_cast(weight), + static_cast(output), G, N, K, @@ -160,9 +160,9 @@ struct ReferenceFactory { ck_tile:: naive_grouped_conv_bwd_data( - input, - weight, - output, + static_cast(input), + static_cast(weight), + static_cast(output), G, N, K, @@ -179,19 +179,20 @@ struct ReferenceFactory ck_tile::naive_grouped_conv_bwd_weight(input, - weight, - output, - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); + OutDataType>( + static_cast(input), + static_cast(weight), + static_cast(output), + G, + N, + K, + C, + input_spatial, + filter_spatial, + output_spatial, + strides, + dilations, + left_pads); } } diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp index cc5c613d95..499e0ef3de 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp @@ -3,10 +3,10 @@ #pragma once -#include -#include - #include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include +#include /// This file contains the implementation details for invoking/testing /// grouped convolution operations in old CK. The main item is the @@ -15,6 +15,63 @@ namespace ck_tile::builder::test { +namespace detail { + +/// @brief Concept for checking whether this is the reference convolution +/// implementation. +/// +/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except +/// with some utility aliases. For that reason, its moved to this detail +/// namespace. +template > +concept CkConvInstance = requires(Conv& conv, + // TODO: This should be changed depending on IsMultiA etc. + // Currently that is not yet supported elsewhere anyway. + const void* p_a, + const void* p_b, + void* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::AElementwiseOp elementwise_a, + Ops::BElementwiseOp elementwise_b, + Ops::CDEElementwiseOp elementwise_cde) { + { + conv.MakeArgument(p_a, + p_b, + // TODO: Support multiple D outputs. + {}, + p_e, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // TODO: Ds lengths/strides + {}, + {}, + // E lengths/strides + lengths, + strides, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde) + }; +}; + +} // namespace detail + /// @brief Concept for checking whether a convolution is invoked like old CK. /// /// This concept is used to tell whether a convolution implementation is @@ -24,13 +81,8 @@ namespace ck_tile::builder::test { /// /// - SIGNATURE is the operation signature. /// - Conv is a convolution instance created by the CK Builder API. -template -concept IsCkConvInstance = - // TODO: This should be implemented by converting the signature into the - // type parameters for DeviceGroupedConvFwdMultipleABD. For now, just leave - // it empty. Improve when needed, you get the point. Also we should probably - // move this to the ck conv factory helper. - true; +template +concept CkConvInstance = detail::CkConvInstance; /// @brief `run()` specialization for forward convolution and old CK. /// @@ -39,10 +91,9 @@ concept IsCkConvInstance = /// operation. This should be caught and reported by the testing framework. /// /// @see run() -template - requires ValidConvSignature && ConvDirectionIsForward && - IsCkConvInstance -void run(Conv& conv, +template + requires ValidConvSignature && ConvDirectionIsForward +void run(CkConvInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs) diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp new file mode 100644 index 0000000000..85493e32eb --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp @@ -0,0 +1,114 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/conv_fwd.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// grouped convolution operations using the reference implementation. +/// The main item is the `run()` function, which is the primary way to +/// invoke the reference execution mechanism. +/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, +/// but its made specific to the reference implementation, which is +/// invoked in a slightly different way. + +namespace ck_tile::builder::test { + +/// @brief Concept for checking whether this is the reference convolution +/// implementation. +/// +/// This concept is used to tell whether a convolution implementation is +/// likely to be the reference implementation - that is, whether we should +/// invoke it like the reference kernel. This is mainly used with `run()` to +/// differentiate which implementation that should be invoked. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept RefConvInstance = requires(Conv& conv, + const void* input, + const void* weight, + void* output, + int G, + int N, + int K, + int C, + std::vector dims) { + { + conv.Run(input, + weight, + output, + G, + N, + K, + C, + dims, // input_spatial + dims, // filter_spatial + dims, // output_spatial + dims, // strides + dims, // dilations + dims // left_pads + ) + }; +}; + +/// @brief `run()` specialization for forward convolution and the reference +/// implementation. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @throws std::runtime_error if the arguments weren't actually valid for the +/// operation. This should be caught and reported by the testing framework. +/// +/// @see run() +template + requires ValidConvSignature && + // TODO: Maybe we can unify this implementation for bwd/weight too? + // for now, just concern outselves with reference and see when the + // rest of the bwd/weight plumbing is there. + ConvDirectionIsForward +void run(RefConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + // We don't want to compute the output dims manually, just get + // them via the existing infrastructure + const auto param = args.to_ck_conv_param(); + + // TODO: The reference convolution is currently missing a few features. + // Just throw for now, but regard these as TODO items that should be resolved + // eventually. + + // Right pads are not supported right now for some reason. + for(auto right_pad : param.input_right_pads_) + { + if(right_pad != 0) + throw std::runtime_error("TODO: Support right pad in reference conv"); + } + + if(!args.make_input_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed input tensor in reference conv"); + if(!args.make_weight_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed weight tensor in reference conv"); + if(!args.make_output_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed output tensor in reference conv"); + + conv.Run(inputs.input, + inputs.weight, + outputs.output, + param.G_, + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.output_spatial_lengths_, + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp index 0ba01a77ca..15fe4d89db 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/testing/type_traits.hpp" @@ -369,6 +370,35 @@ struct TensorDescriptor return get_element_space_size() * data_type_sizeof(DT); } + /// @brief Check if a tensor is packed in memory. + /// + /// This function checks whether the tensor memory is "packed", that is, whether + /// all elements are continuous in memory with no gaps. + bool is_packed() const + { + // First sort by stride, then check if they match the scan of the + // sizes. + const auto& lengths = inner_descriptor_.get_lengths(); + const auto& strides = inner_descriptor_.get_strides(); + + std::array indices; + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](auto i, auto j) { + return strides[i] < strides[j]; + }); + + size_t x = 1; + for(size_t i = 0; i < RANK; ++i) + { + if(strides[indices[i]] != x) + return false; + + x *= lengths[indices[i]]; + } + + return true; + } + /// @brief Get a tensor descriptor for the space backing a tensor. /// /// This function returns a tensor descriptor which represents the buffer space diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index 9c8b858018..609c93cacf 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -220,10 +220,13 @@ UniqueInputs alloc_inputs(const Args& args); /// @param args The run-time arguments of the operation. /// @param inputs The operation inputs to initialize with random data. /// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// /// @see Inputs /// @see tensor_initialization template -void init_inputs(const Args& args, Inputs inputs); +void init_inputs(const Args& args, Inputs inputs) = delete; /// @brief Allocate outputs corresponding to a signature. /// @@ -236,13 +239,16 @@ void init_inputs(const Args& args, Inputs inputs); /// /// @param args The run-time arguments of the operation. /// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// /// @see Outputs /// @see UniqueOutputs /// @see alloc_buffer() /// @see alloc_tensor_buffer() template requires ValidUniqueOutputs -UniqueInputs alloc_outputs(const Args& args); +UniqueInputs alloc_outputs(const Args& args) = delete; /// @brief Compare device operation outputs. /// @@ -262,10 +268,14 @@ UniqueInputs alloc_outputs(const Args& args); /// @param actual The actual results, the results of the operation to-be-tested. /// @param expected The expected results, the results of the reference implementation. /// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// /// @see ValidationReport template -ValidationReport -validate(const Args& args, Outputs actual, Outputs expected); +ValidationReport validate(const Args& args, + Outputs actual, + Outputs expected) = delete; /// @brief Invoke a device operation created by CK Builder. /// @@ -296,10 +306,13 @@ validate(const Args& args, Outputs actual, Outputs void run(Operation& operation, const Args& args, const Inputs& inputs, - const Outputs& outputs); + const Outputs& outputs) = delete; } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp index 275fa490eb..267bf8d2ac 100644 --- a/experimental/builder/include/ck_tile/builder/testing/validation.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp @@ -13,6 +13,7 @@ #include #include #include +#include /// This file implements functionality related to "validation", ie, functionality /// to compare tensors. The functionality in this file should be testing-framework @@ -48,12 +49,22 @@ struct ValidationReport /// The total number of elements in each tensor. uint64_t total_elements; + /// The number of elements which were bitwise 0. + uint64_t zero_elements; + + /// @brief Check whether both the output and reference tensor were both all zeros. + /// + /// If both tensors are all zero, it indicates either an incorrect testing setup + /// or an issue with the testing framework. For that reason we also consider that + /// a failure. + bool is_all_zero() const { return zero_elements == total_elements; } + /// @brief Return whether the check associated to this case was successful. /// /// This function returns whether the check associated to this case was successful, /// which is directly derived from checking whether the number of incorrect elements - /// was 0. - bool is_ok() const { return wrong_elements == 0; } + /// was 0 AND whether the tensor was not all zero. + bool is_ok() const { return wrong_elements == 0 && !is_all_zero(); } }; /// @brief Get comparison cases which were incorrect. @@ -123,10 +134,13 @@ bool ValidationReport::check(std::string_view tensor_name, // Initial pass: count errors // Allocate and reset counter - auto d_error_count = alloc_buffer(sizeof(uint64_t)); - check_hip(hipMemset(d_error_count.get(), 0, sizeof(uint64_t))); + auto d_counters = alloc_buffer(sizeof(uint64_t) * 2); + check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2)); - tensor_foreach(descriptor.get_lengths(), [=, error_count = d_error_count.get()](auto index) { + auto d_error_count = &reinterpret_cast(d_counters.get())[0]; + auto d_zero_count = &reinterpret_cast(d_counters.get())[1]; + + tensor_foreach(descriptor.get_lengths(), [=](auto index) { using CKType = typename factory::internal::DataTypeToCK
::type; const auto* actual = static_cast(actual_data); @@ -137,21 +151,44 @@ bool ValidationReport::check(std::string_view tensor_name, const auto offset = calculate_offset(index, strides); - const auto o = static_cast(type_convert(actual[offset])); - const auto r = static_cast(type_convert(expected[offset])); + const auto a = actual[offset]; + const auto b = expected[offset]; + + const auto o = static_cast(type_convert(a)); + const auto r = static_cast(type_convert(b)); const auto err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { // We expect the number of errors to be very low, so just use an atomic // for now. - atomicAdd(reinterpret_cast(error_count), 1); + atomicAdd(d_error_count, 1); + } + + // Now compare the numbers as bitwise too. + // Update the counter if they're both zero. + using Bytes = std::array; + bool all_zero = true; + for(auto x : std::bit_cast(a)) + { + if(x != std::byte{0}) + all_zero = false; + } + for(auto x : std::bit_cast(b)) + { + if(x != std::byte{0}) + all_zero = false; + } + if(all_zero) + { + atomicAdd(d_zero_count, 1); } }); uint64_t error_count = 0; - check_hip( - hipMemcpy(&error_count, d_error_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost)); + check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); + uint64_t zero_count = 0; + check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); // TODO: Gather detailed coordinates. @@ -159,9 +196,10 @@ bool ValidationReport::check(std::string_view tensor_name, .tensor_name = std::string(tensor_name), .wrong_elements = error_count, .total_elements = descriptor.get_element_size(), + .zero_elements = zero_count, }); - return error_count == 0; + return reports_.back().is_ok(); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 5a52b6a9b5..1ba811bbe0 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -5,6 +5,7 @@ #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" #include "ck_tile/builder/testing/conv_fwd_ck.hpp" +#include "ck_tile/builder/testing/conv_fwd_reference.hpp" #include "ck_tile/host/device_prop.hpp" #include "testing_utils.hpp" @@ -34,6 +35,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; +using Reference = ckb::ConvBuilder::Instance; + TEST(Fwd2DFp16_CShufV3_GNHWC, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); @@ -81,18 +84,17 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) .cde_elementwise_op = {}, }; - auto inputs = ckt::alloc_inputs(args); - auto outputs = ckt::alloc_outputs(args); + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; ckt::run(conv, args, inputs.get(), outputs.get()); - // TODO: This should be allocated via ckt::alloc_outputs() and - // initialized via ckt::run() with the reference implementation - // instead. - auto reference = outputs.get(); + auto ref_conv = Reference{}; + ckt::run(ref_conv, args, inputs.get(), reference.get()); - EXPECT_THAT(outputs.get(), MatchesReference(args, reference)); + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/unit_tensor_descriptor.cpp b/experimental/builder/test/unit_tensor_descriptor.cpp index d9e92bf07e..672ebbd88a 100644 --- a/experimental/builder/test/unit_tensor_descriptor.cpp +++ b/experimental/builder/test/unit_tensor_descriptor.cpp @@ -170,3 +170,22 @@ TEST(TensorDescriptor, ExtentFromVector) EXPECT_THAT([] { return ckt::Extent<5>::from_vector(std::vector{1, 2}); }, Throws()); } + +TEST(TensorDescriptor, IsPacked) +{ + constexpr auto dt = ckb::DataType::INT32; // Irrelevant for this test + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{}) + .is_packed()); + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{5334, 235, 1563, 256, 23}, ckt::PackedRightLayout{}) + .is_packed()); + EXPECT_TRUE(ckt::make_descriptor
(ckt::Extent{}, ckt::Extent{}).is_packed()); + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{461, 345, 5, 93}, ckt::Extent{160425, 5, 1, 1725}) + .is_packed()); + EXPECT_FALSE( + ckt::make_descriptor
(ckt::Extent{10, 11, 12}, ckt::Extent{1, 100, 1100}).is_packed()); + EXPECT_FALSE( + ckt::make_descriptor
(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed()); +} diff --git a/experimental/builder/test/unit_validation.cpp b/experimental/builder/test/unit_validation.cpp index 06736ca624..5f6b620d6b 100644 --- a/experimental/builder/test/unit_validation.cpp +++ b/experimental/builder/test/unit_validation.cpp @@ -67,7 +67,7 @@ TYPED_TEST(ValidationReportTests, SingleCorrect) // Generate a sort-of-random looking sequence auto generator = [strides = desc.get_strides()](const auto& index) { const auto flat_index = ckt::calculate_offset(index, strides); - return static_cast(flat_index * 10'000'019 % 768'351); + return static_cast((flat_index + 1) * 10'000'019 % 768'351); }; ckt::fill_tensor(desc, a.get(), generator); @@ -110,6 +110,27 @@ TYPED_TEST(ValidationReportTests, SingleIncorrect) EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size())); } +TYPED_TEST(ValidationReportTests, ZeroIsIncorrect) +{ + const auto desc = TypeParam::get_descriptor(); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + ckt::ValidationReport report; + report.check("zero_is_incorrect", desc, b.get(), a.get()); + + const auto errors = report.get_errors(); + ASSERT_THAT(errors.size(), Eq(1)); + EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect")); + EXPECT_THAT(errors[0].wrong_elements, Eq(0)); + EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size())); + EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size())); +} + TEST(ValidationReportTests, MultipleSomeIncorrect) { ckt::ValidationReport report; From 2ffbf7f476d99b6fc3db71480b49d221c602e071 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 6 Jan 2026 09:36:54 -0800 Subject: [PATCH 14/23] add tabulate package to aiter docker (#3519) --- Dockerfile.aiter | 2 +- Jenkinsfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile.aiter b/Dockerfile.aiter index 94591f9012..020afeccf4 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -2,7 +2,7 @@ ARG BASE_DOCKER="rocm/pytorch:latest" FROM $BASE_DOCKER ARG AITER_BRANCH="main" ARG CK_AITER_BRANCH="develop" -RUN pip install pandas zmq einops ninja && \ +RUN pip install pandas zmq einops ninja tabulate && \ pip install numpy==1.26.2 && \ sudo mkdir /home/jenkins && \ sudo mkdir /home/jenkins/workspace && \ diff --git a/Jenkinsfile b/Jenkinsfile index 78703ed9aa..7292d9b70c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1046,7 +1046,7 @@ def run_aiter_tests(Map conf=[:]){ sh "rocminfo" sh "python3 --version" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" - //sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" //temporarily disable + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" From 960ef551bf5d615d45e31b954e0faff147e76c85 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Tue, 6 Jan 2026 11:08:54 -0800 Subject: [PATCH 15/23] Fix build error from extra comma (#3516) The newer rocm compiler gives an error with a trailing comma in testing::AllOf. --- experimental/builder/test/unit_error.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/test/unit_error.cpp b/experimental/builder/test/unit_error.cpp index b666462385..201780cc6a 100644 --- a/experimental/builder/test/unit_error.cpp +++ b/experimental/builder/test/unit_error.cpp @@ -30,7 +30,7 @@ TEST(HipError, SourceInfo) // ...the filename HasSubstr("experimental/builder/test/unit_error.cpp"), // ...the function name - HasSubstr("throw_error"), + HasSubstr("throw_error") // Note: Don't include the row/column so that we can move // stuff around in this file. ))); From 2309c8605441b13f7161e4bd17df5b943a79efdf Mon Sep 17 00:00:00 2001 From: kensclin Date: Wed, 7 Jan 2026 04:35:01 +0800 Subject: [PATCH 16/23] [CK_TILE] add preshuffleB mode for ABQuant GEMM (#3495) * [CK_TILE] add preshuffleB mode for ABQuant GEMM * fix precommit error * use template method call for cvt_scale_to_fp32 * fix precommit error * add test code * fix precommit error * switch abquant gemmconfig to default * Add changelog.md * fix precommit error * fix conflict --- CHANGELOG.md | 1 + .../gemm_abquant_quantgrouped.cpp | 60 ++ .../run_gemm_quant_example.inc | 49 +- include/ck_tile/ops/gemm_quant.hpp | 3 + ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 282 ++++++++ ..._abquant_pipeline_ag_bg_cr_base_policy.hpp | 120 ++++ .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 611 ++++++++++++++++++ test/ck_tile/gemm_block_scale/CMakeLists.txt | 6 + .../test_gemm_quant_abquant_preshuffle_2d.cpp | 44 ++ .../test_gemm_quant_fixtures.hpp | 12 +- 10 files changed, 1161 insertions(+), 27 deletions(-) create mode 100644 include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp create mode 100755 include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index b149a74df3..3280ad07dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## (Unreleased) Composable Kernel 1.3.0 ### Added +* Added preshuffleB support for abquant mode in blockscale GEMM. * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". * Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 4a90c07e05..155f19881e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory( BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; } diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index d8988be7b0..398a61f368 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, ck_tile::BaseGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmConfig::Scheduler, has_hot_loop_v, tail_number_v>>>>; + using AQuantPipeline = + std::conditional_t, + ck_tile::AQuantGemmPipelineAgBgCrMem>; + + using BQuantPipeline = std::conditional_t< + GemmConfig::PreshuffleB, + ck_tile::WPQuantBPipelineAgBgCrV2, + std::conditional_t< + std::is_same_v, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + + using ABQuantPipeline = + std::conditional_t, + ck_tile::ABQuantGemmPipelineAgBgCrCompV3>; using GemmPipeline = std::conditional_t< QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant, ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - std::conditional_t, - ck_tile::AQuantGemmPipelineAgBgCrMem>, - std::conditional_t< - QuantMode == ck_tile::QuantType::ABQuantGrouped, - ck_tile::ABQuantGemmPipelineAgBgCrCompV3, - std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::WPQuantBPipelineAgBgCrV2, - std::conditional_t< - std::is_same_v, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>>; + std::conditional_t>>; constexpr bool TiledPermuteN = (BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; @@ -908,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - if((QuantMode == ck_tile::QuantType::ABQuantGrouped || - QuantMode == ck_tile::QuantType::AQuantGrouped || + if((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::RowColQuant || std::is_same_v) && GemmConfig::PreshuffleB) @@ -938,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::ABQuantGrouped) && - !GemmConfig::PreshuffleQuant) + !GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB) { if(a_layout == "R" && b_layout == "R") { diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 1e4aece0d7..696de378aa 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" +#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" @@ -24,6 +25,8 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp new file mode 100644 index 0000000000..63a5151108 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -0,0 +1,282 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// BQ (scale tensor) is block distributed tensor. +// Consecutive QuantGroupSize elements of B are quantized with a separate scale. +// B is block window on block distributed tensor. +// C is block distributed tensor +template +struct BlockGemmWeightPreshuffleABQuantARegBRegCReg +{ + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr auto Scheduler = Problem::Scheduler; + + // Threadblock GEMM tile size + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN; + static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK; + static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + // number of warps along M and N for threadblock's GEMM problem size + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + using I0 = number<0>; + using I1 = number<1>; + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), + "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!"); + static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), + "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), + "Error! WarpGemm's M is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), + "Error! WarpGemm's N is not consistent with BlockGemmShape!"); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); + static constexpr index_t QScalesPerWarpGemmRow = + integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK); + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; + + static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0, + "Error! WarpGemm::kK should be a multiple of QuantGroupSize"); + static_assert(QScalesPerWarpGemmRow == 1, + "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK"); + static_assert(KIterPerWarp % QScalesPerBlockRow == 0, + "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); + + static_assert(KPerBlock / BQuantGroupSize::kK > 0, + "Error! Each row of blockgemm should have a separate scale"); + + static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, + "Error! Warps should cover all Block tile!"); + static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock, + "Error! Warps should cover all Block tile!"); + + // Currently tested combinations (A, B, BQ) + // 1. fp8, fp8, fp32 -> f32 + // 2. bf8, bf8, fp32 -> f32 + // 3. i4, fp8, (fp8/fp32) -> f32 + // 4. i4, bf8, (fp8/fp32) -> f32 + static_assert( + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v) && + std::is_same_v); + + static constexpr index_t InterWaveSchedulingMacClusters = 1; + + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + static constexpr bool TransposeC = Problem::TransposeC; + }; + + public: + using Traits = GemmTraits_; + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + using QuantGroupSize = remove_cvref_t; + + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + static constexpr auto warp_size = get_warp_size(); + + using WG = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 128 / (1 * 16) = 8 + static constexpr index_t NIterPerWarp = + BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); // 128 / (4 * 16) = 2 + static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 128 / 16 = 8 + static constexpr auto MIter_2nd_last = + (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + + static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1 + static constexpr index_t QScalesPerWarpGemmRow = + integer_divide_ceil(WG::kK, QuantGroupSize::kK); + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8 + static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read + + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? DsReadPreload + : MIterPerWarp * KIterPerWarp; + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + return BlockGemmQuantCommon:: + MakeCBlockTile(); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + ABlockTensor& a_warp_tensor, + BFlatBlockTensor& b_warp_tensor, + AQBlockTensor& aq_block_tensor, + BQBlockTensor& bq_block_tensor, + ABlockWindow& a_warp_windows) const + { + using CWarpDstr = typename WG::CWarpDstr; + using AccTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + statically_indexed_array, MIterPerWarp> + c_acc; + + auto zero_accumulators = [&] { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) { + c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; + }); // make sure WG::CWarpTensor exposes a clear/zero + }); + }); + }; + static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { + zero_accumulators(); + static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // warp GEMM + WG{}(c_acc(mIter)(nIter), + a_warp_tensor(number{}), + b_warp_tensor(nIter)(number{})); + }); + __builtin_amdgcn_sched_barrier(0x7F6); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows(number{})(number{})); + } + // barrier + // Could be deleted + if constexpr((mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + AQPickerCommon aq_picker(aq_block_tensor); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = + aq_picker.template cvt_scale_to_fp32(scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; + }); + }); + }); + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp new file mode 100755 index 0000000000..80e41cad45 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -0,0 +1,120 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { + +struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelineAgBgCrPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() + { + using AQDataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK; + + return GetABQGlobalVectorLoadSize(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution() + { + return GemmAQuantPipelineAgBgCrDefaultPolicy::MakeAQDramTileDistribution(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + return GetABQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution(); + } + + // as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed; + // move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here + // temporarily + template + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? + tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + using BTypeToUse = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + + using WarpGemm = WarpGemmDispatcher; + + // TODO : Use a custom block policy for AsBrCr + using BlockGemmPolicy = + BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy; + return BlockGemmWeightPreshuffleABQuantARegBRegCReg{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp new file mode 100644 index 0000000000..0f3951ffcc --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -0,0 +1,611 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +template +struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV2 +{ + using Base = WeightPreshufflePipelineAGmemBGmemCRegV2; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockWeightPreshuffle = remove_cvref_t< + decltype(PipelinePolicy::template GetBlockWeightPreshuffleBQuant())>; + + static constexpr auto config = + BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + using Base::kKPerBlock; + using Base::kMPerBlock; + using Base::kNPerBlock; + + using Base::KIterPerWarp; + using Base::MIterPerWarp; + using Base::NIterPerWarp; + + using Base::BlockSize; + + using Base::kPadK; + using Base::kPadM; + using Base::kPadN; + + using Base::I0; + using Base::I1; + using Base::I2; + + using Base::MWarp; + using Base::NWarp; + + using Base::KPerBlockPerIter; + using Base::MPerBlockPerIter; + + using Base::flatKPerWarp; + using Base::flatNPerWarp; + + using Base::m_preload; + + static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; + static constexpr index_t KPerBlockAQ = + integer_divide_ceil(BlockGemmShape::kK, AQuantGroupSize::kK); + static constexpr index_t KPerBlockBQ = + integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK); + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK); + static constexpr index_t GetVectorSizeAQ() + { + return PipelinePolicy::template GetVectorSizeAQ(); + } + static constexpr index_t GetVectorSizeBQ() + { + return PipelinePolicy::template GetVectorSizeBQ(); + } + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1); + return concat('_', "bquant_pipeline_AgBgCrV2_preshuffleB", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock), + BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeAQ(), GetVectorSizeBQ()), + concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName()); + // clang-format on + } + + template + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + // Estimated number of VMEM vector loads for A per block: + // total A bytes / (threads per block * vector width) + constexpr index_t Aload_inst = + (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + // Estimated number of VMEM vector loads for B per block: + // total B bytes / (threads per block * vector width) + constexpr index_t Bload_inst = + (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + + // Estimated number of VMEM loads for B's quant data (e.g. scales / zp). + // First ceil-divide by quant group size (how many elements share one scale), + // then by vector width to get an approximate number of vector loads. + constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( + ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), + BQuantGroupSize::kK * BQuantGroupSize::kK), + VectorLoadSize); + + // ToDo: Hardcoded, need to change in future. How many instruction emit per iteration + constexpr index_t kLdsInstCycle = 8; + // Total VMEM load instructions (A + B + quant data) + constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; + // Approximate number of LDS reads per block + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; + // Approximate number of LDS writes per block + // (e.g., writing A from VMEM into LDS once per A load) + constexpr index_t ds_write_inst = Aload_inst; + // Number of MFMA instructions per wave for one block tile: + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + // How often (in MFMA units) we should insert DS (LDS) operations. + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // How often (in MFMA units) we should insert VMEM buffer loads. + // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load + // is assumed to cover at most 4 MFMA instructions. + constexpr index_t buffer_load_rep = + min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma + + static_for<0, nloop, 1>{}([&](auto) { + static_for<0, mfma_inst, 1>{}([&](auto i_inst) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + + // Insert LDS read/write groups periodically based on ds_rep. + // The % pattern staggers READ and WRITE so they don't collapse + // into the same cycle in the model. + if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read + } + if constexpr(ds_rep > 0 && i_inst % ds_rep == 1) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write + } + + if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0) + { + if constexpr(ds_write_inst > 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read + } + } + // Always mark some VALU work in the loop to reflect auxiliary scalar + // or vector ALU instructions that coexist with MFMA (Blockscale calculation). + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + + static constexpr bool PreshuffleB = Problem::PreshuffleB; + static constexpr auto TailNum = Problem::TailNum; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t m, + index_t n, + index_t num_loop, + void* p_smem) const + { + (void)m; + (void)n; + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/BQ Dram block window should have the same data type as appropriate " + "([A|B|BQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = std::is_same_v; + static_assert(!is_a_col_major, "A must be row major (col major not supported yet)"); + + constexpr bool is_bq_col_major = std::is_same_v; + static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + + constexpr bool is_b_row_major = std::is_same_v; + static_assert(!is_b_row_major, "B must be col major (row major not supported yet)"); + + const index_t iMWarp = get_warp_id() / NWarp; + // Double-Buffering (loop_count=2) for full load/compute overlap. + const index_t loop_count = 2; + + __builtin_amdgcn_sched_barrier(0); + + // A tile in LDS + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + ADataType* p_a_lds_ping = static_cast(p_smem); + ADataType* p_a_lds_pong = + reinterpret_cast(static_cast(p_smem) + smem_size); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + // ping-pong window for A LDS + auto a_warp_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + auto a_warp_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_ping; + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_pong; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Block GEMM + auto block_weight_preshuffle = BlockWeightPreshuffle(); + // Acc register tile + auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + + // pingpong buffer for B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array, NIterPerWarp> + b_warp_tensor_ping; + + statically_indexed_array, NIterPerWarp> + b_warp_tensor_pong; + + auto aq_copy_dram_window = + make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), + aq_dram_block_window_tmp.get_window_lengths(), + aq_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeAQDramTileDistribution()); + // BQ DRAM window for load + auto bq_copy_dram_window = + make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + bq_dram_block_window_tmp.get_window_lengths(), + bq_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeBQDramTileDistribution()); + + // Prefetch A0 + auto a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // prefetch B + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // Strictly not needed given type deduction, but helps with readability + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution()); + using BQBlockTile = + decltype(make_static_distributed_tensor(BQBlockTileDistr{})); + + // Load tile 0 for BQ data directly into registers for block tile + AQBlockTile aq_block_tile, aq_block_tile_2; + BQBlockTile bq_block_tile, bq_block_tile_2; + aq_block_tile = load_tile(aq_copy_dram_window); + bq_block_tile = load_tile(bq_copy_dram_window); + // move BQ to tile 1 + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + // Prefill A0 + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + __builtin_amdgcn_sched_barrier(0); + + // Prefetch A1 + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + + // preload A00,A10 from lds + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor; + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + __builtin_amdgcn_sched_barrier(0); + + // MAIN LOOP + index_t iCounter = (num_loop - 1) / loop_count; + + while(iCounter > 0) + { + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + a_warp_windows_ping); + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + aq_block_tile_2 = load_tile(aq_copy_dram_window); + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + bq_block_tile_2 = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // Next K + + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + aq_block_tile = load_tile(aq_copy_dram_window); + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i+1 + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_pong, + aq_block_tile_2, + bq_block_tile_2, + a_warp_windows_pong); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + iCounter--; + HotLoopScheduler(); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + aq_block_tile_2 = load_tile(aq_copy_dram_window); + bq_block_tile_2 = load_tile(bq_copy_dram_window); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // GEMM loopK-1 + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + a_warp_windows_ping); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // GEMM loopK + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_pong, + aq_block_tile_2, + bq_block_tile_2, + a_warp_windows_pong); + HotLoopScheduler(); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + a_warp_windows_ping); + Base::LastHotLoopScheduler(); + } + + return c_block_tile; + } + + // Replace lines 485-526 with a single optimized operator: + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem, + index_t m = 0, + index_t n = 0) const // Default value for non-preshuffle case + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + aq_dram_block_window_tmp, + bq_dram_block_window_tmp, + m, + n, + num_loop, + p_smem); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + TailNumber tail_number, + void* p_smem, + index_t n = 0) const + { + const auto RunPipeline = [&](auto bool_val, auto tail_num_) { + (void)bool_val; // Suppress unused parameter warning + constexpr auto tail_num = tail_num_.value; + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + aq_dram_block_window_tmp, + bq_dram_block_window_tmp, + n, // dummy value, won't be used + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, true, tail_number); + } +}; +} // namespace ck_tile diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index f89aea1c17..2dad8be205 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -39,6 +39,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle + test_gemm_quant_abquant_preshuffle_2d.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # AQuant tests add_gtest_executable(test_tile_gemm_quant_aquant_prefill test_gemm_quant_aquant_prefill.cpp ) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp new file mode 100644 index 0000000000..793c9bd1df --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantPreshuffleBTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 6fb1b77fa8..3798cc4443 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -894,10 +894,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase; - using BaseGemmPipeline = - std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + using BaseGemmPipeline = std::conditional_t< + PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); @@ -926,8 +926,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase; using GemmPipeline = - std::conditional_t, + std::conditional_t, ck_tile::ABQuantGemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< From 76696ace4460a5bcf79d9a75a97b30c76507e284 Mon Sep 17 00:00:00 2001 From: kyle-256 Date: Wed, 7 Jan 2026 04:36:04 +0800 Subject: [PATCH 17/23] [CKTILE] Support A/B Quantization in Blockscale Grouped Gemm (#3452) * update grouped_gemm blockwise kernel * update config * update kernel * update examples * remove test code for now * sync test files with origin/develop * update example * fix code lint * fix code-lint * update test code * run clang format * run pre-commit * update api --- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 3 +- .../17_grouped_gemm/abquant_grouped_gemm.cpp | 278 ++++++++ .../17_grouped_gemm/abquant_grouped_gemm.hpp | 171 +++++ .../run_grouped_gemm_abquant_example.inc | 604 ++++++++++++++++++ .../kernel/grouped_gemm_quant_kernel.hpp | 17 +- test/ck_tile/CMakeLists.txt | 1 + .../grouped_gemm_abquant/CMakeLists.txt | 16 + .../test_grouped_gemm_abquant_1x128x128.cpp | 47 ++ .../test_grouped_gemm_abquant_1x1x128.cpp | 47 ++ .../test_grouped_gemm_abquant_ut_cases.inc | 87 +++ .../test_grouped_gemm_abquant_util.hpp | 530 +++++++++++++++ 11 files changed, 1798 insertions(+), 3 deletions(-) create mode 100644 example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp create mode 100644 example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp create mode 100644 example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc create mode 100644 test/ck_tile/grouped_gemm_abquant/CMakeLists.txt create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 9b51af22fe..0f0a0d8ba7 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -14,7 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") quant_grouped_gemm_bf8_rowcol.cpp quant_grouped_gemm_bf8_tensor.cpp ) - + add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp) add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) @@ -25,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp new file mode 100644 index 0000000000..84da1e26da --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp @@ -0,0 +1,278 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/host.hpp" +#include "abquant_grouped_gemm.hpp" + +// Non-persistent grouped gemm for ABQuant +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +// Persistent grouped gemm tileloop for ABQuant +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); +} + +#include "run_grouped_gemm_abquant_example.inc" + +int main(int argc, char* argv[]) +{ + int result1 = run_abquant_grouped_gemm_example(argc, argv); + return result1; +} diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp new file mode 100644 index 0000000000..da8bd5514c --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp @@ -0,0 +1,171 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/utility/json_dump.hpp" + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = Persistent_; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); +}; + +template +struct GemmQuantConfig; + +// ABQuant specialization for GemmQuantConfig +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert( + "stride_As", + "", + "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs + // can be set to zero if + // Ms/Ns/Ks is not empty + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") + .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.") + .insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Row by default.") + .insert("c_layout", "R", "C tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.") + .insert("bquant_group_size", "1x1x128", "BQuant group size. 1x1x128 (default) or 1x128x128") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); +} + +// Forward declaration of the non-persistent version +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr); + +// Forward declaration of the tileloop version for persistent kernels +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc new file mode 100644 index 0000000000..bc5167439d --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc @@ -0,0 +1,604 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_abquant_gemm(int n_warmup, + int n_repeat, + int group_count, + const std::vector& args) +{ + // Workspace memory allocated to hold the gemm descriptions. + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(args)); + + float ave_time = 0; + + if constexpr(!GemmConfig::Persistent) + { + ave_time = grouped_gemm_abquant( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } + + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); + } + + return ave_time; +} + +template +int run_abquant_grouped_gemm_example_with_layouts( + int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AQLayout aq_layout = AQLayout{}, + const BLayout b_layout = BLayout{}, + const BQLayout bq_layout = BQLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + + auto [result, arg_parser] = create_args(argc, argv); + + auto valid_input_data = [&](int group_count, const auto&... args) { + return group_count != 0 && ((args.size() == static_cast(group_count)) && ...); + }; + + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + const int init_method = arg_parser.get_int("init"); + bool validate = arg_parser.get_bool("validate"); + + if(kbatch > 1 && validate && warmup + repeat > 1) + { + std::cout << "WARNING: Data validation enabled with SplitK and more than" + << "1 warmup/repeat. Disabling validation." << std::endl; + validate = false; + } + + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector AQs; // dimension of AQ tensor is calculated from A tensor + std::vector BQs; // dimension of BQ tensor is calculated from B tensor + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); + std::vector stride_AQs = arg_parser.get_int_vec("stride_AQs"); + std::vector stride_BQs = arg_parser.get_int_vec("stride_BQs"); + + ck_tile::index_t AQK, BQK; + + if(!valid_input_data( + group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs)) + { + std::cout << "Please check the input data. Default values will be used." << std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + + for(int i = 0; i < group_count; i++) + { + + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + // Let get_default_stride calculate based on layout + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + // For ABQuantGrouped, both A and B need quantization + static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped, + "This file only supports ABQuantGrouped mode"); + + AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize + BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout)); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc + << " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl; + + if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(bq_tensors[i]); + } + else + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + } + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back( + std::make_unique(aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back( + std::make_unique(bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); + const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + kbatch, + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + float ave_time = invoke_abquant_gemm(warmup, repeat, group_count, gemm_descs); + + std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + if(validate) + { + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Reference implementation for ABQuantGrouped + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass; +} + +template +int run_abquant_grouped_gemm_example_prec_type_with_bquant( + std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + using AQDataType = typename Types::AccDataType; + using BQDataType = typename Types::AccDataType; + using AQuantGroupSize = ck_tile::QuantGroupShape>; + + constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped; + + if(a_layout == "R" && b_layout == "C" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Col{}, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + +template +int run_abquant_grouped_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + std::string c_layout, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(bquant_group_size == "1x1x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else if(bquant_group_size == "1x128x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported BQuantGroupSize! Use 1x1x128 or 1x128x128."); + } +} + +template +int run_abquant_gemm_example_persistency(std::string a_layout, + std::string b_layout, + std::string c_layout, + bool persistent, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(persistent) + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } + else + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } +} + +int run_abquant_grouped_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string c_layout = arg_parser.get_str("c_layout"); + const std::string data_type = arg_parser.get_str("prec"); + bool persistent = arg_parser.get_bool("persistent"); + const std::string bquant_group_size = arg_parser.get_str("bquant_group_size"); + + if(data_type == "fp8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else if(data_type == "bf8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type configuration."); + } +} diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 06a80c8b55..c9e725f5fd 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -484,6 +484,17 @@ struct QuantGroupedGemmKernel tail_num, smem_ptr); } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + bq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr); + } else if constexpr(kQuantType == QuantType::RowColQuant || kQuantType == QuantType::TensorQuant) { @@ -499,7 +510,8 @@ struct QuantGroupedGemmKernel c_ptr, kargs, block_idx_m, block_idx_n); if constexpr(kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::BQuantGrouped) + kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } @@ -527,7 +539,8 @@ struct QuantGroupedGemmKernel c_ptr, kargs, block_idx_m, block_idx_n); if constexpr(kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::BQuantGrouped) + kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 197c9d6e1d..93cd7fa063 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -9,6 +9,7 @@ add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(grouped_gemm_multi_d) add_subdirectory(grouped_gemm_quant) +add_subdirectory(grouped_gemm_abquant) add_subdirectory(gemm_multi_d) add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_streamk) diff --git a/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt b/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt new file mode 100644 index 0000000000..e735aa8e9a --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x1x128 test_grouped_gemm_abquant_1x1x128.cpp) + target_compile_options(test_ck_tile_grouped_gemm_abquant_1x1x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x128x128 test_grouped_gemm_abquant_1x128x128.cpp) + target_compile_options(test_ck_tile_grouped_gemm_abquant_1x128x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() + diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp new file mode 100644 index 0000000000..06b0068cb7 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_abquant_util.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; + +// AQuant group size is fixed at 1x1x128 +using AQuantGroupSize = ck_tile::QuantGroupShape>; +// BQuant group size: 1x128x128 +using BQuantGroupSize_1x128x128 = ck_tile::QuantGroupShape>; + +// clang-format off +using KernelTypes_ABQuant_1x128x128 = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent + + // FP8 variants + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + + // BF8 variants + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x128x128, KernelTypes_ABQuant_1x128x128); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x128x128 +#include "test_grouped_gemm_abquant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp new file mode 100644 index 0000000000..7704eda169 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_abquant_util.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; + +// AQuant group size is fixed at 1x1x128 +using AQuantGroupSize = ck_tile::QuantGroupShape>; +// BQuant group size: 1x1x128 +using BQuantGroupSize_1x1x128 = ck_tile::QuantGroupShape>; + +// clang-format off +using KernelTypes_ABQuant_1x1x128 = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent + + // FP8 variants + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>, + + // BF8 variants + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x1x128, KernelTypes_ABQuant_1x1x128); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x1x128 +#include "test_grouped_gemm_abquant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc new file mode 100644 index 0000000000..48574ab977 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_CLASS_NAME, Basic) +{ + const int group_count = 6; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + std::vector stride_AQs; + std::vector stride_BQs; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); +} + +// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop +// Using 256x256x128 to match the test kernel's tile size (M_Tile=128, N_Tile=128, K_Tile=128) +TYPED_TEST(TEST_CLASS_NAME, SmallUniform) +{ + const int group_count = 2; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + std::vector stride_AQs; + std::vector stride_BQs; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256); + Ns.push_back(256); + Ks.push_back(256); + + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); +} + +TYPED_TEST(TEST_CLASS_NAME, OddTail) +{ + const int group_count = 2; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + std::vector stride_AQs; + std::vector stride_BQs; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256); + Ns.push_back(256); + Ks.push_back(128); + + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); +} diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp new file mode 100644 index 0000000000..c7ed6f5472 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp @@ -0,0 +1,530 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" + +template +class TestCkTileGroupedGemmABQuant : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using AQDataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using BQDataType = std::tuple_element_t<6, Tuple>; + using AccDataType = std::tuple_element_t<7, Tuple>; + using CDataType = std::tuple_element_t<8, Tuple>; + using AQuantGroupSize = std::tuple_element_t<9, Tuple>; + using BQuantGroupSize = std::tuple_element_t<10, Tuple>; + static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using AQLayout = Row; + using BQLayout = Col; + + static constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped; + + struct GemmConfig + { + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(ADataType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr bool PreshuffleB = false; + static constexpr bool TransposeC = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + static constexpr bool IsPersistent = Persistent; + }; + + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; + + std::size_t get_workspace_size(const std::vector& gemm_descs) + { + return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); + } + + template + static constexpr inline auto is_row_major(Layout layout_) + { + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; + } + + auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) + { + using ComputeType = + std::conditional_t; + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + } + + template + float invoke_grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile:: + TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = Config::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Config::M_Warp, + Config::N_Warp, + Config::M_Warp_Tile, + Config::N_Warp_Tile, + Config::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + + template + void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) + { + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Config::M_Warp, + Config::N_Warp, + Config::M_Warp_Tile, + Config::N_Warp_Tile, + Config::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + std::vector& stride_As, + std::vector& stride_Bs, + std::vector& stride_Cs, + std::vector& stride_AQs, + std::vector& stride_BQs, + const int group_count = 8) + { + ck_tile::index_t AQK, BQK; + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + AQK = K / AQuantGroupSize::kK; + BQK = K / BQuantGroupSize::kK; + + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{})); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{})); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout{})); + stride_BQs[i] = + ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout{})); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{})))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(BLayout{})))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(BQLayout{})))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc + << " c_m_n: " << c_m_n_tensors[i].mDesc << " aq: " << aq_tensors[i].mDesc + << " bq: " << bq_tensors[i].mDesc << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back(std::make_unique( + aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back(std::make_unique( + bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); + const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + 1, // k_batch + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + + if constexpr(Persistent) + { + std::vector kargs; + for(const auto& arg : gemm_descs) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + ck_tile::hip_check_error( + hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr); + } + else + { + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + invoke_grouped_gemm_abquant(gemm_descs, stream, kargs_ptr); + } + + // Copy results back to host for validation + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + + EXPECT_TRUE(pass); + } +}; + +// Aliases for split test files +template +using TestCkTileGroupedGemmABQuant_1x1x128 = TestCkTileGroupedGemmABQuant; + +template +using TestCkTileGroupedGemmABQuant_1x128x128 = TestCkTileGroupedGemmABQuant; From aaa35f0bbfa45dadc4380ddd6e0224668ddb97b4 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Tue, 6 Jan 2026 12:46:59 -0800 Subject: [PATCH 18/23] [CK_Tile] Support for various group sizes Preshuffle quant for 2d block scale gemm (#3445) * formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * initial commit with prints * debugging * fine-grained working * debugging medium grained * fixing the tile window * formatting * enabling prefill shapes * working prefill shapes * formatted * clean up * code cleanup * bug fix after merging with develop * clean up after merging with develop * added comments for the tile window and tile distribution encoding --------- Co-authored-by: Cong Ma Co-authored-by: Thomas Ning Co-authored-by: Agarwal --- ...mm_bquant_quantgrouped_preshufflequant.cpp | 245 ++++++- .../run_gemm_quant_example.inc | 2 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 1 + .../gemm_quant/kernel/gemm_quant_kernel.hpp | 645 ++---------------- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 5 +- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 15 +- .../pipeline/gemm_group_quant_utils.hpp | 164 ++++- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 20 +- 8 files changed, 428 insertions(+), 669 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index e0e0a64416..62ca34b057 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -9,36 +9,194 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill; void bquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut) { - using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; lut[hash_multiple_strings( {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, @@ -47,10 +205,63 @@ void bquant_quantgrouped_preshufflequant_instance_factory( lut[hash_multiple_strings( {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 398a61f368..607c53d9af 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -540,7 +540,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 16a0835b1d..313e449c7b 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -322,6 +322,7 @@ struct BQuantBlockUniversalGemmAsBsCr constexpr index_t reg_offset = nIter; auto pull_from_lane = (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; // cross lane ops uint32_t scale_reg_dword; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 4f79361037..004fb18e0b 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -280,12 +280,13 @@ struct QuantGemmKernel // Helper: Create Pre-shuffled Quantization Tensor Descriptor // =================================================================== template CK_TILE_DEVICE static auto - MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B) + MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B) { // Step 1: Calculate base BQ tensor dimensions // ---------------------------------------------------------- @@ -304,8 +305,9 @@ struct QuantGemmKernel // ---------------------------------------------------------- // Pad the X dimension to be a multiple of block_tile_size to ensure // each thread block can process complete tiles without edge cases - const auto block_tile_size = NPerBlock * KPerBlockBQ; - const auto bq_pad0_desc = transform_tensor_descriptor( + const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; + + const auto bq_pad0_desc = transform_tensor_descriptor( bq_desc, make_tuple(make_pass_through_transform(bq_y), make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), @@ -318,7 +320,7 @@ struct QuantGemmKernel // This separates the work into tiles that can be processed by // individual warps/waves const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = WarpTileN * KPerBlockBQ; + const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); const auto bq_unmerge_pad0_desc = transform_tensor_descriptor( @@ -813,12 +815,18 @@ struct QuantGemmKernel static_assert(std::is_same_v, "PreshuffleQuant with BQuantGrouped currently only supports " "ColumnMajor BQ layout"); + using QuantGroupSize = remove_cvref_t; return MakePreshuffledQuantTensorView< GemmPipeline::KPerBlockBQ, + GemmPipeline::NPerBlockBQ, GemmPipeline::NPerBlock, TilePartitioner::BlockGemmShape::WarpTile::at(I1), - GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); + GemmPipeline::GetVectorSizeBQ()>( + bq_ptr, + ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN), + QuantGroupSize::kN, + kargs.QK_B); } else { @@ -879,13 +887,38 @@ struct QuantGemmKernel if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); - constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = + constexpr auto block_n = + TilePartitioner::NPerBlock / + QuantGroupSize::kN; // Number of N-dimension quantization groups per block + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at( + I1); // Number of N-dimension elements per warp + constexpr auto warp_per_group = + (QuantGroupSize::kN < + warp_n) // Determine how many warps share the same scale in N-dimension + ? (warp_n / QuantGroupSize::kN) + : (QuantGroupSize::kN / warp_n); + constexpr auto bqk_per_block = + TilePartitioner::KPerBlock / + QuantGroupSize::kK; // Number of K-dimension quantization groups per block + constexpr auto + tile_window_width = // The pre-shuffled layout flattens warp_n × + // bqk_per_block scales per row, Padded up to warp_size + // to ensure coalesced memory access. ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_n / warp_n; - auto block_n_idx = i_n / block_n; + + // Adapts based on fine vs coarse quantization granularity: + // - Fine-grained (QuantGroupSize::kN < warp_n): + // Multiple quant groups per warp → fewer rows needed per block. + // height = block_n / warp_per_group + // + // - Coarse-grained (QuantGroupSize::kN >= warp_n): + // Each row represents one quant group. + // height = block_n + constexpr auto tile_window_height = + (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n; + auto block_n_idx = + i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a + // block index. return make_tile_window( bq_tensor_view, @@ -1125,596 +1158,6 @@ struct QuantGemmKernel return true; } - template - CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - const AQDataType* aq_ptr, - const BQDataType* bq_ptr, - CDataType* c_ptr, - const QuantGemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) - { - - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - const auto& a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - }(); - - const auto& aq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) - { - static_assert(std::is_same_v); - const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; - const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ; - const auto aq_desc = - make_naive_tensor_descriptor(make_tuple(aq_y, aq_x), - make_tuple(aq_x, 1), - number{}, - number<1>{}); - - const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ; - const auto aq_pad0_desc = transform_tensor_descriptor( - aq_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = - GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; - const auto wave_tile_count_x = - ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); - - const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( - aq_pad0_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto aq_pad1_desc = transform_tensor_descriptor( - aq_unmerge_pad0_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_pass_through_transform(wave_tile_count_x), - make_right_pad_transform( - wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - const auto pad_wave_size = - ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); - const auto aq_merge_pad1_desc = transform_tensor_descriptor( - aq_pad1_desc, - make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)), - make_pass_through_transform(pad_wave_size)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view(aq_ptr, aq_merge_pad1_desc); - } - else if constexpr((kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) && - !PreshuffleQuant) - { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.QK_A), - make_tuple(kargs.stride_AQ, 1), - number{}, - number<1>{}); - } - else // Column major AQ - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions - make_tuple(kargs.stride_AQ, 1), // Same stride pattern - number{}, - number<1>{}); - } - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, 0), // broadcasting over n - number<1>{}, - number<1>{}); - } - else - { - return nullptr; // TODO: use some other "empty" type for this - } - }(); - - const auto& b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - if constexpr(PreshuffleB) - { - index_t kFlatK = GemmPipeline::flatKPerWarp * - (splitk_batch_offset.splitted_k / - GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - return make_naive_tensor_view( - b_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - if constexpr(std::is_same_v) - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - else - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - } - }(); - - const auto& bq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_naive_tensor_view( - bq_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(0, 1), // broadcasting over m - number<1>{}, - number<1>{}); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - if constexpr(PreshuffleQuant) - { - static_assert(std::is_same_v, - "PreshuffleQuant with BQuantGrouped currently only supports " - "ColumnMajor BQ layout"); - - return MakePreshuffledQuantTensorView< - GemmPipeline::KPerBlockBQ, - GemmPipeline::NPerBlock, - TilePartitioner::BlockGemmShape::WarpTile::at(I1), - GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); - } - else - { - using QuantGroupSize = remove_cvref_t; - - if constexpr(std::is_same_v) - { - // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN] - // Dimensions: [K/QuantGroupK, N/QuantGroupN] - // Strides: [N/QuantGroupN, 1] - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), - integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), - number{}, - number<1>{}); - } - else - { - static_assert(std::is_same_v); - // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK] - // Dimensions: [N/QuantGroupN, K/QuantGroupK] - // Strides: [K/QuantGroupK, 1] - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), - integer_divide_ceil(kargs.K, QuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), - number{}, - number<1>{}); - } - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); - } - else - { - return nullptr; // TODO: use some other "empty" type for this - } - }(); - - // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - c_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - c_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple( - a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - // no padding - const auto& aq_pad_view = [&]() { return views.at(I1); }(); - - const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I2); - if constexpr(std::is_same_v) - { - if constexpr(std::is_same_v) - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - else - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - // no padding - const auto& bq_pad_view = [&]() { return views.at(I3); }(); - - // TODO vector write in for C in ColMajor - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I4); - if constexpr(std::is_same_v) - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - if constexpr(PreshuffleB) - { - - return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view); - } - else - { - return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - - const auto& a_pad_view = views.at(I0); - const auto& aq_pad_view = views.at(I1); - const auto& b_pad_view = views.at(I2); - const auto& bq_pad_view = views.at(I3); - const auto& c_pad_view = views.at(I4); - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& aq_block_window = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = - ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_m / warp_m; - auto block_m_idx = i_m / block_m; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {block_m_idx * tile_window_height, 0}); - } - else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) - { - using QuantGroupSize = remove_cvref_t; - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto block_m = TilePartitioner::MPerBlock; - if constexpr(std::is_same_v) - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } - else // Column major AQ - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, number{}), - {0, i_m}); - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto block_k = TilePartitioner::KPerBlock; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return nullptr; // TODO: use some other "empty" type? - } - }(); - - const auto& b_block_window = [&]() { - if constexpr(PreshuffleB) - { - - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); - } - else - { - if constexpr(std::is_same_v) - { - if constexpr(std::is_same_v) - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - else - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - } - else - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {0, i_n}); - } - } - }(); - - const auto& bq_block_window = [&]() { - if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_tile_window(bq_pad_view, - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - using QuantGroupSize = remove_cvref_t; - if constexpr(PreshuffleQuant) - { - static_assert(std::is_same_v); - constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); - constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = - ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_n / warp_n; - auto block_n_idx = i_n / block_n; - - return make_tile_window( - bq_pad_view, - make_tuple(number{}, number{}), - {block_n_idx * tile_window_height, 0}); - } - else - { - if constexpr(std::is_same_v) - { - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {0, i_n / QuantGroupSize::kN}); - } - else - { - static_assert(std::is_same_v); - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } - else - { - return nullptr; // TODO: use some other "empty" type here - } - }(); - - auto c_block_window = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple( - a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window); - } - /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 39f0cbdbd3..a4bba6cf76 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; - constexpr index_t VecLoadSize = GetVectorSizeBQ(); constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; using WarpTile = typename Problem::BlockGemmShape::WarpTile; @@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC BlockSize, NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), - VecLoadSize, + Problem::BQuantGroupSize::kN, + Problem::BQuantGroupSize::kK, BQLayout, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); @@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC KPerBlockBQ, // Logical K dimension NPerBlockBQ, // Logical N dimension Problem::BQuantGroupSize::kN, + Problem::BQuantGroupSize::kK, BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index b43066cdc5..13d400d5fc 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), - 0) + (PreshuffleQuant) + ? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), + 0) : is_bq_row_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 0ec8942426..34f815ed27 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -192,6 +192,7 @@ template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern @@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding static_assert(num_warps == MWarps * NWarps * KWarps); static_assert(KWarps == 1); - /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) - /// - /// This function determines the optimal thread distribution pattern for loading and applying - /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative - /// to warp dimensions. - /// - /// Three distinct distribution patterns are handled: - /// - /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): - /// - Multiple quantization groups exist within a single warp's N-dimension - /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) - /// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast - /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp - /// - /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): - /// - Each warp handles exactly one quantization scale - /// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN - /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 - /// - /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): - /// - Quantization group spans multiple warps - /// - All warps share the same scale value - /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale - /// - /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { // Preshuffle only supported for ColumnMajor currently @@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding if constexpr(PreshuffleQuant) { - // ColumnMajor only for preshuffle - constexpr index_t X1 = warp_size; - constexpr index_t X0 = NPerTile / warp_size; - constexpr index_t Y1 = NWarps; - constexpr index_t Y0 = KPerTile / Y1; + // ============================================================================= + // PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION + // ============================================================================= + // For pre-shuffled quantization, the BQ scale tensor has been reorganized + // (pre-shuffled) to optimize memory access patterns during dequantization. + // + // Tile Dimensions: + // - K-axis (Y in encoding): Corresponds to the K-dimension iteration + // - N-axis (X in encoding): Flattened scale index combining N and K groups + // + // The encoding distributes work across threads such that each thread loads + // the correct pre-shuffled scale for its corresponding B-matrix elements. + // ============================================================================= + if constexpr(NPerQ <= WarpGemm::kN) + { + // ========================================================================= + // CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN) + // ========================================================================= + // Multiple quantization scales exist within a single warp's N-dimension. + // Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp. + // + // Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256 + // → 2 scales per warp in N, 2 K-groups per block + constexpr auto N1 = BlockGemmShape::kK / + KPerQ; // Number of K-dimension quantization groups per block, + // Each K-group of KPerQ elements shares the same scale. + constexpr auto N0 = + WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ + // <= WarpGemm::kN, each warp handles multiple scales. + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension + constexpr auto NR0 = + warp_size / + (N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization + constexpr auto K1 = NWarps; // Number of warps distributed along this dimension + constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile + constexpr auto KR = 1; // No replication in K-dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2>>, - tuple, sequence<1>>, - sequence<1, 2>, - sequence<0, 0>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2, 0>>, + tuple, sequence<1, 0, 2, 1, 3>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else if constexpr(NPerQ < WarpGemm::kN * NWarps) + { + // ========================================================================= + // CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN * + // NWarps) + // ========================================================================= + // Each warp handles exactly one quantization scale in N-dimension. + // Some warps share the same scale (KR > 1 creates warp grouping). + // + // Example: NPerQ=32, WarpGemm::kN=16, NWarps=4 + // → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups) + + constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale + constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales) + constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups + constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN) + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ) + constexpr auto NR0 = + warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<1, 0, 2, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + // ========================================================================= + // CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps) + // ========================================================================= + // The quantization group spans ALL warps in N-dimension. + // All warps share the same scale value for their N-tiles. + // + // Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 + // → 128 >= 16*4=64, so all 4 warps use the same scale + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups + constexpr auto N0 = 1; // Minimal (1) since scale is shared across N + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = 32; // Fixed broadcast size + constexpr auto NR0 = + warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<2, 0, 3, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } } else { + /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) + /// + /// This function determines the optimal thread distribution pattern for loading and + /// applying quantization scales to the B matrix based on the quantization group size + /// (NPerQ) relative to warp dimensions. + /// + /// Three distinct distribution patterns are handled: + /// + /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): + /// - Multiple quantization groups exist within a single warp's N-dimension + /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) + /// - Distribution includes explicit replication factor (XR = NPerQ) for scale + /// broadcast + /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp + /// + /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): + /// - Each warp handles exactly one quantization scale + /// - Scales are distributed across warps with replication factor XR = NPerQ / + /// WarpGemm::kN + /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 + /// + /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): + /// - Quantization group spans multiple warps + /// - All warps share the same scale value + /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale + /// + /// @return A static tile distribution encoding for the BQ scale tensor if constexpr(NPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 43f37ec4d8..e4de7e4211 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; + static constexpr index_t NPerBlockBQ = + integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN); static constexpr index_t KPerBlockBQ = integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = @@ -352,8 +354,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else @@ -427,8 +431,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else @@ -462,8 +468,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else From d7497d26948ca90d0224920472712e0f657fb744 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Wed, 7 Jan 2026 01:05:56 -0700 Subject: [PATCH 19/23] [CK TILE] Refactor function amd_buffer_load_invalid_element_return_zero (#3512) Refactor function amd_buffer_load_invalid_element_return_zero to avoid the inefficient ASM code generated by compiler. Compiler generates suboptimal assembly for ternary operator, causing excessive VGPR usage Tested compilers: - Rocm 7.0.1 - Rocm 7.1.1 Co-authored-by: Thomas Ning --- .../arch/amd_buffer_addressing_builtins.hpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 562b246ac3..9f9770df1b 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2376,12 +2376,23 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{numeric::zero()}; + { + if(src_thread_element_valid) + { + return amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + } + else + { + return thread_buffer{numeric::zero()}; + } + } else - return tmp; + { + return amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + } #endif } From 37e9547a2932aaf327e577350e647af37ca92619 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 7 Jan 2026 03:56:50 -0500 Subject: [PATCH 20/23] Fix ref algorithm dispatching. --- .../builder/include/ck_tile/builder/factory/conv_dispatcher.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 319293cff1..e235db4bb0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -112,7 +112,7 @@ constexpr auto make_conv_instance() return typename ReferenceFactory::Instance{}; } // CK Tile supports common factory for each direction - if constexpr(TileAlgorithm) + else if constexpr(TileAlgorithm) { return typename ConvTileFactory::Instance{}; } From 82803221c3335a2efa2aba0915573e8bb9d14656 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 7 Jan 2026 03:57:02 -0500 Subject: [PATCH 21/23] Fix smoke tests. --- experimental/builder/test/test_conv_description.cpp | 2 +- experimental/builder/test/unit_conv_tuning_params.cpp | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 498de9a42f..7204b09157 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -161,7 +161,7 @@ struct DefaultAlgorithm ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; - ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, + ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, .scheduler = ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index ee1388a77f..9005742930 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams) { ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3; ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE; - } block_gemm; + } block_gemm_pipeline; } kAlgorithm; constexpr auto block_gemm = SetBlockGemm(); @@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion) { constexpr struct Algorithm { - struct GridwiseGemm - { - ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; - } gridwise_gemm; + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; } kAlgorithm; constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion(); From 00f45cca2e089b67160708ab73d1dfeddadcd537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 7 Jan 2026 03:58:09 -0500 Subject: [PATCH 22/23] clang-format --- experimental/builder/test/test_conv_description.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 7204b09157..9e8008ccf0 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -162,7 +162,8 @@ struct DefaultAlgorithm ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, - .scheduler = ckb::PipelineScheduler::INTRAWAVE}; + .scheduler = + ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); From c5cdd51ce4cf5749ca158d466cf3647283d7c4be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Wed, 7 Jan 2026 05:41:22 -0500 Subject: [PATCH 23/23] Fix factory for regular WMMA conv bwd weight. --- .../builder/factory/conv_bwd_weight_wmma_factory.hpp | 8 ++++---- .../conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp | 2 +- experimental/builder/test/impl/conv_algorithm_types.hpp | 2 +- .../builder/test/utils/conv_algorithm_type_utils.hpp | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 817432081b..32161a234a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -51,13 +51,13 @@ struct ConvBwdWeightWmmaFactory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid A source access order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid B source access order"); // The forward convolution kernel class instance. diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp index 0981ea6c11..ff350ac804 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -22,7 +22,7 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3, constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) - .with_transfer(cku::BwdTransfer_4x64x1) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 797f440bdc..27ba1ec3b6 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -669,7 +669,7 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = ConvAlgorithmTemplate, + Transfer_<>, ConvSpecializationBwdWeight_, GridGemm_, Prefetch_>; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index b705ab1e47..d80d6a1b8c 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -428,7 +428,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << "," << to_string(static_cast>(t)); return oss.str(); }