diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt new file mode 100644 index 0000000000..9fbe65e3a7 --- /dev/null +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) + +set(EXAMPLE_FLATMM_COMPILE_OPTIONS) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef) +target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/README.md b/example/ck_tile/18_flatmm/README.md new file mode 100644 index 0000000000..beaac785fc --- /dev/null +++ b/example/ck_tile/18_flatmm/README.md @@ -0,0 +1,35 @@ +# FLATMM Matrix Multiplication + +This folder contains example for FLATMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile FLATMM, but creates the placeholders for the future support on different FLATMM pipeline and different FLATMM modules. In the near future, we will gradually migrate all the FLATMM features from old CK to CK Tile. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the flatmm calculation +make tile_example_flatmm_basic -j +``` +This will result in an executable `build/bin/tile_example_flatmm_basic` + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp new file mode 100644 index 0000000000..05d0c73b7e --- /dev/null +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "flatmm_basic.hpp" + +template +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) +{ + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 2; + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + using CodegenFlatmmShape = + ck_tile::TileFlatmmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; + using CodegenFlatmmPipeline = + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::FlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +#include "run_flatmm_example.inc" + +int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp new file mode 100644 index 0000000000..355ac45ebe --- /dev/null +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -0,0 +1,100 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE 1 +#define CK_TILE_PIPELINE_MEMORY 2 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + +template +struct GemmBasicTypeConfig; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +using Types = GemmBasicTypeConfig; + +// Specific type aliases for easy access +using ADataType = Types::ADataType; +using BDataType = Types::BDataType; +using AccDataType = Types::AccDataType; +using CDataType = Types::CDataType; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "128", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc new file mode 100644 index 0000000000..864d888074 --- /dev/null +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +// mfma_type, 0:32x32, 1:16x16 +template +auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 32, 4, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + return t; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::FlatmmHostArgs args; + args.a_ptr = a_dev_buf.GetDeviceBuffer(); + args.b_shuffle_ptr = b_shuffle_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_dev_buf.GetDeviceBuffer(); + + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + + float ave_time = flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_flatmm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_host( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_origin_host( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_rslt_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + // TODO: add different init types + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + + ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes()); + + a_dev_buf.ToDevice(a_host.data()); + c_rslt_host.SetZero(); + + // do pre-shuffle + std::string mfma = arg_parser.get_str("prec"); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host, mfma, 0); + ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); + b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); + + invoke_flatmm(a_dev_buf, + b_shuffle_dev_buf, + c_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_dev_buf.FromDevice(c_rslt_host.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_ref_host.SetZero(); + + ck_tile::reference_gemm( + a_host, b_origin_host, c_ref_host); + const float max_accumulated_value = + *std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, + c_ref_host, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes()); + b_origin_dev_buf.ToDevice(b_origin_host.data()); + + ck_tile::HostTensor c_gpu_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes()); + c_gpu_ref_host.SetZero(); + c_gpu_ref_dev_buf.SetZero(); + + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy( + d_A, a_dev_buf.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_origin_dev_buf.GetDeviceBuffer(), + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(), + d_C, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data()); + const float max_accumulated_value = + *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, + c_gpu_ref_host, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + +int run_flatmm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/18_flatmm/script/smoke_test_basic.sh b/example/ck_tile/18_flatmm/script/smoke_test_basic.sh new file mode 100755 index 0000000000..a3fc61cc31 --- /dev/null +++ b/example/ck_tile/18_flatmm/script/smoke_test_basic.sh @@ -0,0 +1,34 @@ +#!/bin/bash +EXE="$(find . -name tile_example_flatmm_basic -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=2 -warmup=0 -repeat=1' + +run_tests() { + for m in 128 1024; do + for n in 128 2048; do + for k in 128 4096; do + + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi + + done + done + done +} + +set -x + +run_tests "bf16" +run_tests "fp16" + +set +x diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 7f4ba2ed35..88efe0d8d9 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 82f6d48eda..1714789e63 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -3,10 +3,16 @@ #pragma once +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp" #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" +#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" +#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp new file mode 100644 index 0000000000..935eb2c028 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockFlatmmASmemBSmemCRegV1 +{ + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindow& a_block_window, + const BFlatBlockWindow& b_flat_block_window) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!"); + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = + BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + constexpr index_t NFlatPerBlockPerIter = BlockGemmShape::flatNPerWarp; + constexpr index_t KFlatPerBlockPerIter = BlockGemmShape::flatKPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // construct Bflat-warp-window + auto b_flat_warp_windows_tmp = b_flat_block_window; + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_warp_windows; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_warp_windows(nIter)(kIter) = b_flat_warp_windows_tmp; + + move_tile_window(b_flat_warp_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + }); + }); + + // auto b_warp_windows = b_origin_warp_windows; + auto b_warp_windows = b_flat_warp_windows; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BFlatBlockWindow& b_flat_block_window) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_flat_block_window); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..d5b062a1b3 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +template +struct BlockFlatmmASmemBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp new file mode 100644 index 0000000000..eb45e6c0bd --- /dev/null +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -0,0 +1,496 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" + +namespace ck_tile { + +struct FlatmmProblem +{ + CK_TILE_HOST FlatmmProblem() = default; + CK_TILE_HOST FlatmmProblem( + index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) + : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) + { + } + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; +}; + +struct FlatmmHostArgs : public FlatmmProblem +{ + CK_TILE_HOST FlatmmHostArgs() = default; + CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_, + const void* b_shuffle_ptr_, + void* c_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_) + : FlatmmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_), + a_ptr(a_ptr_), + b_shuffle_ptr(b_shuffle_ptr_), + c_ptr(c_ptr_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_shuffle_ptr; + void* c_ptr; + index_t k_batch; +}; + +template +struct FlatmmKernel +{ + using TilePartitioner = remove_cvref_t; + using FlatmmPipeline = remove_cvref_t; + using BlockGemmShape = + remove_cvref_t; // TileFlatmmShape + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. + using CDataType = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str, FlatmmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + struct FlatmmKernelArgs + { + const void* a_ptr; + const void* b_shuffle_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t k_batch; + }; + + CK_TILE_HOST static constexpr FlatmmKernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs) + { + return FlatmmKernelArgs{hostArgs.a_ptr, + hostArgs.b_shuffle_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + + if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead; + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead * kargs.stride_A; + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead * kargs.stride_B; + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead; + } + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = KRead; + } + else + { + splitted_k = kargs.K - KRead * (kargs.k_batch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs) + { + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) + { + if(kargs.k_batch != 1) + { + std::cerr << "Conditions not met for Kbatch >1 !" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) + { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0) + { + std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) + { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0) + { + std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) + { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0) + { + std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) + { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0) + { + std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) + { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) + { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; + return false; + } + } + return true; + } + + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + CDataType* c_ptr, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k / + BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + const auto& b_flat_tensor_view = [&]() { + return make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& b_flat_tensor_view = views.at(I1); + + // TODO vector write in for C in ColMajor + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& b_flat_pad_view = views.at(I1); + const auto& c_pad_view = views.at(I2); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& b_flat_block_window = + make_tile_window(b_flat_pad_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / BlockGemmShape::WarpTile::at(idxN)), 0}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_flat_block_window, c_block_window); + } + + template + CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + CDataType* c_ptr, + void* smem_ptr, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews(a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_flat_block_window = gemm_tile_windows.at(I1); + const auto& c_block_tile = FlatmmPipeline{}.template operator()( + a_block_window, b_flat_block_window, num_loop, smem_ptr); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I2); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, smem_ptr); + } + + CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const + { + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_flat_ptr = static_cast(kargs.b_shuffle_ptr) + + splitk_batch_offset.b_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + if(kargs.k_batch == 1) + { + RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } + else + { + // Do not compile in case where we have unsupported + // VectorSizeC & data type configuration. + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunFlatmm( + a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp new file mode 100644 index 0000000000..3d08c7a788 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" + +namespace ck_tile { + +template +struct FlatmmPipelineAGmemBGmemCRegV1 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockFlatmm = + remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; + static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + + static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } + static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr index_t kLdsAlignmentInBytes = 16; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV1", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + + // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. + static constexpr bool DoubleSmemBuffer = false; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return PipelinePolicy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_flatmm = BlockFlatmm(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + // Acc register tile + auto c_block_tile = decltype(block_flatmm(a_lds_gemm_window, b_flat_dram_window)){}; + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + PipelinePolicy::template MakeShuffledARegBlockDistribution()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } + } + + index_t iCounter = num_loop - 1; + while(iCounter > 0) + { + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + + block_sync_lds(); + + // GEMM i + block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // move to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + iCounter--; + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 1 + block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp new file mode 100644 index 0000000000..d1aac07d54 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +struct UniversalFlatmmPipelineAgBgCrPolicy +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + return Problem::VectorLoadSize; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * M0)) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M2, M1 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M1, M2 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using BDataType = remove_cvref_t; + + using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t KBPerLoad = + Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt + constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(TileShape::idxN); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? + tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<1, 2>>, // which direction + tuple, sequence<2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = kMPerBlock / M1; + constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackA(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * M0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = kBlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + + using BlockFlatmmPolicy = + BlockFlatmmASmemBSmemCRegV1CustomPolicy; + return BlockFlatmmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp new file mode 100644 index 0000000000..551d390ec6 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +template +struct TileFlatmmShape +{ + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr auto idxM = number<0>{}; + static constexpr auto idxN = number<1>{}; + static constexpr auto idxK = number<2>{}; + + static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + + static constexpr index_t kM = BlockTile::at(idxM); + static constexpr index_t kN = BlockTile::at(idxN); + static constexpr index_t kK = BlockTile::at(idxK); + + static constexpr index_t flatNPerWarp = BlockWarps::at(idxN); + static constexpr index_t flatKPerWarp = WarpTile::at(idxK) * WarpTile::at(idxN); + static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(idxK); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "tile_flatmm_shape", + concat('x', kM, kN, kK, NumWarps), + concat('x', BlockWarps::at(idxM), BlockWarps::at(idxN), BlockWarps::at(idxK)), + concat('x', (WarpTile::at(idxM)), WarpTile::at(idxN), WarpTile::at(idxK))); + // clang-format on + } +}; + +} // namespace ck_tile