From d5746dd120c5d5ed9fd4558af0f189ec6308a155 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Tue, 11 Nov 2025 00:12:23 +0530 Subject: [PATCH 001/114] [CK-Tile] Add gtests for compiler CI for faster testing (#3123) * Add gtests for compiler CI for faster testing * Add changes to have a custom target * Add a gtest suite for gemm kernel for running CI tests with compiler mode * Fix Clang error (EOL) * Removed compiler subfolder from CMake * Add gtest suite for gemm kernel * Disable failed tests * Fix build errors * Resolved PR comments * Update shape for persistent gemm kernel test * Separated types by H/W archs * Made changes to persistent types * Fix persistent build failure issue --------- Co-authored-by: Thomas Ning --- test/ck_tile/gemm/CMakeLists.txt | 6 + .../gemm/test_gemm_pipeline_compiler.cpp | 900 ++++++++++++++++++ 2 files changed, 906 insertions(+) create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 96c071cbc4..c08ab33b91 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -22,6 +22,12 @@ else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") + add_gtest_executable(test_gemm_pipeline_compiler test_gemm_pipeline_compiler.cpp) + target_compile_options(test_gemm_pipeline_compiler PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() + if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp new file mode 100644 index 0000000000..bf39e0b552 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp @@ -0,0 +1,900 @@ +// SPDX-License-Identifier: 
MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +// ============================================================================ +// Comprehensive GEMM Compiler Validation Test Suite +// This file consolidates all GEMM pipeline tests for compiler validation +// Covers essential combinations of data types, layouts, and pipeline types +// ============================================================================ + +// ---------------------------------------------------------------------------- +// Test Class Definitions for Different Pipeline Types +// ---------------------------------------------------------------------------- + +template +class TestGemmMem : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmMemWmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV3 : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmCompV3Wmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV4 : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmCompV4Wmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV6 : public TestCkTileGemmPipeline> +{ +}; + +template +class TestGemmPersistent : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmPersistentWmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +// ---------------------------------------------------------------------------- +// Type Definitions for Each Pipeline Configuration +// ---------------------------------------------------------------------------- + +// Memory Pipeline Types +using MemTestTypes = ::testing::Types< + // Parameters: ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, 
CDataType, + // M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, + // PipelineType + + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// Memory Pipeline WMMA Types +using MemWmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV3 Pipeline Types +using CompV3TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// CompV3 Pipeline WMMA Types +using CompV3WmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV4 Pipeline Types +using CompV4TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// CompV4 Pipeline WMMA Types +using CompV4WmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV6 Pipeline Types +using CompV6TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +// Persistent CompV3 Pipeline Types +using PersistentTestTypes = ::testing::Types, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// Persistent CompV3 Pipeline WMMA Types +using PersistentWmmaTestTypes = ::testing::Types, + std::tuple>; +#endif + +// ---------------------------------------------------------------------------- +// Test Suite Registrations +// ---------------------------------------------------------------------------- + +TYPED_TEST_SUITE(TestGemmMem, MemTestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmMemWmma, MemWmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV3, CompV3TestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmCompV3Wmma, CompV3WmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV4, CompV4TestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmCompV4Wmma, CompV4WmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV6, CompV6TestTypes); +TYPED_TEST_SUITE(TestGemmPersistent, PersistentTestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmPersistentWmma, PersistentWmmaTestTypes); +#endif + 
+// ============================================================================ +// Memory Pipeline Tests (Mem) +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmMem + +TYPED_TEST(TEST_SUITE_NAME, SmallM_SingleRow) +{ + std::vector Ms{1}; + constexpr int N = 1024; + constexpr int K = TestFixture::K_Tile * 2; + + for(int M : Ms) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_M) +{ + this->Run(TestFixture::M_Tile * 2, TestFixture::N_Tile, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_N) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile * 2, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_K) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_512x1024x512) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, Square_1024x1024x1024) +{ + constexpr int M = 1024; + constexpr int N = 1024; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_2048x2048x2048) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, VeryLargeMatrix_4096x4096x4096) +{ + constexpr int M = 4096; + constexpr int N = 4096; + constexpr int K = 4096; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, TallSkinny_4096x128x1024) +{ + constexpr int M = 4096; + constexpr int N = 128; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, ShortWide_128x4096x1024) +{ + constexpr int M = 128; + constexpr int N = 
4096; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, DeepNarrow_2048x2048x8192) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 8192; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_ExtremelyTallMatrix) +{ + constexpr int M = 16384; + constexpr int N = 64; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_ExtremelyWideMatrix) +{ + constexpr int M = 64; + constexpr int N = 16384; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_VeryDeepK) +{ + constexpr int M = 1024; + constexpr int N = 1024; + constexpr int K = 16384; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Memory Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmMemWmma + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_WMMA) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_WMMA) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_WMMA) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V3 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV3 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV3) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if 
constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV3) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, MidLargeM_CompV3) +{ + std::vector Ms{127, 255}; + constexpr int N = 1024; + + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + constexpr int VecLoadSize = (std::is_same_v || + std::is_same_v || + std::is_same_v) + ? 16 + : 8; + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + if(M % VecLoadSize == 0) + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV3) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV3) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, BatchedSmall_CompV3) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Compute V3 Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV3Wmma + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV3Wmma) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); 
+ } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV3Wmma) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV3Wmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV3Wmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V4 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV4 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV4) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV4) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV4) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV4) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Compute V4 Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV4Wmma + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV4Wmma) +{ + this->Run(TestFixture::M_Tile, 
TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV4Wmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV4Wmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V6 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV6 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV6) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV6) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, MidLargeM_CompV6) +{ + std::vector Ms{127, 255}; + constexpr int N = 1024; + + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + constexpr int VecLoadSize = (std::is_same_v || + std::is_same_v || + std::is_same_v) + ? 
16 + : 8; + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + if(M % VecLoadSize == 0) + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV6) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV6) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +// ============================================================================ +// Persistent Kernel Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmPersistent + +TYPED_TEST(TEST_SUITE_NAME, SmallM_Persistent) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_Persistent) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_Persistent) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_Persistent) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Persistent Kernel Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME 
TestGemmPersistentWmma + +TYPED_TEST(TEST_SUITE_NAME, SmallM_PersistentWmma) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_PersistentWmma) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_PersistentWmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_PersistentWmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA From e593a14ae1677d7aed696589e8740796bc6085c1 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Tue, 11 Nov 2025 02:58:08 +0800 Subject: [PATCH 002/114] [ck] correct memory size in grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 (#3168) b1 and b0 use the same layout, so the size of b1_tensors_device should be the same as b0_tensors_device's --- .../grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index 63343df3a8..6f30bdaa73 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -221,8 +221,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b0_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * 
problem_size.Ks[i])); - b1_tensors_device.emplace_back( - std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + b1_tensors_device.emplace_back(std::make_unique( + sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i])); d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); From 7b6ba8d5c2dc7663e15bd8811c18b4c51cf94c99 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Tue, 11 Nov 2025 02:58:20 +0800 Subject: [PATCH 003/114] [ck] Enable missing op for gfx11 and gfx12 (#3187) --- profiler/src/CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 9f86f6d88f..c22867fbed 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -40,6 +40,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp) list(APPEND PROFILER_OPS profile_contraction_scale.cpp) endif() +endif() + +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) @@ -53,7 +56,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) endif() - if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") + if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp) @@ -74,7 +77,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) - endif() 
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR From 9f33b7cfd3df3fcfd540f7633b0abd7019935761 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 10 Nov 2025 11:08:41 -0800 Subject: [PATCH 004/114] fix input range (#3188) --- example/ck_tile/03_gemm/run_gemm_example.inc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index d5f164c40f..703ab810d8 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -309,8 +309,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-2.f, 2.f}(a_m_k); + ck_tile::FillUniformDistribution{-2.f, 2.f}(b_k_n); } else if(init_method == 1) { From 1c544abf57d5a98280c6e26194d568ca475de799 Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Tue, 11 Nov 2025 16:38:15 +0100 Subject: [PATCH 005/114] Extend support for ak1 / bk1 WMMA (#3073) * Extend AK1 / BK1 support: - Add support for AK1 != BK1 - Add support for AK1, BK1 > 8 - Introduce KInner template parameter for pipelines when loading multiple tiles with one instruction * fix clang format --- example/01_gemm/gemm_wmma_fp8_v3.cpp | 10 +- .../blockwise_gemm_pipeline_wmma_selector.hpp | 3 + .../blockwise_gemm_pipeline_wmmaops_base.hpp | 47 +-- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 302 ++++++++++-------- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 250 +++++++++------ .../gridwise_ab_transfer_thread_tiles.hpp | 98 +++++- .../grid/gridwise_ab_transfer_wave_tiles.hpp | 6 +- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 159 ++++++--- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 25 +- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 14 + 
...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp | 4 +- ...emm_wmma_universal_f16_f8_f16_km_kn_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_km_nk_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_km_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_km_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp | 3 +- 24 files changed, 621 insertions(+), 339 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp8_v3.cpp index 0376820b7b..2f8eac113b 100644 --- a/example/01_gemm/gemm_wmma_fp8_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp8_v3.cpp @@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t; using ComputeTypeA = ck::f8_t; using ComputeTypeB = ck::f8_t; -using ALayout = Row; +using ALayout = Col; using BLayout = Col; using CLayout = Row; @@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 64, - 8, 8, + 16, 16, // AK1, BK1 16, 16, 4, 2, + S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 4, 16, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 0, - S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 0, + 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp index 8cff087ddb..89952910e6 100644 --- 
a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp @@ -28,6 +28,7 @@ template constexpr auto BlockGemmPipeline_Selector() { @@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, + KInner, TransposeC>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) @@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, + KInner, TransposeC>{}; } else diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 265db9166a..abc9720714 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -30,6 +30,7 @@ template struct BlockwiseGemmWmmaops_pipeline_base { @@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; using ThisThreadBlock = ThisThreadBlock; @@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base static constexpr index_t B_KRow = 1; #endif - static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5); - static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5); + static constexpr auto wmma_gemm = WmmaGemm{}; + + static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner; + static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread); + static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread); static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); - - static constexpr auto wmma_gemm = - WmmaGemm{}; - static constexpr index_t 
KRepeat = KPerBlock / KPack; static constexpr auto WmmaK = Number{}; @@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base const auto wmma_krow = 0; #endif - // |KRepeat |MRepeat|MWave |KRow |MLane |KPack - return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); + return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); } __device__ static auto CalculateBThreadOriginDataIndex() @@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base const auto wmma_krow = 0; #endif - // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack - return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); + return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); } template @@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base return make_tuple(c_thread_m, c_thread_n); } - using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + using Tuple7 = decltype(CalculateAThreadOriginDataIndex()); /** * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base. @@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base * repeat dimensions. 
*/ __host__ __device__ - BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), - Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(), + Tuple7 b_origin = CalculateBThreadOriginDataIndex()) : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { static_assert(AWmmaTileDesc::IsKnownAtCompileTime() && @@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base Number{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); static constexpr auto b_thread_desc_ = @@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base Number{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); // C[M, N, NumRegWmma] @@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), - Sequence, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), - Sequence, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 5d7c570428..5f731933e2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -32,6 +32,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 { @@ -55,6 +56,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; using Base::I1; - using Base::WaveSize; using typename Base::HotLoopInstList; 
using Base::A_K1; @@ -187,6 +191,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -211,27 +217,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0, I0), - a_thread_buf); - + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + a_thread_buf); if constexpr(m0 == I0) { if constexpr(ck::is_same::value == true) { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple( - Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0, I0), + b_thread_buf); }); } else @@ -239,45 +241,60 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple( - Number{}, n0, I0, I0, I0, I0), + make_tuple(I0, n0, k0, I0, I0, I0, I0), b_block_buf, b_scale_struct.b_scale_thread_bufs( I0)[Number{}], b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), + make_tuple(I0, n0, I0, I0, I0, I0, I0), b_thread_buf); }); } } - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, I0, I0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto 
ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + I0, + I0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + I0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, I0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -324,8 +341,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + static_for<0, KInner, 1>{}([&](auto) { + static_for<0, NRepeat, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); }); }); }); @@ -348,20 +367,20 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, I1, I1, Number{})); + make_tuple(Number{}, I1, I1, I1, I1, I1, Number{})); // B[NRepeat, N1, N2, KPack] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, 
I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -370,9 +389,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; @@ -399,6 +418,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; using Base::I1; @@ -532,6 +555,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -557,33 +582,22 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_offset) { static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, - m0, - I0, - I0, - I0, - I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0_inner, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0_inner, I0, I0, I0, I0), + a_thread_buf); }); if constexpr(ck::is_same::value == true) { static_for<0, NRepeat, 1>{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - n0, - I0, - I0, - I0, - I0), + make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, n0, k0_inner, I0, I0, I0), + make_tuple(I0, n0, k0_inner, I0, I0, I0, I0), b_thread_buf); }); } @@ -592,18 +606,13 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - n0, - I0, - I0, - I0, - I0), + make_tuple(I0, n0, 
k0_offset + k0_inner, I0, I0, I0, I0), b_block_buf, b_scale_struct.b_scale_thread_bufs(I0)[Number< n0 * BScaleStruct::num_scale_k_block + (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}], b_thread_desc_, - make_tuple(I0, n0, k0_inner, I0, I0, I0), + make_tuple(I0, n0, k0_inner, I0, I0, I0, I0), b_thread_buf); }); } @@ -622,62 +631,69 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_inner) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0_inner, - I0, - I0, - Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard. + // B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts. 
+ // It is performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0_offset + k0_inner == KRepeat - 1 && + m0 == MRepeat - 1 && n0 == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0_inner, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard. - // B) reduce VMEM FIFO congestion by applying small delays to - // different wavefronts. 
- // It is performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 && - n0 == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } }); }); }); @@ -729,12 +745,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); static constexpr auto b_thread_desc_ = @@ -743,12 +761,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); using AThreadCopy = @@ -756,9 +776,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -767,9 +787,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 83dadb2175..cbe13b6e00 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -32,6 +32,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v3 { @@ -55,6 +56,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v3 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; @@ -290,40 +295,37 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto 
m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); if constexpr(ck::is_same_v) { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); }); } else { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_scale_struct.b_scale_thread_bufs( - I0)[Number{}], - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); }); } }); @@ -364,6 +366,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -424,41 +429,48 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - Number{}))>{}]; + 
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -489,31 +501,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + 
k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -531,31 +559,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = 
+ b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 465952e285..23f16d38e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -17,6 +17,9 @@ template {}, KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + if constexpr(KInner > 1) + { + // KPack = KInner * KPerWmma + // K1 = KInner * KPerWmmaBlk + // Each thread loads multiple tiles with one instruction + // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), 
+ make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // KPack = KPerWmma (KInner == 1) + if constexpr(ABK1 <= KPerWmmaBlk) + { + // K1 <= single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1 + // (rest of the single WMMA tile for single thread) and then over KRow + // (rest of the single WMMA tile for single wave) + // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // K1 > single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 > single tile, each thread loads KPerWmmaBlk and the next + // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout). 
+ // This layout is needed to support for example AK1 > single tile and + // BK1 <= single tile in the same gemm + // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - + // K1 + constexpr auto desc1 = transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{})); + + return transform_tensor_descriptor( + desc1, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + } } __device__ static constexpr auto GetBlockStep() diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 68476ef3bf..a36ccd43ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -313,14 +313,16 @@ struct ABTransferWaveTiles // This is a block descriptor used to read LDS memory into register // It's defined in a way consistent with the existing implementation to // avoid changes in the pipelines - return make_naive_tensor_descriptor(make_tuple(Number{}, + return 
make_naive_tensor_descriptor(make_tuple(I1, Number{}, + Number{}, Number{}, Number{}, Number{}, Number{}), - make_tuple(Number{}, + make_tuple(I0, Number{}, + Number{}, Number{}, Number{}, Number{}, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index fa7eb4faaa..38ebdab65e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -109,9 +109,20 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - // TODO: I am pretty sure this is always 16 and *should* always be 16. - static constexpr auto KPack = - math::integer_least_multiple(math::integer_least_multiple(AK1Value, BK1Value), 16); + static constexpr index_t KPerWmmaBlk = + WmmaSelector::selected_wmma + .k_per_blk; + + static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk); + + static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk); + + static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB); + + static constexpr index_t KPack = + KInner * + WmmaSelector::selected_wmma + .k_per_wmma; using ThisThreadBlock = ThisThreadBlock; @@ -201,54 +212,115 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 return b1_block_copy_step; } + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + + if constexpr(KInner > 1) + { + // KPack = KInner * KPerWmma + // K1 = KInner * 
KPerWmmaBlk + // Each thread loads multiple tiles with one instruction + // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // KPack = KPerWmma (KInner == 1) + if constexpr(K1 <= KPerWmmaBlk) + { + // K1 <= single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1 + // (rest of the single WMMA tile for single thread) and then over KRow + // (rest of the single WMMA tile for single wave) + // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // K1 > single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 > single tile, each thread loads KPerWmmaBlk and the next + // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout). 
+ // This layout is needed to support for example AK1 > single tile and + // BK1 <= single tile in the same gemm + // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - + // K1 + constexpr auto desc1 = transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{})); + + return transform_tensor_descriptor( + desc1, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + } + } + template __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) { - constexpr auto a_wave_desc = [&]() { - // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto A_KRow = I2; -#else - constexpr auto A_KRow = I1; -#endif - return transform_tensor_descriptor( - ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, 
Sequence<5>{})); - }(); - - return a_wave_desc; + return MakeWmmaTileDescriptor(ABlockDesc_{}); } template __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) { - constexpr auto b0_wave_desc = [&]() { - // BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1 - constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto B_KRow = I2; -#else - constexpr auto B_KRow = I1; -#endif - return transform_tensor_descriptor( - B0BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - }(); - - return b0_wave_desc; + return MakeWmmaTileDescriptor(B0BlockDesc_{}); } template @@ -356,6 +428,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 MRepeat, LRepeat, KPack, + KInner, true>())>; // TransposeC (must be true to work), C' = B' x A' // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 7a5e324468..56f09cee96 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -151,10 +151,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), + static constexpr index_t KPerWmmaBlk = WmmaSelector::selected_wmma - .k_per_wmma); + .k_per_blk; + + static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk); + + 
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk); + + static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB); + + static constexpr index_t KPack = + KInner * + WmmaSelector::selected_wmma + .k_per_wmma; using ThisThreadBlock = ThisThreadBlock; @@ -218,6 +228,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPerBlock, MPerWmma, AK1Value, + KPack, + KInner, + KPerWmmaBlk, UseBlockPaddingA, PermuteA, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -251,6 +264,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPerBlock, NPerWmma, BK1Value, + KPack, + KInner, + KPerWmmaBlk, UseBlockPaddingB, PermuteB, BBlockTransferThreadClusterLengths_BK0_N_BK1, @@ -563,7 +579,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base NPerWmma, MRepeat, NRepeat, - KPack>())>; + KPack, + KInner>())>; // Used to create obj in global function and pass it to Run method using EpilogueCShuffle = diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index bca68764f9..55ede990af 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -95,6 +95,7 @@ struct wmma_type __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index 71b5c5e7cf..806b6e684d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -48,7 +48,9 @@ using 
device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + 
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index f4489dc45f..4516d06492 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -50,7 +50,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 423f86365c..5ace0594f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -53,7 +53,9 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 
16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp index 2eb28958e6..27deab1c8c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -56,7 +56,9 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, 
F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp index d10b9facd5..bd5c7d8783 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp @@ -48,7 +48,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + 
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp index d9d16ede65..1956d1a951 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp index 9277e5e901..934c6aa7ef 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp index e97a649c19..9860b81b78 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 
8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp index c8f1b85ddb..4d7169565a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, 
F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp index fc0220a502..3728368bc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = 
std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp index b87cf64b0f..3506575f5d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, 
BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp index 31ad66409e..eef0d6de6a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp @@ -50,7 +50,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp index 4c37c398fe..2418be62b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -55,7 +55,8 @@ using device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 
256, 128, 256, 64, 16, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, F8, F8> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp index 6b5314b701..38f2869303 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, - 
DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, F8, F8> // clang-format on >; } // namespace instance From 06c651b100c9dc50753277069bdc68411da7ca1a Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Tue, 11 Nov 2025 07:42:26 -0800 Subject: [PATCH 006/114] formatting (#3182) --- include/ck_tile/ops/gemm_quant.hpp | 1 + .../block/block_gemm_quant_common.hpp | 38 +++++++++++++++++++ ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 18 ++------- .../block_universal_gemm_as_aquant_bs_cr.hpp | 17 ++------- .../block_universal_gemm_as_bs_bquant_cr.hpp | 17 ++------- 5 files changed, 48 insertions(+), 43 deletions(-) create mode 100644 include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 3273131875..3e16d937cb 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" #include 
"ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp new file mode 100644 index 0000000000..d695888b88 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Common utilities for quantized GEMM block operations +template +struct BlockGemmQuantCommon +{ + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemmType::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index df55081b69..2d92745f75 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include 
"ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -100,21 +101,8 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 8b95ec6ddf..1f72f4dc12 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -543,20 +544,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = 
make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } template diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9db444b57f..660c30aa6e 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -376,20 +377,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } template From aa1fb29aa102d937e061f138ecb22ef81e7a8fcd Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 11 Nov 2025 07:44:38 -0800 Subject: [PATCH 007/114] Bump commit ref for TheRock in workflows (#3189) * Bump commit ref for TheRock in workflows * Update to more recent commit (could also `rm` the patch) * Revert "Update to more recent commit (could also `rm` the patch)" This reverts 
commit 4b9f4952ead77e068f5ab86a07701c7e9bed48cc. * Rm patch that no longer applies * Fix post_build_upload flag name * Fix artifact_group plumbing for setup test env --- .github/workflows/therock-ci-linux.yml | 6 ++++-- .github/workflows/therock-test-component.yml | 4 ++-- .github/workflows/therock-test-packages.yml | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index f4d0c0063c..86d134e456 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -53,8 +53,8 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit path: "TheRock" + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: Setup ccache run: | @@ -77,6 +77,8 @@ jobs: - name: Patch rocm-libraries run: | git config --global --add safe.directory '*' + # Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo + rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps @@ -128,7 +130,7 @@ jobs: run: | python3 TheRock/build_tools/github_actions/post_build_upload.py \ --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --artifact-group ${{ env.AMDGPU_FAMILIES }} \ --build-dir TheRock/build \ --upload diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 1ccc1d57bc..27eff4fdb0 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,13 +51,13 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 
with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' with: ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} - AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + ARTIFACT_GROUP: ${{ inputs.amdgpu_families }} OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} VENV_DIR: ${{ env.VENV_DIR }} FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }} diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index efb5a6b1a0..81632fce48 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: "Configuring CI options" env: From 88e3212fccf2a879c0e718deecc28caff453bb29 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 11 Nov 2025 11:17:24 -0500 Subject: [PATCH 008/114] chore(copyright): update copyright header for tile_engine directory (#3180) --- tile_engine/ops/commons/test_benchmark.sh | 3 +++ tile_engine/ops/commons/test_validation.py | 3 +++ tile_engine/ops/commons/validation_utils.py | 2 +- tile_engine/ops/gemm/codegen_utils.py | 2 +- tile_engine/ops/gemm/gemm_benchmark.hpp | 2 +- tile_engine/ops/gemm/gemm_benchmark.py | 2 +- tile_engine/ops/gemm/gemm_benchmark_single.cpp | 2 +- tile_engine/ops/gemm/gemm_common.hpp | 2 +- tile_engine/ops/gemm/gemm_instance_builder.py | 3 +++ tile_engine/ops/gemm/gemm_profiler.hpp | 2 +- tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp | 2 +- tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py | 2 +- 
.../ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp | 2 +- tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp | 2 +- tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py | 3 +++ tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp | 2 +- tile_engine/ops/gemm_preshuffle/commons/validation_utils.py | 2 +- tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp | 3 +++ tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py | 2 +- .../ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp | 2 +- tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp | 2 +- .../ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py | 4 ++-- tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp | 2 +- 23 files changed, 34 insertions(+), 19 deletions(-) diff --git a/tile_engine/ops/commons/test_benchmark.sh b/tile_engine/ops/commons/test_benchmark.sh index 1fb7c163af..e2e0324da8 100755 --- a/tile_engine/ops/commons/test_benchmark.sh +++ b/tile_engine/ops/commons/test_benchmark.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Test script for tile engine GEMM benchmarks # This script demonstrates how to run the new individual benchmark executables diff --git a/tile_engine/ops/commons/test_validation.py b/tile_engine/ops/commons/test_validation.py index 79f24265f1..46fb008c27 100644 --- a/tile_engine/ops/commons/test_validation.py +++ b/tile_engine/ops/commons/test_validation.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + """ Test script to verify that the validation logic is working correctly. 
""" diff --git a/tile_engine/ops/commons/validation_utils.py b/tile_engine/ops/commons/validation_utils.py index 3eb7bf8b57..5787446e8c 100644 --- a/tile_engine/ops/commons/validation_utils.py +++ b/tile_engine/ops/commons/validation_utils.py @@ -1,6 +1,6 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 0020fccf05..eecc2228a6 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -1,5 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # -*- coding: utf-8 -*- diff --git a/tile_engine/ops/gemm/gemm_benchmark.hpp b/tile_engine/ops/gemm/gemm_benchmark.hpp index 0e2619785e..7c8df32ad8 100644 --- a/tile_engine/ops/gemm/gemm_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_benchmark.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm/gemm_benchmark.py b/tile_engine/ops/gemm/gemm_benchmark.py index 9f323f2640..cc04dbe0db 100755 --- a/tile_engine/ops/gemm/gemm_benchmark.py +++ b/tile_engine/ops/gemm/gemm_benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
import sys import json diff --git a/tile_engine/ops/gemm/gemm_benchmark_single.cpp b/tile_engine/ops/gemm/gemm_benchmark_single.cpp index bbcc6eb505..6323c066a1 100644 --- a/tile_engine/ops/gemm/gemm_benchmark_single.cpp +++ b/tile_engine/ops/gemm/gemm_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp index 4732f2a1ba..899221547f 100644 --- a/tile_engine/ops/gemm/gemm_common.hpp +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 1aff42b902..8885c821c1 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 575e5240a8..3c6bbc34d3 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp index 53dcdb5e1f..f8c196e32a 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py index fb81b9c2c2..044e08baca 100755 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. import sys import json diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp index 032a625354..41d2f736e1 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp index 4732f2a1ba..899221547f 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index 3f7858f146..cc167fb75f 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp index 8e19c11c7d..3a2cdc71fe 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py index b38ff5dffb..70ce3b0d72 100644 --- a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py +++ b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py @@ -1,6 +1,6 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp index 77a9f26527..748fe581d3 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp @@ -1,3 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + #pragma once #include "ck_tile/core.hpp" diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py index 0217a439f2..d8892be7d6 100755 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. import sys import json diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp index 1f03d1cf9b..4fbb25f0c9 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index abaa5ebd46..1b2cfe3735 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index 57c250f57e..9ce6d8cb25 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -1,5 +1,5 @@ -## Copyright © Advanced Micro Devices, Inc. or its affiliates. -## SPDX-License-Identifier: MIT +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT import argparse import os diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp index 85b731c231..739bd7e677 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -1,4 +1,4 @@ -// Copyright © Advanced Micro Devices, Inc. or its affiliates. +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT #pragma once From 1b1c46e508c1fd40a03f54114b6b78629032fb4f Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 12 Nov 2025 00:23:57 +0800 Subject: [PATCH 009/114] [CK_TILE] Fix gemm_quant (#3186) --- .../38_block_scale_gemm/CMakeLists.txt | 2 +- .../38_block_scale_gemm/gemm_quant_basic.cpp | 4 + .../38_block_scale_gemm/gemm_utils.hpp | 8 ++ include/ck_tile/host/tensor_shuffle_utils.hpp | 98 ++++++++++++++----- .../gemm/warp/warp_gemm_attribute_wmma.hpp | 1 + ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 4 +- .../block_universal_gemm_as_aquant_bs_cr.hpp | 11 +-- .../block_universal_gemm_as_bs_bquant_cr.hpp | 6 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 5 +- .../pipeline/tile_gemm_quant_traits.hpp | 5 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 2 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 14 ++- .../test_gemm_quant_fixtures.hpp | 24 +++-- 13 files changed, 135 insertions(+), 49 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 7358d4d749..b1ae9369a2 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -5,7 +5,7 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index b22596537f..d605a2b780 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f) int 
main(int argc, char* argv[]) { +#if CK_TILE_USE_WMMA + return !run_gemm_example(argc, argv); +#else // Use non-preshuffled GemmConfig for 2D block scale support return !run_gemm_example(argc, argv); +#endif } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 589caf88f4..1839c7f98d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -216,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +template +struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill +{ + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 
2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template @@ -55,21 +82,46 @@ template auto shuffle_b_permuteN(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 
2 : 4; constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; - - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, - NRepeat, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 
2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + } } } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index 90f6204ff3..dd2931f6b7 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -79,6 +79,7 @@ struct WarpGemmAttributeWmma static constexpr index_t kM = Impl::kM; static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK; + static constexpr index_t kCMLane = Impl::kCMLane; static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index 2d92745f75..6422c07e1d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -82,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { diff --git 
a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 1f72f4dc12..bbdd3128bf 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -25,13 +25,11 @@ struct BlockGemmAQuantBase float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { @@ -349,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, // 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3. - constexpr uint32_t kTileRowsOfCPerThread = 4; + constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8; decltype(threadIdx.x) pull_from_lane = 0; if constexpr(WarpGemm::kM == 16) { @@ -410,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // desired row coefficient auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; - constexpr uint32_t kTileRows = 4; + constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 
4 : 8; + ; constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane; // Multiply by 4 because output is stored in tiles of 4 diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 660c30aa6e..28ae709bf0 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -25,13 +25,11 @@ struct BlockGemmBQuantBase float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 36cbb87877..15d2727f3b 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -240,7 +240,10 @@ struct QuantGemmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static auto BlockSize() + { + return is_wave32() ? 
dim3(kBlockSize / 2) : dim3(kBlockSize); + } CK_TILE_HOST static constexpr QuantGemmKernelArgs MakeKernelArgs(const QuantGemmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index c4429b76f9..3a5b86382d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -41,7 +41,8 @@ template + bool UsePersistentKernel_ = false, + int VectorSize_ = 16> struct TileGemmQuantTraits { static constexpr bool kPadM = kPadM_; @@ -50,7 +51,7 @@ struct TileGemmQuantTraits static constexpr QuantType kQuantType = QuantType_; - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; using ALayout = ALayout_; diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 3a49e69c37..1c4a25c8bd 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -5,7 +5,7 @@ endif() list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Typed Test Suite for GEMM Quantization add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 6454101daf..6226a2de9e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -69,7 +69,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; - + // WP pipeline requires per-thread tile size aligned to 
Problem::VectorLoadSize. + // static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + // VectorLoadSize == 0). gfx9 cards match the requirements but it fails on gfx12. so we only + // need to check the limitation on RDNA cards, i.e. assume wave size is 32. + constexpr ck_tile::index_t WaveSize = 32; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile); + constexpr bool SupportVectorSize16 = + (M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0; + constexpr int VectorSize = PreshuffleB ? (SupportVectorSize16 ? 16 : 8) : 16; using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, @@ -89,7 +97,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test ALayout, BLayout, GemmConfig::TransposeC, - DoubleSmemBuffer>; + DoubleSmemBuffer, + false, + VectorSize>; // Let the derived class create the appropriate pipeline and epilogue static_cast(this) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index cabc0ec02c..5aac095514 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -7,6 +7,16 @@ #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else + return is_8bit ? 
64 : 32; +#endif +} + struct GemmConfigBase { static constexpr bool kPadM = false; @@ -40,7 +50,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleQuant : public GemmConfigBase @@ -75,7 +85,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleBPrefill : public GemmConfigBase @@ -94,7 +104,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill @@ -132,7 +142,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase Date: Tue, 11 Nov 2025 14:26:01 -0500 Subject: [PATCH 010/114] chore(copyright): update copyright header for script directory (#3184) * chore(copyright): update copyright header for tile_engine directory * chore(copyright): update copyright header for script directory --------- Co-authored-by: Vidyasagar Ananthan --- script/check_copyright_year.sh | 3 +++ script/clang-format-overwrite.sh | 3 +++ script/cmake-ck-dev.sh | 3 +++ script/convert_miopen_driver_to_profiler.py | 3 ++- script/count_vgpr.sh | 3 +++ .../generate_list_of_files_not_referenced_in_tests.py | 4 ++-- script/dependency-parser/main.py | 2 +- script/dependency-parser/src/enhanced_ninja_parser.py | 2 +- script/dependency-parser/src/selective_test_filter.py | 2 +- 
script/gemm_profile.sh | 2 +- script/hipclang_opt.sh | 3 +++ script/install_precommit.sh | 3 +++ script/launch_tests.sh | 2 +- script/ninja_json_converter.py | 2 +- script/process_perf_data.py | 3 +++ script/process_perf_data.sh | 3 +++ script/process_qa_data.sh | 3 +++ script/profile_batched_gemm.sh | 3 +++ script/profile_gemm.sh | 3 +++ script/profile_gemm_bilinear.sh | 3 +++ script/profile_grouped_conv_bwd_data.sh | 3 +++ script/profile_grouped_conv_bwd_weight.sh | 3 +++ script/profile_grouped_conv_fwd.sh | 3 +++ script/profile_grouped_conv_fwd_outelementop.sh | 3 +++ script/profile_grouped_gemm.sh | 3 +++ script/profile_mixed_gemm.sh | 3 +++ script/profile_onnx_gemm.sh | 3 +++ script/profile_permute_scale.sh | 3 +++ script/profile_reduce_no_index.sh | 3 +++ script/profile_reduce_with_index.sh | 3 +++ script/profile_resnet50.sh | 3 +++ script/profile_splitK_gemm.sh | 3 +++ script/remod_for_ck_tile.py | 3 +++ script/remove_exec_bit.sh | 2 +- script/run_ck_profiler_gemm_with_csv_shapes.py | 2 +- script/run_full_performance_tests.sh | 3 +++ script/run_gemm_performance_tests.sh | 3 +++ script/run_performance_tests.sh | 3 +++ script/sccache_wrapper.sh | 3 +++ script/test_convnd_fwd.sh | 3 +++ script/test_reduce_no_index.sh | 3 +++ script/uninstall_precommit.sh | 3 +++ 42 files changed, 108 insertions(+), 11 deletions(-) diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh index f7709472ef..1b63c6b711 100755 --- a/script/check_copyright_year.sh +++ b/script/check_copyright_year.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + current_year=$(date +%Y) exit_code=0 diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index 74391ded28..23b57b9935 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | grep -v 'build/' | grep -v 'include/rapidjson'| xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|.hpp|.inc|include/rapidjson/")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 6220009b03..9643af1de0 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # exit when a command exits with non-zero status; also when an unbound variable is referenced set -eu # pipefail is supported by many shells, not supported by sh and dash diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index d814e0719c..5aff9c0a7f 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -1,5 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + # Convert miopen driver command to ck Profiler # Example: python3 ../script/convert_miopen_driver_to_profiler.py # /opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 -k 64 -y 3 -x 3 diff --git a/script/count_vgpr.sh b/script/count_vgpr.sh index 07debc53a8..651a894db6 100755 --- a/script/count_vgpr.sh +++ b/script/count_vgpr.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + FILE=$1 for num in {0..255} diff --git a/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py b/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py index 8419b9491e..58bb9e8e93 100644 --- a/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py +++ b/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT -## Copyright © Advanced Micro Devices, Inc. or its affiliates. -## SPDX-License-Identifier: MIT # This script generate list of files that are not referenced from any test (list in JSON format) # Script only looks at not referenced files from three directories: include, library and profiler diff --git a/script/dependency-parser/main.py b/script/dependency-parser/main.py index 623ae05afd..f345362b26 100644 --- a/script/dependency-parser/main.py +++ b/script/dependency-parser/main.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/script/dependency-parser/src/enhanced_ninja_parser.py b/script/dependency-parser/src/enhanced_ninja_parser.py index ff6344a4c1..2ac8e8537a 100644 --- a/script/dependency-parser/src/enhanced_ninja_parser.py +++ b/script/dependency-parser/src/enhanced_ninja_parser.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT """ diff --git a/script/dependency-parser/src/selective_test_filter.py b/script/dependency-parser/src/selective_test_filter.py index d3228ef624..83f7f7eebe 100644 --- a/script/dependency-parser/src/selective_test_filter.py +++ b/script/dependency-parser/src/selective_test_filter.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/script/gemm_profile.sh b/script/gemm_profile.sh index 89419ca711..d3d66bcaa9 100755 --- a/script/gemm_profile.sh +++ b/script/gemm_profile.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT BIN=./bin/tile_example_gemm_weight_preshuffle diff --git a/script/hipclang_opt.sh b/script/hipclang_opt.sh index c51bd51d97..ba5636eeb6 100755 --- a/script/hipclang_opt.sh +++ b/script/hipclang_opt.sh @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + rm *.ll *.s BC_FILE=$1 diff --git a/script/install_precommit.sh b/script/install_precommit.sh index 545dcfa666..f80b06a95a 100755 --- a/script/install_precommit.sh +++ b/script/install_precommit.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + run_and_check() { "$@" status=$? diff --git a/script/launch_tests.sh b/script/launch_tests.sh index 52151b71f6..1911613023 100755 --- a/script/launch_tests.sh +++ b/script/launch_tests.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT # Get the directory where the script is located diff --git a/script/ninja_json_converter.py b/script/ninja_json_converter.py index e68f7ccfa3..5e974cf730 100644 --- a/script/ninja_json_converter.py +++ b/script/ninja_json_converter.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/script/process_perf_data.py b/script/process_perf_data.py index b35ba64041..5f81512a4c 100644 --- a/script/process_perf_data.py +++ b/script/process_perf_data.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import io import argparse diff --git a/script/process_perf_data.sh b/script/process_perf_data.sh index 50c84924f5..4786ddded0 100755 --- a/script/process_perf_data.sh +++ b/script/process_perf_data.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd need the following python packages: diff --git a/script/process_qa_data.sh b/script/process_qa_data.sh index 420453cddc..d56ef5c1ec 100755 --- a/script/process_qa_data.sh +++ b/script/process_qa_data.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd need the following python packages: diff --git a/script/profile_batched_gemm.sh b/script/profile_batched_gemm.sh index f90baaed68..bb7d61deec 100755 --- a/script/profile_batched_gemm.sh +++ b/script/profile_batched_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh index b88159e74d..f766ca50fa 100755 --- a/script/profile_gemm.sh +++ b/script/profile_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_gemm_bilinear.sh b/script/profile_gemm_bilinear.sh index e6edefae85..057d7d7e49 100755 --- a/script/profile_gemm_bilinear.sh +++ b/script/profile_gemm_bilinear.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 DRIVER="../build/bin/ckProfiler" diff --git a/script/profile_grouped_conv_bwd_data.sh b/script/profile_grouped_conv_bwd_data.sh index a1d2f450c9..3805ed86cd 100755 --- a/script/profile_grouped_conv_bwd_data.sh +++ b/script/profile_grouped_conv_bwd_data.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_grouped_conv_bwd_weight.sh b/script/profile_grouped_conv_bwd_weight.sh index e3652202d4..146431621c 100755 --- a/script/profile_grouped_conv_bwd_weight.sh +++ b/script/profile_grouped_conv_bwd_weight.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_grouped_conv_fwd.sh b/script/profile_grouped_conv_fwd.sh index 9a974525ad..8491aecf9e 100755 --- a/script/profile_grouped_conv_fwd.sh +++ b/script/profile_grouped_conv_fwd.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_grouped_conv_fwd_outelementop.sh b/script/profile_grouped_conv_fwd_outelementop.sh index ac444a25c2..a0df8cd4c5 100755 --- a/script/profile_grouped_conv_fwd_outelementop.sh +++ b/script/profile_grouped_conv_fwd_outelementop.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_grouped_gemm.sh b/script/profile_grouped_gemm.sh index 8adb7c81ac..fe452d5cab 100755 --- a/script/profile_grouped_gemm.sh +++ b/script/profile_grouped_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_mixed_gemm.sh b/script/profile_mixed_gemm.sh index 383c7ea36e..a867bf3a77 100755 --- a/script/profile_mixed_gemm.sh +++ b/script/profile_mixed_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_onnx_gemm.sh b/script/profile_onnx_gemm.sh index c2721e7f59..ea18fc761e 100755 --- a/script/profile_onnx_gemm.sh +++ b/script/profile_onnx_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 DRIVER="../build/bin/ckProfiler" diff --git a/script/profile_permute_scale.sh b/script/profile_permute_scale.sh index 945d10f47b..31d6a06c5e 100755 --- a/script/profile_permute_scale.sh +++ b/script/profile_permute_scale.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_reduce_no_index.sh b/script/profile_reduce_no_index.sh index 66bfe1dcd3..3bae07906b 100755 --- a/script/profile_reduce_no_index.sh +++ b/script/profile_reduce_no_index.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + DRIVER="../build/bin/ckProfiler" VERIFY="-v $1" INIT=$2 diff --git a/script/profile_reduce_with_index.sh b/script/profile_reduce_with_index.sh index 43543f4430..943a590528 100755 --- a/script/profile_reduce_with_index.sh +++ b/script/profile_reduce_with_index.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + DRIVER="../build/bin/ckProfiler" VERIFY="-v $1" INIT=$2 diff --git a/script/profile_resnet50.sh b/script/profile_resnet50.sh index b55cb2ccef..ec6b32c0c8 100755 --- a/script/profile_resnet50.sh +++ b/script/profile_resnet50.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_splitK_gemm.sh b/script/profile_splitK_gemm.sh index d62f0e4753..843d59c918 100755 --- a/script/profile_splitK_gemm.sh +++ b/script/profile_splitK_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/remod_for_ck_tile.py b/script/remod_for_ck_tile.py index 7601c9d619..feb50dc290 100755 --- a/script/remod_for_ck_tile.py +++ b/script/remod_for_ck_tile.py @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + import os root_dir = os.getcwd() diff --git a/script/remove_exec_bit.sh b/script/remove_exec_bit.sh index 2926683d6a..0b3ca80422 100755 --- a/script/remove_exec_bit.sh +++ b/script/remove_exec_bit.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT for file in $(git diff --cached --name-only --diff-filter=ACM | grep -E '\.(cpp|hpp|txt|inc)$'); do diff --git a/script/run_ck_profiler_gemm_with_csv_shapes.py b/script/run_ck_profiler_gemm_with_csv_shapes.py index eb0eb9c920..2590e3942e 100644 --- a/script/run_ck_profiler_gemm_with_csv_shapes.py +++ b/script/run_ck_profiler_gemm_with_csv_shapes.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # -*- coding: utf-8 -*- diff --git a/script/run_full_performance_tests.sh b/script/run_full_performance_tests.sh index 508200b21a..55740da097 100755 --- a/script/run_full_performance_tests.sh +++ b/script/run_full_performance_tests.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ # you would also need to set up some environment variables in order to diff --git a/script/run_gemm_performance_tests.sh b/script/run_gemm_performance_tests.sh index 12adad30f8..c72b2a760b 100755 --- a/script/run_gemm_performance_tests.sh +++ b/script/run_gemm_performance_tests.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + # # in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ # run the script as "./run_gemm_performance_tests.sh diff --git a/script/run_performance_tests.sh b/script/run_performance_tests.sh index 4e13b59d34..9163e6d693 100755 --- a/script/run_performance_tests.sh +++ b/script/run_performance_tests.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ # run the script as "./run_performance_tests.sh diff --git a/script/sccache_wrapper.sh b/script/sccache_wrapper.sh index 30fd17e520..1a7e37881e 100755 --- a/script/sccache_wrapper.sh +++ b/script/sccache_wrapper.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set -e COMPILERS_HASH_DIR=${COMPILERS_HASH_DIR:-"/tmp/.sccache"} SCCACHE_EXTRAFILES=${SCCACHE_EXTRAFILES:-"${COMPILERS_HASH_DIR}/rocm_compilers_hash_file"} diff --git a/script/test_convnd_fwd.sh b/script/test_convnd_fwd.sh index 8bd2c2fc33..d716caac15 100644 --- a/script/test_convnd_fwd.sh +++ b/script/test_convnd_fwd.sh @@ -1,4 +1,7 @@ #!/usr/bin/env bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # set -e diff --git a/script/test_reduce_no_index.sh b/script/test_reduce_no_index.sh index b956303837..717a872c45 100755 --- a/script/test_reduce_no_index.sh +++ b/script/test_reduce_no_index.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + ## The following will be used for CI diff --git a/script/uninstall_precommit.sh b/script/uninstall_precommit.sh index b0d4d15166..394425acdd 100755 --- a/script/uninstall_precommit.sh +++ b/script/uninstall_precommit.sh @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + pre-commit uninstall From c54ecd905b07849076069d56c284472230564568 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 11 Nov 2025 14:27:33 -0500 Subject: [PATCH 011/114] docs: update ckProfiler readme with selective building option (#3140) * docs: update ckProfiler readme with selective building option * docs: add list of operations for ckProfiler --- profiler/README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/profiler/README.md b/profiler/README.md index 05bbc7b4f9..86f668eacb 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -1,5 +1,23 @@ [Back to the main page](../README.md) # Composable Kernel profiler + +## Building Specific Profilers +To reduce build time, filter which operations to compile using CMake options: + +```bash +# Build all grouped_gemm variants (grouped_gemm, grouped_gemm_fastgelu, grouped_gemm_tile_loop, etc.) +cmake -DCK_PROFILER_OP_FILTER="grouped_gemm" .. + +# Build ONLY base grouped_gemm (excludes variants - use exact regex match with ^ and $) +cmake -DCK_PROFILER_OP_FILTER="^grouped_gemm$" .. +``` + +Both `CK_PROFILER_OP_FILTER` and `CK_PROFILER_INSTANCE_FILTER` accept regex patterns. Default builds all operations. 
+ +To find the complete list of operations, run the following command: +```bash +find profiler/src -name "profile_*.cpp" | sed 's|profiler/src/profile_||' | sed 's|.cpp||' | sort +``` ## Profiler GEMM UNIVERSAL kernels ```bash # arg1: tensor operation (gemm_universal: Universal GEMM) From b145a5fe80d2f9d965f2c8555808017c3a660fc2 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 11 Nov 2025 15:15:49 -0500 Subject: [PATCH 012/114] Add CK Tile Tutorials Folder with GEMM and COPY Kernel (#3038) * feat: add tutorial folder with gemm tutorial * chore: move copy kernel from examples folder to tutorial * Update tutorial/ck_tile/01_naive_gemm/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tutorial/ck_tile/01_naive_gemm/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * chore: remove handdrawn images * docs: add write ups to explain the gemm kernel * docs: add about block level pipeline and static distributed tensors --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- CMakeLists.txt | 6 + example/ck_tile/CMakeLists.txt | 1 - tutorial/CMakeLists.txt | 15 + .../ck_tile/00_copy_kernel}/CMakeLists.txt | 6 +- .../ck_tile/00_copy_kernel}/README.md | 0 .../ck_tile/00_copy_kernel}/copy_basic.cpp | 22 +- .../ck_tile/00_copy_kernel}/copy_basic.hpp | 0 .../00_copy_kernel}/test_tile_example.sh | 2 +- .../01_naive_gemm/BLOCK_LEVEL_PIPELINE.md | 589 +++++++++++++++++ tutorial/ck_tile/01_naive_gemm/CMakeLists.txt | 7 + .../01_naive_gemm/HOST_LEVEL_PIPELINE.md | 618 ++++++++++++++++++ .../01_naive_gemm/KERNEL_ENTRY_POINT.md | 464 +++++++++++++ tutorial/ck_tile/01_naive_gemm/README.md | 150 +++++ tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md | 506 ++++++++++++++ ...e_gemm_block_pipeline_agmem_bgmem_creg.hpp | 165 +++++ ...ice_gemm_block_policy_agmem_bgmem_creg.hpp | 135 ++++ ...ce_gemm_host_pipeline_agmem_bgmem_creg.hpp | 92 +++ ...tice_gemm_host_policy_agmem_bgmem_creg.hpp | 51 ++ 
.../ck_tile/01_naive_gemm/practice_gemm.cpp | 131 ++++ .../ck_tile/01_naive_gemm/practice_gemm.hpp | 69 ++ .../ck_tile/01_naive_gemm/reference_gemm.hpp | 36 + ...ce_gemm_warp_pipeline_asmem_bsmem_creg.hpp | 195 ++++++ ...tice_gemm_warp_policy_asmem_bsmem_creg.hpp | 35 + tutorial/ck_tile/CMakeLists.txt | 7 + 24 files changed, 3287 insertions(+), 15 deletions(-) create mode 100644 tutorial/CMakeLists.txt rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/CMakeLists.txt (54%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/README.md (100%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/copy_basic.cpp (86%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/copy_basic.hpp (100%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/test_tile_example.sh (95%) create mode 100644 tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md create mode 100644 tutorial/ck_tile/01_naive_gemm/CMakeLists.txt create mode 100644 tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md create mode 100644 tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md create mode 100644 tutorial/ck_tile/01_naive_gemm/README.md create mode 100644 tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md create mode 100644 tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp create mode 100644 tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp create mode 100644 
tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp create mode 100644 tutorial/ck_tile/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 049da5637f..7b4990dba4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -683,6 +683,12 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) PACKAGE_NAME examples ) add_subdirectory(example) + + add_subdirectory(tutorial) + rocm_package_setup_component(tutorials + LIBRARY_NAME composablekernel + PACKAGE_NAME tutorials + ) add_subdirectory(tile_engine) if(BUILD_TESTING) add_subdirectory(test) diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index a6cfcde86e..92ee0a4c31 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -25,7 +25,6 @@ add_subdirectory(22_gemm_multi_abd) add_subdirectory(35_batched_transpose) add_subdirectory(36_pooling) add_subdirectory(38_block_scale_gemm) -add_subdirectory(39_copy) add_subdirectory(40_streamk_gemm) add_subdirectory(41_batched_contraction) diff --git a/tutorial/CMakeLists.txt b/tutorial/CMakeLists.txt new file mode 100644 index 0000000000..a2f35ca53f --- /dev/null +++ b/tutorial/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include +) + +message(STATUS "Building tutorials...") +add_custom_target(tutorials) + +# add all tutorial subdir +file(GLOB dir_list LIST_DIRECTORIES true *) +FOREACH(subdir ${dir_list}) + if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") + add_subdirectory(${subdir}) + ENDIF() +ENDFOREACH() diff --git a/example/ck_tile/39_copy/CMakeLists.txt b/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt similarity index 54% rename from example/ck_tile/39_copy/CMakeLists.txt rename to tutorial/ck_tile/00_copy_kernel/CMakeLists.txt index 98397a33d2..91dd036eff 
100644 --- a/example/ck_tile/39_copy/CMakeLists.txt +++ b/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt @@ -1,7 +1,9 @@ -add_executable(tile_example_copy EXCLUDE_FROM_ALL copy_basic.cpp) +add_executable(tile_tutorial_copy_kernel EXCLUDE_FROM_ALL copy_basic.cpp) # Impact: This flag ensures that the compiler doesn't make # assumptions about memory aliasing that could interfere with Composable Kernel's explicit memory access patterns. -target_compile_options(tile_example_copy PRIVATE +target_compile_options(tile_tutorial_copy_kernel PRIVATE -mllvm -enable-noalias-to-md-conversion=0 ) + +add_dependencies(tutorials tile_tutorial_copy_kernel) diff --git a/example/ck_tile/39_copy/README.md b/tutorial/ck_tile/00_copy_kernel/README.md similarity index 100% rename from example/ck_tile/39_copy/README.md rename to tutorial/ck_tile/00_copy_kernel/README.md diff --git a/example/ck_tile/39_copy/copy_basic.cpp b/tutorial/ck_tile/00_copy_kernel/copy_basic.cpp similarity index 86% rename from example/ck_tile/39_copy/copy_basic.cpp rename to tutorial/ck_tile/00_copy_kernel/copy_basic.cpp index de91dc1be9..282e9ff8c1 100644 --- a/example/ck_tile/39_copy/copy_basic.cpp +++ b/tutorial/ck_tile/00_copy_kernel/copy_basic.cpp @@ -54,10 +54,10 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); // Define tile configuration - using ThreadTile = ck_tile::sequence<1, 4>; // per-thread tile size along M and N - using WaveTile = ck_tile::sequence<64, 4>; // wave size along M and N dimension - using BlockWaves = ck_tile::sequence<4, 1>; // number of waves along M dimension - using BlockTile = ck_tile::sequence<512, 4>; // block size along M and N dimension + using ThreadTile = ck_tile::sequence<1, 4>; // per-thread tile size along M and N + using WaveTile = ck_tile::sequence<64, 4>; // per-wave tile size along M and N dimension + using BlockWaves = ck_tile::sequence<4, 1>; // number of waves per block along M and N dimension + using BlockTile = ck_tile::sequence<512, 
4>; // per-block tile size along M and N dimension // Calculate grid size ck_tile::index_t kGridSize = @@ -68,14 +68,14 @@ bool run(const ck_tile::ArgParser& arg_parser) using Shape = ck_tile::TileCopyShape; using Problem = ck_tile::TileCopyProblem; using Policy = ck_tile::TileCopyPolicy; - using Kernel = ck_tile::ElementWiseTileCopyKernel; - // using Kernel = ck_tile::TileCopyKernel; - // using Kernel = ck_tile::TileCopyKernel_LDS; + using Kernel = ck_tile::ElementWiseTileCopyKernel; // operates on element by + // element basis. - // question: Why do we not have a pipeline? - // answer: For basic copy operation, pipeline is not needed. - // we intentionally do not use pipeline for this example and let the kernel be composite of - // Problem and Policy + // We also implement two variations of the copy kernel: + // 1. TileCopyKernel: This is the basic copy kernel that operates on tile by tile basis. + // 2. TileCopyKernel_LDS: This is the copy kernel that operates on tile by tile basis and uses + // the LDS. 
using Kernel = ck_tile::TileCopyKernel; using Kernel = + // ck_tile::TileCopyKernel_LDS; auto blockSize = Kernel::BlockSize(); diff --git a/example/ck_tile/39_copy/copy_basic.hpp b/tutorial/ck_tile/00_copy_kernel/copy_basic.hpp similarity index 100% rename from example/ck_tile/39_copy/copy_basic.hpp rename to tutorial/ck_tile/00_copy_kernel/copy_basic.hpp diff --git a/example/ck_tile/39_copy/test_tile_example.sh b/tutorial/ck_tile/00_copy_kernel/test_tile_example.sh similarity index 95% rename from example/ck_tile/39_copy/test_tile_example.sh rename to tutorial/ck_tile/00_copy_kernel/test_tile_example.sh index 416338fac4..4ee5fdf15d 100755 --- a/example/ck_tile/39_copy/test_tile_example.sh +++ b/tutorial/ck_tile/00_copy_kernel/test_tile_example.sh @@ -4,7 +4,7 @@ set -euo pipefail -BIN="${BIN:-../../../build/bin/tile_example_copy}" +BIN="${BIN:-../../../build/bin/tile_tutorial_copy_kernel}" WARMUP="${WARMUP:-20}" REPEAT="${REPEAT:-100}" VALIDATE="${VALIDATE:-1}" diff --git a/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md b/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md new file mode 100644 index 0000000000..114fccfd56 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md @@ -0,0 +1,589 @@ +# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg + +## Overview + +The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates: +1. **Data movement** from DRAM → Registers → LDS +2. **GEMM computation** using data in LDS +3. **Iteration** over the K dimension when needed + +This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C. 
+ +--- + +## Architecture: Problem and Policy + +Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern: + +### Problem: `PracticeGemmBlockPipelineProblem` +Contains: +- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType` +- **Shape information**: `BlockTile` and `WaveTile` dimensions + +### Policy: `PracticeGemmBlockPolicy` +Contains strategies for: +1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`) + - Defines how 256 threads in a block map to elements of a block tile + - Each thread knows which elements to load/store from DRAM to its registers + - We'll cover tile distribution construction in detail later + +2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`) + - Describes how data is logically organized in Local Data Share (LDS) + - Optimizes for bank conflict avoidance and efficient access patterns + - We'll cover LDS descriptor construction in detail later + +3. **Warp Pipeline** (`GetPracticeWaveGemmPipeline`) + - Returns the warp-level GEMM implementation + +--- + +## Inputs and Outputs + +```cpp +template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp> +CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const +``` + +### Inputs: +- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock) +- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock) +- `num_loop`: Number of iterations along K dimension +- `p_smem`: Pointer to shared memory (LDS) + +### Output: +- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs) + +--- + +## Step-by-Step Walkthrough + +### Step 1: Create LDS Tensor Views + +```cpp +// A tile in LDS +ADataType* p_a_lds = static_cast<ADataType*>(p_smem); +constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); +auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, 
a_lds_block_desc); + +// B tile in LDS (placed after A in shared memory) +BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); +constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); +auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); +``` + +**What's happening:** +- We partition the shared memory (`p_smem`) into two regions: one for A, one for B +- We create **tensor views** over these LDS regions using descriptors from the policy +- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory + +**Memory Layout:** +``` +Shared Memory (LDS): +┌─────────────────────┬─────────────────────┐ +│ A Block Tile │ B Block Tile │ +│ (256×32 fp16) │ (128×32 fp16) │ +└─────────────────────┴─────────────────────┘ +↑ ↑ +p_a_lds p_b_lds +``` + +--- + +### Step 2: Create Tile Windows for Data Movement + +We create **6 tile windows** for different purposes: + +#### 2a. DRAM → Registers (Load from DRAM) + +```cpp +auto a_copy_dram_window = make_tile_window( + a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), // 256×32 + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); // ← Tile distribution! +``` + +**Key Points:** +- `a_copy_dram_window` is a `tile_window_with_static_distribution` +- The **tile distribution** tells each thread which elements to load from DRAM +- This window will **slide along the K dimension** in the loop + +#### 2b. Registers → LDS (Store to LDS) + +```cpp +auto a_copy_lds_window = make_tile_window( + a_lds_block, + make_tuple(number{}, number{}), // 256×32 + {0, 0}, // Origin at (0, 0) in LDS + a_copy_dram_window.get_tile_distribution()); // ← Same distribution as DRAM! 
+``` + +**Key Points:** +- Uses the **same tile distribution** as `a_copy_dram_window` +- This ensures each thread stores to LDS in the same pattern it loaded from DRAM +- Origin is always `{0, 0}` because LDS is reused for each K iteration + +#### 2c. LDS → Registers (GEMM Input) + +```cpp +auto a_lds_gemm_window = make_tile_window( + a_lds_block, + make_tuple(number{}, number{}), + {0, 0}); // No tile distribution! +``` + +**Key Points:** +- This is a `tile_window_with_static_lengths` (no explicit distribution) +- Used as input to the warp-level GEMM +- The warp GEMM will handle its own thread mapping internally + +**Similar windows are created for B:** +- `b_copy_dram_window`: Load B from DRAM +- `b_copy_lds_window`: Store B to LDS +- `b_lds_gemm_window`: Read B from LDS for GEMM + +--- + +### Step 3: Create Distributed Tensors (VGPRs) + +```cpp +using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); +using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + +using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); +using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + +ABlockTile a_block_tile; // Per-thread registers for A +BBlockTile b_block_tile; // Per-thread registers for B +``` + +#### What is `make_static_distributed_tensor`? + +**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**. + +**Key Properties:** +1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers +2. **Compile-time sized**: Buffer size determined by tile distribution at compile time +3. 
**Zero-overhead**: All indexing and layout transformations happen at compile time + +**How it works:** + +```cpp +template <typename DataType_, typename StaticTileDistribution_> +struct static_distributed_tensor +{ + using DataType = remove_cvref_t<DataType_>; + using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>; + + // Calculate per-thread storage size from tile distribution + using ThreadTensorDesc = + remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>; + + static constexpr index_t kThreadElementSpaceSize = + ThreadTensorDesc{}.get_element_space_size(); + + // Per-thread register array (VGPRs) + thread_buffer<DataType, get_thread_buffer_size()> thread_buf_; +}; +``` + +**The tile distribution defines:** +- **Which elements each thread owns** in the tile +- **How many elements** each thread stores (buffer size) +- **How elements are laid out** in each thread's registers + +**Concrete Example for 256×32 tile with 256 threads:** + +``` +Thread 0: a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]] (32 fp16 values) +Thread 1: a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]] (32 fp16 values) +Thread 2: a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]] (32 fp16 values) +... 
+Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values) +``` + +**Collectively:** +- All 256 threads together hold the **entire 256×32 tile** (8192 elements) +- Each thread's buffer lives in its **own VGPRs** +- No two threads own the same element + +**Distributed Ownership Analogy:** +Think of a tile as a **jigsaw puzzle**: +- The **tile distribution** is the cutting pattern +- Each **thread** gets one puzzle piece (its slice) +- Each **`static_distributed_tensor`** is a box holding all pieces +- Each thread's **`thread_buf_`** is its individual piece in its own registers + +--- + +### Step 4: The GEMM Loop + +```cpp +// Initialize C accumulator to zero +auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; +tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + +index_t iCounter = num_loop; // Number of K iterations + +while(iCounter > 0) +{ + // 1. Load from DRAM to registers + a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs + b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs + + // 2. Move windows for next iteration + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32) + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32) + + // 3. Store from registers to LDS + store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS + store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS + + // 4. Synchronize threads (ensure all data is in LDS) + block_sync_lds(); + + // 5. Compute GEMM using data in LDS + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + // 6. Synchronize threads (before overwriting LDS in next iteration) + block_sync_lds(); + + iCounter--; +} + +return c_block_tile; // Return accumulated result in registers +``` + +--- + +## Detailed Loop Breakdown + +### Phase 1: Load (DRAM → VGPRs) + +```cpp +a_block_tile = load_tile(a_copy_dram_window); +``` + +**What happens:** +1. 
Each thread reads **its assigned elements** from DRAM (determined by tile distribution) +2. Data is loaded into **per-thread registers** (VGPRs) +3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once) + +**Example for Thread 0:** +``` +Thread 0 loads: + A[0,0:7] (8 fp16 values, one vector load) + A[0,8:15] (8 fp16 values, one vector load) + ... +``` + +### Phase 2: Move Windows + +```cpp +constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); +move_tile_window(a_copy_dram_window, a_dram_tile_window_step); +``` + +**What happens:** +- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example) +- This prepares for the next K iteration +- The window origin moves from `(0, 0)` → `(0, 32)` → `(0, 64)` → ... + +**Visualization for Problem Size 512×256×64:** +``` +Matrix A (512×64): +┌─────────────────────────────────────┐ +│ Block 0: rows 0-255 │ +│ ┌──────────┬──────────┐ │ +│ │ K=0:31 │ K=32:63 │ │ ← Window slides right +│ │ Iter 0 │ Iter 1 │ │ +│ └──────────┴──────────┘ │ +└─────────────────────────────────────┘ +``` + +### Phase 3: Store (VGPRs → LDS) + +```cpp +store_tile(a_copy_lds_window, a_block_tile); +``` + +**What happens:** +1. Each thread writes **its elements** from registers to LDS +2. Uses the **same distribution** as the DRAM load +3. Data is now in **shared memory**, accessible to all threads in the block + +**Why this step?** +- GEMM computation needs **all threads** to access **all data** +- Registers are per-thread; LDS is shared across the block +- LDS acts as a "staging area" for collaborative computation + +### Phase 4: Synchronize + +```cpp +block_sync_lds(); +``` + +**What happens:** +- All threads in the block **wait** until everyone has finished storing to LDS +- Ensures no thread starts reading from LDS before all writes are complete +- Critical for correctness! 
+ +### Phase 5: GEMM Computation + +```cpp +block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); +``` + +**What happens:** +1. The warp-level GEMM reads data from LDS +2. Performs matrix multiplication using MFMA instructions +3. Accumulates results into `c_block_tile` (in registers) + +**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results. + +### Phase 6: Synchronize Again + +```cpp +block_sync_lds(); +``` + +**What happens:** +- Ensures all threads have finished reading from LDS +- Safe to overwrite LDS in the next iteration + +--- + +## Memory Flow Diagram + +``` +Iteration 0 (K=0:31): +┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐ +│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │ +│ A[0:255,│ │ (per- │ │ A_block │ +│ 0:31] │ │ thread) │ │ │ +└─────────┘ └──────────┘ └─────────┘ + │ + │ block_gemm + ↓ + ┌──────────┐ + │ c_block_ │ + │ tile │ + │ (VGPRs) │ + └──────────┘ + +Iteration 1 (K=32:63): +┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐ +│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │ +│ A[0:255,│ │ (per- │ │ A_block │ +│ 32:63] │ │ thread) │ │ (reused)│ +└─────────┘ └──────────┘ └─────────┘ + │ + │ block_gemm + ↓ + ┌──────────┐ + │ c_block_ │ + │ tile │ + │ (accum.) │ + └──────────┘ +``` + +--- + +## Example: Problem Size 512×256×64 + +### Block 0 Computation + +**Input:** +- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially +- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed) +- `num_loop`: 2 (since K=64, KPerBlock=32) + +**Iteration 0:** +1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs +2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63] +3. Store to LDS +4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T` + +**Iteration 1:** +1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs +2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (out of bounds, but loop ends) +3. Store to LDS +4. 
Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T` + +**Output:** +- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers + +--- + +## Key Concepts Summary + +### 1. Tile Distribution +- **Maps threads to data elements** for load/store operations +- Each thread knows exactly which elements it's responsible for +- Enables **parallel, vectorized** memory access +- **Same distribution** used for DRAM load and LDS store + +### 2. Static Distributed Tensor +- **Per-thread register storage** (VGPRs) +- Each thread owns a **different slice** of the tile +- **Compile-time sized** for zero-overhead abstraction +- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile` + +### 3. Tile Window Movement +- Windows **slide** over larger tensors +- Enables iteration over the K dimension +- `move_tile_window(window, step)` updates the origin + +### 4. LDS as Staging Area +- **Shared memory** accessible to all threads in a block +- Required because GEMM needs all threads to access all data +- **Reused** across K iterations (same LDS buffer) + +### 5. 
Synchronization +- `block_sync_lds()` ensures memory consistency +- **Before GEMM**: All stores to LDS are complete +- **After GEMM**: All reads from LDS are complete + +--- + +## Deep Dive: `static_distributed_tensor` Mechanics + +### How Tile Distribution Creates Per-Thread Storage + +When you call: +```cpp +using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); +ABlockTile a_block_tile; +``` + +**Step 1: Extract Thread Tensor Descriptor** + +The tile distribution contains a `ys_to_d_descriptor` that maps: +- **Y dimensions** (logical tile coordinates, e.g., M, K) +- **D dimension** (per-thread register index, linearized) + +```cpp +using ThreadTensorDesc = + decltype(StaticTileDistribution{}.get_ys_to_d_descriptor()); +``` + +**Step 2: Calculate Per-Thread Buffer Size** + +```cpp +static constexpr index_t kThreadElementSpaceSize = + ThreadTensorDesc{}.get_element_space_size(); + +static constexpr index_t get_thread_buffer_size() +{ + return kThreadElementSpaceSize / PackedSize; +} +``` + +**Example:** +- 256×32 tile distributed across 256 threads +- Each thread owns 32 elements (one row) +- `thread_buffer_size = 32` (for PackedSize=1) + +**Step 3: Allocate Thread Buffer** + +```cpp +thread_buffer thread_buf_; +``` + +This is essentially: +```cpp +fp16_t data[32]; // Per-thread register array (VGPRs) +``` + +### Usage in Load/Store Operations + +**Load from DRAM:** +```cpp +a_block_tile = load_tile(a_copy_dram_window); +``` + +What happens internally: +1. Each thread queries the tile distribution: "Which elements do I own?" +2. Thread 0 learns it owns A[0,0:31] +3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]` +4. All 256 threads do this **in parallel** + +**Store to LDS:** +```cpp +store_tile(a_copy_lds_window, a_block_tile); +``` + +What happens internally: +1. Each thread reads from its `a_block_tile.thread_buf_` +2. Thread 0 writes A[0,0:31] from its registers to LDS +3. 
All 256 threads do this **in parallel** +4. After `block_sync_lds()`, the entire tile is in shared LDS + +### Distributed Indexing + +The `static_distributed_tensor` supports compile-time indexing: + +```cpp +// Access using distributed indices +auto value = a_block_tile(tile_distributed_index{}); +``` + +Internally: +1. Convert distributed index → Y index (logical tile coordinates) +2. Calculate buffer offset using `ThreadTensorDesc` +3. Access `thread_buf_[offset]` + +All of this happens **at compile time** with zero runtime overhead! + +### Why This Design? + +**Benefits:** +1. **Parallel Memory Access**: All threads load/store simultaneously +2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once) +3. **Zero Overhead**: All indexing resolved at compile time +4. **Type Safety**: Distribution mismatch caught at compile time +5. **Register Pressure**: Compiler knows exact VGPR usage + +**Trade-offs:** +- Requires compile-time tile sizes +- Distribution must be static +- More complex type system + +### Memory Hierarchy Summary + +``` +┌─────────────────────────────────────────────────────────────┐ +│ DRAM (Global Memory) │ +│ Full matrices A, B, C │ +└─────────────────────────────────────────────────────────────┘ + │ + │ load_tile (parallel, vectorized) + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ VGPRs (Per-Thread Registers) │ +│ Thread 0: a_block_tile.thread_buf_ = [A[0,0:31]] │ +│ Thread 1: a_block_tile.thread_buf_ = [A[1,0:31]] │ +│ ... 
│ +│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]] │ +│ │ +│ ← static_distributed_tensor manages this distribution │ +└─────────────────────────────────────────────────────────────┘ + │ + │ store_tile (parallel, vectorized) + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ LDS (Shared Memory) │ +│ Entire block tile (256×32) │ +│ Accessible to all threads in block │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Key Insight:** +`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time. + + + diff --git a/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt new file mode 100644 index 0000000000..e16977921a --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(tile_tutorial_naive_gemm EXCLUDE_FROM_ALL practice_gemm.cpp) + +target_compile_options(tile_tutorial_naive_gemm PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) + +add_dependencies(tutorials tile_tutorial_naive_gemm) \ No newline at end of file diff --git a/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md b/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md new file mode 100644 index 0000000000..43cb01fb36 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md @@ -0,0 +1,618 @@ +# Host-Level Pipeline: Orchestrating Block-Level GEMM + +This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation. + +## Overview + +The host-level pipeline is responsible for: +1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C +2. **Block-to-tile mapping**: Assigning each thread block to a specific tile +3. 
**Creating tile windows**: Establishing sliding windows over tensor views +4. **Delegating computation**: Calling the block-level pipeline to perform actual GEMM +5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM + +```cpp +template +struct PracticeGemmHostPipeline +{ + template + CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, + const BDRAMTensorView& b_dram, + CDRAMTensorView& c_dram) const + { + // 1. Calculate problem dimensions and tile coverage + // 2. Map thread block to tile coordinates + // 3. Create tile windows over A and B + // 4. Call block-level pipeline to compute + // 5. Store result to C + } +}; +``` + +--- + +## Step 1: Calculate Problem Dimensions and Tile Coverage + +```cpp +// Size of the entire problem +const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K +const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N +const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K + +// Size of the block tile +const auto MPerBlock = BlockTile::at(number<0>{}); // 256 +const auto NPerBlock = BlockTile::at(number<1>{}); // 128 +const auto KPerBlock = BlockTile::at(number<2>{}); // 32 + +// Number of block tiles needed to cover C matrix +const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2 +const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2 +``` + +### What's Happening: + +1. **Extract problem dimensions** from tensor descriptors: + - `M = 512`: Rows in A and C + - `N = 256`: Columns in B and C + - `K = 64`: Inner dimension (columns of A, rows of B) + +2. **Get block tile sizes** from the `BlockTile` configuration: + - `MPerBlock = 256`: Each block processes 256 rows + - `NPerBlock = 128`: Each block processes 128 columns + - `KPerBlock = 32`: Each block processes 32 elements in K dimension per iteration + +3. 
**Calculate tile coverage**: + - `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in M direction + - `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in N direction + - **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**! + +### Visual Representation: + +``` +Matrix C (512 × 256): +┌──────────────────────┬──────────────────────┐ +│ Tile (0,0) │ Tile (0,1) │ ← num_tile_n = 2 +│ 256×128 │ 256×128 │ +│ Block 0 │ Block 2 │ +│ │ │ +├──────────────────────┼──────────────────────┤ +│ Tile (1,0) │ Tile (1,1) │ +│ 256×128 │ 256×128 │ +│ Block 1 │ Block 3 │ +│ │ │ +└──────────────────────┴──────────────────────┘ + ↑ + num_tile_m = 2 + +Total blocks needed = 2 × 2 = 4 blocks + +Each block computes one 256×128 tile of the output matrix C. +``` + +### How Blocks Cover Matrices A and B: + +``` +Matrix A (512 × 64): Matrix B (256 × 64): +┌─────────────┬──────┐ ┌─────────────┬──────┐ +│ Block 0,2 │ K │ │ Block 0,1 │ K │ +│ uses rows │ → │ │ uses rows │ → │ +│ 0-255 │ │ │ 0-127 │ │ +├─────────────┼──────┤ ├─────────────┼──────┤ +│ Block 1,3 │ K │ │ Block 2,3 │ K │ +│ uses rows │ → │ │ uses rows │ → │ +│ 256-511 │ │ │ 128-255 │ │ +└─────────────┴──────┘ └─────────────┴──────┘ + 256 rows 64 cols 128 rows 64 cols + +Each block needs to iterate over K dimension (64/32 = 2 iterations) +``` + +--- + +## Step 2: Map Thread Block to Tile Coordinates + +```cpp +// Get block id (0 to total_blocks - 1) +const auto id_block = get_block_id(); + +// Map block id to 2D tile coordinates +const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n); +const auto tile_id = block2tile(id_block); + +const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate +const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate +``` + +### What's Happening: + +Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from linear block ID to 2D tile coordinates.
+ +### The `MakeBlock2TileMap` Function: + +```cpp +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) +{ + // Create a merge transform: (N0, M0) → linear index + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + // Convert linear block_id back to 2D coordinates + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + // Return (m_idx, n_idx) - note the swap! + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; +} +``` + +### In Our Example (2×2 Grid): + +```cpp +// Block 0: +id_block = 0 +tile_id = block2tile(0) = (0, 0) // Top-left tile +tile_id_m = 0, tile_id_n = 0 + +// Block 1: +id_block = 1 +tile_id = block2tile(1) = (1, 0) // Bottom-left tile +tile_id_m = 1, tile_id_n = 0 + +// Block 2: +id_block = 2 +tile_id = block2tile(2) = (0, 1) // Top-right tile +tile_id_m = 0, tile_id_n = 1 + +// Block 3: +id_block = 3 +tile_id = block2tile(3) = (1, 1) // Bottom-right tile +tile_id_m = 1, tile_id_n = 1 +``` + +**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing! 
+ +--- + +## Step 3: Calculate Tile Origin and Create Tile Windows + +```cpp +// Calculate the starting position of this tile in the global matrix +const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256 +const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128 + +// Create tile windows over A and B tensor views +const auto a_block_window = make_tile_window( + a_dram, // Tensor view over A + make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), // Window size: 256×32 + {tile_origin_m, 0} // Origin: varies by block +); + +const auto b_block_window = make_tile_window( + b_dram, // Tensor view over B + make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), // Window size: 128×32 + {tile_origin_n, 0} // Origin: varies by block +); +``` + +### Tile Origins for Each Block: + +```cpp +// Block 0 (Tile 0,0): +tile_origin_m = 0 * 256 = 0 +tile_origin_n = 0 * 128 = 0 +a_block_window origin: (0, 0) → covers A rows 0-255 +b_block_window origin: (0, 0) → covers B rows 0-127 + +// Block 1 (Tile 1,0): +tile_origin_m = 1 * 256 = 256 +tile_origin_n = 0 * 128 = 0 +a_block_window origin: (256, 0) → covers A rows 256-511 +b_block_window origin: (0, 0) → covers B rows 0-127 + +// Block 2 (Tile 0,1): +tile_origin_m = 0 * 256 = 0 +tile_origin_n = 1 * 128 = 128 +a_block_window origin: (0, 0) → covers A rows 0-255 +b_block_window origin: (128, 0) → covers B rows 128-255 + +// Block 3 (Tile 1,1): +tile_origin_m = 1 * 256 = 256 +tile_origin_n = 1 * 128 = 128 +a_block_window origin: (256, 0) → covers A rows 256-511 +b_block_window origin: (128, 0) → covers B rows 128-255 +``` + +### What are Tile Windows? + +A **tile window** is a **sliding window** over a larger tensor view.
It: +- Defines a **rectangular region** within the tensor +- Has a **fixed size** (e.g., 256×32 for A) +- Has an **origin** (starting position) +- Can be **moved** to access different regions +### Visual Representation (Block 0 Example): + +``` +Matrix A (512 × 64): Matrix B (256 × 64): +┌─────────────┬─────────────┐ ┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ │ │ ┏━━━━━━━━━┓ │ │ +│ ┃ Window ┃ │ │ │ ┃ Window ┃ │ │ +│ ┃ 256×32 ┃ │ │ │ ┃ 128×32 ┃ │ │ +│ ┃ K=0-31 ┃ │ │ │ ┃ K=0-31 ┃ │ │ +│ ┗━━━━━━━━━┛ │ │ │ ┗━━━━━━━━━┛ │ │ +│ │ │ ├─────────────┼─────────────┤ +├─────────────┼─────────────┤ │ │ │ +│ │ │ │ │ │ +│ │ │ │ │ │ +│ │ │ │ │ │ +└─────────────┴─────────────┘ └─────────────┴─────────────┘ + Origin: (0, 0) Origin: (0, 0) + Covers rows 0-255 Covers rows 0-127 + Covers cols 0-31 (first K iteration) Covers cols 0-31 (first K iteration) +``` + +**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration. + +### Tile Window Properties: + +```cpp +// Tile window structure (conceptual): +struct tile_window { + TensorView& tensor_view; // Reference to underlying tensor + Tuple window_lengths; // Size of the window (256, 32) + MultiIndex window_origin; // Starting position (0, 0) + + // Can move the window: + void move(MultiIndex step); // Shift window by step + + // Access data through the window: + auto load(); // Load data from windowed region +}; +``` + + +### Tile Window Movement: Iterating Over K Dimension + +In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension: + +``` +Matrix A (512 × 64) - Block 0's view: +┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │ +│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ ← Window slides along K +│ ┃ 256×32 ┃ │ ║ 256×32 ║ │ +│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │ +│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │ +├─────────────┼─────────────┤ +│ │ │ +│ Block 1's │ │ +│ region │ │ +└─────────────┴─────────────┘ + +Matrix B (256 × 64) - Block 0's view: 
+┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │ +│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ +│ ┃ 128×32 ┃ │ ║ 128×32 ║ │ +│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │ +│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │ +├─────────────┼─────────────┤ +│ Block 2's │ │ +│ region │ │ +└─────────────┴─────────────┘ +``` + +### How Windows Move (Conceptual - handled by block pipeline): + +```cpp +// Iteration 0: +a_block_window origin: (tile_origin_m, 0) // K columns 0-31 +b_block_window origin: (tile_origin_n, 0) // K columns 0-31 +// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] + +// Move windows to next K position: +move_tile_window(a_block_window, {0, 32}); +move_tile_window(b_block_window, {0, 32}); + +// Iteration 1: +a_block_window origin: (tile_origin_m, 32) // K columns 32-63 +b_block_window origin: (tile_origin_n, 32) // K columns 32-63 +// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] + +// Final result: +// C_tile = C_partial_0 + C_partial_1 +``` + +**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C. + +--- + +## Step 4: Delegate to Block-Level Pipeline + +```cpp +// Get the block-level pipeline from policy +constexpr auto block_gemm_pipeline = + Policy::template GetPracticeGemmBlockPipeline(); + +// Calculate number of K iterations needed +int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2 + +// Allocate shared memory (LDS) for block-level computation +__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; + +// Call block-level pipeline to compute C tile +const auto c_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); +``` + +### What's Happening: + +1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation +2. 
**Calculate K iterations**: How many times to iterate over the K dimension + - In our example: `K=64, KPerBlock=32` → **2 iterations** + - Each iteration processes 32 elements of the K dimension + - Results are accumulated across iterations + +3. **Allocate shared memory**: + - `__shared__` declares memory shared by all threads in the block + - `GetStaticLDSSize()` returns the required size in bytes + - This memory is used for: + - Staging data from DRAM → LDS + - Cooperative loading by threads + - Fast access during computation + +4. **Execute block pipeline**: + - Takes A and B tile windows as input + - Performs the GEMM computation: `C_tile = A_tile × B_tile` + - Returns result in `c_block_tile` (stored in VGPRs - registers) + +### Memory Hierarchy During Computation: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ DRAM (Global Memory) - Slowest, Largest │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ A matrix │ │ B matrix │ │ C matrix │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ │ +└─────────────────────────────────────────────────────────────┘ + ↓ load ↓ load ↑ store +┌─────────────────────────────────────────────────────────────┐ +│ LDS (Shared Memory) - Fast, Limited Size (~64KB) │ +│ ┌─────────────┐ ┌─────────────┐ │ +│ │ A_tile │ │ B_tile │ ← Staged here │ +│ │ (p_smem) │ │ (p_smem) │ │ +│ └─────────────┘ └─────────────┘ │ +└─────────────────────────────────────────────────────────────┘ + ↓ load ↓ load +┌─────────────────────────────────────────────────────────────┐ +│ VGPRs (Registers) - Fastest, Smallest (~256 regs/thread) │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ c_block_tile (accumulated result) │ │ +│ │ Computation happens here using MFMA instructions │ │ +│ └─────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### Block Pipeline Responsibilities: + +The block pipeline (called here) will: +1. 
Load A and B tiles from DRAM → LDS (cooperative loading) +2. Distribute work among warps +3. Each warp loads its portion from LDS → VGPRs +4. Perform MFMA operations: `C += A × B` +5. Accumulate results in VGPRs +6. Return final `c_block_tile` in registers + +--- + +## Step 5: Store Results to DRAM + +```cpp +// Create a tile window over C for writing results +auto c_window = make_tile_window( + c_dram, // Tensor view over C + make_tuple(number<MPerBlock>{}, number<NPerBlock>{}), // Window size: 256×128 + {tile_origin_m, tile_origin_n} // Origin: varies by block +); + +// Store computed tile from VGPRs to DRAM +store_tile(c_window, c_block_tile); +``` + +### C Window Origins for Each Block: + +```cpp +// Block 0: Writes to top-left tile +c_window origin: (0, 0) → writes to C[0:255, 0:127] + +// Block 1: Writes to bottom-left tile +c_window origin: (256, 0) → writes to C[256:511, 0:127] + +// Block 2: Writes to top-right tile +c_window origin: (0, 128) → writes to C[0:255, 128:255] + +// Block 3: Writes to bottom-right tile +c_window origin: (256, 128) → writes to C[256:511, 128:255] +``` + +### What's Happening: + +1. **Create C tile window**: + - Size: 256×128 (matches our block tile size) + - Origin: Varies by block - each block writes to its assigned region + - This window defines **where** to write the results + +2. **Store tile to DRAM**: + - `c_block_tile`: Computed results in VGPRs (registers) + - `c_window`: Destination window in DRAM + - `store_tile()`: Efficiently writes data from registers → DRAM + +### The `store_tile` Function: + +Recall from our earlier discussion, `store_tile` does: + +```cpp +template <typename TileWindow, typename DistributedTensor> +void store_tile(TileWindow& tile_window_tmp, + const DistributedTensor& dstr_tensor) +{ + // 1. Extract tile distribution from distributed tensor + using TileDstr = typename DistributedTensor::TileDistribution; + + // 2.
Upgrade simple tile window to one with distribution + auto tile_window = make_tile_window( + tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + TileDstr{} // Add distribution info + ); + + // 3. Store using vectorized writes + tile_window.store(dstr_tensor); +} +``` + +### Memory Flow: + +``` +VGPRs (Registers) DRAM (Global Memory) +┌─────────────────────┐ ┌─────────────────────┐ +│ c_block_tile │ │ C matrix │ +│ ┌───┬───┬───┬───┐ │ │ ┌───────────────┐ │ +│ │W0 │W1 │W2 │W3 │ │ store_tile │ │ │ │ +│ ├───┼───┼───┼───┤ │ ==========> │ │ c_window │ │ +│ │...│...│...│...│ │ vectorized │ │ (256×128) │ │ +│ └───┴───┴───┴───┘ │ │ │ │ │ +│ Distributed across │ │ └───────────────┘ │ +│ threads/warps │ │ Origin: (0, 0) │ +└─────────────────────┘ └─────────────────────┘ + +Each thread writes its portion using vector stores (e.g., float4) +``` + +### Store Optimization: + +The `store_tile` function: +- Uses **vectorized stores** (write multiple elements at once) +- Ensures **coalesced memory access** (adjacent threads write adjacent memory) +- Respects **tile distribution** (each thread knows what data it owns) +- Handles **out-of-bounds** checking (for partial tiles at boundaries) + +--- + +## Complete Flow Visualization + +Let's trace the complete flow for **Block 0** (other blocks follow the same pattern): + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Step 1: Calculate Tile Coverage │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ M=512, N=256, K=64 │ │ +│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32 │ │ +│ │ num_tile_m = ceil(512/256) = 2 │ │ +│ │ num_tile_n = ceil(256/128) = 2 │ │ +│ │ Total blocks needed = 2 × 2 = 4 blocks │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 
2: Map Block to Tile (Block 0 example) │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Block ID: 0 │ │ +│ │ Tile coordinates: (0, 0) - top-left tile │ │ +│ │ Tile origin: (0, 0) │ │ +│ │ │ │ +│ │ (Blocks 1,2,3 get different tile coordinates) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 3: Create Tile Windows │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ a_block_window: 256×32 starting at (0,0) over A │ │ +│ │ b_block_window: 128×32 starting at (0,0) over B │ │ +│ │ Windows initially cover K columns 0-31 │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 4: Execute Block Pipeline (2 K iterations) │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Allocate shared memory (LDS) │ │ +│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem) │ │ +│ │ │ │ +│ │ K Iteration 0 (K=0-31): │ │ +│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] │ │ +│ │ └─ Move windows: {0, 32} │ │ +│ │ │ │ +│ │ K Iteration 1 (K=32-63): │ │ +│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] │ │ +│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1 │ │ +│ │ │ │ +│ │ Return c_block_tile in VGPRs (256×128 accumulated result) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 5: Store Results │ +│ 
┌─────────────────────────────────────────────────────────────┐ │ +│ │ Create c_window: 256×128 starting at (0,0) over C │ │ +│ │ store_tile(c_window, c_block_tile) │ │ +│ │ └─ Write from VGPRs → DRAM (vectorized stores) │ │ +│ │ │ │ +│ │ Block 0 writes to C[0:255, 0:127] │ │ +│ │ (Other blocks write to their respective regions) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + +All 4 blocks execute in parallel, each computing its assigned 256×128 tile! +``` + +--- + +## Key Concepts Summary + +### 1. **Tile Coverage** +- Determines how many thread blocks are needed +- Each block processes one tile of the output matrix C +- Calculated as `ceil(dimension / tile_size)` + +### 2. **Block-to-Tile Mapping** +- Maps linear block ID to 2D tile coordinates +- Uses column-major ordering for better memory coalescing +- Each block knows which tile it's responsible for + +### 3. **Tile Windows** +- **Sliding windows** over larger tensor views +- Define a rectangular region with fixed size and movable origin +- Provide efficient, structured access to tensor data +- Can be moved to access different regions (e.g., for K iterations) + +### 4. **Memory Hierarchy** +- **DRAM (Global)**: Largest, slowest - stores full matrices +- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access +- **VGPRs (Registers)**: Smallest, fastest - performs computation + +### 5. **Data Flow** +``` +DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM + ↑ ↓ + A, B matrices C matrix +``` + +--- + +## Next Steps + +The host-level pipeline has set up the work and delegated to the block-level pipeline. 
Next, we'll explore: +- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed +- **Warp-level pipeline**: How warps perform MFMA operations +- **Memory optimization**: LDS usage, bank conflicts, coalescing + +The host level provides the **orchestration**, while the block and warp levels provide the **execution**! + diff --git a/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md b/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md new file mode 100644 index 0000000000..7cd0d06fc5 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md @@ -0,0 +1,464 @@ +# PracticeGemmKernel: Understanding the Kernel Entry Point + +This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views. + +## Overview + +The `PracticeGemmKernel` is a templated struct that: +1. Takes raw device memory pointers for matrices A, B, and C +2. Wraps them into **tensor views** - logical, structured views over physical memory +3. 
Dispatches to the host-level pipeline for computation + +```cpp +template <typename Problem_, typename Policy_> +struct PracticeGemmKernel +{ + using Problem = remove_cvref_t<Problem_>; + using Policy = remove_cvref_t<Policy_>; + + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, + const typename Problem::BDataType* p_b, + typename Problem::CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c) const + { + // Step 1: Create tensor views over raw memory + auto a_dram = make_naive_tensor_view<address_space_enum::global>( + p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{}); + + auto b_dram = make_naive_tensor_view<address_space_enum::global>( + p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{}); + + const auto c_dram = make_naive_tensor_view<address_space_enum::global>( + p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{}); + + // Step 2: Dispatch to host-level pipeline + PracticeGemmHostPipeline<Problem, Policy>{}(a_dram, b_dram, c_dram); + } +}; +``` + +--- + +## What are Tensor Views? + +A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory—it simply provides a way to interpret and access existing memory as a multi-dimensional tensor. + +### Key Components of a Tensor View: + +1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers) +2. **Raw Pointer**: Points to the actual data in memory +3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A) +4. **Strides**: How to navigate through memory to access elements +5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction +6.
**Guaranteed Vector Stride**: The stride of those vectorizable elements + +--- + +## The Memory Abstraction Hierarchy + +CK Tile uses a three-layer abstraction to go from raw memory to structured tensors: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Layer 3: TENSOR VIEW │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ • Logical multi-dimensional structure │ │ +│ │ • Shape: (M, K) = (256, 32) │ │ +│ │ • Strides: (32, 1) for row-major layout │ │ +│ │ • Provides: operator[], coordinate-based access │ │ +│ │ • Knows: How to map (i,j) → linear offset │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ ↓ wraps │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Layer 2: BUFFER VIEW │ │ +│ │ ┌─────────────────────────────────────────────────────┐ │ │ +│ │ │ • Linear view of memory │ │ │ +│ │ │ • Pointer: p_data_ → device memory │ │ │ +│ │ │ • Size: Total number of elements │ │ │ +│ │ │ • Address space: global/LDS/generic │ │ │ +│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │ +│ │ └─────────────────────────────────────────────────────┘ │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ ↓ wraps │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Layer 1: RAW PHYSICAL MEMORY │ │ +│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │ +│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ... 
│ │ │ +│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │ +│ │ ↑ │ │ +│ │ p_a (raw pointer from hipMalloc) │ │ +│ └─────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +--- + +## Deep Dive: `make_naive_tensor_view` + +Let's break down the function call for matrix A: + +```cpp +auto a_dram = make_naive_tensor_view( + p_a, // Raw pointer to device memory + make_tuple(M, K), // Shape: (256, 32) + make_tuple(stride_a, 1), // Strides: (32, 1) - row-major + number<8>{}, // Guaranteed vector length + number<1>{} // Guaranteed vector stride +); +``` + +### Function Signature: + +```cpp +template +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_view(DataType* __restrict__ p, + const tuple& lengths, + const tuple& strides, + number = number<-1>{}, + number = number<-1>{}) +{ + // Step 1: Create tensor descriptor (shape + stride information) + auto desc = make_naive_tensor_descriptor(lengths, + strides, + number{}, + number{}); + + // Step 2: Create buffer view (pointer + size + address space) + auto buffer_view = + make_buffer_view(p, desc.get_element_space_size()); + + // Step 3: Combine into tensor view + return tensor_view{buffer_view, desc}; +} +``` + +--- + +## Parameter Breakdown + +### 1. **Template Parameter: `address_space_enum::global`** + +Specifies where the memory lives: +- `global`: GPU global memory (DRAM) - slowest but largest +- `lds`: Local Data Share (shared memory) - fast, limited size +- `generic`: Generic address space +- `vgpr`: Vector General Purpose Registers - fastest, smallest + +In our case, `global` means the data is in GPU DRAM. + +### 2. **`p_a` - Raw Pointer** + +The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data. + +### 3. 
**`make_tuple(M, K)` - Shape/Lengths** + +Defines the logical dimensions of the tensor: +- For matrix A: `(256, 32)` means 256 rows, 32 columns +- This is the **logical view**, independent of how data is physically laid out + +### 4. **`make_tuple(stride_a, 1)` - Strides** + +Defines how to navigate through memory: +- **Stride for dimension 0 (rows)**: `stride_a = K = 32` + - To move to the next row, advance 32 elements +- **Stride for dimension 1 (columns)**: `1` + - To move to the next column, advance 1 element + +**Row-major layout example:** +``` +Memory: [a₀₀, a₀₁, a₀₂, ..., a₀₃₁, a₁₀, a₁₁, a₁₂, ..., a₁₃₁, ...] + ↑ ↑ + Row 0 starts here Row 1 starts here (offset = 32) + +To access element A[i][j]: + offset = i * stride_a + j * 1 + = i * 32 + j +``` + +### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length** + +This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."** + +#### Why is this important? + +Modern GPUs can load multiple elements in one instruction (vectorized loads): +- `float4`: Load 4 floats at once +- `float8`: Load 8 floats at once (if supported) + +By specifying `number<8>{}`, we're telling the system: +- "You can safely use vector loads of up to 8 elements" +- "The memory alignment and layout support this" + +**Example:** +```cpp +// Without vectorization (slow): +for (int j = 0; j < 8; j++) { + data[j] = memory[offset + j]; // 8 separate loads +} + +// With vectorization (fast): +float8 vec = *reinterpret_cast<const float8*>(&memory[offset]); // 1 load! +``` + +### 6. **`number<1>{}` - Guaranteed Last Dimension Vector Stride** + +This specifies the **stride between consecutive vectorizable elements** in the last dimension.
+ +- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)" +- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively + +**Why does this matter?** + +For efficient vectorized loads, elements must be: +1. **Contiguous** (stride = 1) ✓ +2. **Aligned** properly in memory +3. **Within the same cache line** (ideally) + +If the stride were `2`, it would mean: +``` +A[i][0] is at offset 0 +A[i][1] is at offset 2 (not 1!) +A[i][2] is at offset 4 +``` +This would prevent efficient vectorization. + +--- + +## What is a Buffer View? + +A **buffer view** is the middle layer between raw memory and tensor view. It provides: + +### Core Responsibilities: + +1. **Memory Management** + - Holds the raw pointer: `T* p_data_` + - Tracks buffer size: `BufferSizeType buffer_size_` + - Knows the address space: `global`, `lds`, etc. + +2. **Vectorized Access** + ```cpp + template + CK_TILE_DEVICE VectorType get(index_t offset); + ``` + - Provides efficient vector loads/stores + - Handles alignment requirements + +3. **Bounds Checking** (optional) + ```cpp + template + CK_TILE_DEVICE auto get(index_t i, index_t linear_offset); + ``` + - Can optionally check if access is within bounds + - Returns invalid value (default 0) for out-of-bounds access + +4. 
**Address Space Awareness** + - Uses different load/store instructions based on address space + - Global memory: `global_load`, `global_store` + - LDS: `ds_read`, `ds_write` + +### Buffer View Structure: + +```cpp +template +struct buffer_view +{ + T* p_data_; // Raw pointer + BufferSizeType buffer_size_; // Total elements + remove_cvref_t invalid_element_value_; // Value for OOB access + + // Access operators + const T& operator[](index_t i) const; // Read + T& operator()(index_t i); // Write + + // Vectorized access + template + VectorType get(index_t offset); +}; +``` + +--- + +## Visual Example: Matrix A Memory Layout + +Let's visualize how matrix A (256×32, fp16) is organized: + +### Raw Physical Memory (Linear): +``` +GPU DRAM Address Space: +┌─────────────────────────────────────────────────────────────────┐ +│ Byte 0 │ +│ ↓ │ +│ [a₀₀][a₀₁][a₀₂]...[a₀₃₁][a₁₀][a₁₁][a₁₂]...[a₁₃₁][a₂₀]... │ +│ ↑ ↑ │ +│ Row 0 (32 elements) Row 1 (32 elements) │ +│ │ +│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes │ +└─────────────────────────────────────────────────────────────────┘ + ↑ + p_a (raw pointer) +``` + +### Buffer View Layer: +``` +buffer_view +┌─────────────────────────────────────────────────────────────────┐ +│ p_data_ = p_a │ +│ buffer_size_ = 256 × 32 = 8,192 elements │ +│ address_space = global (DRAM) │ +│ │ +│ Provides: │ +│ • Linear indexing: buffer_view[i] → element at offset i │ +│ • Vectorized loads: get(offset) → load 4 fp16s at once│ +│ • Bounds checking: is offset < buffer_size_? │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Tensor View Layer: +``` +tensor_view +┌─────────────────────────────────────────────────────────────────┐ +│ Shape: (256, 32) │ +│ Strides: (32, 1) │ +│ Guaranteed vector length: 8 │ +│ Guaranteed vector stride: 1 │ +│ │ +│ Logical 2D View: │ +│ Col: 0 1 2 ... 31 │ +│ Row 0: [a₀₀][a₀₁][a₀₂] ... [a₀₃₁] ← Can vector load 8 at once│ +│ Row 1: [a₁₀][a₁₁][a₁₂] ... 
[a₁₃₁] │ +│ Row 2: [a₂₀][a₂₁][a₂₂] ... [a₂₃₁] │ +│ ... │ +│ Row 255: [a₂₅₅,₀] ... [a₂₅₅,₃₁] │ +│ │ +│ Provides: │ +│ • Multi-dimensional indexing: A[i][j] │ +│ • Coordinate transformation: (i,j) → linear offset = i*32 + j │ +│ • Tile window creation: Extract sub-tensors │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Complete Flow: Raw Memory → Tensor View + +Let's trace the complete transformation for matrix A: + +### Step 1: Kernel Launch (Host Side) +```cpp +// On host: Allocate device memory +hipMalloc(&p_a, M * K * sizeof(fp16_t)); // p_a receives the raw device pointer + +// Launch kernel +kernel<<<grid_size, block_size>>>(p_a, p_b, p_c, M, N, K, ...); +``` + +### Step 2: Inside Kernel (Device Side) +```cpp +// Receive raw pointer +const fp16_t* p_a; // Points to GPU DRAM + +// Step 2a: Create tensor descriptor +auto desc = make_naive_tensor_descriptor( + make_tuple(256, 32), // Shape + make_tuple(32, 1), // Strides + number<8>{}, // Vector length + number<1>{} // Vector stride +); +// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8" + +// Step 2b: Create buffer view +auto buffer_view = make_buffer_view<address_space_enum::global>( + p_a, // Raw pointer + 256 * 32 // Total elements +); +// buffer_view now wraps p_a with size and address space info + +// Step 2c: Create tensor view +auto a_dram = tensor_view{buffer_view, desc}; +// a_dram now provides structured, multi-dimensional access to p_a +``` + +### Step 3: Using the Tensor View +```cpp +// Access element A[i][j] +auto value = a_dram[make_tuple(i, j)]; + +// Create a tile window (sub-tensor) +auto tile = make_tile_window( + a_dram, + make_tuple(16, 16), // 16×16 tile + make_tuple(0, 0) // Starting at origin +); + +// Load tile into registers with vectorization +auto tile_data = load_tile(tile); // Uses vector loads internally! +``` + +--- + +## Why This Abstraction? + +### Benefits: + +1. **Type Safety**: Can't accidentally access wrong dimensions +2.
**Performance**: Compiler knows about vectorization opportunities +3. **Flexibility**: Same code works for different memory spaces (DRAM, LDS, registers) +4. **Maintainability**: Logical structure separate from physical layout +5. **Optimization**: Guaranteed vector properties enable aggressive optimizations + +### Example: Without Tensor Views (Manual Indexing) +```cpp +// Ugly, error-prone, hard to optimize: +for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + float val = p_a[tile_offset_i * stride_a + tile_offset_j + i * stride_a + j]; + // Hope the compiler vectorizes this? 🤞 + } +} +``` + +### Example: With Tensor Views (Clean, Optimized) +```cpp +// Clean, safe, automatically vectorized: +auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin); +auto tile_data = load_tile(tile); // Vectorized loads guaranteed! +``` + +--- + +## Summary + +The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction: + +1. **Raw Memory**: Linear array of bytes in GPU DRAM +2. **Buffer View**: Adds size, address space, and vectorized access +3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing + +This abstraction enables: +- ✅ Clean, readable code +- ✅ Type-safe multi-dimensional access +- ✅ Automatic vectorization +- ✅ Flexible memory space handling +- ✅ Efficient tile-based computation + +The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation! + diff --git a/tutorial/ck_tile/01_naive_gemm/README.md b/tutorial/ck_tile/01_naive_gemm/README.md new file mode 100644 index 0000000000..f2caf7d993 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/README.md @@ -0,0 +1,150 @@ +# CK Tile Practice GEMM Example + +This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. 
It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system. + +## CK Tile API Structure + +In the composable_kernel library's ck_tile API, **A Kernel is composed of a Problem, a Policy and an Epilogue**: + +1. **Problem** describes the shape, data type, data layout, precision of our GEMM matrices +2. **Policy** describes how the data in the matrix (or tile) is mapped to the threads +3. **Epilogue** describes additional computation work performed after the gemm computations (this example does not have an epilogue) + +## Overview + +This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing: + +- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), kernel launch +- **Block-level Pipelining** - creating tensor views, dispatching to block-level GEMM +- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing to DRAM and Register memory +- **Warp-level GEMM Computation** - Warp tiles, MFMA level computation + +## Problem Setup and Data Flow + +### Problem Size Configuration +We set the problem size using the M, N and K variables: +```cpp +ck_tile::index_t M = 1024; // Number of rows in A and C +ck_tile::index_t N = 512; // Number of columns in B and C +ck_tile::index_t K = 256; // Number of columns in A, rows in B +``` + +### Host Matrix Creation +Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to the GPU global/DRAM memory: +```cpp +// Host tensors with proper strides +ck_tile::HostTensor a_host(a_lengths, a_strides); // M × K +ck_tile::HostTensor b_host(b_lengths, b_strides); // N × K +ck_tile::HostTensor c_host(c_lengths, c_strides); // M × N + +// Initialize with random data +ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); +ck_tile::FillUniformDistributionIntegerValue{-5.f, 
5.f}(b_host); + +// Allocate device memory and transfer data +ck_tile::DeviceMem a_device(a_host); +a_device.ToDevice(a_host.data()); +``` + +### PracticeGemmShape Configuration +A PracticeGemmShape struct holds the dimension of each BlockTile and WaveTile: + +```cpp +using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block +using WaveTile = ck_tile::sequence<16, 16, 16>; // M, N, K per wave +``` +- A BlockTile of size MxK (256x32) on A matrix and NxK (128x32) on B matrix. A WaveTile of size MxN (16x16) on C matrix. + + +- BlockTiles iterate in K dimension to fetch data required for computing region of C covered by C's block tile. +- BlockTiles are further subdivided into WarpTiles. +- WarpTiles over A and B similarly work together to calculate the WarpTile of C. + +### Problem and Policy Composition +```cpp +// A Problem is composed from Shape and info about the data +using PracticeGemmHostProblem = ck_tile:: + PracticeGemmHostProblem; + +// A Policy is created describing data-to-thread mapping +using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; + +// A Kernel is then composed of Problem and Policy +using gemm_kernel = ck_tile::PracticeGemmKernel; +``` + +### Kernel Launch +`ck_tile::launch_kernel()` is used to launch the kernel on device. 
It calls the `operator()` function of `PracticeGemmKernel{}`: +```cpp +float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel( + gemm_kernel{}, // Kernel composed of Problem + Policy + kGridSize, // Grid dimensions + kBlockSize, // Block dimensions + 0, // Dynamic shared memory + // Kernel arguments: device buffers and problem dimensions + a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(), + M, N, K, stride_a, stride_b, stride_c)); +``` + +### Result Verification +The results from the kernel are compared with results from CPU based computation function: +```cpp +// CPU reference implementation +ck_tile::HostTensor c_host_ref(c_lengths, c_strides); +reference_basic_gemm(a_host, b_host, c_host_ref); + +// Device results +ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + +// Verify correctness +bool pass = ck_tile::check_err(c_host_dev, c_host_ref); +``` + +### Runtime Flow + +The main program (`practice_gemm.cpp`) is the entry point for the runtime flow: + +```cpp +int main() +{ + // 1. Define data types and problem sizes + using ADataType = ck_tile::half_t; + ck_tile::index_t M = 2048, N = 1024, K = 512; + + // 2. Create host tensors and initialize + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); + + // 3. Allocate device memory and transfer data + ck_tile::DeviceMem a_device(a_host); + + // 4. Configure tile shapes + using BlockTile = ck_tile::sequence<256, 128, 32>; + using WaveTile = ck_tile::sequence<16, 16, 16>; + + // 5. Launch kernel + using gemm_kernel = ck_tile::PracticeGemmKernel; + float ave_time = ck_tile::launch_kernel(/*...*/); + + // 6. Verify results + bool pass = verify_results(a_host, b_host, c_host); + + // 7. 
Print performance metrics + print_performance_metrics(ave_time, M, N, K); +} +``` + +## Building and Running + +```bash +# From composable_kernel root directory +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ +make tile_example_practice_gemm -j + +# Run with sample sizes +./bin/tile_example_practice_gemm +``` +This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework. diff --git a/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md b/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md new file mode 100644 index 0000000000..d0b8400b9c --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md @@ -0,0 +1,506 @@ +# Practice GEMM: Step-by-Step Code Walkthrough + +This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API. + +## Overview + +We'll implement `C = A × B` where: +- `A` is an `M × K` matrix +- `B` is an `N × K` matrix (note: transposed layout) +- `C` is an `M × N` matrix + +The implementation uses a hierarchical tiling strategy with two levels: +1. **Block Tiles**: Processed by thread blocks +2. **Wave Tiles**: Processed by warps (wavefronts) within blocks + +--- + +## Step 1: Define Data Types + +```cpp +using ADataType = ck_tile::half_t; +using BDataType = ck_tile::half_t; +using CDataType = float; +using AccDataType = float; +``` + +**What's happening:** +- We use `half_t` (FP16) for input matrices A and B. 
+- We use `float` (FP32) for output matrix C and accumulation for numerical accuracy +- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity +--- + +## Step 2: Define Problem Size + +```cpp +ck_tile::index_t M = 512; +ck_tile::index_t N = 256; +ck_tile::index_t K = 64; +ck_tile::index_t verification = 1; + +ck_tile::index_t stride_a = K; +ck_tile::index_t stride_b = K; +ck_tile::index_t stride_c = N; +``` + +**What's happening:** +- `M = 512`: Number of rows in A and C +- `N = 256`: Number of columns in B and C +- `K = 64`: Inner dimension (columns of A, rows of B) +- Strides define memory layout (row-major for A and C, transposed for B) + +**Memory Layout:** +``` +Matrix A (M×K): Matrix B (N×K): Matrix C (M×N): +[512 rows] [256 rows] [512 rows] +[64 cols] [64 cols] [256 cols] +stride = K stride = K stride = N +``` + +--- + +## Step 3: Create Host Tensors + +```cpp +auto a_lengths = std::array{M, K}; +auto b_lengths = std::array{N, K}; +auto c_lengths = std::array{M, N}; + +auto a_strides = std::array{stride_a, 1}; +auto b_strides = std::array{stride_b, 1}; +auto c_strides = std::array{stride_c, 1}; + +ck_tile::HostTensor a_host(a_lengths, a_strides); +ck_tile::HostTensor b_host(b_lengths, b_strides); +ck_tile::HostTensor c_host(c_lengths, c_strides); +``` + +**What's happening:** +- We create three tensors on the host (CPU) memory +- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`) +- `HostTensor` is a CK Tile utility class that manages CPU memory + +**Stride explanation:** +- For A: `stride_a = K` means moving to the next row requires skipping K elements +- For B: `stride_b = K` means B is stored in transposed format +- For C: `stride_c = N` means row-major layout + +--- + +## Step 4: Initialize Tensors with Random Data + +```cpp +ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); +ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); +c_host.SetZero(); +``` + 
+**What's happening:** +- A and B are filled with random values in the range [-5.0, 5.0] +- C is initialized to zero (will store the output) + +**Optional: Print Tensor Contents** +```cpp +// Commented out in the code, but available for debugging: +// a_host.print_first_n(10); // Print first 10 elements of A +``` + +The `print_first_n()` helper function can display tensor contents for debugging purposes. + +--- + +## Step 5: Allocate Device Memory and Transfer Data + +```cpp +ck_tile::DeviceMem a_device(a_host); +ck_tile::DeviceMem b_device(b_host); +ck_tile::DeviceMem c_device(c_host); +``` + +**What's happening:** +- `DeviceMem` allocates GPU memory matching the size of host tensors +- The constructor **automatically transfers data from host to device** +- This is a convenience wrapper around `hipMalloc` and `hipMemcpy` + +**Memory Flow:** +``` +CPU (Host) GPU (Device) +┌─────────┐ ┌─────────┐ +│ a_host │ ────────> │a_device │ +│ b_host │ ────────> │b_device │ +│ c_host │ ────────> │c_device │ +└─────────┘ └─────────┘ +``` + +--- + +## Step 6: Configure Hierarchical Tiling + +```cpp +using BlockTile = ck_tile::sequence<256, 128, 32>; +using WaveTile = ck_tile::sequence<16, 16, 16>; +``` + +**What's happening:** +- We define a two-level tiling hierarchy for the GEMM computation + +### Block Tile (256 × 128 × 32) +- **256**: M dimension per block (rows of A and C) +- **128**: N dimension per block (columns of B and C) +- **32**: K dimension per block (inner dimension) +- Each block tile is processed by one **thread block** (256 threads) + +### Wave Tile (16 × 16 × 16) +- **16 × 16**: Output tile dimensions (M × N) per warp iteration +- **16**: K dimension per warp iteration +- Each wave tile is processed by one **warp** (64 threads on AMD GPUs) + +**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. 
Multiple MFMA operations may be needed to compute one wave tile. + +**Important Note:** +In this example, the problem size (M × N × K = 512 × 256 × 64, see Step 2) is **larger** than the block tile size (256 × 128 × 32), so **four thread blocks** (2 in M × 2 in N) are needed to cover C, and each block iterates over **two** K-slices. The diagrams below show the work of a **single block tile** for one K-slice. + +### Tiling Visualization: + +#### Matrix A (M × K = 256 × 32): +``` +┌─────────────────────────────────────┐ +│ One Block Tile (256 × 32) │ +│ ┌────┬────┐ │ +│ │16×│16× │ ← Wave tiles (16×16) │ +│ │ 16│ 16 │ in M×K space │ +│ ├────┼────┤ │ +│ │ │ │ │ +│ ├────┼────┤ │ +│ │ .. │ .. │ 16 tiles in M │ +│ ├────┼────┤ 2 tiles in K │ +│ │ │ │ │ +│ └────┴────┘ │ +│ │ +└─────────────────────────────────────┘ +``` + +#### Matrix B (N × K = 128 × 32): +``` +┌──────────────────────────────┐ +│ One Block Tile (128 × 32) │ +│ ┌────┬────┐ │ +│ │16×│16× │ ← Wave tiles │ +│ │ 16│ 16 │ (16×16) │ +│ ├────┼────┤ │ +│ │ │ │ │ +│ ├────┼────┤ 8 tiles in N │ +│ │ .. │ .. │ 2 tiles in K │ +│ ├────┼────┤ │ +│ │ │ │ │ +│ └────┴────┘ │ +└──────────────────────────────┘ +``` + +#### Matrix C (M × N = 256 × 128) - Output: +``` +┌─────────────────────────────────────────────────┐ +│ One Block Tile (256 × 128) │ +│ │ +│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │ +│ │16× │ │ │ │ │ │ │ │ │ +│ │ 16 │ │ │ │ │ │ │ │ │ +│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ +│ │ │ │ │ │ │ │ │ │ │ +│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ +│ │ │ │ │ │ │ │ │ │ │ +│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ +│ │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ ..
│ │ +│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ +│ │ │ │ │ │ │ │ │ │ │ +│ └────┴────┴────┴────┴────┴────┴────┴────┘ │ +│ │ +│ 16 wave tiles in M direction │ +│ 8 wave tiles in N direction │ +│ Total: 128 wave tiles (16×16 each) │ +└─────────────────────────────────────────────────┘ +``` + +#### How Wave Tiles Combine (C = A × B): +``` +Matrix A Matrix B (stored transposed N×K) Matrix C +(256×32) (128×32) (256×128) + +Row of A tiles: Row of B tiles: One wave tile in C: +┌────┬────┐ ┌────┬────┐ ┌────┐ +│ A₀ │ A₁ │ × │ B₀ │ B₁ │ = │ C │ (16×16) +└────┴────┘ └────┴────┘ └────┘ + 16×16 each 16×16 each + +Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ + ↑ ↑ + K=0..15 K=16..31 + +Each wave tile in C is computed by: +- Taking one row of wave tiles from A (2 tiles along K) +- Taking one row of wave tiles from B (2 tiles along K) + Note: B is stored transposed (N×K), so a "row" in storage corresponds + to a "column" in the logical B^T matrix used in computation +- Performing dot product: Σ(A_k × B_k^T) for k=0,1 +``` + +**Key Insight:** +- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B +- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory +- This is the fundamental operation repeated across all 128 wave tiles in C +- Each warp computes one wave tile using MFMA instructions + +--- + +## Step 7: Create Shape, Problem, and Policy Structs + +```cpp +using PracticeGemmShape = ck_tile::PracticeGemmShape; +std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl; + +using PracticeGemmHostProblem = ck_tile:: + PracticeGemmHostProblem; + +using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; +``` + +**What's happening:** + +### 1. **Shape Struct** +Encapsulates all tile shape information (BlockTile and WaveTile dimensions). + +### 2. 
**Problem Struct** +Holds complete problem description: +- Data types (ADataType, BDataType, CDataType, AccDataType) +- Shape information (BlockTile, WaveTile) + +In more complex examples, this would also include: +- Data layouts (row-major, column-major) +- Mathematical operations (e.g., transposed GEMM) + +### 3. **Policy Struct** +Describes data movement and thread-to-data mapping: +- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions +- In more complex kernels, includes: + - DRAM access patterns + - LDS (Local Data Share) usage strategies + - Thread distribution within blocks + +**CK Tile Design Pattern:** +``` +Kernel = Problem + Policy + Epilogue + ↑ ↑ ↑ + (What) (How) (Post-processing) +``` + +--- + +## Step 8: Calculate Grid and Block Dimensions + +```cpp +ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * + ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); + +std::cout << "kGridSize: " << kGridSize << std::endl; + +constexpr ck_tile::index_t kBlockSize = 256; +constexpr ck_tile::index_t kBlockPerCU = 1; +``` + +**What's happening:** + +### Grid Size Calculation +```cpp +kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N) + = ceil(512 / 256) × ceil(256 / 128) + = 2 × 2 + = 4 thread blocks +``` + +Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in M direction, 2 blocks in N direction). 
+ +### Block Configuration +- `kBlockSize = 256`: Each thread block has 256 threads + - 256 threads / 64 threads per warp = **4 warps per block** +- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity) + +**Thread Hierarchy:** +``` +GPU +└── Grid of 4 Thread Blocks, each with: + └── 256 Threads + ├── Warp 0 (threads 0-63) + ├── Warp 1 (threads 64-127) + ├── Warp 2 (threads 128-191) + └── Warp 3 (threads 192-255) +``` + +--- + +## Step 9: Create and Launch the Kernel + +```cpp +using gemm_kernel = + ck_tile::PracticeGemmKernel; + +float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel(gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_device.GetDeviceBuffer()), + static_cast(b_device.GetDeviceBuffer()), + static_cast(c_device.GetDeviceBuffer()), + M, + N, + K, + stride_a, + stride_b, + stride_c)); +``` + +**What's happening:** + +### 1. Kernel Composition +```cpp +using gemm_kernel = ck_tile::PracticeGemmKernel; +``` +The kernel is composed from Problem and Policy structs, following the CK Tile design pattern. + +### 2. Kernel Launch +`launch_kernel()` is a CK Tile utility that: +- Launches the GPU kernel using HIP runtime +- Measures execution time +- Returns average execution time in milliseconds + +### 3. Launch Parameters +- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled +- **Grid size**: `kGridSize = 4` - number of thread blocks (computed in Step 8) +- **Block size**: `kBlockSize = 256` - threads per block +- **Shared memory**: `0` - no dynamic shared memory in this example +- **Kernel arguments**: Device pointers and problem dimensions + +### 4.
Kernel Execution Flow +``` +launch_kernel() calls gemm_kernel.operator()() + ↓ +PracticeGemmKernel::operator() + ↓ +Creates tensor views over device memory + ↓ +Calls block-level pipeline + ↓ +Block pipeline calls warp-level pipeline + ↓ +Warp pipeline calls MFMA instructions + ↓ +Results written back to C matrix +``` + +--- + +## Step 10: Verify Results + +```cpp +auto pass = true; + +if(verification) +{ + // Reference gemm on CPU + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + + // Copy GPU results back to host + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + c_device.FromDevice(c_host_dev.mData.data()); + + // Compare results + pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; +} +``` + +**What's happening:** + +### 1. CPU Reference Implementation +```cpp +reference_basic_gemm<...>(a_host, b_host, c_host_ref); +``` +Computes GEMM on CPU using a simple nested loop implementation (ground truth). + +### 2. Copy GPU Results to Host +```cpp +c_device.FromDevice(c_host_dev.mData.data()); +``` +Transfers the computed result from GPU memory back to CPU for comparison. + +### 3. Error Checking +```cpp +ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); +``` +Compares GPU and CPU results element-wise with tolerance: +- **Relative error**: 1e-3 (0.1%) +- **Absolute error**: 1e-3 + +**Verification Flow:** +``` +CPU GPU +┌─────────┐ ┌─────────┐ +│ a_host │ ────────> │a_device │ +│ b_host │ ────────> │b_device │ +└─────────┘ └─────────┘ + │ │ + ↓ ↓ +reference_gemm() GPU kernel + │ │ + ↓ ↓ +┌──────────┐ ┌──────────┐ +│c_host_ref│ │c_device │ +└──────────┘ └──────────┘ + │ │ + │ ↓ + │ FromDevice() + │ │ + ↓ ↓ + └────> check_err() <───┘ + │ + ↓ + Pass/Fail +``` + +--- + +## Complete Execution Flow Summary + +``` +1. Define data types (FP16 inputs, FP32 output) + ↓ +2. 
Set problem size (M=512, N=256, K=64) + ↓ +3. Create host tensors and initialize with random data + ↓ +4. Allocate device memory and transfer data (CPU → GPU) + ↓ +5. Configure hierarchical tiling (BlockTile, WaveTile) + ↓ +6. Create Shape, Problem, and Policy structs + ↓ +7. Calculate grid/block dimensions (4 blocks, 256 threads each) + ↓ +8. Compose and launch kernel (Problem + Policy) + ↓ +9. Execute GEMM on GPU + │ ├─ Block-level pipeline + │ ├─ Warp-level pipeline + │ └─ MFMA instructions + ↓ +10. Verify results (compare GPU vs CPU reference) + ↓ +11. Calculate and print performance metrics + ↓ +12. Return success/failure +``` + +--- \ No newline at end of file diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..31fa4ac3eb --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { +/* Block-level GEMM pipeline: moves A and B block tiles from DRAM through registers into LDS and accumulates the block GEMM result in a register tile that is returned to the caller. */ +template +struct PracticeGemmBlockPipelineAGmemBGmemCreg +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; + + using BlockTile = typename Problem::Shape::BlockTile; + using WaveTile = typename Problem::Shape::WaveTile; + + static constexpr index_t MPerBlock = BlockTile::at(number<0>{}); + static constexpr index_t NPerBlock = BlockTile::at(number<1>{}); + static constexpr index_t KPerBlock = BlockTile::at(number<2>{}); + + static constexpr index_t MPerWave = WaveTile::at(number<0>{}); + static constexpr index_t NPerWave = WaveTile::at(number<1>{}); + static constexpr index_t KPerWave = WaveTile::at(number<2>{}); + + using BlockGemm = + remove_cvref_t())>; + /* Returns the LDS size in bytes: the A tile footprint rounded up to a multiple of 16 bytes, plus the B tile footprint. */ + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLDSSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + /* Runs num_loop iterations of (load A/B from DRAM -> store to LDS -> block GEMM) and returns the accumulated C block tile. p_smem is the LDS buffer that holds the A tile followed by the B tile; the static_asserts require the DRAM windows to match the block tile lengths. */ + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // ----------------------------------------------------------------------------------------- + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc =
Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + /* Byte size of the A tile in LDS, rounded up to a multiple of 16; the B tile is placed at this byte offset from p_smem. */ + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr =
decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); + + // ------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // non-prefetch + index_t iCounter = num_loop; + + while(iCounter > 0) + { + a_block_tile = load_tile(a_copy_dram_window); // from DRAM to registers + b_block_tile = load_tile(b_copy_dram_window); // from DRAM to registers + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + store_tile(a_copy_lds_window, a_block_tile); // from registers to LDS + store_tile(b_copy_lds_window, b_block_tile); // from registers to LDS + /* The LDS stores above must be complete before block_gemm reads the tiles; presumably block_sync_lds() is a block-wide LDS barrier -- confirm against its definition. */ + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // from LDS to registers + block_sync_lds(); + /* NOTE(review): the second sync presumably keeps the next iteration's stores from overwriting LDS tiles that are still being read -- confirm. */ + iCounter--; + } + + return c_block_tile; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..99c4379ad8 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp @@ -0,0 +1,135 @@ +#pragma once + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" + +#include
"../warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp" +#include "../warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp" + +namespace ck_tile { +/* Compile-time bundle of the element types and the tile Shape used by the block-level pipeline. */ +template +struct PracticeGemmBlockPipelineProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using Shape = Shape_; +}; +/* Policy for the block-level pipeline: selects the warp-level GEMM pipeline and builds the LDS descriptors and DRAM tile distributions for A and B. */ +struct PracticeGemmBlockPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetPracticeWaveGemmPipeline() + { + return PracticeGemmWarpPipelineASmemBSmemCreg{}; + } + /* LDS descriptor for the A block tile: a three-dimensional naive descriptor whose dimensions 1 and 2 (kKPerBlock / kKPack, kKPack) are merged into one, giving a two-dimensional view; kKPack is fixed at 8. */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kKPack = 8; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return a_lds_block_desc; + } + /* LDS descriptor for the B block tile; same construction as for A with kNPerBlock in place of kMPerBlock. */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kKPack = 8; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), +
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } + /* Tile distribution for loading the A block tile from DRAM. K1 is the number of elements in 16 bytes, K0 the number of such groups along K, M2 the rows covered by one warp, M1 the number of warps in the block, and M0 the remaining repeats along M. */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + constexpr index_t kMWarp = BlockGemm::MWarp; + constexpr index_t kNWarp = BlockGemm::NWarp; + constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size(); + + constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + /* Tile distribution for loading the B block tile from DRAM; mirrors the A distribution with N in place of M. */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + constexpr index_t kMWarp = BlockGemm::MWarp; + constexpr index_t kNWarp = BlockGemm::NWarp; + constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size(); + + constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>,
+ tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..ef12634e42 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { +template +struct PracticeGemmHostPipeline +{ + using ADataType = typename Problem_::ADataType; + using BDataType = typename Problem_::BDataType; + using CDataType = typename Problem_::CDataType; + using AccDataType = typename Problem_::AccDataType; + + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using BlockTile = typename Problem::Shape::BlockTile; + using WaveTile = typename Problem::Shape::WaveTile; + + template + CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, + const BDRAMTensorView& b_dram, + CDRAMTensorView& c_dram) const + { + + // Size of the entire problem + const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K + const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N + const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K + + // Size of the block tile + const auto MPerBlock = BlockTile::at(number<0>{}); + const auto NPerBlock = BlockTile::at(number<1>{}); + const auto KPerBlock = BlockTile::at(number<2>{}); + + // Number of block tiles in the N direction to cover C (resultant) matrix + const auto num_tile_n = integer_divide_ceil(N, NPerBlock); + // Number of block tiles in the M direction to cover C (resultant) matrix + const auto num_tile_m = 
integer_divide_ceil(M, MPerBlock); + + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("num_tile_m: %d, num_tile_n: %d\n", num_tile_m, num_tile_n); + // printf("total number of tiles: %d\n", num_tile_m * num_tile_n); + // } + + // Get block id + const auto id_block = + get_block_id(); // 0 to (M_block/BlockTile_M) * (N_block/BlockTile_N) - 1 + + // Map block id to tile id + const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n); + + const auto tile_id = block2tile(id_block); + + const auto tile_id_m = tile_id.at(number<0>{}); + const auto tile_id_n = tile_id.at(number<1>{}); + + // if(get_thread_id() == 0 && get_block_id() == 15) + // { + // printf("tile_id_m: %d, tile_id_n: %d\n", tile_id_m, tile_id_n); + // } + + const auto tile_origin_m = tile_id_m * MPerBlock; + const auto tile_origin_n = tile_id_n * NPerBlock; + + // create a tile window over dram for A and B + const auto a_block_window = make_tile_window( + a_dram, make_tuple(number{}, number{}), {tile_origin_m, 0}); + + const auto b_block_window = make_tile_window( + b_dram, make_tuple(number{}, number{}), {tile_origin_n, 0}); + + constexpr auto block_gemm_pipeline = + Policy::template GetPracticeGemmBlockPipeline(); + + int num_loops_k = integer_divide_ceil(K, KPerBlock); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; + const auto c_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); + auto c_window = make_tile_window(c_dram, + make_tuple(number{}, number{}), + {tile_origin_m, tile_origin_n}); + store_tile(c_window, c_block_tile); + } +}; +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..d66c3c8522 --- /dev/null +++ 
b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" + +#include "../block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp" +#include "../block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp" + +namespace ck_tile { + +template +struct PracticeGemmHostProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using Shape = remove_cvref_t; +}; + +struct PracticeGemmHostPolicy +{ + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) + { + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPracticeGemmBlockPipeline() + { + using PracticeGemmBlockPipelineProblem_ = + PracticeGemmBlockPipelineProblem; + return PracticeGemmBlockPipelineAGmemBGmemCreg{}; + } +}; +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp new file mode 100644 index 0000000000..ee2e125e24 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "ck_tile/host.hpp" +#include "practice_gemm.hpp" +#include "reference_gemm.hpp" + +int main() +{ + // TODO: GemmTypeConfig + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using CDataType = float; + using AccDataType = float; + + // ArgParser + ck_tile::index_t M = 512; + ck_tile::index_t N = 256; + ck_tile::index_t K = 64; + ck_tile::index_t verification = 1; + + ck_tile::index_t stride_a = K; + ck_tile::index_t stride_b = K; + ck_tile::index_t stride_c = N; + + auto a_lengths = std::array{M, K}; + auto b_lengths = std::array{N, K}; + auto c_lengths = std::array{M, N}; + + auto a_strides = std::array{stride_a, 1}; + auto b_strides = std::array{stride_b, 1}; + auto c_strides = std::array{stride_c, 1}; + + // tensors on host (cpu) + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host(c_lengths, c_strides); + + // initialize tensors + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); + c_host.SetZero(); + + // Print the tensors using the new print_first_n member function + // std::cout << "Tensor A (first 10 elements): "; + // a_host.print_first_n(10); + // std::cout << std::endl; + + // std::cout << "Tensor B (first 10 elements): "; + // b_host.print_first_n(10); + // std::cout << std::endl; + + // std::cout << "Tensor C (first 10 elements): "; + // c_host.print_first_n(10); + // std::cout << std::endl; + + // Create device tensors of same size as host tensors and copy data + ck_tile::DeviceMem a_device(a_host); + ck_tile::DeviceMem b_device(b_host); + ck_tile::DeviceMem c_device(c_host); + + // TODO: BlockTileConfig + // constexpr ck_tile::index_t warpSize = 64; + constexpr ck_tile::index_t kBlockSize = 256; + + using BlockTile = ck_tile::sequence<256, 128, 32>; + using WaveTile = ck_tile::sequence<16, 16, 16>; + + std::cout << "Creating PracticeGemmShape, PracticeGemmProblem, 
PracticeGemmPolicy" << std::endl; + using PracticeGemmShape = ck_tile::PracticeGemmShape; + std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl; + using PracticeGemmHostProblem = ck_tile:: + PracticeGemmHostProblem; + using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; + + ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * + ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); + + std::cout << "kGridSize: " << kGridSize << std::endl; + constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU + + std::cout << "kBlockSize: " << kBlockSize << std::endl; + std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl; + + using gemm_kernel = + ck_tile::PracticeGemmKernel; + + float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel(gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_device.GetDeviceBuffer()), + static_cast(b_device.GetDeviceBuffer()), + static_cast(c_device.GetDeviceBuffer()), + M, + N, + K, + stride_a, + stride_b, + stride_c)); + + auto pass = true; + + if(verification) + { + // reference gemm + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + c_device.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); + std::cout << "valid:" << (pass ? 
"y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp new file mode 100644 index 0000000000..88879ee221 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp" +#include "host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp" + +namespace ck_tile { + +template +struct PracticeGemmShape +{ + using BlockTile = remove_cvref_t; + using WaveTile = remove_cvref_t; + + static constexpr index_t BlockTile_M = BlockTile::at(number<0>{}); + static constexpr index_t BlockTile_N = BlockTile::at(number<1>{}); + static constexpr index_t BlockTile_K = BlockTile::at(number<2>{}); + + static constexpr index_t WaveTile_M = WaveTile::at(number<0>{}); + static constexpr index_t WaveTile_N = WaveTile::at(number<1>{}); + static constexpr index_t WaveTile_K = WaveTile::at(number<2>{}); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "practice_gemm_shape", + concat('x', BlockTile_M, BlockTile_N, BlockTile_K), + concat('x', WaveTile_M, WaveTile_N, WaveTile_K)); + // clang-format on + } +}; + +template +struct PracticeGemmKernel +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()(const typename 
Problem::ADataType* p_a, + const typename Problem::BDataType* p_b, + typename Problem::CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c) const + { + + auto a_dram = make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{}); + + auto b_dram = make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{}); + + const auto c_dram = make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{}); + + PracticeGemmHostPipeline{}(a_dram, b_dram, c_dram); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp new file mode 100644 index 0000000000..8f975be7dc --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_basic_gemm(const ck_tile::HostTensor& a_m_k, + const ck_tile::HostTensor& b_n_k, + ck_tile::HostTensor& c_m_n) +{ + const int N = b_n_k.mDesc.get_lengths()[0]; + const int K = b_n_k.mDesc.get_lengths()[1]; + + auto f = [&](auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_n_k(n, k); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + c_m_n(m, n) = ck_tile::type_convert(v_acc); + } + }; + + ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(1); +} diff --git a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp new file mode 100644 index 0000000000..bf058af9c5 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp @@ -0,0 +1,195 @@ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { + +template +struct PracticeGemmWarpPipelineASmemBSmemCreg +{ + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using WaveGemmShape = remove_cvref_t; + + using WarpGemm = remove_cvref_t< + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename 
WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == WaveGemmShape::BlockTile_M && + NPerBlock == WaveGemmShape::BlockTile_N && + KPerBlock == WaveGemmShape::BlockTile_K, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + +#if !defined(ENABLE_PREFETCH) + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + 
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // Warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == WaveGemmShape::BlockTile_M && + NPerBlock == WaveGemmShape::BlockTile_N && + KPerBlock == WaveGemmShape::BlockTile_K, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + + static_assert(std::is_same_v, "wrong!"); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp 
b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp new file mode 100644 index 0000000000..2efa2bcc2a --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCReg +// Default policy class should not be templated, put template on member functions instead +struct PracticeGemmWarpPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; + + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); + } + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/CMakeLists.txt b/tutorial/ck_tile/CMakeLists.txt new file mode 100644 index 0000000000..9895f5a71d --- /dev/null +++ b/tutorial/ck_tile/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(00_copy_kernel) +add_subdirectory(01_naive_gemm) + From 92c1f4981ab1d081978c8f6132ca93949d4749e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 11 Nov 2025 22:55:33 +0100 Subject: [PATCH 013/114] [CK_BUILDER] Add grouped conv fwd ck tile traits (#3183) * [CK BUILDER] Add grouped conv fwd ck tile traits * Update instance_traits_tile_grouped_convolution_forward.hpp * Update grouped_convolution_forward_kernel.hpp --- .../ck_tile/builder/reflect/conv_traits.hpp | 3 + ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 1 + ...raits_tile_grouped_convolution_forward.hpp 
| 140 ++++++++++++++++++ .../builder/reflect/instance_traits_util.hpp | 81 +++++++++- .../builder/test/test_fwd_instance_traits.cpp | 123 +++++++++++++++ include/ck_tile/core/arch/arch.hpp | 10 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 6 +- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 7 + .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 7 + .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 7 + .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 7 + .../gemm_pipeline_ag_bg_cr_comp_v6.hpp | 7 + .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 7 + .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 7 + .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 7 + .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 7 + .../kernel/grouped_gemm_quant_kernel.hpp | 4 +- .../grouped_convolution_forward_kernel.hpp | 17 +++ 18 files changed, 433 insertions(+), 15 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp mode change 100755 => 100644 include/ck_tile/core/arch/arch.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 86cf11f647..4b946011c2 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -15,6 +15,9 @@ #include #include #include +#include +#include "ck_tile/ops/epilogue.hpp" +#include namespace ck_tile::reflect::conv { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index a0def3e5d9..6913889c4f 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ 
b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -4,6 +4,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" // Forward declaration to avoid circular dependency diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp new file mode 100644 index 0000000000..f364b37ae5 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp @@ -0,0 +1,140 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// InstanceTraits specialization for GroupedConvolutionForwardKernel +// +// CRITICAL MAINTENANCE NOTE: +// This InstanceTraits file MUST be kept strictly in sync with the device implementation header: +// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +// "In sync" means that the template parameter order, names, and types in the declaration below +// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter +// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are +// difficult to diagnose. Always update both files together and review changes carefully. + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" + +// Forward declaration to avoid circular dependency. 
+namespace ck_tile::device { + +template +struct GroupedConvolutionForwardKernel; + +} // namespace ck_tile::device + +namespace ck_tile { +namespace reflect { + +// Specialization for GroupedConvolutionForwardKernel +template +struct InstanceTraits> +{ + // CK Tile Conv Traits + // Spatial dimension + static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial; + // Specialization + static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType_::ConvSpecialization; + // Layout types + using InLayout = typename GroupedConvTraitsType_::InLayout; + using WeiLayout = typename GroupedConvTraitsType_::WeiLayout; + using DsLayout = typename GroupedConvTraitsType_::DsLayout; + using OutLayout = typename GroupedConvTraitsType_::OutLayout; + // Vector size + static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA; + static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB; + static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC; + // Num Groups To Merge + static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + // Split image (large tensors) + static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + + // TilePartitioner + // Block configuration + static constexpr int kMPerBlock = TilePartitioner_::MPerBlock; + static constexpr int kNPerBlock = TilePartitioner_::NPerBlock; + static constexpr int kKPerBlock = TilePartitioner_::KPerBlock; + + static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{}); + static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{}); + static 
constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{}); + + // Data types + using ADataType = typename GemmPipeline_::ADataType; + using BDataType = typename GemmPipeline_::BDataType; + // Gemm Pipeline + using GemmPipeline = GemmPipeline_; + static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler; + static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer; + static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups; + + // Epilogue Pipeline + using AccDataType = typename EpiloguePipeline_::AccDataType; + using EDataType = typename EpiloguePipeline_::ODataType; + using DsDataType = typename EpiloguePipeline_::DsDataType; + using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "GroupedConvolutionForwardKernel"; + + // Template parameters in exact order matching InstanceTraits member order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," + << ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization + oss << "," << detail::layout_name(); // 3. InLayout + oss << "," << detail::layout_name(); // 4. WeiLayout + oss << "," << detail::tuple_name(); // 5. DsLayout + oss << "," << detail::layout_name(); // 6. OutLayout + oss << "," << kVectorSizeA; // 7. VectorSizeA + oss << "," << kVectorSizeB; // 8. VectorSizeB + oss << "," << kVectorSizeC; // 9. VectorSizeC + oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge + oss << "," << kEnableSplitImage; // 11. EnableSplitImage + oss << "," << kMPerBlock; // 12. MPerBlock + oss << "," << kNPerBlock; // 13. NPerBlock + oss << "," << kKPerBlock; // 14. KPerBlock + oss << "," << kMWarp; // 15. MWarp + oss << "," << kNWarp; // 16. NWarp + oss << "," << kKWarp; // 17. KWarp + oss << "," << kMWarpTile; // 18. 
MWarpTile + oss << "," << kNWarpTile; // 19. NWarpTile + oss << "," << kKWarpTile; // 20. KWarpTile + oss << "," << detail::type_name(); // 21. ADataType + oss << "," << detail::type_name(); // 22. BDataType + oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched + oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer + oss << "," << kNumWaveGroups; // 26. NumWaveGroups + oss << "," << detail::type_name(); // 27. AccDataType + oss << "," << detail::type_name(); // 28. EDataType + oss << "," << detail::tuple_name(); // 29. DsDataType + oss << "," + << detail::elementwise_op_name(); // 30. + // CDEElementwiseOperation + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index e4d154ae10..2e918c5c2d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -28,6 +28,10 @@ #include #include #include +#include +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" namespace ck_tile::reflect::detail { @@ -38,7 +42,7 @@ namespace impl { template consteval std::string_view type_name_impl() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "fp32"; @@ -50,11 +54,11 @@ consteval std::string_view type_name_impl() return "s8"; else if constexpr(std::is_same_v) return "s32"; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || std::is_same_v) return "bf16"; - else if constexpr(std::is_same_v) + 
else if constexpr(std::is_same_v || std::is_same_v) return "fp8"; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || std::is_same_v) return "bf8"; else return std::string_view{}; // Return empty for supported types @@ -168,6 +172,17 @@ constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineSchedule } } +constexpr std::string_view pipeline_scheduler_name(ck_tile::GemmPipelineScheduler sched) +{ + using enum ck_tile::GemmPipelineScheduler; + switch(sched) + { + case Default: return "Default"; + case Intrawave: return "Intrawave"; + case Interwave: return "Interwave"; + } +} + // Convert BlockGemmPipelineVersion enum to string constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver) { @@ -206,6 +221,26 @@ constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched) } } +// Convert TailNumber enum to string +constexpr std::string_view tail_number_name(ck_tile::TailNumber tail_num) +{ + using enum ck_tile::TailNumber; + switch(tail_num) + { + case Odd: return "Odd"; + case Even: return "Even"; + case One: return "One"; + case Two: return "Two"; + case Three: return "Three"; + case Four: return "Four"; + case Five: return "Five"; + case Six: return "Six"; + case Seven: return "Seven"; + case Empty: return "Empty"; + case Full: return "Full"; + } +} + // Convert std::array to string template inline std::string array_to_string(const std::array& arr) @@ -356,17 +391,53 @@ constexpr std::string tuple_name() }(static_cast(nullptr)); } +template + requires requires { [](ck_tile::tuple*) {}(static_cast(nullptr)); } +constexpr std::string tuple_name() +{ + return [](ck_tile::tuple*) constexpr { + if constexpr(sizeof...(Ts) == 0) + { + return std::string("EmptyTuple"); + } + else if constexpr((IsLayoutType && ...)) + { + // Lambda wrapper for layout_name + auto layout_name_fn = []() { return layout_name(); }; + return detail::build_list_string("tuple", + layout_name_fn); + } + else if 
constexpr((IsDataType && ...)) + { + // Lambda wrapper for type_name + auto type_name_fn = []() { return type_name(); }; + return detail::build_list_string("tuple", type_name_fn); + } + else + { + static_assert((IsLayoutType && ...) || (IsDataType && ...), + "tuple elements must be all layouts or all data types, not mixed"); + return std::string{}; // unreachable + } + }(static_cast(nullptr)); +} + // Concept to check if a type is a ck::Tuple template concept IsCkTuple = requires { [](ck::Tuple*) {}(static_cast(nullptr)); }; +// Concept to check if a type is a ck_tile::tuple +template +concept IsCkTileTuple = + requires { [](ck_tile::tuple*) {}(static_cast(nullptr)); }; + // Deduces whether to use tuple_name or type_name // Handles both scalar data types and ck::Tuple types template constexpr std::string type_or_type_tuple_name() { - if constexpr(IsCkTuple) + if constexpr(IsCkTuple || IsCkTileTuple) { return tuple_name(); } diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index b57b20eb7d..af950b441c 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace { @@ -720,4 +721,126 @@ TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat) EXPECT_EQ(instance_str, expected_str); } +TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) +{ + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits<2 /*NDimSpatial*/, + ck_tile::ConvolutionSpecialization::Default /*ConvSpec*/, + ck_tile::tensor_layout::convolution::NHWGC /*InLayout*/, + ck_tile::tensor_layout::convolution::GKYXC /*WeiLayout*/, + ck_tile::tuple<> /*DsLayout*/, + ck_tile::tensor_layout::convolution::NHWGK /*OutLayout*/, + 4 /*VectorSizeA*/, + 4 /*VectorSizeB*/, + 4 /*VectorSizeC*/, + 1 /*NumGroupsToMerge*/, + false /*EnableSplitImage*/>; + + using GemmShape = ck_tile::TileGemmShape< + 
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>, + ck_tile::sequence<4 /*M_Warp*/, 1 /*N_Warp*/, 1 /*K_Warp*/>, + ck_tile::sequence<16 /*M_Warp_Tile*/, 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/>>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + false /*DoubleSmemBuffer*/, + typename GroupedConvTraitsType::AsLayoutFwd, + typename GroupedConvTraitsType::BsLayoutFwd, + typename GroupedConvTraitsType::CLayoutFwd, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + 1 /*NumWaveGroups*/>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ck_tile::bf16_t /*InDataType*/, + ck_tile::bf16_t /*WeiDataType*/, + float /*AccDataType*/, + GemmShape, + GemmUniversalTraits, + ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, + true /*has_hot_loop_v*/, + ck_tile::TailNumber::Full /*tail_number_v*/, + ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, + ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, + ck_tile::bf16_t /*OutDataType*/, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename ck_tile::GemmPipelineAgBgCrCompV3; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem /*DsDataType*/, + float /*AccDataType*/, + ck_tile::bf16_t /*OutDataType*/, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + 
ck_tile::element_wise::PassThrough /*CDElementWise*/, + 128 /*MPerBlock*/, + 128 /*NPerBlock*/, + 4 /*M_Warp*/, + 1 /*N_Warp*/, + 16 /*M_Warp_Tile*/, + 16 /*N_Warp_Tile*/, + 16 /*K_Warp_Tile*/, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + ck_tile::memory_operation_enum::set /*memory_operation*/, + 1 /*kNumWaveGroups*/, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeC>>; + + using GroupedConvFwdKernel = + ck_tile::device::GroupedConvolutionForwardKernel; + + std::string instance_str = ck_tile::reflect::instance_string(); + + std::string expected_str = "GroupedConvolutionForwardKernel" + "<2" // NDimSpatial + ",Default" // ConvSpecialization + ",NHWGC" // InLayout + ",GKYXC" // WeiLayout + ",EmptyTuple" // DsLayout + ",NHWGK" // OutLayout + ",4" // VectorSizeA + ",4" // VectorSizeB + ",4" // VectorSizeC + ",1" // NumGroupsToMerge + ",0" // EnableSplitImage + ",128" // MPerBlock + ",128" // NPerBlock + ",32" // KPerBlock + ",4" // MWarp + ",1" // NWarp + ",1" // KWarp + ",16" // MWarpTile + ",16" // NWarpTile + ",16" // KWarpTile + ",bf16" // ADataType + ",bf16" // BDataType + ",COMPUTE_V3" // BlkGemmPipelineVer + ",Intrawave" // BlkGemmPipeSched + ",0" // DoubleSmemBuffer + ",1" // NumWaveGroups + ",fp32" // AccDataType + ",bf16" // EDataType + ",EmptyTuple" // DsDataType + ",PassThrough" // CDEElementwiseOperation + ">"; + + EXPECT_EQ(instance_str, expected_str); +} + } // anonymous namespace diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp old mode 100755 new mode 100644 index 5bf8548470..b66c00e392 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -299,12 +299,12 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0) #endif } -#define CK_CONSTANT_ADDRESS_SPACE \ - __attribute__((address_space( \ +#define CK_TILE_CONSTANT_ADDRESS_SPACE \ + __attribute__((address_space( \ static_cast>(address_space_enum::constant)))) template -__device__ T* 
cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) +__device__ T* cast_pointer_to_generic_address_space(T CK_TILE_CONSTANT_ADDRESS_SPACE* p) { // cast a pointer in "Constant" address space (4) to "Generic" address space (0) // only c-style pointer cast seems be able to be compiled @@ -315,13 +315,13 @@ __device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* } template -__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) +__host__ __device__ T CK_TILE_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) { // cast a pointer in "Generic" address space (0) to "Constant" address space (4) // only c-style pointer cast seems be able to be compiled; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" - return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) + return (T CK_TILE_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) #pragma clang diagnostic pop } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 551dc6f50d..a72b1ba544 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -190,7 +190,7 @@ struct GroupedGemmKernel */ CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { - using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; + using ConstantPointer = const void CK_TILE_CONSTANT_ADDRESS_SPACE*; const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; int occupancy; HIP_CHECK_ERROR( @@ -518,7 +518,7 @@ struct GroupedGemmKernel // For non-persistent kernels template > - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, index_t group_count) const { const index_t block_id = 
ck_tile::get_block_1d_id(); @@ -541,7 +541,7 @@ struct GroupedGemmKernel template , typename = void> // extra template parameter to avoid redefinition - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count) const { const index_t grid_size = ck_tile::get_grid_size(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 91da3cd27b..b293097d89 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -164,6 +164,13 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_ASYNC"; + // clang-format on + } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index aaa04615fd..a1bbcbe990 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -170,6 +170,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using Base::PrefetchStages; using Base::UsePersistentKernel; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V3"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index ff1e33bd5d..238b4e2389 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -172,6 +172,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V4"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 7263ddd5a1..6343ff9872 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -99,6 +99,13 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 static constexpr index_t NumWarps = BlockGemmShape::NumWarps; static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{}); + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V5"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp index 2ae9001098..5b57560f6e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp @@ -159,6 +159,13 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto 
is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V6"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index d363626efd..ba71e3b6cb 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -214,6 +214,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "MEMORY"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index eb363d59b8..8a4fb59b51 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -70,6 +70,13 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t kLdsAlignmentInBytes = 16; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "BASIC_V1"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index c309f8908a..32217e0024 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -70,6 +70,13 @@ struct GemmPipelineAGmemBGmemCRegV2 // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. static constexpr bool DoubleSmemBuffer = false; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "BASIC_V2"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 87f6c753b4..cae2bd0e9f 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -176,6 +176,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp; static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "PRESHUFFLE_V2"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 75ac1ca6ab..32f1279e93 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -208,7 +208,7 @@ struct QuantGroupedGemmKernel */ CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { - using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; + using ConstantPointer = const void CK_TILE_CONSTANT_ADDRESS_SPACE*; const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>; int occupancy; HIP_CHECK_ERROR( @@ 
-499,7 +499,7 @@ struct QuantGroupedGemmKernel template , typename = void> // extra template parameter to avoid redefinition - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count) const { const index_t grid_size = ck_tile::get_grid_size(); diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 7e70d2b422..6de331fe6d 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -16,6 +16,10 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp" +#endif + namespace ck_tile { /// @brief The Grouped Convolution kernel device arguments. @@ -568,6 +572,19 @@ struct GroupedConvolutionForwardKernel // clang-format on } +#ifdef CK_EXPERIMENTAL_BUILDER + CK_TILE_HOST std::string GetInstanceString() const + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. 
Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_tile_grouped_convolution_forward.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } +#endif + CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) { return dim3( From 40d2ed0f2a442026c57dc17e6e7bd281b6c2535c Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 12 Nov 2025 10:26:14 +0800 Subject: [PATCH 014/114] [CK_TILE] Share partition index across threads and specify offset in load_tile()/async_load_tile()/load_tile_transpose() (#2905) * Allow sharing partition index across threads * Fix typo PartitoinIndex -> PartitionIndex * Remove C++20 'requires' usages * Add missing template arguments * Fix load_tile() overload ambiguity issue * Use SFINAE to exclude invalid arguments * Add additional offset parameter to the async_load_tile() * Remove async_load_tile() default argument to avoid ambiguity * Extract tile_window coordinate compute logic as method * Use warp-shared LDS base address in tile_window::async_load() * Add constraint to tile_window::load() templates * Fix wrong type traits is_class_v<> usages * Add missing constraint to async_load_tile() * Add missing tile_window::load() overload * Add more constraint to avoid load_tile() call ambiguity * Rename ParitionIndex as ReplacementPartitionIndex * Update pre_computed_warp_coords_ in move_extended() * Fix inconsistency between template parameters and documentation * Allow specifying pre-computed parition index * Add type straits is_sequence<> & is_tile_distribution<> * Add type straits is_tensor_view<> * Add type constraints to make_tile_window() templates * Allow passing partition_index to set_tile_if() * Allow specifying partition_index to store_tile() * Add missing template parameter of replace_bottom_tensor_view() * Allow passing partition_index to Default2DEpilogue * Make get_partition_index() public * Add _with_offset() 
postfix to avoid resolution error * Remove ReplacementPartitionIndex template param * Add missing comments * Add load_tile_transpose_with_offset() overload --- include/ck_tile/core/container/sequence.hpp | 11 + include/ck_tile/core/tensor/load_tile.hpp | 51 ++++- .../core/tensor/load_tile_transpose.hpp | 60 +++++- .../core/tensor/static_distributed_tensor.hpp | 41 +++- include/ck_tile/core/tensor/store_tile.hpp | 51 +++++ include/ck_tile/core/tensor/tensor_view.hpp | 15 ++ .../ck_tile/core/tensor/tile_distribution.hpp | 27 ++- .../core/tensor/tile_scatter_gather.hpp | 5 +- include/ck_tile/core/tensor/tile_window.hpp | 195 +++++++++++++++--- .../ops/epilogue/default_2d_epilogue.hpp | 41 +++- .../ck_tile/ops/reduce/block/block_reduce.hpp | 2 +- 11 files changed, 441 insertions(+), 58 deletions(-) diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index cfec2237f9..1a88a98cbf 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -214,6 +214,17 @@ CK_TILE_HOST_DEVICE static void print(const sequence&) printf(">"); } +template +struct is_sequence : std::false_type +{ +}; +template +struct is_sequence> : std::true_type +{ +}; +template +inline constexpr bool is_sequence_v = is_sequence::value; + namespace impl { template struct __integer_sequence; diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 2e9ab0f5c6..1be4259e97 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -17,6 +17,19 @@ #include "ck_tile/core/tensor/null_tensor.hpp" namespace ck_tile { +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. 
+template >> +CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load_with_offset( + offset, number{}, bool_constant{}); +} template CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, @@ -49,6 +62,23 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, tile_window, elementwise, number{}, bool_constant{}); } +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. +template > && + std::is_class_v>> +CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile, + const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load_with_offset( + offset, dst_tile, number{}, bool_constant{}); +} + template {}, bool_constant{}, bool_constant{}); } +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. +template > && + std::is_class_v>> +CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.async_load_with_offset( + offset, lds_tile, number{}, bool_constant{}); +} + template = {}, bool_constant = {}) { - return tile_window.async_load( - lds_tile, number{}, bool_constant{}); + return async_load_tile_with_offset( + lds_tile, tile_window, 0, number{}, bool_constant{}); } template ::distr_encoding_valid, Policy>> -CK_TILE_DEVICE auto -load_tile_transpose(const tile_window_with_static_distribution& tile_window) +CK_TILE_DEVICE auto load_tile_transpose_with_offset( + const tile_window_with_static_distribution& __restrict__ tile_window, + index_t offset) { using OutTileDstrEncode = typename OutputTileDistributionTraits< typename TileDistribution_::DstrEncode, typename BottomTensorView_::DataType>::TransposedDstrEncode; auto out_tensor = make_static_distributed_tensor( 
make_static_tile_distribution(OutTileDstrEncode{})); - auto trans_tensor = tile_window.template load_transpose(); + auto trans_tensor = tile_window.template load_transpose_with_offset(offset); constexpr auto input_distr = TileDistribution_{}; constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); @@ -443,4 +446,49 @@ load_tile_transpose(const tile_window_with_static_distribution, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE auto +load_tile_transpose(const tile_window_with_static_distribution& __restrict__ tile_window) +{ + return load_tile_transpose_with_offset(tile_window, 0); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index b73a27c8d5..5228ad978a 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -155,11 +155,11 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTi // get X indices from tuple of tile_distributed_index<> template -CK_TILE_HOST_DEVICE constexpr auto -get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, - DistributedIndices distributed_indices) +CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices( + StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices, + decltype(get_partition_index(tile_distribution)) partition_index) { - const auto partition_index = detail::get_partition_index(tile_distribution); constexpr auto y_indices = tile_distribution.get_y_indices_from_distributed_indices(distributed_indices); @@ -170,6 +170,16 @@ get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, return x_coord.get_bottom_index(); } +// get X indices from tuple of tile_distributed_index<> +template +CK_TILE_HOST_DEVICE constexpr auto 
+get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices) +{ + return get_x_indices_from_distributed_indices( + tile_distribution, distributed_indices, get_partition_index(tile_distribution)); +} + template CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor& out_tensor, @@ -192,6 +202,29 @@ set_tile_if(static_distributed_tensor& out_ten }); } +template +CK_TILE_HOST_DEVICE void +set_tile_if(static_distributed_tensor& out_tensor, + DataType value, + XIndicesPredicate predicate, + decltype(get_partition_index(std::declval())) partition_index) +{ + constexpr auto out_spans = + static_distributed_tensor::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + StaticTileDistribution{}, distributed_indices, partition_index); + + if(predicate(x_indices)) + { + out_tensor(distributed_indices) = value; + } + }); + }); +} + // this function used inside span loop over template CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number) diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index d5a716664d..b535b40534 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -38,6 +39,31 @@ store_tile(tile_window_with_static_lengths& t tile_window.store(dstr_tensor); } +template +CK_TILE_DEVICE void 
+store_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor, + decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr, + partition_index); + + tile_window.store(dstr_tensor); +} + template +CK_TILE_DEVICE void +store_tile_raw(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor, + decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr, + partition_index); + + tile_window.store_raw(dstr_tensor); +} + template +struct is_tensor_view : std::false_type +{ +}; +template +struct is_tensor_view> : std::true_type +{ +}; +template <> +struct is_tensor_view : std::true_type +{ +}; +template +inline constexpr bool is_tensor_view_v = is_tensor_view::value; + template CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) { - return Distribution::_get_partition_index(); + return Distribution::get_partition_index(); } -} // namespace detail // distributed span template @@ -91,7 +89,7 @@ struct tile_distribution CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; } - CK_TILE_HOST_DEVICE static auto _get_partition_index() + 
CK_TILE_HOST_DEVICE static auto get_partition_index() { // only support warp-tile and block-tile static_assert(NDimP == 1 or NDimP == 2, "wrong!"); @@ -172,9 +170,9 @@ struct tile_distribution } #endif - template + template CK_TILE_HOST_DEVICE auto - calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const + calculate_index(const PartitionIndex& ps_idx = get_partition_index()) const { const auto ps_ys_idx = container_concat(ps_idx, array{0}); const auto window_adaptor_thread_coord_tmp = @@ -230,6 +228,23 @@ struct tile_distribution } }; +template +struct is_tile_distribution : std::false_type +{ +}; +template +struct is_tile_distribution> : std::true_type +{ +}; +template +inline constexpr bool is_tile_distribution_v = is_tile_distribution::value; + namespace detail { template diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 4b04fd513d..e77ca805bb 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -189,8 +189,7 @@ struct tile_scatter_gather // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat(get_partition_index(tile_distribution), array{0})); #endif BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = @@ -836,7 +835,7 @@ struct tile_scatter_gather // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_dstr_), array{0})); + container_concat(get_partition_index(tile_dstr_), array{0})); #endif BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index cfa2420f2f..1123ce7604 
100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -12,6 +12,7 @@ #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_view.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/utility/functional.hpp" @@ -67,18 +68,54 @@ struct tile_window_with_static_distribution const typename Base::BottomTensorView& bottom_tensor_view, const typename Base::WindowLengths& window_lengths, const typename Base::BottomTensorIndex& window_origin, - const typename Base::TileDstr& tile_distribution) + const typename Base::TileDstr& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) : pre_computed_coords_{} { - this->window_origin_ = window_origin; - this->window_lengths_ = window_lengths; - this->bottom_tensor_view_ = bottom_tensor_view; - this->tile_dstr_ = tile_distribution; + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; + this->tile_dstr_ = tile_distribution; + + pre_computed_coords_ = + prepare_coords(bottom_tensor_view, window_origin, tile_distribution, partition_index); + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + auto use_lane_id_0 = partition_index; + use_lane_id_0[1] = 0; + + pre_computed_warp_coords_ = + prepare_coords(bottom_tensor_view, window_origin, tile_distribution, use_lane_id_0); + } + } + + CK_TILE_DEVICE constexpr tile_window_with_static_distribution( + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution) + : 
tile_window_with_static_distribution(bottom_tensor_view, + window_lengths, + window_origin, + tile_distribution, + get_partition_index(tile_distribution)) + { + } + + CK_TILE_DEVICE constexpr auto + prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) const + { + array, NumCoord> + coords; + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat(partition_index, multi_index{0})); typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); @@ -105,18 +142,31 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - pre_computed_coords_(iCoord) = - make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + coords(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); }); + + return coords; } template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const + { + return load_with_offset( + 0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load_with_offset(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - load(dst_tensor, number{}, bool_constant{}); + load_with_offset(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -236,6 +286,19 @@ struct tile_window_with_static_distribution CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, number = {}, bool_constant = {}) const + { + 
load_with_offset( + 0, dst_tensor, number{}, bool_constant{}); + } + + template >>> + CK_TILE_DEVICE auto load_with_offset(index_t offset, + DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const { using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; @@ -258,7 +321,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const vector_t vec_value = this->get_bottom_tensor_view().template get_vectorized_elements( - bottom_tensor_thread_coord, 0, bool_constant{}); + bottom_tensor_thread_coord, offset, bool_constant{}); // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( @@ -450,10 +513,12 @@ struct tile_window_with_static_distribution template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, - number = {}, - bool_constant = {}) const + bool oob_conditional_check = true, + typename = std::enable_if_t>>> + CK_TILE_DEVICE auto async_load_with_offset(index_t offset, + LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; using LdsDataType = typename LdsTileWindow::DataType; @@ -472,12 +537,15 @@ struct tile_window_with_static_distribution auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0]; + auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1]; + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; // Use precomputed window origin auto lds_bottom_tensor_thread_idx = - window_origin + window_adaptor_thread_coord.get_bottom_index(); + window_origin + window_adaptor_warp_coord.get_bottom_index(); // Use precomputed tensor descriptor const auto lds_coord = @@ -490,7 +558,7 @@ struct tile_window_with_static_distribution 
this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, - number<0>{}, + offset, bool_constant{}); // Move thread coordinate if not last access @@ -503,18 +571,33 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys); } }); }); } template - CK_TILE_DEVICE auto load_transpose() const + CK_TILE_DEVICE auto load_transpose(number = {}, + bool_constant = {}) const + { + return this->template load_transpose_with_offset( + 0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - this->template load_transpose( - dst_tensor, number{}, bool_constant{}); + this->template load_transpose_with_offset(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -522,9 +605,10 @@ struct tile_window_with_static_distribution typename DistributedTensor, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true> - CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor, - number = {}, - bool_constant = {}) const + CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset, + DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const { using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; @@ -550,7 +634,7 @@ struct tile_window_with_static_distribution const vector_t vec_value = this->get_bottom_tensor_view() .template get_transpose_vectorized_elements( - bottom_tensor_thread_coord, 0); + bottom_tensor_thread_coord, offset); // write into distributed tensor static_for<0, 
Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto orig_idx_ys = generate_tuple( @@ -862,16 +946,26 @@ struct tile_window_with_static_distribution pre_computed_coords_(iCoord)(I1), step); }); + + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_warp_coords_(iCoord)(I1), + step); + }); + } } CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) { // TODO: this use less register for FA, but more register for GEMM // need investigation - const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - this->tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(this->tile_dstr_), - array{0})); + const auto window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(this->tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(get_partition_index(this->tile_dstr_), + array{0})); typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); @@ -908,6 +1002,12 @@ struct tile_window_with_static_distribution // per-thread coordinate for bottom tensor array, NumCoord> pre_computed_coords_; + // pre_computed_warp_coords_ exists only in the global memory tile_window + std::conditional_t< + Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global, + array, NumCoord>, + std::byte> + pre_computed_warp_coords_; }; // TODO: use strategy @@ -929,6 +1029,27 @@ make_tile_window(const TensorView_& tensor_view, tensor_view, window_lengths, origin, tile_distribution}; } +template && + is_tile_distribution_v>> +CK_TILE_DEVICE constexpr auto +make_tile_window(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const 
StaticTileDistribution_& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index, + number = {}) +{ + return tile_window_with_static_distribution, + remove_cvref_t, + remove_cvref_t, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, partition_index}; +} + // this version can't be called in a constexpr context template +CK_TILE_DEVICE constexpr auto +make_tile_window(const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) +{ + return make_tile_window(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + partition_index); +} + template CK_TILE_DEVICE constexpr auto make_tile_window_raw(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution) { - auto w = make_tile_window(tile_window.get_bottom_tensor_view(), - tile_window.get_window_lengths(), - tile_window.get_window_origin(), - tile_distribution); + auto w = make_tile_window(tile_window, tile_distribution); w.init_raw(); return w; } diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 2843966cd7..8cf47c46e7 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -93,13 +93,27 @@ struct Default2DEpilogue const DsDramWindows& ds_dram_windows, void* = nullptr) const { + constexpr bool is_partition_index = + std::is_convertible_v; + const auto storeOrUpdateTile = [&](const auto& o_tile) { // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { if constexpr(MemoryOperation == memory_operation_enum::set) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + store_tile_raw(o_dram_window_tmp, + cast_tile(o_tile), + 
/*partition_index=*/ds_dram_windows); + } + else + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + } } else { @@ -111,16 +125,35 @@ struct Default2DEpilogue { if constexpr(MemoryOperation == memory_operation_enum::set) { - store_tile(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + store_tile(o_dram_window_tmp, + cast_tile(o_tile), + /*partition_index=*/ds_dram_windows); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_tile)); + } } else { - update_tile(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + update_tile(o_dram_window_tmp, + cast_tile(o_tile), + /*partition_index=*/ds_dram_windows); + } + else + { + update_tile(o_dram_window_tmp, cast_tile(o_tile)); + } } } }; - if constexpr(!std::is_same_v && Problem::NumDTensor >= 1) + if constexpr(!std::is_same_v && !is_partition_index && + Problem::NumDTensor >= 1) { using elementwise_result_t = decltype(load_tile( make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(), diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 7a10d1fa56..2fd8a48eee 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -32,7 +32,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, constexpr index_t idim_p_lane = NDimP - 1; - const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution()); + const auto ps_idx = get_partition_index(acc_tensor.get_tile_distribution()); const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size(); From 299c9bca1bee2ef77bb78878bcdd9d11a13564e5 Mon Sep 17 00:00:00 2001 From: Yashvardhan Agarwal Date: Wed, 12 Nov 2025 17:30:20 +0200 Subject: [PATCH 015/114] [CK_Tile] Pooling example readme update (#3174) * pooling 
example readme update - The updated readme explains the transformations of the pooling kernel using a mermaid diagram * Update example/ck_tile/36_pooling/README.md Co-authored-by: spolifroni-amd * resolve comments --------- Co-authored-by: spolifroni-amd --- example/ck_tile/36_pooling/README.md | 110 +++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/example/ck_tile/36_pooling/README.md b/example/ck_tile/36_pooling/README.md index ab49b57095..4417e03734 100644 --- a/example/ck_tile/36_pooling/README.md +++ b/example/ck_tile/36_pooling/README.md @@ -2,6 +2,116 @@ This folder contains example for the pooling operator using ck_tile tile-programming implementation. Currently the pooling kernel only supports 2D and 3D pooling. +## Tensor Descriptor Transformations + +The pooling kernel transforms the input tensor into 2D format suitable for reduction. This section explains the transformation pipeline for both 2D and 3D pooling operations. + +### 3D Pooling Transformations + +For 3D pooling, the input tensor has shape `(N, D, H, W, C)` where: +- `N`: batch size +- `D`: depth dimension +- `H`: height dimension +- `W`: width dimension +- `C`: channel dimension + +The transformations convert this 5D tensor into a 2D tensor where rows represent output positions (M) and columns represent pooling window elements (K). + +```mermaid +graph TD + %% Input Tensor: (N, D, H, W, C) + Input["Input Tensor<br/>(N, D, H, W, C)"] + style Input fill:#e1f5fe + + %% Pass-through N dimension + PassN["Pass-through N<br/>(batch size)"] + style PassN fill:#f3e5f5 + Input --> PassN + + %% Pad spatial dimensions + PadD["Pad D<br/>(depth with left/right padding)"] + style PadD fill:#fff9c4 + Input --> PadD + + PadH["Pad H<br/>(height with left/right padding)"] + style PadH fill:#fff9c4 + Input --> PadH + + PadW["Pad W<br/>(width with left/right padding)"] + style PadW fill:#fff9c4 + Input --> PadW + + %% Pass-through C dimension + PassC["Pass-through C<br/>(channels)"] + style PassC fill:#f3e5f5 + Input --> PassC + + %% Embed sliding windows + EmbedD["Embed D<br/>window(Z) × output_positions(Dₒ)"] + style EmbedD fill:#fff3e0 + PadD --> EmbedD + + EmbedH["Embed H<br/>window(Y) × output_positions(Hₒ)"] + style EmbedH fill:#fff3e0 + PadH --> EmbedH + + EmbedW["Embed W<br/>window(X) × output_positions(Wₒ)"] + style EmbedW fill:#fff3e0 + PadW --> EmbedW + + %% Merge into 2D matrix + MergeM["Merge M<br/>(N, Dₒ, Hₒ, Wₒ, C)<br/>→ output positions"] + style MergeM fill:#e8f5e9 + PassN --> MergeM + EmbedD --> MergeM + EmbedH --> MergeM + EmbedW --> MergeM + PassC --> MergeM + + MergeK["Merge K<br/>(Z, Y, X)<br/>→ window elements"] + style MergeK fill:#e8f5e9 + EmbedD --> MergeK + EmbedH --> MergeK + EmbedW --> MergeK + + %% Final padding for block alignment + PadM["Right-pad M<br/>(for block alignment)"] + style PadM fill:#fff9c4 + MergeM --> PadM + + PadK["Right-pad K<br/>(for block alignment)"] + style PadK fill:#fff9c4 + MergeK --> PadK + + %% Result + Result["2D Matrix<br/>(M × K)"] + style Result fill:#c8e6c9 + PadM --> Result + PadK --> Result +``` + +**Transformation Steps:** +1. **Padding**: Apply left and right padding to spatial dimensions (D, H, W) to handle boundary conditions +2. **Sliding Windows**: Use embed transforms to create sliding windows across each spatial dimension, expanding each dimension into (window_size, output_positions) +3. **Reshaping**: Merge all dimensions into a 2D matrix where: + - M dimension = N × Dₒ × Hₒ × Wₒ × C (total output positions) + - K dimension = Z × Y × X (elements per pooling window) +4. **Block Alignment**: Apply right padding to ensure M and K dimensions are aligned to block size + +### 2D Pooling Transformations + +2D pooling follows the same transformation pipeline but operates on 4D tensors with shape `(N, H, W, C)`. The process is identical except: +- Only H and W dimensions are padded and embedded +- K dimension merges only (Y, X) window elements +- M dimension merges (N, Hₒ, Wₒ, C) + +### Output Tensor Transformations + +The output tensor transformations are simpler: +- Merge all output dimensions (N, Dₒ/Hₒ, Wₒ, C) into a single M dimension +- Apply right padding for block alignment +- The result is a 1D tensor that maps directly to the M dimension of the computation matrix + ## build ``` # in the root of ck_tile From 7414a0f4d43fbb581421c236c7c68bf2ba7664ca Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Wed, 12 Nov 2025 20:23:54 +0100 Subject: [PATCH 016/114] Wmma support for gemm_reduce (#3145) * Initial implementation GEMM+Reduce: - device struct - epilogue struct * Fix tests, improve profiler and add initial instances * Add instances * Fix compilation error * Address review comments * Fix logging --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 661 ++++++++++++++++++ .../grid/epilogue_cshuffle_v3_reduce_wmma.hpp | 470 +++++++++++++
.../grid/epilogue_cshuffle_v3_wmma_base.hpp | 1 + .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 24 + .../gpu/gemm_reduce/CMakeLists.txt | 7 +- ..._f16_f16_f16_f32_f32_km_kn_mn_instance.cpp | 88 +++ ..._f16_f16_f16_f32_f32_km_nk_mn_instance.cpp | 88 +++ ..._f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp | 88 +++ ..._f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp | 86 +++ .../profiler/profile_gemm_reduce_impl.hpp | 55 +- test/gemm_reduce/CMakeLists.txt | 10 +- ...duce_fp16_xdl.cpp => gemm_reduce_fp16.cpp} | 2 +- 12 files changed, 1568 insertions(+), 12 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp rename test/gemm_reduce/{gemm_reduce_fp16_xdl.cpp => gemm_reduce_fp16.cpp} (96%) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..166c1a7581 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -0,0 +1,661 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_reduce_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid, + const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops, + const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + auto epilogue_args = + EpilogueType(p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = p_reduces_grid; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; +#endif +} + +} // 
namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + + using CDEShuffleBlockTransferScalarPerVectors = + Sequence; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using ReduceTrait = ReduceTrait_; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_c_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, 
+ p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_reduces_grid_{p_reduces_grid}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + StrideA_{StrideA}, + StrideB_{StrideB}, + StrideC_{StrideC}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_c_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t StrideA_; + index_t StrideB_; + index_t StrideC_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, + static_cast(arg.p_c_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1); + + float ave_time = 0; + + index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.KRaw_); + + const auto Run = [&](const auto& kernel) { + // Note: cache flushing not supported + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.p_reduces_grid_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_); + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = + kernel_gemm_reduce_wmma_cshuffle_v3; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline setting"); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = + kernel_gemm_reduce_wmma_cshuffle_v3; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! 
Invalid pipeline v1 setting"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Even) + { + const auto kernel = + kernel_gemm_reduce_wmma_cshuffle_v3; + Run(kernel); + } + else if(TailNum == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_reduce_wmma_cshuffle_v3; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v3 setting"); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Device implementation supports only gfx11 and gfx12! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "FP8 and BF8 not supported on gfx11! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Without padding, K must be divisible by AK1 and BK1! 
" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, + static_cast(arg.p_c_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); 
+ BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + 
reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmReduce_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp new file mode 100644 index 0000000000..c2bd65f134 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp @@ -0,0 +1,470 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" + +namespace ck { + +template +struct ReduceTrait_ +{ + using ReduceAccDataType_ = ReduceAccDataType; + using ReducePtrsGlobal_ = ReducePtrsGlobal; + using ReduceOperations_ = ReduceOperations; + using ReduceInElementwiseOperations_ = ReduceInElementwiseOperations; + using ReduceAccElementwiseOperations_ = ReduceAccElementwiseOperations; + using ReduceGlobalMemoryDataOperation_ = ReduceGlobalMemoryDataOperation; + using CReduceThreadClusterLengths_MPerBlock_NPerBlock_ = + CReduceThreadClusterLengths_MPerBlock_NPerBlock; + static constexpr index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_ = + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock; + static constexpr index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_ = + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock; +}; + +template +struct EpilogueReduceCShuffle + : EpilogueCShuffleBase +{ + using Base = EpilogueCShuffleBase< + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + MPerBlock, + NPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + CDEElementwiseOperation, + ThisThreadBlock, + BlockwiseGemmPipe>; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + using Base::GetCShuffleLDSDescriptor; + using Base::GetVgprToLDSEpilogueDescriptor; + using Base::I0; + using Base::I1; + using Base::I3; + using Base::NumDTensor; + + // assume Reduce is packed tensor + __device__ static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + const auto d_grid_desc_mraw = 
make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + __device__ EpilogueReduceCShuffle( + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_, + const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_, + const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_, + const index_t MRaw_) + : p_reduces_grid(p_reduces_grid_), + reduce_in_element_ops(reduce_in_element_ops_), + reduce_out_element_ops(reduce_out_element_ops_), + MRaw(MRaw_), + reduce_grid_desc_m{MakeReduceGridDescriptor_M(MRaw)} + { + } + + template + __device__ void Run(CThreadBuf& c_thread_buf, + DsGridPointer p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + 
e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + auto reduce_grid_desc_mblock_mperblock = + MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // LDS buffer + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + // Thread transfer Vgpr to LDS + auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor(); + + // Space Filling Curve Vgpr + constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{}; + + // Space Filling Curve Vmem + constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{}; + + // Block descriptor + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + GetCShuffleLDSDescriptor(); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); 
+ + // Thread transfer LDS to Vmem + auto cde_shuffle_block_copy_lds_and_global = + Base::template GetLDSToVmemEpilogueDescriptor( + c_ds_desc_refs, + e_grid_desc_mblock_mperblock_nblock_nperblock, + cde_element_op, + block_m_id, + block_n_id); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert( + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) * + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == + BlockSize, + "wrong!"); + + static_assert( + (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) % + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) == + 0 && + (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) % + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) / + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) / + 
ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1); + + static constexpr index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size(); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = + make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{}, + Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + CShuffleDataType, + typename ReduceTrait::ReduceAccDataType_, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = 
reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + typename ReduceTrait::ReduceAccDataType_, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_, + ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_m_id, // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); + + // CShuffle and Store + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block loads its C data from LDS, D from global, applies elementwise + // operation and stores result E to global + cde_shuffle_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + static_for<0, NumReduce, 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, 
reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = + remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto reduce_identityVal = ReduceOperation::template GetIdentityValue< + typename ReduceTrait::ReduceAccDataType_>(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_cde_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); + + // move on E + cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); + } + }); + } + + typename 
ReduceTrait::ReducePtrsGlobal_ p_reduces_grid; + typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops; + typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops; + index_t MRaw; + ReduceGridDesc_M reduce_grid_desc_m; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp index d2c6c92c9f..30f81b7411 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 56f09cee96..020d0110cf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -25,6 +25,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" namespace ck { @@ -622,6 +623,29 @@ struct GridwiseGemm_wmma_cshuffle_v3_base BlockwiseGemmPipe, BlockSize>; + template + using EpilogueReduceCShuffle = EpilogueReduceCShuffle< + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + MPerBlock, + NPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + 
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + CDEElementwiseOperation, + ThisThreadBlock, + BlockwiseGemmPipe, + GemmSpec, + BlockSize, + ReduceTrait>; + template __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 7ee3efe7f5..12d1026ea1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -1,7 +1,12 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_reduce_instance device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp + + device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..d92e84380f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT 
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //##############################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, 
PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 
2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, 
ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, 
Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..b21531e394 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //##############################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, 
ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, 
BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, 
GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + 
>; + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..d32e663b1c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using 
Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[m, n] = a[m, k] * b[k, n] +using device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //##############################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | 
PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 
0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, 
ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 
4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..f4013b5414 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //##############################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, 
ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + // v1 Interwave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, 
BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + // v3 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, 
GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; 
+ +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_reduce_impl.hpp index 74a1b60fe3..c870a95cbe 100644 --- a/profiler/include/profiler/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_reduce_impl.hpp @@ -34,6 +34,7 @@ using ReduceOutElementOps = ck::Tuple; using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; +#ifdef CK_USE_XDL void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( std::vector&); @@ -45,6 +46,20 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( std::vector&); +#endif +#ifdef CK_USE_WMMA +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); +#endif } // namespace instance } // namespace device @@ -211,33 +226,61 @@ bool profile_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + 
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); +#endif } } @@ -274,6 +317,8 @@ bool profile_gemm_reduce_impl(int do_verification, auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + std::string gemm_name = gemm_ptr->GetTypeString(); + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { ++num_kernel; @@ -289,8 +334,6 @@ bool profile_gemm_reduce_impl(int do_verification, float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - std::string gemm_name = gemm_ptr->GetTypeString(); - std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + @@ -317,9 +360,9 @@ bool profile_gemm_reduce_impl(int do_verification, 
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); - ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); - ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); - ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); + pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); if(do_log) { @@ -346,7 +389,7 @@ bool profile_gemm_reduce_impl(int do_verification, } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << gemm_name << ": does not support this GEMM problem" << std::endl; } } diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index 121ecde609..ae2246e628 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,4 +1,6 @@ -add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) -endif() \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") + add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) + endif() +endif() diff --git a/test/gemm_reduce/gemm_reduce_fp16_xdl.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp similarity index 96% rename from test/gemm_reduce/gemm_reduce_fp16_xdl.cpp rename to test/gemm_reduce/gemm_reduce_fp16.cpp index b1f2c36c9f..30657c87c5 100644 --- a/test/gemm_reduce/gemm_reduce_fp16_xdl.cpp +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include From 3784c0e7c395af214fdddd5f702691b354bfe8d4 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 12 Nov 2025 11:47:07 -0800 Subject: [PATCH 017/114] add permissions for /tmp folder (#3201) --- Dockerfile.aiter | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Dockerfile.aiter b/Dockerfile.aiter index dab3f9588d..94591f9012 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -17,4 +17,6 @@ RUN pip install pandas zmq einops ninja && \ useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \ chown -R jenkins:jenkins /home/jenkins && \ chmod -R a+rwx /home/jenkins && \ + chown -R jenkins:jenkins /tmp && \ + chmod -R a+rwx /tmp && \ sudo usermod -aG irc jenkins From 9342365713f6c8601e35921e7adeba9769b784b7 Mon Sep 17 00:00:00 2001 From: John Afaganis Date: Wed, 12 Nov 2025 17:05:53 -0700 Subject: [PATCH 018/114] Add C++17 deprecation warning to CHANGELOG.md (#3203) * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md --- CHANGELOG.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 213631721f..44d0837b40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## (Unreleased) Composable Kernel for ROCm - -### Added +## Composable Kernel 1.1.0 for ROCm 7.2.0 ### Added * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM @@ -32,6 +30,10 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added an optional template parameter `Arch` (`gfx9_t`, `gfx12_t` etc.) to `make_kernel` to support linking multiple object files that have the same kernel compiled for different architectures. 
* FMHA examples and tests can be built for multiple architectures (gfx9, gfx950, gfx12) at the same time. +### Upcoming changes + +* To enhance capabilities and user experience, Composable Kernel will adopt C++20 features in ROCm 8.0, updating the minimum compiler requirement to C++20. Please ensure your development environment meets this requirement for a seamless transition. + ## Composable Kernel 1.1.0 for ROCm 7.1.0 ### Added From 797ddfa41e5e2c45f9eea9e6c969ba528e5a9c39 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 12 Nov 2025 19:07:28 -0500 Subject: [PATCH 019/114] chore(copyright): update copyright header for test_data directory (#3194) * chore(copyright): update copyright header for tile_engine directory * chore(copyright): update copyright header for script directory * chore(copyright): update copyright header for test_data directory --- test_data/generate_model_configs.py | 530 ++++++++++++++-------------- test_data/generate_test_dataset.sh | 2 +- test_data/miopen_to_csv.py | 2 +- test_data/run_model_with_miopen.py | 2 +- 4 files changed, 268 insertions(+), 268 deletions(-) diff --git a/test_data/generate_model_configs.py b/test_data/generate_model_configs.py index 567870fd73..f3c47e3715 100644 --- a/test_data/generate_model_configs.py +++ b/test_data/generate_model_configs.py @@ -1,265 +1,265 @@ -#!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -""" -Generate Model Configuration Combinations for MIOpen Testing - -This script generates all possible combinations of model parameters -and saves them as CSV files that can be read by the shell script. 
-""" - -import csv -import argparse - - -def generate_2d_configs(mode="full"): - """Generate all 2D model configuration combinations - - Args: - mode: 'small' for minimal set (~50 configs), 'half' for reduced set (~250 configs), 'full' for comprehensive set (~500 configs) - """ - - # Define parameter ranges - models_2d = [ - "resnet18", - "resnet34", - "resnet50", - "mobilenet_v2", - "mobilenet_v3_large", - "mobilenet_v3_small", - "vgg11", - "vgg16", - "vgg19", - "alexnet", - "googlenet", - "densenet121", - "densenet161", - "squeezenet1_0", - "squeezenet1_1", - "shufflenet_v2_x1_0", - ] - - if mode == "small": - # Minimal set for quick testing - batch_sizes = [1, 8] # Just two batch sizes - # Very limited input dimensions - only 2 key sizes - input_dims = [ - (224, 224), # Standard (most common) - (256, 256), # Medium - ] - # Use only first 3 models for minimal testing - models_2d = models_2d[:3] # Only resnet18, resnet34, resnet50 - elif mode == "half": - # Reduced set for faster testing - batch_sizes = [1, 8, 32] # Small, medium, large - # Reduced input dimensions - 5 key sizes - input_dims = [ - (64, 64), # Small - (224, 224), # Standard (most common) - (512, 512), # Large - (224, 320), # Rectangular - (227, 227), # AlexNet preferred - ] - else: # full mode - # More comprehensive but still limited - batch_sizes = [1, 4, 8, 16, 32] - # More dimensions but skip some redundant ones - input_dims = [ - (64, 64), - (128, 128), - (224, 224), - (256, 256), - (512, 512), # Square - (224, 320), - (320, 224), # Rectangular (reduced from 4) - (227, 227), # AlexNet preferred - (299, 299), # Inception preferred - ] - - precisions = ["fp32"] # , 'fp16', 'bf16'] - channels = [3] # Most models expect RGB - - configs = [] - config_id = 1 - - # Generate all combinations (but limit to reasonable subset) - for model in models_2d: - for batch_size in batch_sizes: - for height, width in input_dims: - for precision in precisions: - # Skip some combinations to keep dataset manageable - 
if batch_size > 16 and height > 256: - continue # Skip large batch + large image combinations - if precision != "fp32" and batch_size < 8: - continue # Skip mixed precision with tiny batches - - config_name = f"{model}_b{batch_size}_{height}x{width}_{precision}" - - config = { - "config_name": config_name, - "model": model, - "batch_size": batch_size, - "channels": channels[0], - "height": height, - "width": width, - "precision": precision, - } - - configs.append(config) - config_id += 1 - - return configs - - -def generate_3d_configs(mode="full"): - """Generate all 3D model configuration combinations - - Args: - mode: 'small' for minimal set (~10 configs), 'half' for reduced set (~50 configs), 'full' for comprehensive set (~100 configs) - """ - - models_3d = ["r3d_18", "mc3_18", "r2plus1d_18"] - - if mode == "small": - # Minimal set for quick testing - batch_sizes = [1, 4] # Just two batch sizes - temporal_sizes = [8] # Only smallest temporal size - # Very limited spatial dimensions - input_dims = [ - (112, 112), # Standard for 3D - ] - # Use only first model for minimal testing - models_3d = models_3d[:1] # Only r3d_18 - elif mode == "half": - # Reduced set for faster testing - batch_sizes = [1, 4, 8] # Skip batch_size=2 - temporal_sizes = [8, 16] # Skip 32 (most expensive) - # Reduced spatial dimensions - input_dims = [ - (112, 112), # Small (common for video) - (224, 224), # Standard - (224, 320), # Rectangular - ] - else: # full mode - # More comprehensive but still reasonable - batch_sizes = [1, 2, 4, 8] # 3D models are more memory intensive - temporal_sizes = [8, 16, 32] - # More dimensions - input_dims = [ - (112, 112), - (224, 224), - (256, 256), # Standard sizes - (224, 320), - (320, 224), # Rectangular - ] - - precisions = ["fp32"] # , 'fp16'] # Skip bf16 for 3D to reduce combinations - channels = [3] - - configs = [] - - for model in models_3d: - for batch_size in batch_sizes: - for temporal_size in temporal_sizes: - for height, width in input_dims: - 
for precision in precisions: - # Skip very large combinations - if batch_size > 4 and temporal_size > 16: - continue - if batch_size > 2 and height > 224: - continue - - config_name = f"{model}_b{batch_size}_t{temporal_size}_{height}x{width}_{precision}" - - config = { - "config_name": config_name, - "model": model, - "batch_size": batch_size, - "channels": channels[0], - "temporal_size": temporal_size, - "height": height, - "width": width, - "precision": precision, - } - - configs.append(config) - - return configs - - -def save_configs_to_csv(configs, filename, config_type): - """Save configurations to CSV file""" - - if not configs: - print(f"No {config_type} configurations generated") - return - - fieldnames = list(configs[0].keys()) - - with open(filename, "w", newline="\n", encoding="utf-8") as csvfile: - csvfile.write(f"# {config_type} Model Configurations\n") - csvfile.write(f"# Generated {len(configs)} configurations\n") - - writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") - writer.writeheader() - - for config in configs: - writer.writerow(config) - - print(f"Generated {len(configs)} {config_type} configurations → {filename}") - - -def main(): - parser = argparse.ArgumentParser( - description="Generate model configuration combinations" - ) - parser.add_argument( - "--output-2d", - type=str, - default="model_configs_2d.csv", - help="Output file for 2D configurations", - ) - parser.add_argument( - "--output-3d", - type=str, - default="model_configs_3d.csv", - help="Output file for 3D configurations", - ) - parser.add_argument( - "--mode", - choices=["small", "half", "full"], - default="full", - help="Configuration mode: small (~60 total), half (~300 total) or full (~600 total) (default: half)", - ) - parser.add_argument( - "--limit", - type=int, - help="Limit number of configurations per type (for testing)", - ) - - args = parser.parse_args() - - print(f"Generating {args.mode} model configurations...") - - print("Generating 2D 
model configurations...") - configs_2d = generate_2d_configs(mode=args.mode) - if args.limit: - configs_2d = configs_2d[: args.limit] - save_configs_to_csv(configs_2d, args.output_2d, "2D") - - print("Generating 3D model configurations...") - configs_3d = generate_3d_configs(mode=args.mode) - if args.limit: - configs_3d = configs_3d[: args.limit] - save_configs_to_csv(configs_3d, args.output_3d, "3D") - - print( - f"\nTotal configurations: {len(configs_2d)} 2D + {len(configs_3d)} 3D = {len(configs_2d) + len(configs_3d)}" - ) - print("\nTo use these configurations:") - print(" Update generate_test_dataset.sh to read from these CSV files") - - -if __name__ == "__main__": - main() +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Generate Model Configuration Combinations for MIOpen Testing + +This script generates all possible combinations of model parameters +and saves them as CSV files that can be read by the shell script. 
+""" + +import csv +import argparse + + +def generate_2d_configs(mode="full"): + """Generate all 2D model configuration combinations + + Args: + mode: 'small' for minimal set (~50 configs), 'half' for reduced set (~250 configs), 'full' for comprehensive set (~500 configs) + """ + + # Define parameter ranges + models_2d = [ + "resnet18", + "resnet34", + "resnet50", + "mobilenet_v2", + "mobilenet_v3_large", + "mobilenet_v3_small", + "vgg11", + "vgg16", + "vgg19", + "alexnet", + "googlenet", + "densenet121", + "densenet161", + "squeezenet1_0", + "squeezenet1_1", + "shufflenet_v2_x1_0", + ] + + if mode == "small": + # Minimal set for quick testing + batch_sizes = [1, 8] # Just two batch sizes + # Very limited input dimensions - only 2 key sizes + input_dims = [ + (224, 224), # Standard (most common) + (256, 256), # Medium + ] + # Use only first 3 models for minimal testing + models_2d = models_2d[:3] # Only resnet18, resnet34, resnet50 + elif mode == "half": + # Reduced set for faster testing + batch_sizes = [1, 8, 32] # Small, medium, large + # Reduced input dimensions - 5 key sizes + input_dims = [ + (64, 64), # Small + (224, 224), # Standard (most common) + (512, 512), # Large + (224, 320), # Rectangular + (227, 227), # AlexNet preferred + ] + else: # full mode + # More comprehensive but still limited + batch_sizes = [1, 4, 8, 16, 32] + # More dimensions but skip some redundant ones + input_dims = [ + (64, 64), + (128, 128), + (224, 224), + (256, 256), + (512, 512), # Square + (224, 320), + (320, 224), # Rectangular (reduced from 4) + (227, 227), # AlexNet preferred + (299, 299), # Inception preferred + ] + + precisions = ["fp32"] # , 'fp16', 'bf16'] + channels = [3] # Most models expect RGB + + configs = [] + config_id = 1 + + # Generate all combinations (but limit to reasonable subset) + for model in models_2d: + for batch_size in batch_sizes: + for height, width in input_dims: + for precision in precisions: + # Skip some combinations to keep dataset manageable + 
if batch_size > 16 and height > 256: + continue # Skip large batch + large image combinations + if precision != "fp32" and batch_size < 8: + continue # Skip mixed precision with tiny batches + + config_name = f"{model}_b{batch_size}_{height}x{width}_{precision}" + + config = { + "config_name": config_name, + "model": model, + "batch_size": batch_size, + "channels": channels[0], + "height": height, + "width": width, + "precision": precision, + } + + configs.append(config) + config_id += 1 + + return configs + + +def generate_3d_configs(mode="full"): + """Generate all 3D model configuration combinations + + Args: + mode: 'small' for minimal set (~10 configs), 'half' for reduced set (~50 configs), 'full' for comprehensive set (~100 configs) + """ + + models_3d = ["r3d_18", "mc3_18", "r2plus1d_18"] + + if mode == "small": + # Minimal set for quick testing + batch_sizes = [1, 4] # Just two batch sizes + temporal_sizes = [8] # Only smallest temporal size + # Very limited spatial dimensions + input_dims = [ + (112, 112), # Standard for 3D + ] + # Use only first model for minimal testing + models_3d = models_3d[:1] # Only r3d_18 + elif mode == "half": + # Reduced set for faster testing + batch_sizes = [1, 4, 8] # Skip batch_size=2 + temporal_sizes = [8, 16] # Skip 32 (most expensive) + # Reduced spatial dimensions + input_dims = [ + (112, 112), # Small (common for video) + (224, 224), # Standard + (224, 320), # Rectangular + ] + else: # full mode + # More comprehensive but still reasonable + batch_sizes = [1, 2, 4, 8] # 3D models are more memory intensive + temporal_sizes = [8, 16, 32] + # More dimensions + input_dims = [ + (112, 112), + (224, 224), + (256, 256), # Standard sizes + (224, 320), + (320, 224), # Rectangular + ] + + precisions = ["fp32"] # , 'fp16'] # Skip bf16 for 3D to reduce combinations + channels = [3] + + configs = [] + + for model in models_3d: + for batch_size in batch_sizes: + for temporal_size in temporal_sizes: + for height, width in input_dims: + 
for precision in precisions: + # Skip very large combinations + if batch_size > 4 and temporal_size > 16: + continue + if batch_size > 2 and height > 224: + continue + + config_name = f"{model}_b{batch_size}_t{temporal_size}_{height}x{width}_{precision}" + + config = { + "config_name": config_name, + "model": model, + "batch_size": batch_size, + "channels": channels[0], + "temporal_size": temporal_size, + "height": height, + "width": width, + "precision": precision, + } + + configs.append(config) + + return configs + + +def save_configs_to_csv(configs, filename, config_type): + """Save configurations to CSV file""" + + if not configs: + print(f"No {config_type} configurations generated") + return + + fieldnames = list(configs[0].keys()) + + with open(filename, "w", newline="\n", encoding="utf-8") as csvfile: + csvfile.write(f"# {config_type} Model Configurations\n") + csvfile.write(f"# Generated {len(configs)} configurations\n") + + writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") + writer.writeheader() + + for config in configs: + writer.writerow(config) + + print(f"Generated {len(configs)} {config_type} configurations → {filename}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate model configuration combinations" + ) + parser.add_argument( + "--output-2d", + type=str, + default="model_configs_2d.csv", + help="Output file for 2D configurations", + ) + parser.add_argument( + "--output-3d", + type=str, + default="model_configs_3d.csv", + help="Output file for 3D configurations", + ) + parser.add_argument( + "--mode", + choices=["small", "half", "full"], + default="full", + help="Configuration mode: small (~60 total), half (~300 total) or full (~600 total) (default: half)", + ) + parser.add_argument( + "--limit", + type=int, + help="Limit number of configurations per type (for testing)", + ) + + args = parser.parse_args() + + print(f"Generating {args.mode} model configurations...") + + print("Generating 2D 
model configurations...") + configs_2d = generate_2d_configs(mode=args.mode) + if args.limit: + configs_2d = configs_2d[: args.limit] + save_configs_to_csv(configs_2d, args.output_2d, "2D") + + print("Generating 3D model configurations...") + configs_3d = generate_3d_configs(mode=args.mode) + if args.limit: + configs_3d = configs_3d[: args.limit] + save_configs_to_csv(configs_3d, args.output_3d, "3D") + + print( + f"\nTotal configurations: {len(configs_2d)} 2D + {len(configs_3d)} 3D = {len(configs_2d) + len(configs_3d)}" + ) + print("\nTo use these configurations:") + print(" Update generate_test_dataset.sh to read from these CSV files") + + +if __name__ == "__main__": + main() diff --git a/test_data/generate_test_dataset.sh b/test_data/generate_test_dataset.sh index 1124311feb..e9c4937445 100755 --- a/test_data/generate_test_dataset.sh +++ b/test_data/generate_test_dataset.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT # Generate Comprehensive Convolution Test Dataset for CK diff --git a/test_data/miopen_to_csv.py b/test_data/miopen_to_csv.py index d6a85e1e3f..e4ca42adeb 100644 --- a/test_data/miopen_to_csv.py +++ b/test_data/miopen_to_csv.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/test_data/run_model_with_miopen.py b/test_data/run_model_with_miopen.py index 9eee3b53fb..2e655fb82c 100644 --- a/test_data/run_model_with_miopen.py +++ b/test_data/run_model_with_miopen.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT """ From 9af30f04b65b8e50877d01ce8377a8cd581d462c Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Thu, 13 Nov 2025 00:56:18 -0600 Subject: [PATCH 020/114] Ck tile engine commons (#3166) * Moving Preshuffle to commons * Fixing Common Validations * Addressing Review Comments * Partial Rebasing * Partial Rebasing * Partial Rebasing * Rebasing Complete --- ...tion_utils.py => gemm_validation_utils.py} | 450 ++++++++++++++-- tile_engine/ops/gemm/codegen_utils.py | 210 -------- tile_engine/ops/gemm/gemm_instance_builder.py | 3 +- .../gemm_multi_d_instance_builder.py | 5 +- .../commons/validation_utils.py | 483 ------------------ .../gemm_preshuffle_instance_builder.py | 36 +- 6 files changed, 434 insertions(+), 753 deletions(-) rename tile_engine/ops/commons/{validation_utils.py => gemm_validation_utils.py} (58%) delete mode 100644 tile_engine/ops/gemm/codegen_utils.py delete mode 100644 tile_engine/ops/gemm_preshuffle/commons/validation_utils.py diff --git a/tile_engine/ops/commons/validation_utils.py b/tile_engine/ops/commons/gemm_validation_utils.py similarity index 58% rename from tile_engine/ops/commons/validation_utils.py rename to tile_engine/ops/commons/gemm_validation_utils.py index 5787446e8c..1b4a7191cd 100644 --- a/tile_engine/ops/commons/validation_utils.py +++ b/tile_engine/ops/commons/gemm_validation_utils.py @@ -1,16 +1,19 @@ #!/usr/bin/env python -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT - -""" -Validation utilities for GEMM kernel generation. -Extracted from tile_engine_develop for consistency. -""" +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
import logging from typing import Tuple, List -# Element size mapping for different data types +GEMM_PIPELINES = ["mem", "compv3", "compv4"] + +GEMM_PRESHUFFLE_PIPELINES = ["preshufflev2"] + +LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", +} + ELEMENT_SIZE_MAP = { "fp16": 2, "bf16": 2, @@ -47,9 +50,79 @@ WARP_SUPPORTED_COMBINATIONS = { ], } -# [TODO] Handle this while moving code to commons -# Supported warp tile combinations for different GPU architectures and data types -WARP_TILE_SUPPORTED_COMBINATIONS = { +GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx90a": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { "gfx90a": { "fp16_fp16_fp16": [ [32, 32, 8], @@ -132,7 +205,6 @@ 
WARP_TILE_SUPPORTED_COMBINATIONS = { }, } -# Unsupported trait combinations TRAIT_UNSUPPORTED_COMBINATIONS = { ("compv3", "cshuffle", "interwave"), ("compv3", "default", "interwave"), @@ -220,7 +292,7 @@ def validate_lds_capacity( matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) total_tile_in_lds = matrix_a_size + matrix_b_size - max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16 if total_tile_in_lds > max_tile_size: error_msg = ( @@ -234,7 +306,7 @@ def validate_lds_capacity( return True, "" -def validate_warp_tile_combination( +def validate_gemm_warp_tile_combination( warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, @@ -250,7 +322,51 @@ def validate_warp_tile_combination( current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] # Check if we have GPU-specific combinations - gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + gpu_warp_tile_combinations = GEMM_WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not gpu_warp_tile_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") + return True, "" + + # Check if we have combinations for this data type combination + allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) + if not allowed_combinations: + # For data type combinations not in the list, be permissive + logging.debug( + f"No warp tile combinations found for data types: {warp_tile_key}" + ) + return True, "" + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. 
" + f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" + ) + return False, error_msg + + return True, "" + + +def validate_gemm_preshuffle_warp_tile_combination( + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + gpu_name: str, +) -> Tuple[bool, str]: + """Validate warp tile combination against GPU-specific supported combinations.""" + + # Construct the key for looking up supported combinations + warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + # Check if we have GPU-specific combinations + gpu_warp_tile_combinations = GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get( + gpu_name, {} + ) if not gpu_warp_tile_combinations: # If GPU not recognized, try to be permissive but log warning logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") @@ -292,7 +408,6 @@ def is_tile_config_valid( pipeline: str, layout: str, gpu_target: str, - trait_name: str = None, ) -> bool: """ Comprehensive tile configuration validation. 
@@ -349,37 +464,81 @@ def is_tile_config_valid( logging.debug(f"LDS validation failed: {lds_error}") return False - # Validate whole workgroup cover configuration - wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - layout, - a_datatype, - b_datatype, - ) - if not wr_cover_valid: - logging.debug( - f"Whole workgroup cover configuration validation failed: {wg_cover_error}" + if pipeline in GEMM_PIPELINES: + gemm_valid, gemm_valid_error = validate_gemm( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + layout, + gpu_target, ) - return False + if not gemm_valid: + logging.debug(f"GEMM validation failed: {gemm_valid_error}") + return False - # Validate warp tile combination - warp_tile_valid, warp_tile_error = validate_warp_tile_combination( - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - gpu_target, - ) - if not warp_tile_valid: - logging.debug(f"Warp tile validation failed: {warp_tile_error}") - return False + # Validate warp tile combination + warp_tile_valid, warp_tile_error = validate_gemm_warp_tile_combination( + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + gpu_target, + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False + + elif pipeline in GEMM_PRESHUFFLE_PIPELINES: + preshuffle_valid, preshuffle_valid_error = validate_gemm_preshuffle( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + layout, + gpu_target, + ) + if not preshuffle_valid: + logging.debug( + f"GEMM Preshuffle validation failed: {preshuffle_valid_error}" + ) + return False + + # Validate warp tile combination + warp_tile_valid, warp_tile_error = ( + 
validate_gemm_preshuffle_warp_tile_combination( + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + gpu_target, + ) + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False return True @@ -398,12 +557,6 @@ def get_dtype_string(datatype: str) -> str: return dtype_map.get(datatype, "float") -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - - def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: """ Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. @@ -600,3 +753,200 @@ def get_global_vector_load_size( return int(PackedSize * 2 / element_size(DataType)) else: return PackedSize + + +def validate_gemm( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + layout: str, + gpu_target: str, + trait_name: str = None, +) -> bool: + # GEMM Validation + # Validate whole workgroup cover configuration + whole_workgroup_cover_valid, whole_workgroup_cover_error = ( + validate_whole_wg_cover_configuration( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + layout, + a_datatype, + b_datatype, + ) + ) + if not whole_workgroup_cover_valid: + logging.debug( + f"Whole workgroup cover configuration validation failed: {whole_workgroup_cover_error}" + ) + return False, whole_workgroup_cover_error + + return True, "" + + +def validate_gemm_preshuffle( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + layout: str, + gpu_target: str, + trait_name: str = None, +) -> bool: + # Preshuffle Validations + # Validate vector load alignment + 
m_iter_per_warp = tile_m / (warp_m * warp_tile_m) + vector_valid, vector_error = validate_vector_load_alignment( + warp_tile_m, + warp_tile_k, + a_datatype, + m_iter_per_warp, + wave_size=64, + vector_load_size=16, + ) + if not vector_valid: + logging.debug(f"Vector load alignment failed: {vector_error}") + return False, "vector load alignment error" + + # Validate M0, M1, M2 configuration for matrix A row-major layout + m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration( + tile_m, + tile_k, + warp_m, + warp_n, + warp_k, + a_datatype, + vector_load_size=16, + warp_size=64, + ) + if not m0_m1_m2_valid: + logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}") + return False, m0_m1_m2_error + + return True, "" + + +def validate_vector_load_alignment( + wg_m: int, + wg_k: int, + a_datatype: str, + m_iter_per_warp: int, + wave_size: int, + vector_load_size: int, +) -> Tuple[bool, str]: + try: + # Calculate the memory access pattern size + a_element_size = element_size(a_datatype) + access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size + + # Check if it's aligned to vector load size + if access_size % vector_load_size != 0: + error_msg = ( + f"Vector load alignment violation: " + f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) " + f"% {vector_load_size} = {access_size % vector_load_size} != 0. " + f"Access size: {access_size} bytes" + ) + return False, error_msg + + return True, "" + + except Exception as e: + return False, f"Error in vector load validation: {str(e)}" + + +def validate_m0_m1_m2_configuration( + tile_m: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + a_datatype: str, + vector_load_size: int = 16, + warp_size: int = 64, +) -> Tuple[bool, str]: + """ + Validate M0, M1, M2 configuration for matrix A row-major layout. + This ensures proper memory access pattern alignment. 
+ """ + try: + # Validation for A as row-major + MPerBlock = tile_m + + # Calculate K1 using element size + K1 = vector_load_size / element_size(a_datatype) + + # Check if K1 is valid (must be integer) + if K1 != int(K1): + return ( + False, + f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})", + ) + K1 = int(K1) + + # Calculate K0 + if tile_k % K1 != 0: + return False, f"tile_k({tile_k}) must be divisible by K1({K1})" + K0 = tile_k // K1 + + # Calculate M2 + if warp_size % K0 != 0: + return False, f"warp_size({warp_size}) must be divisible by K0({K0})" + M2 = warp_size // K0 + + # Calculate number of warps and block size + NumWarps = warp_m * warp_n * warp_k + BlockSize = NumWarps * warp_size + + # Calculate M0 (assuming get_warp_size() returns warp_size) + M0 = BlockSize // warp_size # This should equal NumWarps + + # Calculate M1 + if (M2 * M0) == 0: + return False, f"M2({M2}) * M0({M0}) cannot be zero" + + if MPerBlock % (M2 * M0) != 0: + return ( + False, + f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}", + ) + M1 = MPerBlock // (M2 * M0) + + # Validate the assertion: M0 * M1 * M2 == MPerBlock + calculated_m_per_block = M0 * M1 * M2 + if calculated_m_per_block != MPerBlock: + error_msg = ( + f"Incorrect M0, M1, M2 configuration! " + f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). 
" + f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}" + ) + return False, error_msg + + return True, "" + + except ZeroDivisionError as e: + return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}" + except Exception as e: + return False, f"Error in M0/M1/M2 validation: {str(e)}" diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py deleted file mode 100644 index eecc2228a6..0000000000 --- a/tile_engine/ops/gemm/codegen_utils.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -# -*- coding: utf-8 -*- - -""" -Mappings and utility functions for kernel code generation. -""" - -DATA_TYPE_MAP = { - "fp32": "float", - "fp16": "ck_tile::half_t", - "bf16": "ck_tile::bf16_t", - "int8": "ck_tile::int8_t", - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "int4": "ck_tile::pk_int4_t", - "int32": "ck_tile::int32_t", -} - -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - -DEFAULT_EPILOGUE = """ - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - kPadM, - kPadN, - WarpTileM, - WarpTileN, - WarpTileK, - UniversalGemmProblem::TransposeC, - true, - memory_operation>>; -""" - -CSHUFFLE_EPILOGUE = """ - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - WarpM, - WarpN, - WarpTileM, - WarpTileN, - WarpTileK, - UniversalGemmProblem::TransposeC, - memory_operation>>; -""" - -PIPELINE_MAP = { - "mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", 
"ck_tile::GemmPipelineAgBgCrMem"], - "compv3": [ - "ck_tile::BaseGemmPipelineAgBgCrCompV3", - "ck_tile::GemmPipelineAgBgCrCompV3", - ], - "compv4": [ - "ck_tile::BaseGemmPipelineAgBgCrCompV4", - "ck_tile::GemmPipelineAgBgCrCompV4", - ], -} - -SCHEDULER_MAP = { - "interwave": "ck_tile::GemmPipelineScheduler::Interwave", - "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", -} - -EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE} - - -def BOOL_MAP(b_): - return {True: "true", False: "false"}[bool(b_)] - - -# To Do: add some more supported combinations -warp_tile_supported_combinations = { - "gfx90a": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], - }, - "gfx942": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], - "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], - }, - "gfx950": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 32], - [16, 16, 64], - [16, 16, 128], - [32, 32, 64], - ], - "bf8_bf8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 64], - [16, 16, 32], - [16, 16, 128], - [32, 32, 64], - ], - "fp8_bf8_fp16": [ - [16, 16, 128], - [32, 32, 64], - ], - 
"bf8_fp8_fp16": [ - [16, 16, 128], - [32, 32, 64], - ], - }, - "gfx1201": { - "fp16_fp16_fp16": [ - [16, 16, 16], - ], - }, -} - -# To Do: remove some unsupported combinations -trait_unsupported_combinations = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), -} - - -ELEMENT_SIZE_MAP = { - "fp16": 2, - "bf16": 2, - "int8": 1, - "fp8": 1, - "bf8": 1, - "int4": 0.5, - "int32": 4, -} - - -def element_size(data_type: str) -> float: - """Calculate the size (in bytes) of a single element for given data type.""" - data_type = data_type.lower() - if data_type not in ELEMENT_SIZE_MAP: - raise ValueError(f"Unsupported data type: {data_type}") - return ELEMENT_SIZE_MAP[data_type] diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 8885c821c1..d450f20105 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -21,7 +21,8 @@ def _import_validation_utils(): # Load the module dynamically spec = importlib.util.spec_from_file_location( - "validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py") + "validation_utils", + os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), ) validation_utils = importlib.util.module_from_spec(spec) spec.loader.exec_module(validation_utils) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index cc167fb75f..06da7ea8a2 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -21,7 +21,8 @@ def _import_validation_utils(): # Load the module dynamically spec = importlib.util.spec_from_file_location( - "validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py") + "validation_utils", + os.path.join(parent_dir, 
"commons", "gemm_validation_utils.py"), ) validation_utils = importlib.util.module_from_spec(spec) spec.loader.exec_module(validation_utils) @@ -824,7 +825,7 @@ def main(): elif elementwise_function == "add": function_name = "MultiDAdd" elif elementwise_function == "passthrough": - function_name = "PassThrough" # TODO Change this + function_name = "PassThrough" args.elementwise_function = function_name diff --git a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py deleted file mode 100644 index 70ce3b0d72..0000000000 --- a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py +++ /dev/null @@ -1,483 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -""" -Validation utilities for GEMM kernel generation. -Extracted from tile_engine_develop for consistency. -""" - -import logging -from typing import Tuple, List - -# Element size mapping for different data types -ELEMENT_SIZE_MAP = { - "fp16": 2, - "bf16": 2, - "int8": 1, - "fp8": 1, - "bf8": 1, - "int4": 0.5, - "int32": 4, - "fp32": 4, - "fp64": 8, -} - -# [TODO] Handle this while moving code to commons -# Supported warp tile combinations for different GPU architectures and data types -WARP_TILE_SUPPORTED_COMBINATIONS = { - "gfx90a": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], - }, - "gfx942": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - "bf8_bf8_fp16": [[32, 32, 
16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], - "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], - }, - "gfx950": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 32], - [16, 16, 64], - [16, 16, 128], - [32, 32, 64], - ], - "bf8_bf8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 64], - [16, 16, 32], - [16, 16, 128], - [32, 32, 64], - ], - }, -} - -# Unsupported trait combinations -TRAIT_UNSUPPORTED_COMBINATIONS = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), -} - - -def element_size(data_type: str) -> float: - """Calculate the size (in bytes) of a single element for given data type.""" - data_type = data_type.lower() - if data_type not in ELEMENT_SIZE_MAP: - raise ValueError(f"Unsupported data type: {data_type}") - return ELEMENT_SIZE_MAP[data_type] - - -def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: - """Check if a trait combination is valid.""" - if pipeline not in ["preshufflev2"]: - raise ValueError("Accepted pipeline values are: ['preshufflev2']") - if epilogue not in ["default", "cshuffle"]: - return ValueError("Accepted epilogue values are: ['default', 'cshuffle']") - if scheduler not in ["default"]: - return ValueError("Accepted scheduler values are: ['default']") - return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS - - -def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: - """Validate warp configuration.""" - return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] - - -def validate_dimension_alignment( - tile_m: int, - tile_n: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - warp_tile_m: int, - 
warp_tile_n: int, - warp_tile_k: int, -) -> Tuple[bool, List[str]]: - """Check if tile dimensions are properly aligned with warp dimensions.""" - alignment_issues = [] - - if tile_m % (warp_m * warp_tile_m) != 0: - alignment_issues.append( - f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" - ) - if tile_n % (warp_n * warp_tile_n) != 0: - alignment_issues.append( - f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" - ) - if tile_k % (warp_k * warp_tile_k) != 0: - alignment_issues.append( - f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" - ) - - return len(alignment_issues) == 0, alignment_issues - - -def validate_lds_capacity( - tile_m: int, - tile_n: int, - tile_k: int, - a_datatype: str, - b_datatype: str, - pipeline: str, -) -> Tuple[bool, str]: - """Validate LDS capacity requirements.""" - matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) - matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) - total_tile_in_lds = matrix_a_size + matrix_b_size - - max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16 - - if total_tile_in_lds > max_tile_size: - error_msg = ( - f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " - f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). 
Breakdown:\n" - f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" - f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B" - ) - return False, error_msg - - return True, "" - - -def validate_warp_tile_combination( - warp_tile_m: int, - warp_tile_n: int, - warp_tile_k: int, - a_datatype: str, - b_datatype: str, - c_datatype: str, - gpu_name: str, -) -> Tuple[bool, str]: - """Validate warp tile combination against GPU-specific supported combinations.""" - - # Construct the key for looking up supported combinations - warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" - current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] - - # Check if we have GPU-specific combinations - gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) - if not gpu_warp_tile_combinations: - # If GPU not recognized, try to be permissive but log warning - logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") - return True, "" - - # Check if we have combinations for this data type combination - allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) - if not allowed_combinations: - # For data type combinations not in the list, be permissive - logging.debug( - f"No warp tile combinations found for data types: {warp_tile_key}" - ) - return True, "" - - # Check if current combination is in the allowed list - if current_combination not in allowed_combinations: - error_msg = ( - f"Invalid warp tile combination: {current_combination} not in allowed list. 
" - f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" - ) - return False, error_msg - - return True, "" - - -def is_tile_config_valid( - tile_m: int, - tile_n: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - warp_tile_m: int, - warp_tile_n: int, - warp_tile_k: int, - a_datatype: str, - b_datatype: str, - c_datatype: str, - pipeline: str, - gpu_target: str, - trait_name: str = None, -) -> bool: - """ - Comprehensive tile configuration validation. - Returns True if configuration is valid, False otherwise. - """ - # Basic sanity checks - if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: - return False - if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: - return False - if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: - return False - - # Check that warp tiles fit within block tiles - if warp_m * warp_tile_m > tile_m: - return False - if warp_n * warp_tile_n > tile_n: - return False - if warp_k * warp_tile_k > tile_k: - return False - - # Validate vector load alignment - m_iter_per_warp = tile_m / (warp_m * warp_tile_m) - vector_valid, vector_error = validate_vector_load_alignment( - warp_tile_m, - warp_tile_k, - a_datatype, - m_iter_per_warp, - wave_size=64, - vector_load_size=16, - ) - if not vector_valid: - logging.debug(f"Vector load alignment failed: {vector_error}") - return False - - # Validate M0, M1, M2 configuration for matrix A row-major layout - m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration( - tile_m, - tile_k, - warp_m, - warp_n, - warp_k, - a_datatype, - vector_load_size=16, - warp_size=64, - ) - if not m0_m1_m2_valid: - logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}") - return False - - # Validate warp configuration - if not validate_warp_configuration(warp_m, warp_n, warp_k): - logging.debug( - f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})" - ) - return False - - # Validate dimension alignment - is_aligned, 
alignment_issues = validate_dimension_alignment( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) - if not is_aligned: - logging.debug( - f"Dimension alignment failed: {', '.join(alignment_issues)}. " - f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " - f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - ) - return False - - # Validate LDS capacity - lds_valid, lds_error = validate_lds_capacity( - tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline - ) - if not lds_valid: - logging.debug(f"LDS validation failed: {lds_error}") - return False - - # Validate warp tile combination - warp_tile_valid, warp_tile_error = validate_warp_tile_combination( - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - gpu_target, - ) - if not warp_tile_valid: - logging.debug(f"Warp tile validation failed: {warp_tile_error}") - return False - - return True - - -def validate_vector_load_alignment( - wg_m: int, - wg_k: int, - a_datatype: str, - m_iter_per_warp: int, - wave_size: int, - vector_load_size: int, -) -> Tuple[bool, str]: - try: - # Calculate the memory access pattern size - a_element_size = element_size(a_datatype) - access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size - - # Check if it's aligned to vector load size - if access_size % vector_load_size != 0: - error_msg = ( - f"Vector load alignment violation: " - f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) " - f"% {vector_load_size} = {access_size % vector_load_size} != 0. 
" - f"Access size: {access_size} bytes" - ) - return False, error_msg - - return True, "" - - except Exception as e: - return False, f"Error in vector load validation: {str(e)}" - - -def validate_m0_m1_m2_configuration( - tile_m: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - a_datatype: str, - vector_load_size: int = 16, - warp_size: int = 64, -) -> Tuple[bool, str]: - """ - Validate M0, M1, M2 configuration for matrix A row-major layout. - This ensures proper memory access pattern alignment. - """ - try: - # Validation for A as row-major - MPerBlock = tile_m - - # Calculate K1 using element size - K1 = vector_load_size / element_size(a_datatype) - - # Check if K1 is valid (must be integer) - if K1 != int(K1): - return ( - False, - f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})", - ) - K1 = int(K1) - - # Calculate K0 - if tile_k % K1 != 0: - return False, f"tile_k({tile_k}) must be divisible by K1({K1})" - K0 = tile_k // K1 - - # Calculate M2 - if warp_size % K0 != 0: - return False, f"warp_size({warp_size}) must be divisible by K0({K0})" - M2 = warp_size // K0 - - # Calculate number of warps and block size - NumWarps = warp_m * warp_n * warp_k - BlockSize = NumWarps * warp_size - - # Calculate M0 (assuming get_warp_size() returns warp_size) - M0 = BlockSize // warp_size # This should equal NumWarps - - # Calculate M1 - if (M2 * M0) == 0: - return False, f"M2({M2}) * M0({M0}) cannot be zero" - - if MPerBlock % (M2 * M0) != 0: - return ( - False, - f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}", - ) - M1 = MPerBlock // (M2 * M0) - - # Validate the assertion: M0 * M1 * M2 == MPerBlock - calculated_m_per_block = M0 * M1 * M2 - if calculated_m_per_block != MPerBlock: - error_msg = ( - f"Incorrect M0, M1, M2 configuration! " - f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). 
" - f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}" - ) - return False, error_msg - - return True, "" - - except ZeroDivisionError as e: - return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}" - except Exception as e: - return False, f"Error in M0/M1/M2 validation: {str(e)}" - - -# [TODO] Handle this while moving code to commons Add more datatype to this function if needed -def get_dtype_string(datatype: str) -> str: - """Get C++ type string for datatype""" - dtype_map = { - "fp16": "ck_tile::fp16_t", - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "bf16": "ck_tile::bf16_t", - "fp32": "float", - "fp64": "double", - } - return dtype_map.get(datatype, "float") - - -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - - -def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: - """ - Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. 
- """ - code = str(layout_code).strip().lower() - - a_layout = LAYOUT_MAP[code[0]] - b_layout = LAYOUT_MAP[code[1]] - c_layout = LAYOUT_MAP[code[2]] - return a_layout, b_layout, c_layout diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index 9ce6d8cb25..654a039b9c 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -8,15 +8,34 @@ import itertools import logging import multiprocessing import concurrent.futures - from pathlib import Path +import importlib.util -from commons.validation_utils import ( - is_tile_config_valid, - is_trait_combination_valid, - get_dtype_string, - get_abc_layouts, -) + +def _import_validation_utils(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "validation_utils", + os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), + ) + validation_utils = importlib.util.module_from_spec(spec) + spec.loader.exec_module(validation_utils) + + return validation_utils + + +# Import validation functions +_validation_utils = _import_validation_utils() +is_tile_config_valid = _validation_utils.is_tile_config_valid +is_trait_combination_valid = _validation_utils.is_trait_combination_valid +get_dtype_string = _validation_utils.get_dtype_string +get_abc_layouts = _validation_utils.get_abc_layouts + +logging.basicConfig(level=logging.INFO) class GemmPreshuffleKernelBuilder: @@ -305,6 +324,8 @@ class GemmPreshuffleKernelBuilder: b_datatype = self.datatype c_datatype = self.datatype + layout = self.layout + # Special handling for certain data types if self.datatype in ["fp8", "bf8"]: c_datatype = "fp16" @@ -324,6 +345,7 @@ class 
GemmPreshuffleKernelBuilder: b_datatype, c_datatype, pipeline, + layout, self.gpu_target, ) From 6fd8ddabe798b1856a92049c5979611246b5b367 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Thu, 13 Nov 2025 00:43:40 -0700 Subject: [PATCH 021/114] [CK TILE GEMM] Refactor block_scale_gemm examples (#3181) * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * [CK TILE GEMM] Refactor block_scale_gemm examples - Set quant group size to (1, 1, 64) for targets excluding gfx950, where warp tile size (16, 16, 128) is incompatible. --- .../38_block_scale_gemm/CMakeLists.txt | 15 +- example/ck_tile/38_block_scale_gemm/README.md | 42 +- .../gemm_aquant_quantgrouped.cpp | 53 +++ .../gemm_bquant_quantgrouped_prefill_bf8.cpp | 47 ++ ...gemm_bquant_quantgrouped_prefill_bf8i4.cpp | 49 ++ .../gemm_bquant_quantgrouped_prefill_fp8.cpp | 47 ++ ...gemm_bquant_quantgrouped_prefill_fp8i4.cpp | 49 ++ ...quant_quantgrouped_preshuffleb_prefill.cpp | 53 +++ .../38_block_scale_gemm/gemm_quant.cpp | 130 ++++++ .../38_block_scale_gemm/gemm_quant_basic.cpp | 428 ------------------ .../38_block_scale_gemm/gemm_quant_rowcol.cpp | 30 ++ .../38_block_scale_gemm/gemm_quant_tensor.cpp | 30 ++ .../38_block_scale_gemm/gemm_utils.hpp | 54 +-- .../run_gemm_quant_example.inc | 273 ++++++++++- 14 files changed, 805 insertions(+), 495 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp create mode 100644 
example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant.cpp delete mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index b1ae9369a2..932acb72fd 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -6,8 +6,19 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") - add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) - target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + set(EXE_NAME tile_example_gemm_quant) + add_executable(${EXE_NAME} EXCLUDE_FROM_ALL + gemm_quant.cpp + gemm_aquant_quantgrouped.cpp + gemm_bquant_quantgrouped_prefill_bf8i4.cpp + gemm_bquant_quantgrouped_prefill_fp8i4.cpp + gemm_bquant_quantgrouped_prefill_bf8.cpp + gemm_bquant_quantgrouped_prefill_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_prefill.cpp + gemm_quant_rowcol.cpp + gemm_quant_tensor.cpp + ) + target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 496697ca32..64ecebd15a 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ 
b/example/ck_tile/38_block_scale_gemm/README.md @@ -40,23 +40,31 @@ This will result in an executable `build/bin/tile_example_gemm_quant_basic` ## example ``` args: - -b batch size (default:1) - -m m dimension (default:1024) - -n n dimension (default:2048) - -k k dimension (default:64) - -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: C) - -c_layout Tensor C data layout (default: R) - -stride_a Tensor A stride (default:0) - -stride_b Tensor B stride (default:0) - -stride_c Tensor C stride (default:0) - -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1) - -e Absolute error tolerance (default:1e-5) - -prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8) - -warmup number of iterations before benchmark the kernel (default:10) - -repeat number of iterations to benchmark the kernel (default:100) - -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -quant_mode Which quant method to use (aquant, bquant, tensor, rowcol) + -h Print help message (default:false) + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -a_layout A tensor data layout - Row or Column (default:R) + -b_layout B tensor data layout - Row or Column (default:C) + -bq_layout Bq tensor data layout - Row or Column (default:C) + -c_layout C tensor data layout - Row or Column (default:R) + -stride_a Tensor A stride (default:0) + -stride_q Tensor AQ stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) + -prec Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) + -warmup Number of iterations before benchmarking the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:1000) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k SplitK value (default:1) + -device Device id that will be used to run the kernel (default:0) + -init 0:random, 1:linear, 2:constant(1) (default:0) + -flush_cache Flush cache before running the kernel (default:true) +-rotating_count Rotating count (default:1000) + -quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant) + -preshuffleb Enable preshuffle of tensor B (default:false) + -group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128) ``` User need to select correct mapping of config for each quant mode: diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp new file mode 100644 index 0000000000..3786230ff0 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuant; + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp new file mode 100644 index 0000000000..cb9f8b62cf --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp new file mode 100644 index 0000000000..33ae3bc4a9 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; +/* Expands to a run_gemm_example_prec_type call for ck_tile::QuantType::BQuantGrouped. The names TypeConfig, QuantGroupSize and arg_parser are not macro parameters and must be visible where the macro is expanded. The expansion already ends in ';', so 'return RUN_GEMM_EXAMPLE_PREC_TYPE;' leaves one extra empty statement behind. */ +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); +/* Registers the bf8i4 grouped B-quantization runners for the non-preshuffled path in lut. Keys are hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", group}) for group 1x1x128, 1x8x128, 1x32x128 and 1x64x128; the 1x1x64 entry is compiled in only when CK_GFX950_SUPPORT is not defined. Each lambda declares its own QuantGroupSize before expanding the macro. */ +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp new file mode 100644 index 0000000000..526c35b081 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; +/* Expands to a run_gemm_example_prec_type call for ck_tile::QuantType::BQuantGrouped. The names TypeConfig, QuantGroupSize and arg_parser are not macro parameters and must be visible where the macro is expanded. The expansion already ends in ';', so 'return RUN_GEMM_EXAMPLE_PREC_TYPE;' leaves one extra empty statement behind. */ +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); +/* Registers the fp8 grouped B-quantization runners for the non-preshuffled path in lut. Keys are hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", group}) for group 1x1x128, 1x8x128, 1x32x128 and 1x64x128; the 1x1x64 entry is compiled in only when CK_GFX950_SUPPORT is not defined. Each lambda declares its own QuantGroupSize before expanding the macro. */ +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp new file mode 100644 index 0000000000..4b2a8efb14 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; +/* Expands to a run_gemm_example_prec_type call for ck_tile::QuantType::BQuantGrouped. The names TypeConfig, QuantGroupSize and arg_parser are not macro parameters and must be visible where the macro is expanded. The expansion already ends in ';', so 'return RUN_GEMM_EXAMPLE_PREC_TYPE;' leaves one extra empty statement behind. */ +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); +/* Registers the fp8i4 grouped B-quantization runners for the non-preshuffled path in lut. Keys are hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", group}) for group 1x1x128, 1x8x128, 1x32x128 and 1x64x128; the 1x1x64 entry is compiled in only when CK_GFX950_SUPPORT is not defined. Each lambda declares its own QuantGroupSize before expanding the macro. */ +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp new file mode 100644 index 0000000000..d9591bb588 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; +/* Registers the grouped B-quantization runners for the preshuffled-B path in lut. Every key is hash_multiple_strings({prec, "bquant", "preshuffleb", "1x1x128"}) with prec being one of fp8, bf8, fp8i4 or bf8i4; every value is a lambda that forwards the parsed command line to run_gemm_example_prec_type with ck_tile::QuantType::BQuantGrouped. Only the literal group size "1x1x128" is registered here. */ +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp new file mode 100644 index 0000000000..a35f867f5d --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" +#include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "gemm_utils.hpp" +/* Declares every command-line option of the example with its default value and help text, parses argv, and returns the tuple (parse succeeded, parser). The precision names listed in the "prec" help text must stay in sync with the key strings the instance factories register, because gen_lut_key hashes the "prec" value exactly as the user typed it. */ +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("h", "false", "Print help message") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row or Column") + .insert("b_layout", "C", "B tensor data layout - Row or Column") + .insert("bq_layout", "C", "Bq tensor data layout - Row or Column") + .insert("c_layout", "R", "C tensor data layout - Row or Column") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0: No validation, 1: Validation on CPU, 2: Validation on GPU") + .insert("prec", + "fp8", + "Data type. 
For AQuant: fp8, bf8, fp8i4, or bf8i4; for Bquant: fp8, bf8, fp8i4, " + "or bf8i4") + .insert("warmup", "50", "Number of iterations before benchmarking the kernel") + .insert("repeat", "1000", "Number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "SplitK value") + .insert("device", "0", "Device id that will be used to run the kernel") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("flush_cache", "true", "Flush cache before running the kernel") + .insert("rotating_count", "1000", "Rotating count") + .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") + .insert("preshuffleb", "false", "Enable preshuffle of tensor B") + .insert("group_size", + "1x1x128", + "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} +/* Builds the lookup key for the requested configuration by hashing, in order: the "prec" value, the "quant_mode" value, then "preshuffleb" or "non-preshuffleb" (bquant only), then the "group_size" value (every mode except rowcol and tensor). The component order has to match the order used by the factories when they register their entries. */ +auto gen_lut_key(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + std::string quant_mode = arg_parser.get_str("quant_mode"); + + std::vector params = {data_type, quant_mode}; + + if(quant_mode == "bquant") + { + std::string preshuffleb = + arg_parser.get_bool("preshuffleb") ? 
"preshuffleb" : "non-preshuffleb"; + params.push_back(preshuffleb); + } + if(quant_mode != "rowcol" && quant_mode != "tensor") + { + // NOTE: rowcol and tensor pipeline do not use group size + std::string group_size_str = arg_parser.get_str("group_size"); + params.push_back(group_size_str); + } + + return hash_multiple_strings(params); +} +/* Instance factories defined in the sibling translation units of this example; each one adds its runners to lut. */ +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut); +void quant_rowcol_instance_factory( + std::unordered_map>& lut); +void quant_tensor_instance_factory( + std::unordered_map>& lut); +/* Entry point: parses the command line, selects the HIP device, fills the lookup table through every factory and runs the entry matching the requested configuration. Returns -1 when parsing fails, when help is requested, or when no entry matches the key; otherwise returns whatever the selected runner returns. The table is searched once and the found iterator is reused for the call. */ +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result || arg_parser.get_bool("h")) + { + arg_parser.print(); + return -1; + } + + auto device_id = arg_parser.get_int("device"); + std::cout << "Device ID: " << device_id << std::endl; + ck_tile::hip_check_error(hipSetDevice(device_id)); + + std::unordered_map> lut; + aquant_quantgrouped_instance_factory(lut); + bquant_quantgrouped_fp8_instance_factory(lut); + bquant_quantgrouped_bf8_instance_factory(lut); + bquant_quantgrouped_fp8i4_instance_factory(lut); + bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_instance_factory(lut); + quant_rowcol_instance_factory(lut); + quant_tensor_instance_factory(lut); + + auto key = gen_lut_key(arg_parser); + + if(const auto it = lut.find(key); it != lut.end()) + { + return it->second(arg_parser); + } + else + { + std::cerr + << "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported." 
+ << std::endl; + return -1; + } +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp deleted file mode 100644 index d605a2b780..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ /dev/null @@ -1,428 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -// This example demonstrates 2D block scale quantization (N×K) for BQuant -// using non-preshuffled configuration. -// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example -// This is currently done separately to avoid too verbose dispatching. - -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core/config.hpp" -#include "ck_tile/host.hpp" -#include "gemm_utils.hpp" - -template -float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) -{ - static_assert(std::is_same_v); - using ComputeDataType = std::conditional_t; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using GemmTraits = ck_tile::TileGemmQuantTraits; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - - // This example only supports BQuant (no AQuant) - // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 - using BaseGemmPipeline = std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; - - const ck_tile::index_t K_split = - (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = 
BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr bool transpose_c = false; - - // row-col and tensor quants use the regular pipeline, A/B quants use their own - using PipelineProblem = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmRowColTensorQuantPipelineProblem, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>>; - - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrMem, // memory pipeline hardcoded - // for aquant - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - GemmConfig::TiledMMAPermuteN>>; - using Kernel = - ck_tile::QuantGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << PipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - float ave_time = 0; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper - rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString( - hipMemsetAsync(args.c_ptr, - 0, - args.M * args.N * sizeof(typename TypeConfig::CDataType), - s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - - return ave_time; - }; - return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); -} - -#include "run_gemm_quant_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if((QuantMode 
== ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) && - GemmConfig::PreshuffleB) - { - throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); - } - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return 0; -} - -// Forward declaration for dispatch function -template