diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index d479cd35f6..03f333eb5a 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -20,3 +20,4 @@ add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) +add_subdirectory(tutorial) diff --git a/example/ck_tile/tutorial/00_add_basic/CMakeLists.txt b/example/ck_tile/tutorial/00_add_basic/CMakeLists.txt new file mode 100644 index 0000000000..7dac79b535 --- /dev/null +++ b/example/ck_tile/tutorial/00_add_basic/CMakeLists.txt @@ -0,0 +1,21 @@ +set(EXAMPLE_ADD_BASIC "add_basic") + +message("adding example ${EXAMPLE_ADD_BASIC}") + +add_executable(${EXAMPLE_ADD_BASIC} EXCLUDE_FROM_ALL add_basic.cpp) +target_include_directories(${EXAMPLE_ADD_BASIC} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_ADD_BASIC_COMPILE_OPTIONS) + +# generate assembly +# list(APPEND EXAMPLE_ADD_BASIC_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_ADD_BASIC_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${EXAMPLE_ADD_BASIC} PRIVATE ${EXAMPLE_ADD_BASIC_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/tutorial/00_add_basic/add_basic.cpp b/example/ck_tile/tutorial/00_add_basic/add_basic.cpp new file mode 100644 index 0000000000..6abf42cc01 --- /dev/null +++ b/example/ck_tile/tutorial/00_add_basic/add_basic.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "reference_add_vector.hpp" +#include "add_basic.hpp" +#include + +// This example demonstrates how to use the ck_tile library to perform an elementwise vector +// addition using a custom kernel. The kernel is defined in the vector_add.hpp file, and the +// reference implementation is provided in the reference_vector_add.hpp file. + +// parse command line arguments +// -m: size of the vectors +// -v: validation flag (1 for validation, 0 for no validation) +// -prec: precision of the data type (fp16, fp32, int8, int32) +// -warmup: number of warmup iterations (number of kernel launches before measuring performance) +// -repeat: number of repeat iterations (number of kernel launches to measure performance) +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "41943040", "m dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; // input data type + using ComputeDataType = float; // compute data type + using YDataType = DataType; // output data type + + ck_tile::index_t m = arg_parser.get_int("m"); // size of the vectors + int do_validation = arg_parser.get_int("v"); // do we verify the result on cpu + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + ck_tile::HostTensor x_host_a( + {m}); // length input vector A, if given two arguments (m, n) the HostTensor will be created + // with shape (m, n) + ck_tile::HostTensor x_host_b( + {m}); // length input vector B, if given two arguments (m, n) the HostTensor will be created + // with shape (m, n) + + ck_tile::HostTensor y_host_ref({m}); + ck_tile::HostTensor y_host_dev({m}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}( + x_host_a); // fill the input vector A with random values + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host_b); + + ck_tile::DeviceMem x_buf_a( + x_host_a.get_element_space_size_in_bytes()); // allocate device memory for input vector A + // (this a wrapper over hipMalloc) + ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf_a.ToDevice( + x_host_a + .data()); // copy the input vector A to device memory, this is a wrapper over hipMemcpy + x_buf_b.ToDevice(x_host_b.data()); + + // Dividing the problem into blocktile, warptile, and vector + // The blocktile is the size of the tile that will be processed by a single thread block (also + // called work group) The warptile is the size of the tile that will be processed by a single + // warp (also called wavefront) The vector is the size of the tile that will be processed by a + // single thread (also called work item) The problem is divided into blocks of size BlockTile, + // each block is further divided into warps of size WarpTile and each warp is composed of 64 or + // 32 threads of size Vector each of the thread in a warp will process one vector worth elements + // of the data + using BlockTile = ck_tile::sequence<8192>; // Size of the block tile (Entire problem is divided + // into blocks of this size) + using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp + // will cover some part of blockTile) + using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp + using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread + // will cover some part of WarpTile) + + // Interpretation of above configurations + // Each thread will cover 1 element (Vector) + // Each WarpTile will cover 64 elements (WarpTile) --> since 64 threads in a warp + // if we have 8 warps in a block (BlockWarps) then we have 8 * 64 = 512 threads in a block + // if 8 warps are not enough to cover the entire blockTile then each of the 8 concurrent warps + // will iterate over the blockTile several times + + constexpr ck_tile::index_t kBlockSize = 512; + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); + std::cout << "block x-size = " << BlockTile::at(ck_tile::number<0>{}) << std::endl; + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::AddVectorShape; + std::cout << "Problem Shape:: M = " << m << std::endl; + std::cout << "BlockTile: " << BlockTile::at(ck_tile::number<0>{}) << std::endl; + std::cout << "Number of Blocks in Grid: " << m / BlockTile::at(ck_tile::number<0>{}) + << std::endl; + std::cout << "BlockWarps: " << BlockWarps::at(ck_tile::number<0>{}) << std::endl; + std::cout << "WarpTile: " << WarpTile::at(ck_tile::number<0>{}) << std::endl; + std::cout << "Vector: " << Vector::at(ck_tile::number<0>{}) << std::endl; + std::cout << "Repeat: " << Shape::Repeat_M + << std::endl; // number of times a warp will iterate over the blockTile, covering + // different parts of the blockTile + std::cout << "Threads per Block: " << kBlockSize << std::endl; + std::cout << "ThreadBlocks per CU: " << kBlockPerCu << std::endl; + + // What is a Problem in CKTile? + // A Problem defines the shape of the data, the precision of the data + using Problem = ck_tile::AddVectorProblem; + + // What is a Policy in CKTile? + // A Policy defines how to map the data between threads and data in memory + + // The kernel is the function that will be executed on the device + // It requires a Problem and Policy to be defined + using Kernel = ck_tile::AddVectorKernel; + + // The kernel is launched with the following parameters: + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, // wrapper over hipStreamCreate + ck_tile::make_kernel( // numOfThreadsPerBlock, numOfBlocksPerCU + Kernel{}, // kernel + kGridSize, // number of blocks in the grid + kBlockSize, // number of threads in a block + 0, // shared memory size + static_cast(x_buf_a.GetDeviceBuffer()), // input vector A + static_cast(x_buf_b.GetDeviceBuffer()), // input vector B + static_cast(y_buf.GetDeviceBuffer()), // output vector + m)); + + std::size_t num_btype = sizeof(XDataType) * m + sizeof(YDataType) * m; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + ck_tile::reference_add_vector(x_host_a, x_host_b, y_host_ref); + y_buf.FromDevice(y_host_dev.mData.data()); + pass = ck_tile::check_err(y_host_dev, y_host_ref); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } +} diff --git a/example/ck_tile/tutorial/00_add_basic/add_basic.hpp b/example/ck_tile/tutorial/00_add_basic/add_basic.hpp new file mode 100644 index 0000000000..1113d22729 --- /dev/null +++ b/example/ck_tile/tutorial/00_add_basic/add_basic.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +// struct that holds the tile size of the block, warp, and vector +// and the number of warps per block +// and the number of threads per warp +// and the number of times the warp tile is repeated in the block tile +// and the block size +template +struct AddVectorShape +{ + static constexpr index_t Block_M = BlockTile::at(number<0>{}); + + static constexpr index_t Warp_M = WarpTile::at(number<0>{}); + + static constexpr index_t Vector_M = Vector::at(number<0>{}); + + static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{}); + + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + + static constexpr index_t Repeat_M = + Block_M / + (WarpPerBlock_M * Warp_M); // Number of times the warp tile is repeated in the block tile + + static constexpr index_t BlockSize = + warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); +}; + +template +struct AddVectorProblem +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +// data mapping beween threads and memory +struct AddDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + return make_static_tile_distribution( + tile_distribution_encoding, // Replicate + tuple>, // Hierarchical + tuple, sequence<1>>, // Parallel + tuple, sequence<2>>, // Parallel + sequence<1, 1>, // Yield + sequence<0, 3>>{} // Yield + ); + } +}; + +template +struct AddVectorKernel +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + + // body of the kernel + CK_TILE_DEVICE void + operator()(const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M) const + { + using S = typename Problem::BlockShape; + + // create tensor view for the input and output data, this defines how the data is laid out + // in memory + const auto x_m_n_a = make_naive_tensor_view( + p_x_a, + make_tuple(M), + make_tuple(1), + number{}); // raw pointer, shape of the tensor, stride of the tensor, and + // lastGarunteedVectorLength + + const auto x_m_n_b = make_naive_tensor_view( + p_x_b, make_tuple(M), make_tuple(1), number{}); + + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M), make_tuple(1), number{}); + + // origin of the block tile + const auto iM = get_block_id() * S::Block_M; + + // creating tile windows for the input and output data + auto x_window_a = make_tile_window(x_m_n_a, + make_tuple(number{}), + {iM}, + Policy::template MakeXBlockTileDistribution()); + + auto x_window_b = make_tile_window(x_m_n_b, + make_tuple(number{}), + {iM}, + Policy::template MakeXBlockTileDistribution()); + + auto y_window = make_tile_window(y_m_n, + make_tuple(number{}), + {iM}, + Policy::template MakeXBlockTileDistribution()); + + // Load tile data + const auto xa = + load_tile(x_window_a); // load tile data from global tensor view, load from where? what? + // how many? logical memory layout? all are defined in x_window_a + const auto xb = load_tile(x_window_b); + auto y_compute = load_tile(y_window); + + // Process the vector add + constexpr auto spans = decltype(xa)::get_distributed_spans(); // shape of the tile + sweep_tile_span(spans[number<0>{}], [&](auto idx) { // iterate over the tile + const auto tile_idx = make_tuple(idx); + const auto a_val = type_convert(xa[tile_idx]); + const auto b_val = type_convert(xb[tile_idx]); + y_compute(tile_idx) = a_val + b_val; + + }); + + // Store results + store_tile(y_window, + cast_tile(y_compute)); // store the result back to global tensor view + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/00_add_basic/reference_add_vector.hpp b/example/ck_tile/tutorial/00_add_basic/reference_add_vector.hpp new file mode 100644 index 0000000000..cd453386c8 --- /dev/null +++ b/example/ck_tile/tutorial/00_add_basic/reference_add_vector.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_add_vector(const HostTensor& xa_m_n, + const HostTensor& xb_m_n, + HostTensor& y_m_n) +{ + auto f = [&](auto m) { + const int N = 1; + + for(int n = 0; n < N; ++n) + { + y_m_n(m, n) = ck_tile::type_convert(xa_m_n(m, n)) + + ck_tile::type_convert(xb_m_n(m, n)); + } + }; + + make_ParallelTensorFunctor(f, + y_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/01_add/CMakeLists.txt b/example/ck_tile/tutorial/01_add/CMakeLists.txt new file mode 100644 index 0000000000..2906797992 --- /dev/null +++ b/example/ck_tile/tutorial/01_add/CMakeLists.txt @@ -0,0 +1,21 @@ +set(EXAMPLE_ADD "add") + +message("adding example ${EXAMPLE_ADD}") + +add_executable(${EXAMPLE_ADD} EXCLUDE_FROM_ALL add.cpp) +target_include_directories(${EXAMPLE_ADD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_ADD_COMPILE_OPTIONS) + +# generate assembly +# list(APPEND EXAMPLE_ADD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_ADD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${EXAMPLE_ADD} PRIVATE ${EXAMPLE_ADD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/tutorial/01_add/add.cpp b/example/ck_tile/tutorial/01_add/add.cpp new file mode 100644 index 0000000000..347ff4cb41 --- /dev/null +++ b/example/ck_tile/tutorial/01_add/add.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "reference_add.hpp" +#include "add.hpp" +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "10240", "m dimension") + .insert("n", "4096", "n dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "200", "cold iter") + .insert("repeat", "1000", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = DataType; + + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + ck_tile::HostTensor x_host_a({m, n}); + ck_tile::HostTensor x_host_b({m, n}); + + ck_tile::HostTensor y_host_ref({m, n}); + ck_tile::HostTensor y_host_dev({m, n}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host_a); + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host_b); + + ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf_a.ToDevice(x_host_a.data()); + x_buf_b.ToDevice(x_host_b.data()); + + using BlockWarps = + ck_tile::sequence<1, 8>; // number of concurrent warps in one block (if 8 warps * 64 threads + // per warp, 512 threads in one block are NEEDED) + using BlockTile = + ck_tile::sequence<1, 4096>; // shape of one blockTile (elements covered by one block) + using WarpTile = ck_tile::sequence<1, 512>; // shape of one warpTile (elements covered by one + // warp (64 threads)) + using Vector = ck_tile::sequence<1, 8>; // shape of one vector (elements covered by one thread) + + constexpr ck_tile::index_t kBlockSize = + 512; // number of blockWarps * number of threads per warp + constexpr ck_tile::index_t kBlockPerCu = 1; + ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); + std::cout << "block x-size = " << BlockTile::at(ck_tile::number<0>{}) << std::endl; + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::AddShape; + using Porblem = ck_tile::AddProblem; + + using Kernel = ck_tile::Add; + + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf_a.GetDeviceBuffer()), + static_cast(x_buf_b.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n)); + + std::size_t num_btype = 2 * sizeof(XDataType) * m * n + sizeof(YDataType) * m * n; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + ck_tile::reference_add(x_host_a, x_host_b, y_host_ref); + y_buf.FromDevice(y_host_dev.mData.data()); + pass = ck_tile::check_err(y_host_dev, y_host_ref); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } +} diff --git a/example/ck_tile/tutorial/01_add/add.hpp b/example/ck_tile/tutorial/01_add/add.hpp new file mode 100644 index 0000000000..a8201e5f6f --- /dev/null +++ b/example/ck_tile/tutorial/01_add/add.hpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +template + typename BlockTile, // block size, seq + typename WarpTile, // warp size, seq + typename Vector> // contiguous pixels(vector size) along seq +struct AddShape +{ + static constexpr index_t Block_M = BlockTile::at(number<0>{}); // elements along M in one Block + static constexpr index_t Block_N = BlockTile::at(number<1>{}); // elements along N in one Block + + static constexpr index_t Warp_M = WarpTile::at(number<0>{}); // elements along M in one Warp + static constexpr index_t Warp_N = WarpTile::at(number<1>{}); // elements along N in one Warp + + static constexpr index_t Vector_M = Vector::at(number<0>{}); // elements along M in one Vector + static constexpr index_t Vector_N = Vector::at(number<1>{}); // elements along N in one Vector + + static constexpr index_t WarpPerBlock_M = + BlockWarps::at(number<0>{}); // num concurrent warps along M + static constexpr index_t WarpPerBlock_N = + BlockWarps::at(number<1>{}); // num concurrent warps along N + + static constexpr index_t ThreadPerWarp_M = + Warp_M / + Vector_M; // num threads along M in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64) + static constexpr index_t ThreadPerWarp_N = + Warp_N / + Vector_N; // num threads along N in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64) + + static constexpr index_t Repeat_M = + Block_M / + (WarpPerBlock_M * + Warp_M); // num of time a warp iterates along M to ensure the entire block is covered + static constexpr index_t Repeat_N = + Block_N / + (WarpPerBlock_N * + Warp_N); // num of time a warp iterates along N to ensure the entire block is covered + + static constexpr index_t BlockSize = + warpSize * + reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); // num of threads in one block +}; + +template +struct AddProblem +{ + using XDataType = remove_cvref_t; // data type of input tensor + using ComputeDataType = remove_cvref_t; // data type of compute tensor + using YDataType = remove_cvref_t; // data type of output tensor + using BlockShape = remove_cvref_t; // block shapes and sizes +}; + +struct AddDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, // how many sub division is a block divided in + sequence>, // how many sub division is a block divided in + tuple, sequence<1, 2>>, // What are the shapes of those sub divisions + tuple, sequence<2, 2>>, // What are the shapes of those sub divisions + sequence<1, 1, 2, 2>, // How much data does a thread work on and how many iterations + // of warps are there + sequence<0, 3, 0, 3>>{}); // How much data does a thread work on and how many + // iterations of warps are there + } +}; + +template +struct Add +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + + CK_TILE_DEVICE void operator()( + const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + const auto x_m_n_a = make_naive_tensor_view( + p_x_a, + make_tuple(M, N), + make_tuple(N, 1), + number{}, + number<1>{}); // raw data, shape of tensor, stride of tensor, lastGarunteedVectorLength, + // lastGarunteedVectorStride + + const auto x_m_n_b = make_naive_tensor_view( + p_x_b, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto iM = get_block_id() * S::Block_M; // origin of the block along + + auto x_window_a = make_tile_window(x_m_n_a, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); + + auto x_window_b = make_tile_window(x_m_n_b, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); + + auto y_window = make_tile_window(y_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); + + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto xa = load_tile(x_window_a); + const auto xb = load_tile(x_window_b); + auto y_compute = load_tile(y_window); + + constexpr auto spans = decltype(xa)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + const auto x = ck_tile::type_convert(xa[i_j_idx]); + const auto y = ck_tile::type_convert(xb[i_j_idx]); + y_compute(i_j_idx) = x + y; + }); + }); + + store_tile(y_window, cast_tile(y_compute)); + move_tile_window(x_window_a, {0, S::Block_N}); + move_tile_window(x_window_b, {0, S::Block_N}); + move_tile_window(y_window, {0, S::Block_N}); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/01_add/reference_add.hpp b/example/ck_tile/tutorial/01_add/reference_add.hpp new file mode 100644 index 0000000000..0e15f76899 --- /dev/null +++ b/example/ck_tile/tutorial/01_add/reference_add.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_add(const HostTensor& xa_m_n, + const HostTensor& xb_m_n, + HostTensor& y_m_n) +{ + auto f = [&](auto m) { + const int N = xa_m_n.mDesc.get_lengths()[1]; + + for(int n = 0; n < N; ++n) + { + y_m_n(m, n) = ck_tile::type_convert(xa_m_n(m, n)) + + ck_tile::type_convert(xb_m_n(m, n)); + } + }; + + make_ParallelTensorFunctor(f, + y_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/02_gemm/CMakeLists.txt b/example/ck_tile/tutorial/02_gemm/CMakeLists.txt new file mode 100644 index 0000000000..db0937b314 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/CMakeLists.txt @@ -0,0 +1,26 @@ +set(EXAMPLE_BASIC_GEMM "basic_gemm") + +message("adding example ${EXAMPLE_BASIC_GEMM}") + +add_executable(${EXAMPLE_BASIC_GEMM} EXCLUDE_FROM_ALL gemm.cpp) +target_include_directories(${EXAMPLE_BASIC_GEMM} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS) + +# generate assembly +# list(APPEND EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +if(DEFINED kernel) + message("Compiling with Kernel: ${kernel}") + target_compile_definitions(${EXAMPLE_BASIC_GEMM} PRIVATE KERNEL_${kernel}=1) +endif() + +target_compile_options(${EXAMPLE_BASIC_GEMM} PRIVATE ${EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/tutorial/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/tutorial/02_gemm/block_gemm_asmem_bsmem_creg.hpp new file mode 100644 index 0000000000..fcd8c0c997 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "block_gemm_asmem_bsmem_creg_default_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmASmemBSmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using WarpGemm = remove_cvref_t< + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + +#if defined(ENABLE_PREFETCH) + // A block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarp * WarpGemm::kM); + constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarp * WarpGemm::kN); + constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile aWarpTile; + BLdsTile bWarpTile; + + // Prefetch from LDS to warp register + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + aWarpTile = load_tile(a_block_window); + bWarpTile = load_tile(b_block_window); + } +#endif + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + +#if !defined(ENABLE_PREFETCH) + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; +#if defined(ENABLE_PREFETCH) +#pragma message("local data share prefetch") + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); +#else + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); +#endif + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); +#else + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); +#endif + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // Warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + +#if !defined(ENABLE_PREFETCH) + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + static_assert(std::is_same_v, "wrong!"); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + // Hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; +#if defined(ENABLE_PREFETCH) + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); +#else + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); +#endif + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); +#else + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); +#endif + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + // Warp GEMM + if constexpr(KIterPerWarp == 0) + { + // c = a * b + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + // c += a * b + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/tutorial/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp new file mode 100644 index 0000000000..379f993db7 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +#include "config.h" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCReg +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmASmemBSmemCRegDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { +#if defined(ADJUST_BLOCK_TILE_SHAPE) + constexpr index_t kMWarp = 2; + constexpr index_t kNWarp = 2; +#else + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; +#endif + +#if defined(NAIVE_IMPLEMENTATION) +#pragma message("mfma m32 n32 k8") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); + } +#elif defined(USING_MFMA_32x32x_8x2) +#pragma message("mfma m32 n32 k16") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); + } +#elif defined(USING_MFMA_16x16x16) +#pragma message("mfma m16 n16 k16") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); + } +#elif defined(USING_MFMA_16x16x_16x2) +#pragma message("mfma m16 n16 k32") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp); + } +#endif + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/tutorial/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..37e90ff0b7 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp" + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + using BlockGemm = remove_cvref_t())>; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + +#if defined(ENABLE_INSTRUCTION_SCH) + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + + static constexpr index_t GetSmemPack() { return Policy::template GetSmemPack(); } + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemm::MWarp; + constexpr index_t WaveNumN = BlockGemm::NWarp; + + constexpr index_t AB_LDS_RW_Width = GetSmemPack(); + + constexpr index_t A_Buffer_Load_Inst_Num = + kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); + constexpr index_t B_LDS_Write_Inst_Num = + kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 + ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_a = + (num_mfma_per_issue - num_dswrite_per_issue_a * 2 >= 1) ? 2 : 1; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_a, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_a * + num_dswrite_per_issue_a, + 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } +#endif + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // ----------------------------------------------------------------------------------------- + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + +#if defined(ENABLE_PREFETCH) + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())); +#else + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); +#endif + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock); + + // ------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + +#if defined(ENABLE_PREFETCH) +#pragma message("global prefetch") + // Prefetch + // Global read 0 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + + if(num_loop > 1) + { + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // LDS write 0 + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + // Prefetch from LDS to warp register in block gemm + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + } + + __builtin_amdgcn_sched_barrier(0); + + // Main body + if(num_loop > 2) + { + index_t iCounter = 0; + do + { + block_sync_lds(); + + // LDS write 1 + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // Prefetch from LDS to warp register in block gemm + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + +#if defined(ENABLE_INSTRUCTION_SCH) + HotLoopScheduler(); +#endif + + __builtin_amdgcn_sched_barrier(0); + + iCounter += 1; + } while(iCounter < (num_loop - 2)); + } + + // Tail + if(num_loop > 1) + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + } + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); +#else + // non-prefetch + index_t iCounter = num_loop; + + while(iCounter > 0) + { + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + iCounter--; + } +#endif + return c_block_tile; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/tutorial/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp new file mode 100644 index 0000000000..9c387b71c9 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "block_gemm_asmem_bsmem_creg.hpp" + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" + +#include "config.h" + +namespace ck_tile { + +// Default policy for BlockGemmPipelineAGmemBGmemCReg +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy +{ + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + +#if defined(NAIVE_IMPLEMENTATION) + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(PADDING_K_FIRST) + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(PADDING_MN_FIRST) + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE) + using ADataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); +#endif + return a_lds_block_desc; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + +#if defined(PADDING_K_FIRST) || defined(NAIVE_IMPLEMENTATION) + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(PADDING_K_FIRST) + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(PADDING_MN_FIRST) + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE) + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); +#endif + + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + +#if defined(ENABLE_INSTRUCTION_SCH) + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + + // Assume DataType is even! + if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 && + PackedSize == 2) + { + return (PackedSize * 32 / sizeof(DataType)); + } + else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0) + { + return (PackedSize * 16 / sizeof(DataType)); + } + else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0) + { + return (PackedSize * 8 / sizeof(DataType)); + } + else if constexpr(sizeof(DataType) >= PackedSize * 4 && + XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0) + { + return (PackedSize * 4 / sizeof(DataType)); + } + else if constexpr(sizeof(DataType) >= PackedSize * 2 && + XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0) + { + return (PackedSize * 2 / sizeof(DataType)); + } + else + { + return PackedSize; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() + { + using ADataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + return GetGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() + { + using BDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + return GetGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Problem::TransposeC; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPack() + { + constexpr index_t kKPack = 8; + return kKPack; + } +#endif + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + return BlockGemmASmemBSmemCReg{}; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/02_gemm/config.h b/example/ck_tile/tutorial/02_gemm/config.h new file mode 100644 index 0000000000..f284eefcc4 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/config.h @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#if defined(KERNEL_A) +#define PADDING_K_FIRST +#define USING_MFMA_32x32x_8x2 +#elif defined(KERNEL_B) +#define PADDING_K_FIRST +#define USING_MFMA_16x16x16 +#elif defined(KERNEL_C) +#define PADDING_K_FIRST +#define USING_MFMA_16x16x_16x2 +#elif defined(KERNEL_D) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#elif defined(KERNEL_E) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#elif defined(KERNEL_F) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH +#elif defined(KERNEL_G) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH +#define ENABLE_INSTRUCTION_SCH +#elif defined(KERNEL_H) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH +#define ENABLE_INSTRUCTION_SCH +#define ENABLE_CACHE_AWARE_WG_SCH +#else +#define NAIVE_IMPLEMENTATION +#endif diff --git a/example/ck_tile/tutorial/02_gemm/gemm.cpp b/example/ck_tile/tutorial/02_gemm/gemm.cpp new file mode 100644 index 0000000000..d7aeaa3dd0 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/gemm.cpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "config.h" +#include "ck_tile/host.hpp" +#include "gemm.hpp" +#include "reference_gemm.hpp" + +/* + * Toy code of GEMM + * Assume simplest case. + * A [M, K] + * B [N, K] + * C [M, N] + */ + +// elementwise lambda +struct CElementFunction +{ + template + CK_TILE_HOST_DEVICE auto operator()(const X& x) const + { + return x; + } +}; + +int main(int argc, char* argv[]) +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + + ck_tile::index_t verification = 0; + ck_tile::index_t M = 3328; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + + if(argc == 2) + { + verification = std::stoi(argv[1]); + } + if(argc == 5) + { + verification = std::stoi(argv[1]); + M = std::stoi(argv[2]); + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } + +#if defined(KERNEL_A) + printf("*** Kernel A test *** \n"); + printf(" --> Using mfma_32x32x(8x2)\n"); +#elif defined(KERNEL_B) + printf("*** Kernel B test *** \n"); + printf(" --> Using mfma_16x16x16\n"); +#elif defined(KERNEL_C) + printf("*** Kernel C test *** \n"); + printf(" --> Using mfma_16x16x(16x2)\n"); +#elif defined(KERNEL_D) + printf("*** Kernel D test *** \n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); +#elif defined(KERNEL_E) + printf("*** Kernel E test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); + printf(" --> Adjust block tile shape\n"); +#elif defined(KERNEL_F) + printf("*** Kernel F test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); + printf(" --> Adjust block tile shape\n"); + printf(" --> Enable prefetch\n"); +#elif defined(KERNEL_G) + printf("*** Kernel G test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); + printf(" --> Adjust block tile shape\n"); + printf(" --> Enable prefetch\n"); + printf(" --> Enable instruction schedule\n"); +#elif defined(KERNEL_H) + printf("*** Kernel H test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); + printf(" --> Adjust block tile shape\n"); + printf(" --> Enable prefetch\n"); + printf(" --> Enable instruction schedule\n"); + printf(" --> Enable cache-aware thread blocks schedule\n"); +#else + printf("*** Naive implementation test ***\n"); +#endif + + const ck_tile::index_t Lda = K; + const ck_tile::index_t Ldb = K; + const ck_tile::index_t Ldc = N; + + const auto a_lengths = std::array{M, K}; + const auto a_strides = std::array{Lda, 1}; + + const auto b_lengths = std::array{N, K}; + const auto b_strides = std::array{Ldb, 1}; + + const auto c_lengths = std::array{M, N}; + const auto c_strides = std::array{Ldc, 1}; + + // host verify + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.mData.data()); + b_buf.ToDevice(b_host.mData.data()); + + // Alignment + constexpr ck_tile::index_t kAAlignment = 8; + constexpr ck_tile::index_t kBAlignment = 8; + constexpr ck_tile::index_t kCAlignment = 8; + + constexpr ck_tile::index_t kBlockSize = 256; + +#ifdef ADJUST_BLOCK_TILE_SHAPE + constexpr ck_tile::index_t kGemmMPerBlock = 128; + constexpr ck_tile::index_t kGemmKPerBlock = 64; +#else + constexpr ck_tile::index_t kGemmMPerBlock = 256; + constexpr ck_tile::index_t kGemmKPerBlock = 32; +#endif + constexpr ck_tile::index_t kGemmNPerBlock = 128; + + ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock); + + std::cout << "grid size " << kGridSize << std::endl; + + constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + + using gemm_kernel = ck_tile::Gemm; + + float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000}, + ck_tile::make_kernel( + gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_buf.GetDeviceBuffer()), + static_cast(b_buf.GetDeviceBuffer()), + static_cast(c_buf.GetDeviceBuffer()), + M, + N, + K, + Lda, + Ldb, + Ldc, + CElementFunction{})); + auto pass = true; + + if(verification) + { + // reference gemm + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + c_buf.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/example/ck_tile/tutorial/02_gemm/gemm.hpp b/example/ck_tile/tutorial/02_gemm/gemm.hpp new file mode 100644 index 0000000000..46f4bf1d32 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/gemm.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "config.h" +#include "grid_gemm.hpp" + +namespace ck_tile { + +template +struct GridGemmProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + using CElementFunction = CElementFunction_; +}; + +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; + +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +// C = A * B +template +struct Gemm +{ + using GridGemmProblem = + GridGemmProblem; + + struct GridGemmPolicy + { + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kMPerBlock_; + static constexpr index_t kNPerBlock = kNPerBlock_; + static constexpr index_t kKPerBlock = kKPerBlock_; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) + { +#if defined(ENABLE_CACHE_AWARE_WG_SCH) + return [=](index_t block_1d_id) { + constexpr index_t M01 = 4; + constexpr index_t GroupNum = 8; + + const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + + const auto xcd_id = block_1d_id % GroupNum; + + const auto l_block_id = block_1d_id - (xcd_id % 2); + + const auto ridn = GroupNum * M01 * (update_N0 / 2); + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; + + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = + (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; + + auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); + auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); + + const auto total_update_size = update_N0 * update_M0; + + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; + auto rlen = N0 - update_N0; + + auto rm = 0; + auto rn = 0; + if(rlen > 0) + { + rm = (x - 1) / rlen; + rn = x % rlen; + } + + if(rlen > 0 and rm < M0) + { + n = rn + update_N0; + m = rm; + } + else + { + x = x - rlen * M0; + rm = (x - 1) / update_N0; + rn = x % update_N0; + n = rn; + m = update_M0 + rm; + } + } + return make_multi_index(m, n); + }; +#else + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; +#endif + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() + { + using BlockGemmPipelineProblem_ = + BlockGemmPipelineProblem>; + return BlockGemmPipelineAGmemBGmemCReg{}; + } + }; + + using GridGemm = GridGemm; + + CK_TILE_DEVICE void operator()(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t Lda, + const index_t Ldb, + const index_t Ldc, + const CElementFunction& c_element_func) const + { + const auto a_dram = [&] { + return make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(Lda, 1), number{}, number<1>{}); + }(); + + const auto b_dram = [&] { + return make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(Ldb, 1), number{}, number<1>{}); + }(); + + const auto c_dram = [&] { + return make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(Ldc, 1), number{}, number<1>{}); + }(); + + GridGemm{}(a_dram, b_dram, c_dram, c_element_func); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/02_gemm/grid_gemm.hpp b/example/ck_tile/tutorial/02_gemm/grid_gemm.hpp new file mode 100644 index 0000000000..ba7596853a --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/grid_gemm.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +template +struct GridGemm +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; + using CElementFunction = typename Problem::CElementFunction; + + static constexpr auto kMPerBlock = Policy::kMPerBlock; + static constexpr auto kNPerBlock = Policy::kNPerBlock; + static constexpr auto kKPerBlock = Policy::kKPerBlock; + + template + CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid, + const BGridTensorView& b_grid, + CGridTensorView& c_grid, + const CElementFunction& c_element_func) const + { + const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{}); + const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{}); + const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{}); + + // divide problem + const auto id_block = get_block_id(); + + const auto num_tile_m = integer_divide_ceil(M, kMPerBlock); + const auto num_tile_n = integer_divide_ceil(N, kNPerBlock); + + const auto block2tile = Policy::template MakeBlock2TileMap(num_tile_m, num_tile_n); + + const auto id_tile = block2tile(id_block); + + const auto iM = + __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) * kMPerBlock); + const auto iN = + __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) * kNPerBlock); + + // A block window + auto a_block_window = make_tile_window( + a_grid, make_tuple(number{}, number{}), {iM, 0}); + + // B block window + auto b_block_window = make_tile_window( + b_grid, make_tuple(number{}, number{}), {iN, 0}); + + constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline(); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()]; + + const auto acc_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char); + + // cast to CDataType and apply CElementFunction + const auto c_block_tile = tile_elementwise_in( + [&](const auto& acc) { return c_element_func(type_convert(acc)); }, + acc_block_tile); + + // store C + auto c_window = make_tile_window( + c_grid, make_tuple(number{}, number{}), {iM, iN}); + + store_tile(c_window, c_block_tile); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/02_gemm/reference_gemm.hpp b/example/ck_tile/tutorial/02_gemm/reference_gemm.hpp new file mode 100644 index 0000000000..951f99c252 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/reference_gemm.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_basic_gemm(const ck_tile::HostTensor& a_m_k, + const ck_tile::HostTensor& b_n_k, + ck_tile::HostTensor& c_m_n) +{ + const int N = b_n_k.mDesc.get_lengths()[0]; + const int K = b_n_k.mDesc.get_lengths()[1]; + + auto f = [&](auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_n_k(n, k); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + c_m_n(m, n) = ck_tile::type_convert(v_acc); + } + }; + + ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])( + std::thread::hardware_concurrency()); +} diff --git a/example/ck_tile/tutorial/02_gemm/stream_config.hpp b/example/ck_tile/tutorial/02_gemm/stream_config.hpp new file mode 100644 index 0000000000..1469b74b75 --- /dev/null +++ b/example/ck_tile/tutorial/02_gemm/stream_config.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +struct StreamConfig +{ + hipStream_t stream_id_ = nullptr; + bool time_kernel_ = false; + int log_level_ = 0; +}; diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/tutorial/03_flash_attention_fwd/CMakeLists.txt new file mode 100644 index 0000000000..11e85d6030 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/CMakeLists.txt @@ -0,0 +1,38 @@ +set(EXAMPLE_BASIC_FLASH_ATTENTION "basic_flash_attention_fwd") + +message("adding example ${EXAMPLE_BASIC_FLASH_ATTENTION}") + +add_executable(${EXAMPLE_BASIC_FLASH_ATTENTION} EXCLUDE_FROM_ALL flash_attention_fwd.cpp) +target_include_directories(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS) + +# list(APPEND EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF) +if(ENABLE_TOY_FA_FWD_OPT) + message("Compiling with toy FA fwd optimization") + target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_OPT) +endif() + +option(ENABLE_TOY_FA_FWD_QK_SWIZZLE "Enable toy FA fwd QK swizzle" OFF) +if(ENABLE_TOY_FA_FWD_QK_SWIZZLE) + message("Compiling with toy FA fwd QK swizzle") + target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_QK_SWIZZLE) +endif() + +option(ENABLE_TOY_FA_FWD_CACHE_AWARE "Enable toy FA fwd cache aware" OFF) +if(ENABLE_TOY_FA_FWD_CACHE_AWARE) + message("Compiling with toy FA fwd cache aware") + target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_CACHE_AWARE) +endif() + +target_compile_options(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE ${EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_problem.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_problem.hpp new file mode 100644 index 0000000000..b56823b1be --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_problem.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// Problem Description for BlockGemmARegBSmemCReg +template +struct BlockGemmARegBSmemCRegProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp new file mode 100644 index 0000000000..ba87ab9c9b --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -0,0 +1,565 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +#include "block_gemm_areg_bsmem_creg_problem.hpp" +#include "block_gemm_areg_bsmem_creg_v1_default_policy.hpp" +#include "block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV1 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using BlockGemmPolicy = Policy; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr auto config = + Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + template + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MPerXDL = WG::kM; + constexpr index_t NPerXDL = WG::kN; + constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = config.template get<1>(); + + constexpr index_t B_LDS_RW_Width = SmemPack; + + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * VectorSizeB); + + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + // B split schedule + constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_b_issue_cycle = + B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + // C += A * B + template + __device__ void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BLdsTile& b_block_tensor_tmp) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C += A * B + template + __device__ void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + __device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct C-Block-Tensor + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..ad9c3218b3 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +namespace ck_tile { + +struct BlockGemmARegBSmemCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); + } + else if constexpr(kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); + } + else if constexpr(kM0 == 128) + { +#if !defined(TOY_FA_FWD_QK_SWIZZLE) + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#else +#pragma message("Enable toy FA fwd QK swizzle") + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); +#endif + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp new file mode 100644 index 0000000000..d7e6ee4c96 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +namespace ck_tile { + +struct BlockGemmARegBSmemCRegV1K8Policy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); + } + else if constexpr(kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); + } + else if constexpr(kM0 == 128) + { +#if !defined(TOY_FA_FWD_QK_SWIZZLE) + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}, 4, 1); +#endif + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp new file mode 100644 index 0000000000..eb3ba8f6db --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t k_loops = Policy::AKDim / kKPerBlock; + + // Move this part into Policy? + __host__ __device__ static constexpr index_t GetStaticLdsSize() + { + return sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + // Cold A Register Cache + template + __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + + ignore = a_element_func; + ignore = b_element_func; + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm( + get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), + b_lds_gemm_window)){}; + + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(k_loops > 1) + { + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + sequence<0, 0>{}, + sequence{}); + a_block_tile = load_tile(a_copy_dram_window); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + } + + __builtin_amdgcn_sched_barrier(0); + + if constexpr(k_loops > 2) + { + static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + sequence<0, (i_k0 + 1) * kKPerBlock>{}, + sequence{}); + a_block_tile = load_tile(a_copy_dram_window); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }); + } + + // tail + { + if constexpr(k_loops > 1) + { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 2) * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + + block_sync_lds(); + } + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + sequence<0, (k_loops - 1) * kKPerBlock>{}, + sequence{}); + + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 1) * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + } + + set_slice_tile(a_reg_block_tensor_tmp, + a_copy_reg_tensor, + sequence<0, 0>{}, + sequence{}); + + return c_block_tile; + } + + // Hot A Register Cache + template + __host__ __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + ignore = b_element_func; + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + + set_slice_tile(a_copy_reg_tensor, + a_reg_block_tensor_tmp, + sequence<0, 0>{}, + sequence{}); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); + + // Acc register tile + auto c_block_tile = decltype(block_gemm( + get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), + b_lds_gemm_window)){}; + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + +#if !defined(TOY_FA_FWD_OPT) + static_for<0, k_loops, 1>{}([&](auto i_k0) { + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + }); +#else + using BLdsTile = typename decltype(block_gemm)::BLdsTile; + BLdsTile bWarpTile; + + // Global read 0 + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + if constexpr(k_loops > 1) + { + // LDS write 0 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + block_sync_lds(); + + // LDS read 0 + bWarpTile = load_tile(b_lds_gemm_window); + } + + if constexpr(k_loops > 2) + { + __builtin_amdgcn_sched_barrier(0); + static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + + // LDS write 1 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); + + block_sync_lds(); + + // LDS read 1 + bWarpTile = load_tile(b_lds_gemm_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }); + } + // tail + { + if constexpr(k_loops > 1) + { + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 2) * kKPerBlock>{}, + sequence{}), + bWarpTile); + + block_sync_lds(); + } + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + bWarpTile = load_tile(b_lds_gemm_window); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 1) * kKPerBlock>{}, + sequence{}), + bWarpTile); + } +#endif + return c_block_tile; + } + + template + __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + a_reg_block_tensor_tmp, + p_smem); + } + + template + __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + return operator()( + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + a_reg_block_tensor_tmp, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp new file mode 100644 index 0000000000..bdf5818325 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +namespace ck_tile { + +template +struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy +{ + static constexpr index_t AKDim = AKDim_; + + template + __host__ __device__ static constexpr auto GetBlockGemm() + { + using BlockGemmPolicy = BlockGemmARegBSmemCRegV1K8Policy; + + return BlockGemmARegBSmemCRegV1{}; + } + + template + __host__ __device__ static constexpr auto MakeARegBlockDescriptor() + { + constexpr auto blockgemm = GetBlockGemm(); + using BlockGemm = remove_cvref_t; + + static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!"); + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = AKDim; + + constexpr auto config = + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + return a_block_dstr; + } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + return MakeARegBlockDescriptor(); + } + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_problem.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_problem.hpp new file mode 100644 index 0000000000..233194ec1e --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/block_gemm_pipeline_problem.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd.cpp new file mode 100644 index 0000000000..2de69b5855 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd.cpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck_tile/host.hpp" + +#include "reference_batched_gemm.hpp" +#include "reference_batched_softmax.hpp" +#include "flash_attention_fwd.hpp" + +/* + * Toy code of flash attention forward pass + * Assume simplest case. + * Q [Batch, HeadNum, SeqenceLengthQ, HeadDim] + * K [Batch, HeadNum, SeqenceLengthK, HeadDim] + * V [Batch, HeadNum, HeadDim, SeqenceLengthK] + * O [Batch, HeadNum, SeqenceLengthQ, HeadDim] + */ + +int main(int argc, char* argv[]) +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using SaccDataType = float; + using SMPLComputeDataType = float; + using PDataType = ck_tile::half_t; + using OaccDataType = float; + using ODataType = ck_tile::half_t; + + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; + + if(argc == 3) + { + init_method = std::stoi(argv[1]); + verification = std::stoi(argv[2]); + } + else if(argc == 8) + { + init_method = std::stoi(argv[1]); + verification = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); + } + + std::array q_lengths{Batch, M0, K0}; + std::array q_strides{M0 * K0, K0, 1}; + + std::array k_lengths{Batch, N0, K0}; + std::array k_strides{N0 * K0, K0, 1}; + + std::array v_lengths{Batch, N1, N0}; + std::array v_strides{N1 * N0, N0, 1}; + + std::array s_lengths{Batch, M0, N0}; + std::array s_strides{M0 * N0, N0, 1}; + + std::array p_lengths{Batch, M0, N0}; + std::array p_strides{M0 * N0, N0, 1}; + + std::array o_lengths{Batch, M0, N1}; + std::array o_strides{M0 * N1, N1, 1}; + + // host verify + ck_tile::HostTensor q_host(q_lengths, q_strides); + ck_tile::HostTensor k_host(k_lengths, k_strides); + ck_tile::HostTensor v_host(v_lengths, v_strides); + ck_tile::HostTensor o_host_dev(o_lengths, o_strides); + + switch(init_method) + { + case 0: break; + case 1: + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + break; + case 2: + ck_tile::FillUniformDistribution{-3.f, 3.f}(q_host); + ck_tile::FillUniformDistribution{-3.f, 3.f}(k_host); + ck_tile::FillUniformDistribution{-3.f, 3.f}(v_host); + break; + default: + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); + } + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host_dev.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.mData.data()); + k_buf.ToDevice(k_host.mData.data()); + v_buf.ToDevice(v_host.mData.data()); + + constexpr ck_tile::index_t kM0PerBlock = 128; + constexpr ck_tile::index_t kN0PerBlock = 128; + constexpr ck_tile::index_t kK0PerBlock = 32; + constexpr ck_tile::index_t kN1PerBlock = 128; + constexpr ck_tile::index_t kK1PerBlock = 32; + + constexpr ck_tile::index_t kBlockSize = 256; + constexpr ck_tile::index_t kHeadDim = 128; + + ck_tile::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock); + + std::cout << "grid size " << kGridSize << std::endl; + + constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + + float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true}, + ck_tile::make_kernel( + ck_tile::FlashAttentionFwd{}, + kGridSize, + kBlockSize, + 0, + static_cast(q_buf.GetDeviceBuffer()), + static_cast(k_buf.GetDeviceBuffer()), + static_cast(v_buf.GetDeviceBuffer()), + static_cast(o_buf.GetDeviceBuffer()), + M0, + N0, + K0, + N1, + Batch, + K0, // StrideQ + K0, // StrideK + N0, // StrideV + N1, // StrideO + M0 * K0, // BatchStrideQ + N0 * K0, // BatchStrideK + N1 * N0, // BatchStrideV + M0 * N1)); // BatchStrideO + + // reference + auto pass = true; + if(verification) + { + o_buf.FromDevice(o_host_dev.mData.data()); + + ck_tile::HostTensor s_host_ref(s_lengths, s_strides); + ck_tile::HostTensor p_host_ref(p_lengths, p_strides); + ck_tile::HostTensor o_host_ref(o_lengths, o_strides); + + ck_tile::reference_batched_gemm( + q_host, k_host, s_host_ref); + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref); + ck_tile::reference_batched_gemm( + p_host_ref, v_host, o_host_ref); + + pass &= ck_tile::check_err(o_host_dev, o_host_ref); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = + std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0; + std::size_t num_btype = + sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 + + sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd.hpp new file mode 100644 index 0000000000..81fadf9d1a --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +#include "block_gemm_pipeline_problem.hpp" +#include "block_gemm_areg_bsmem_creg_v1.hpp" +#include "flash_attention_fwd_impl.hpp" + +namespace ck_tile { + +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) +{ + return [=](index_t block_1d_id) { + constexpr index_t M01 = 4; + constexpr index_t GroupNum = 8; + + const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + + const auto xcd_id = block_1d_id % GroupNum; + + const auto l_block_id = block_1d_id - (xcd_id % 2); + + const auto ridn = GroupNum * M01 * (update_N0 / 2); + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; + + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; + + auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); + auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); + + const auto total_update_size = update_N0 * update_M0; + + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; + auto rlen = N0 - update_N0; + + auto rm = 0; + auto rn = 0; + if(rlen > 0) + { + rm = (x - 1) / rlen; + rn = x % rlen; + } + + if(rlen > 0 and rm < M0) + { + n = rn + update_N0; + m = rm; + } + else + { + x = x - rlen * M0; + rm = (x - 1) / update_N0; + rn = x % update_N0; + n = rn; + m = update_M0 + rm; + } + } + return make_multi_index(m, n); + }; +} + +// S[M0, N0] = Q[M0, K0] * K[N0, K0] +// P[M0, N0] = Softmax(S[M0, N0]) +// O[M0, N1] = P[M0, N0] * V[N1, N0] +template +struct FlashAttentionFwd +{ + __device__ void operator()(const QDataType* q_ptr, + const KDataType* k_ptr, + const VDataType* v_ptr, + ODataType* o_ptr, + const index_t M0, + const index_t N0, + const index_t K0, + const index_t N1, + const index_t /* Batch */, + const index_t StrideQ, + const index_t StrideK, + const index_t StrideV, + const index_t StrideO, + const index_t BatchStrideQ, + const index_t BatchStrideK, + const index_t BatchStrideV, + const index_t BatchStrideO) const + { + const index_t id_block = get_block_id(); + + const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock); + const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock); + +#if defined(TOY_FA_FWD_CACHE_AWARE) +#pragma message("Enable toy FA fwd cache aware") + const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1); + + const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0; + const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0); + + const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) % + num_tile_m0 * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % + num_tile_n1 * kN1PerBlock); + +#else + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + + return make_tuple(quotient, modulus); + }; + const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); + const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); + const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); + +#endif + + const auto kernel_impl = FlashAttentionFwdImpl{}; + + kernel_impl(q_ptr + iBatch * BatchStrideQ, + k_ptr + iBatch * BatchStrideK, + v_ptr + iBatch * BatchStrideV, + o_ptr + iBatch * BatchStrideO, + M0, + N0, + K0, + N1, + StrideQ, + StrideK, + StrideV, + StrideO, + iM0, + iN1); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd_impl.hpp new file mode 100644 index 0000000000..c7c7ead371 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -0,0 +1,440 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/reduce.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp" +#include "block_gemm_pipeline_problem.hpp" +#include "block_gemm_areg_bsmem_creg_v1.hpp" +#include "tile_gemm_shape.hpp" + +namespace ck_tile { + +// S[M0, N0] = Q[M0, K0] * K[N0, K0] +// P[M0, N0] = Softmax(S[M0, N0]) +// O[M0, N1] = P[M0, N0] * V[N1, N0] +template +struct FlashAttentionFwdImpl +{ + // block gemm0 pipeline + using BlockGemm0Problem = + BlockGemmPipelineProblem>; + + using BlockGemm0Policy = + BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; + + using BlockGemm0Pipeline = BlockGemmPipelineAGmemBGmemCReg; + + // block gemm1 + using BlockGemm1 = BlockGemmARegBSmemCRegV1< + BlockGemmARegBSmemCRegProblem>, + BlockGemmARegBSmemCRegV1DefaultPolicy>; + + // 3d, with padding + __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kK1PerBlock; +#if !defined(TOY_FA_FWD_QK_SWIZZLE) + constexpr index_t kKPack = 4; +#else + constexpr index_t kKPack = 8; +#endif + + constexpr auto dataTypeSize = sizeof(VDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + __device__ static constexpr auto MakeVDramTileDistribution() + { + using BDataType = VDataType; + + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kK1PerBlock; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + __device__ static constexpr index_t GetStaticLdsSize() + { + return max(BlockGemm0Pipeline::GetStaticLdsSize(), + static_cast(MakeVLdsBlockDescriptor().get_element_space_size() * + sizeof(VDataType))); + } + + __device__ void operator()(const QDataType* q_ptr, + const KDataType* k_ptr, + const VDataType* v_ptr, + ODataType* o_ptr, + const index_t M0, + const index_t N0, + const index_t K0, + const index_t N1, + const index_t StrideQ, + const index_t StrideK, + const index_t StrideV, + const index_t StrideO, + const index_t iM0, + const index_t iN1) const + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + + // allocate LDS + __shared__ char smem_ptr[GetStaticLdsSize()]; + + // Q/K/V DRAM and DRAM window + const auto q_dram = make_naive_tensor_view( + q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); + + const auto k_dram = make_naive_tensor_view( + k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{}); + + const auto v_dram = make_naive_tensor_view( + v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{}); + + auto q_dram_window = make_tile_window( + q_dram, + make_tuple(number{}, number{}), + {iM0, 0}, + BlockGemm0Policy::template MakeADramTileDistribution()); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {iN1, 0}, + MakeVDramTileDistribution()); + // Q in register + auto q_reg_tensor = load_tile(q_dram_window); + + // V LDS and LDS window + // V LDS occupies the same LDS allocation Q/K LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); + +#if defined(TOY_FA_FWD_OPT) + // V LDS tile window for store + auto v_copy_lds_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); + + // V LDS tile for block GEMM + auto v_lds_gemm_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); +#else + auto v_lds_window = make_tile_window( + v_lds, make_tuple(number{}, number{}), {0, 0}); +#endif + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SaccBlockTileType = + decltype(gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, nullptr)); + + using SBlockTileType = decltype(tile_elementwise_in( + type_convert, SaccBlockTileType{})); + + using PBlockTileType = decltype(tile_elementwise_in(type_convert, + SaccBlockTileType{})); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm1( + get_slice_tile( + PBlockTileType{}, sequence<0, 0>{}, sequence{}), + v_dram_window)); + + // init Sacc, Oacc, M, L + auto s_acc = SaccBlockTileType{}; + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); + tile_elementwise_inout( + [](auto& e) { e = std::numeric_limits::lowest(); }, m); + tile_elementwise_inout([](auto& e) { e = 0; }, l); + + // loop over Column of S (J loop) + index_t iN0 = 0; + + do + { + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + + // S{j} + const auto s = + tile_elementwise_in(type_convert, s_acc); + +#if defined(TOY_FA_FWD_OPT) + // prefetch load v tile + auto v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); +#endif + // m_local = rowmax(S{j}) + auto m_local = block_tile_reduce( + s, sequence<1>{}, f_max, std::numeric_limits::lowest()); + + block_tile_reduce_sync(m_local, f_max); + + // m{j-1} + const auto m_old = m; + + // m{j} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + // Pcompute{j} + auto p_compute = + make_static_distributed_tensor(s.get_tile_distribution()); + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + p_compute(i_j_idx) = exp(s[i_j_idx] - m[i_idx]); + }); + }); + + // rowsum(Pcompute{j}) + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum); + + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto tmp = exp(m_old[i_idx] - m[i_idx]); + + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + block_sync_lds(); +#if !defined(TOY_FA_FWD_OPT) + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_lds_window, v); + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); + block_sync_lds(); + }); +#else + using VLdsTile = typename decltype(gemm1)::BLdsTile; + VLdsTile vWarpTile; + + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + if constexpr(k1_loops > 1) + { + store_tile(v_copy_lds_window, v_prefetch); + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + } + if constexpr(k1_loops > 2) + { + __builtin_amdgcn_sched_barrier(0); + static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) { + block_sync_lds(); + + // LDS write 1 + store_tile(v_copy_lds_window, v_prefetch); + + // Global read 2 + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + gemm1.template HotLoopScheduler<8, 4>(); + __builtin_amdgcn_sched_barrier(0); + }); + } + // tail + { + if constexpr(k1_loops > 1) + { + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } + store_tile(v_copy_lds_window, v_prefetch); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } +#endif + // move tile windows + move_tile_window(k_dram_window, {kN0PerBlock, 0}); + iN0 += kN0PerBlock; + } while(iN0 < N0); + + // Oacc + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto tmp = 1 / l[i_idx]; + + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + + // type cast Oacc into O + const auto o = tile_elementwise_in(type_convert, o_acc); + + // O DRAM and O DRAM window + auto o_dram = make_naive_tensor_view( + o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), number<32>{}, number<1>{}); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {iM0, iN1}, + o.get_tile_distribution()); + + // store O + store_tile(o_dram_window, o); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/reference_batched_gemm.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/reference_batched_gemm.hpp new file mode 100644 index 0000000000..111e59a835 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/reference_batched_gemm.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_batched_gemm(const ck_tile::HostTensor& a_b_m_k, + const ck_tile::HostTensor& b_b_n_k, + ck_tile::HostTensor& c_b_m_n) +{ + const int N = b_b_n_k.mDesc.get_lengths()[1]; + const int K = b_b_n_k.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_b_m_k(batch, m, k); + BDataType v_b = b_b_n_k(batch, n, k); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + c_b_m_n(batch, m, n) = ck_tile::type_convert(v_acc); + } + }; + + ck_tile::make_ParallelTensorFunctor( + f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/reference_batched_softmax.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/reference_batched_softmax.hpp new file mode 100644 index 0000000000..cc75ba8599 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/reference_batched_softmax.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_batched_softmax(const ck_tile::HostTensor& a_b_m_n, + ck_tile::HostTensor& b_b_m_n) +{ + const int N = a_b_m_n.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + AccDataType v_max = std::numeric_limits::lowest(); + + // max + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_b_m_n(batch, m, n); + + v_max = v_max < v_a ? v_a : v_max; + } + + AccDataType v_exp_sum = 0; + + // sum + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_b_m_n(batch, m, n); + + v_exp_sum += ck_tile::exp(v_a - v_max); + } + + // elementwise + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_b_m_n(batch, m, n); + + b_b_m_n(batch, m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum; + } + }; + + ck_tile::make_ParallelTensorFunctor( + f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} diff --git a/example/ck_tile/tutorial/03_flash_attention_fwd/tile_gemm_shape.hpp b/example/ck_tile/tutorial/03_flash_attention_fwd/tile_gemm_shape.hpp new file mode 100644 index 0000000000..b9877ec1a1 --- /dev/null +++ b/example/ck_tile/tutorial/03_flash_attention_fwd/tile_gemm_shape.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/CMakeLists.txt new file mode 100644 index 0000000000..4a114d4af3 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/CMakeLists.txt @@ -0,0 +1,66 @@ +set(FLASH_ATTENTION_FWD_KNOWN_APIS "fwd") +set(FLASH_ATTENTION_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${FLASH_ATTENTION_FWD_KNOWN_APIS}) & link, or \"all\".") +if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL "all") + set(FLASH_ATTENTION_FWD_ENABLE_APIS ${FLASH_ATTENTION_FWD_KNOWN_APIS}) +endif() + +option(TOY_FA_FWD_OPT "Enable toy flash attention forward optimization" ON) +option(TOY_FA_FWD_QK_SWIZZLE "Enable toy flash attention forward QK swizzle" OFF) + +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --list_blobs + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list Flash Attention kernels via Python. ${ret}") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/flash_attention_fwd_blobs.txt FLASH_ATTENTION_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --gen_blobs +) + +set(EXAMPLE_FA "codegen_basic_flash_attention_fwd") +message("adding example ${EXAMPLE_FA}") + +add_executable(${EXAMPLE_FA} + EXCLUDE_FROM_ALL + flash_attention_fwd.cpp +) + +target_compile_definitions(${EXAMPLE_FA} + PRIVATE + $<$:TOY_FA_FWD_OPT=1> +) + +target_include_directories(${EXAMPLE_FA} + PRIVATE + ${CMAKE_CURRENT_LIST_DIR} +) + +target_sources(${EXAMPLE_FA} PRIVATE ${FLASH_ATTENTION_FWD_GEN_BLOBS}) + +message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}") + +set(EXAMPLE_FA_COMPILE_OPTIONS) +list(APPEND EXAMPLE_FA_COMPILE_OPTIONS + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) + +target_compile_options(${EXAMPLE_FA} + PRIVATE + ${EXAMPLE_FA_COMPILE_OPTIONS} +) + +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_problem.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_problem.hpp new file mode 100644 index 0000000000..b56823b1be --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_problem.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// Problem Description for BlockGemmARegBSmemCReg +template +struct BlockGemmARegBSmemCRegProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp new file mode 100644 index 0000000000..ba87ab9c9b --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -0,0 +1,565 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +#include "block_gemm_areg_bsmem_creg_problem.hpp" +#include "block_gemm_areg_bsmem_creg_v1_default_policy.hpp" +#include "block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV1 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using BlockGemmPolicy = Policy; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr auto config = + Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + template + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MPerXDL = WG::kM; + constexpr index_t NPerXDL = WG::kN; + constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = config.template get<1>(); + + constexpr index_t B_LDS_RW_Width = SmemPack; + + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * VectorSizeB); + + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + // B split schedule + constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_b_issue_cycle = + B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + // C += A * B + template + __device__ void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BLdsTile& b_block_tensor_tmp) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C += A * B + template + __device__ void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + __device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct C-Block-Tensor + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..ad9c3218b3 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +namespace ck_tile { + +struct BlockGemmARegBSmemCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); + } + else if constexpr(kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); + } + else if constexpr(kM0 == 128) + { +#if !defined(TOY_FA_FWD_QK_SWIZZLE) + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#else +#pragma message("Enable toy FA fwd QK swizzle") + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); +#endif + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp new file mode 100644 index 0000000000..d7e6ee4c96 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +namespace ck_tile { + +struct BlockGemmARegBSmemCRegV1K8Policy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); + } + else if constexpr(kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); + } + else if constexpr(kM0 == 128) + { +#if !defined(TOY_FA_FWD_QK_SWIZZLE) + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}, 4, 1); +#endif + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp new file mode 100644 index 0000000000..eb3ba8f6db --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t k_loops = Policy::AKDim / kKPerBlock; + + // Move this part into Policy? + __host__ __device__ static constexpr index_t GetStaticLdsSize() + { + return sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + // Cold A Register Cache + template + __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + + ignore = a_element_func; + ignore = b_element_func; + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm( + get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), + b_lds_gemm_window)){}; + + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(k_loops > 1) + { + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + sequence<0, 0>{}, + sequence{}); + a_block_tile = load_tile(a_copy_dram_window); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + } + + __builtin_amdgcn_sched_barrier(0); + + if constexpr(k_loops > 2) + { + static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + sequence<0, (i_k0 + 1) * kKPerBlock>{}, + sequence{}); + a_block_tile = load_tile(a_copy_dram_window); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }); + } + + // tail + { + if constexpr(k_loops > 1) + { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 2) * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + + block_sync_lds(); + } + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + sequence<0, (k_loops - 1) * kKPerBlock>{}, + sequence{}); + + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 1) * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + } + + set_slice_tile(a_reg_block_tensor_tmp, + a_copy_reg_tensor, + sequence<0, 0>{}, + sequence{}); + + return c_block_tile; + } + + // Hot A Register Cache + template + __host__ __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + ignore = b_element_func; + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + + set_slice_tile(a_copy_reg_tensor, + a_reg_block_tensor_tmp, + sequence<0, 0>{}, + sequence{}); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); + + // Acc register tile + auto c_block_tile = decltype(block_gemm( + get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), + b_lds_gemm_window)){}; + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + +#if !defined(TOY_FA_FWD_OPT) + static_for<0, k_loops, 1>{}([&](auto i_k0) { + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + }); +#else + using BLdsTile = typename decltype(block_gemm)::BLdsTile; + BLdsTile bWarpTile; + + // Global read 0 + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + if constexpr(k_loops > 1) + { + // LDS write 0 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + block_sync_lds(); + + // LDS read 0 + bWarpTile = load_tile(b_lds_gemm_window); + } + + if constexpr(k_loops > 2) + { + __builtin_amdgcn_sched_barrier(0); + static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + + // LDS write 1 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); + + block_sync_lds(); + + // LDS read 1 + bWarpTile = load_tile(b_lds_gemm_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }); + } + // tail + { + if constexpr(k_loops > 1) + { + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 2) * kKPerBlock>{}, + sequence{}), + bWarpTile); + + block_sync_lds(); + } + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + bWarpTile = load_tile(b_lds_gemm_window); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, (k_loops - 1) * kKPerBlock>{}, + sequence{}), + bWarpTile); + } +#endif + return c_block_tile; + } + + template + __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + a_reg_block_tensor_tmp, + p_smem); + } + + template + __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + return operator()( + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + a_reg_block_tensor_tmp, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp new file mode 100644 index 0000000000..bdf5818325 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +namespace ck_tile { + +template +struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy +{ + static constexpr index_t AKDim = AKDim_; + + template + __host__ __device__ static constexpr auto GetBlockGemm() + { + using BlockGemmPolicy = BlockGemmARegBSmemCRegV1K8Policy; + + return BlockGemmARegBSmemCRegV1{}; + } + + template + __host__ __device__ static constexpr auto MakeARegBlockDescriptor() + { + constexpr auto blockgemm = GetBlockGemm(); + using BlockGemm = remove_cvref_t; + + static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!"); + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = AKDim; + + constexpr auto config = + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + return a_block_dstr; + } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + return MakeARegBlockDescriptor(); + } + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_problem.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_problem.hpp new file mode 100644 index 0000000000..233194ec1e --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/block_gemm_pipeline_problem.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp new file mode 100644 index 0000000000..30b42ba4e4 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck_tile/host.hpp" + +#include "reference_batched_gemm.hpp" +#include "reference_batched_softmax.hpp" +#include "flash_attention_fwd.hpp" + +/* + * Toy code of flash attention forward pass + * Assume simplest case. + * Q [Batch, HeadNum, SeqenceLengthQ, HeadDim] + * K [Batch, HeadNum, SeqenceLengthK, HeadDim] + * V [Batch, HeadNum, HeadDim, SeqenceLengthK] + * O [Batch, HeadNum, SeqenceLengthQ, HeadDim] + */ + +int main(int argc, char* argv[]) +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using SaccDataType = float; + using SMPLComputeDataType = float; + using PDataType = ck_tile::half_t; + using OaccDataType = float; + using ODataType = ck_tile::half_t; + + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; + + if(argc == 3) + { + init_method = std::stoi(argv[1]); + verification = std::stoi(argv[2]); + } + else if(argc == 8) + { + init_method = std::stoi(argv[1]); + verification = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); + } + + std::array q_lengths{Batch, M0, K0}; + std::array q_strides{M0 * K0, K0, 1}; + + std::array k_lengths{Batch, N0, K0}; + std::array k_strides{N0 * K0, K0, 1}; + + std::array v_lengths{Batch, N1, N0}; + std::array v_strides{N1 * N0, N0, 1}; + + std::array s_lengths{Batch, M0, N0}; + std::array s_strides{M0 * N0, N0, 1}; + + std::array p_lengths{Batch, M0, N0}; + std::array p_strides{M0 * N0, N0, 1}; + + std::array o_lengths{Batch, M0, N1}; + std::array o_strides{M0 * N1, N1, 1}; + + // host verify + ck_tile::HostTensor q_host(q_lengths, q_strides); + ck_tile::HostTensor k_host(k_lengths, k_strides); + ck_tile::HostTensor v_host(v_lengths, v_strides); + ck_tile::HostTensor o_host_dev(o_lengths, o_strides); + + switch(init_method) + { + case 0: break; + case 1: + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + break; + case 2: + ck_tile::FillUniformDistribution{-3.f, 3.f}(q_host); + ck_tile::FillUniformDistribution{-3.f, 3.f}(k_host); + ck_tile::FillUniformDistribution{-3.f, 3.f}(v_host); + break; + default: + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); + } + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host_dev.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.mData.data()); + k_buf.ToDevice(k_host.mData.data()); + v_buf.ToDevice(v_host.mData.data()); + + // Construct the FlashAttnArgs object with your arguments + ck_tile::FlashAttnArgs flash_attention_args{ + static_cast(q_buf.GetDeviceBuffer()), + static_cast(k_buf.GetDeviceBuffer()), + static_cast(v_buf.GetDeviceBuffer()), + static_cast(o_buf.GetDeviceBuffer()), + M0, + N0, + K0, + N1, + Batch, + K0, // strideQ + K0, // strideK + N0, // strideV + N1, // strideO + M0 * K0, // batchStrideQ + N0 * K0, // batchStrideK + N1 * N0, // batchStrideV + M0 * N1 // batchStrideO + }; + + float ave_time = ck_tile::flash_attention_fwd(flash_attention_args, + ck_tile::stream_config{nullptr, true}); + + // reference + auto pass = true; + if(verification) + { + o_buf.FromDevice(o_host_dev.mData.data()); + + ck_tile::HostTensor s_host_ref(s_lengths, s_strides); + ck_tile::HostTensor p_host_ref(p_lengths, p_strides); + ck_tile::HostTensor o_host_ref(o_lengths, o_strides); + + ck_tile::reference_batched_gemm( + q_host, k_host, s_host_ref); + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref); + ck_tile::reference_batched_gemm( + p_host_ref, v_host, o_host_ref); + + pass &= ck_tile::check_err(o_host_dev, o_host_ref); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = + std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0; + std::size_t num_btype = + sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 + + sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp new file mode 100644 index 0000000000..32cb98b886 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +#include "block_gemm_pipeline_problem.hpp" +#include "block_gemm_areg_bsmem_creg_v1.hpp" +#include "flash_attention_fwd_impl.hpp" + +namespace ck_tile { + +template +struct FlashAttnArgs +{ + // Pointers to device buffers for Q, K, V, O + QDataType* q_ptr; + KDataType* k_ptr; + VDataType* v_ptr; + ODataType* o_ptr; + + // Problem sizes + index_t M0; + index_t N0; + index_t K0; + index_t N1; + index_t Batch; + + // Strides within a batch + index_t strideQ; + index_t strideK; + index_t strideV; + index_t strideO; + + // Batch strides + index_t batchStrideQ; + index_t batchStrideK; + index_t batchStrideV; + index_t batchStrideO; +}; + +// S[M0, N0] = Q[M0, K0] * K[N0, K0] +// P[M0, N0] = Softmax(S[M0, N0]) +// O[M0, N1] = P[M0, N0] * V[N1, N0] +template +struct FlashAttentionFwd +{ + __device__ void operator()(const QDataType* q_ptr, + const KDataType* k_ptr, + const VDataType* v_ptr, + ODataType* o_ptr, + const index_t M0, + const index_t N0, + const index_t K0, + const index_t N1, + const index_t /* Batch */, + const index_t StrideQ, + const index_t StrideK, + const index_t StrideV, + const index_t StrideO, + const index_t BatchStrideQ, + const index_t BatchStrideK, + const index_t BatchStrideV, + const index_t BatchStrideO) const + { + const index_t id_block = get_block_id(); + + const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock); + const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock); + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + + return make_tuple(quotient, modulus); + }; + + const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); + const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); + + const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); + + const auto kernel_impl = FlashAttentionFwdImpl{}; + + kernel_impl(q_ptr + iBatch * BatchStrideQ, + k_ptr + iBatch * BatchStrideK, + v_ptr + iBatch * BatchStrideV, + o_ptr + iBatch * BatchStrideO, + M0, + N0, + K0, + N1, + StrideQ, + StrideK, + StrideV, + StrideO, + iM0, + iN1); + } +}; + +template +float flash_attention_fwd(const FlashAttnArgs& a, + const stream_config& stream_config); + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp new file mode 100644 index 0000000000..c7c7ead371 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -0,0 +1,440 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/reduce.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp" +#include "block_gemm_pipeline_problem.hpp" +#include "block_gemm_areg_bsmem_creg_v1.hpp" +#include "tile_gemm_shape.hpp" + +namespace ck_tile { + +// S[M0, N0] = Q[M0, K0] * K[N0, K0] +// P[M0, N0] = Softmax(S[M0, N0]) +// O[M0, N1] = P[M0, N0] * V[N1, N0] +template +struct FlashAttentionFwdImpl +{ + // block gemm0 pipeline + using BlockGemm0Problem = + BlockGemmPipelineProblem>; + + using BlockGemm0Policy = + BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; + + using BlockGemm0Pipeline = BlockGemmPipelineAGmemBGmemCReg; + + // block gemm1 + using BlockGemm1 = BlockGemmARegBSmemCRegV1< + BlockGemmARegBSmemCRegProblem>, + BlockGemmARegBSmemCRegV1DefaultPolicy>; + + // 3d, with padding + __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kK1PerBlock; +#if !defined(TOY_FA_FWD_QK_SWIZZLE) + constexpr index_t kKPack = 4; +#else + constexpr index_t kKPack = 8; +#endif + + constexpr auto dataTypeSize = sizeof(VDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + __device__ static constexpr auto MakeVDramTileDistribution() + { + using BDataType = VDataType; + + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kK1PerBlock; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + __device__ static constexpr index_t GetStaticLdsSize() + { + return max(BlockGemm0Pipeline::GetStaticLdsSize(), + static_cast(MakeVLdsBlockDescriptor().get_element_space_size() * + sizeof(VDataType))); + } + + __device__ void operator()(const QDataType* q_ptr, + const KDataType* k_ptr, + const VDataType* v_ptr, + ODataType* o_ptr, + const index_t M0, + const index_t N0, + const index_t K0, + const index_t N1, + const index_t StrideQ, + const index_t StrideK, + const index_t StrideV, + const index_t StrideO, + const index_t iM0, + const index_t iN1) const + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + + // allocate LDS + __shared__ char smem_ptr[GetStaticLdsSize()]; + + // Q/K/V DRAM and DRAM window + const auto q_dram = make_naive_tensor_view( + q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); + + const auto k_dram = make_naive_tensor_view( + k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{}); + + const auto v_dram = make_naive_tensor_view( + v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{}); + + auto q_dram_window = make_tile_window( + q_dram, + make_tuple(number{}, number{}), + {iM0, 0}, + BlockGemm0Policy::template MakeADramTileDistribution()); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {iN1, 0}, + MakeVDramTileDistribution()); + // Q in register + auto q_reg_tensor = load_tile(q_dram_window); + + // V LDS and LDS window + // V LDS occupies the same LDS allocation Q/K LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); + +#if defined(TOY_FA_FWD_OPT) + // V LDS tile window for store + auto v_copy_lds_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); + + // V LDS tile for block GEMM + auto v_lds_gemm_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); +#else + auto v_lds_window = make_tile_window( + v_lds, make_tuple(number{}, number{}), {0, 0}); +#endif + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SaccBlockTileType = + decltype(gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, nullptr)); + + using SBlockTileType = decltype(tile_elementwise_in( + type_convert, SaccBlockTileType{})); + + using PBlockTileType = decltype(tile_elementwise_in(type_convert, + SaccBlockTileType{})); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm1( + get_slice_tile( + PBlockTileType{}, sequence<0, 0>{}, sequence{}), + v_dram_window)); + + // init Sacc, Oacc, M, L + auto s_acc = SaccBlockTileType{}; + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); + tile_elementwise_inout( + [](auto& e) { e = std::numeric_limits::lowest(); }, m); + tile_elementwise_inout([](auto& e) { e = 0; }, l); + + // loop over Column of S (J loop) + index_t iN0 = 0; + + do + { + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + + // S{j} + const auto s = + tile_elementwise_in(type_convert, s_acc); + +#if defined(TOY_FA_FWD_OPT) + // prefetch load v tile + auto v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); +#endif + // m_local = rowmax(S{j}) + auto m_local = block_tile_reduce( + s, sequence<1>{}, f_max, std::numeric_limits::lowest()); + + block_tile_reduce_sync(m_local, f_max); + + // m{j-1} + const auto m_old = m; + + // m{j} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + // Pcompute{j} + auto p_compute = + make_static_distributed_tensor(s.get_tile_distribution()); + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + p_compute(i_j_idx) = exp(s[i_j_idx] - m[i_idx]); + }); + }); + + // rowsum(Pcompute{j}) + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum); + + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto tmp = exp(m_old[i_idx] - m[i_idx]); + + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + block_sync_lds(); +#if !defined(TOY_FA_FWD_OPT) + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_lds_window, v); + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); + block_sync_lds(); + }); +#else + using VLdsTile = typename decltype(gemm1)::BLdsTile; + VLdsTile vWarpTile; + + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + if constexpr(k1_loops > 1) + { + store_tile(v_copy_lds_window, v_prefetch); + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + } + if constexpr(k1_loops > 2) + { + __builtin_amdgcn_sched_barrier(0); + static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) { + block_sync_lds(); + + // LDS write 1 + store_tile(v_copy_lds_window, v_prefetch); + + // Global read 2 + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + gemm1.template HotLoopScheduler<8, 4>(); + __builtin_amdgcn_sched_barrier(0); + }); + } + // tail + { + if constexpr(k1_loops > 1) + { + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } + store_tile(v_copy_lds_window, v_prefetch); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } +#endif + // move tile windows + move_tile_window(k_dram_window, {kN0PerBlock, 0}); + iN0 += kN0PerBlock; + } while(iN0 < N0); + + // Oacc + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto tmp = 1 / l[i_idx]; + + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + + // type cast Oacc into O + const auto o = tile_elementwise_in(type_convert, o_acc); + + // O DRAM and O DRAM window + auto o_dram = make_naive_tensor_view( + o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), number<32>{}, number<1>{}); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {iM0, iN1}, + o.get_tile_distribution()); + + // store O + store_tile(o_dram_window, o); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/generate.py new file mode 100644 index 0000000000..ce2f9d32c8 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/generate.py @@ -0,0 +1,578 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Any +import functools +import itertools +import copy +from dataclasses import dataclass + +def get_if_str(size_, total, last_else=True): + if size_ == "head_dim_128_seq_16384": + return 'if' + else: + return 'else if' + +DATA_TYPE_MAP = {'fp32': 'float', + 'fp16': 'ck_tile::half_t', + 'bf16': 'ck_tile::bf16_t'} + +def BOOL_MAP(b_) -> str: + return 'true' if b_ else 'false' + +class FlashAttentionFwdCodegen: + API_TRAITS_DEFINE = """ + +template +struct flash_attention_fwd_traits_ +{ + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kHeadDim = kHeadDim_; + static constexpr index_t kM0PerBlock = kM0PerBlock_; + static constexpr index_t kN0PerBlock = kN0PerBlock_; + static constexpr index_t kK0PerBlock = kK0PerBlock_; + static constexpr index_t kN1PerBlock = kN1PerBlock_; + static constexpr index_t kK1PerBlock = kK1PerBlock_; + + static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); + static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; +}; + +template +using traits_ = flash_attention_fwd_traits_; +""" + + API_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "flash_attention_fwd.hpp" + +namespace ck_tile {{ + +{F_traits_define} + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float flash_attention_fwd_(const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config); + +template +float flash_attention_fwd(const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) {{ + float r = -1; +{F_dispatch} + return r; +}} + +template float flash_attention_fwd( + const FlashAttnArgs&, + const ck_tile::stream_config&); + +}} +""" + + API_INNER_CASE = """ {F_if} {F_VEC_COND} + r = flash_attention_fwd_>(a, stream_config); +""" + + INSTANCE_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "flash_attention_fwd_api_common.hpp" + +namespace ck_tile { +// clang-format off +// +{F_instance_def} +// clang-format on + +} +""" + + def __init__(self, working_path, kernel_filter): + self.working_path = working_path + self.kernel_filter = kernel_filter + + @dataclass + class h_traits: + F_SaccDataType: str + F_SMPLComputeDataType: str + F_PDataType: str + F_OaccDataType: str + F_kBlockSize: int + F_kHeadDim: int + F_kM0PerBlock: int + F_kN0PerBlock: int + F_kK0PerBlock: int + F_kN1PerBlock: int + F_kK1PerBlock: int + + @property + def trait_name(self) -> str: + return (f"{DATA_TYPE_MAP[self.F_SaccDataType]}, " + f"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " + f"{DATA_TYPE_MAP[self.F_PDataType]}, " + f"{DATA_TYPE_MAP[self.F_OaccDataType]}, " + f"{self.F_kBlockSize}, {self.F_kHeadDim}, " + f"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " + f"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}") + + @property + def def_name(self) -> str: + return (f"template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " + f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " + f"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " + f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, " + "const ck_tile::stream_config&);") + + @dataclass + class h_instance: + F_DataTypePair: str # "q,k,v,o" + F_SizeCategory: str # "small", "medium", "large" + instance_list: List[Any] # List[h_traits] + + INSTANCE_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "flash_attention_fwd_api_common.hpp" + +namespace ck_tile {{ +// clang-format off +// +{F_instance_def} +// clang-format on +}} +""" + + @property + def name(self) -> str: + q_type, k_type, v_type, o_type = self.F_DataTypePair.split(',') + dtype_str = f"{q_type}_{k_type}_{v_type}_{o_type}" + return f"flash_attention_fwd_{dtype_str}_{self.F_SizeCategory}" + + @property + def content(self) -> str: + instance_defs = '\n'.join(ins.def_name for ins in self.instance_list) + return self.INSTANCE_BASE.format(F_instance_def=instance_defs) + + @property + def name_api(self) -> str: + return "flash_attention_fwd_api" + + @property + def name_common_header(self) -> str: + return "flash_attention_fwd_api_common" + + @property + def content_common_header(self) -> str: + return f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "flash_attention_fwd.hpp" + +namespace ck_tile {{ + +template +struct flash_attention_fwd_traits_ +{{ + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kHeadDim = kHeadDim_; + static constexpr index_t kM0PerBlock = kM0PerBlock_; + static constexpr index_t kN0PerBlock = kN0PerBlock_; + static constexpr index_t kK0PerBlock = kK0PerBlock_; + static constexpr index_t kN1PerBlock = kN1PerBlock_; + static constexpr index_t kK1PerBlock = kK1PerBlock_; + + static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; +}}; + + +template +using traits_ = flash_attention_fwd_traits_; + + +template +float flash_attention_fwd_(const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) {{ + using SaccDataType = typename Traits_::SaccDataType; + using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; + using PDataType = typename Traits_::PDataType; + using OaccDataType = typename Traits_::OaccDataType; + + index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); + + if(stream_config.log_level_ > 0) + std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; + + return ck_tile::launch_kernel(stream_config, + ck_tile::make_kernel( + ck_tile::FlashAttentionFwd{{}}, + kGridSize, + Traits_::kBlockSize, + 0, + a.q_ptr, + a.k_ptr, + a.v_ptr, + a.o_ptr, + a.M0, + a.N0, + a.K0, + a.N1, + a.Batch, + a.strideQ, // StrideQ + a.strideK, // StrideK + a.strideV, // StrideV + a.strideO, // StrideO + a.batchStrideQ, // BatchStrideQ + a.batchStrideK, // BatchStrideK + a.batchStrideV, // BatchStrideV + a.batchStrideO)); // BatchStrideO +}} +}} +""" + def content_api(self, args) -> str: + # Sort based on dtype + t_dtype_dict = {} + blobs = self.get_blobs(args) + + for blob in blobs: + if blob.F_DataTypePair not in t_dtype_dict: + t_dtype_dict[blob.F_DataTypePair] = {} + if blob.F_SizeCategory not in t_dtype_dict[blob.F_DataTypePair]: + t_dtype_dict[blob.F_DataTypePair][blob.F_SizeCategory] = [] + t_dtype_dict[blob.F_DataTypePair][blob.F_SizeCategory].append(blob) + + d_str = '' + for i_d, dtype_ in enumerate(t_dtype_dict): + blob_per_t = t_dtype_dict[dtype_] + size_str = '' + + for i_size, size_ in enumerate(blob_per_t): + blob_per_size = blob_per_t[size_] + inner_str = "" + + for i_b, b_ in enumerate(blob_per_size): + for i_ins, ins in enumerate(b_.instance_list): + idx_in_size = i_b * len(b_.instance_list) + i_ins + len_in_size = sum(len(b.instance_list) for b in blob_per_size) + + size_cond = "" + if size_ == "head_dim_128_seq_16384": + size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 128 && a.N1 == 128)" + elif size_ == "head_dim_64_seq_16384": + size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 64 && a.N1 == 64)" + elif size_ == "head_dim_128_seq_8192": + size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 128 && a.N1 == 128)" + elif size_ == "head_dim_64_seq_8192": + size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 64 && a.N1 == 64)" + elif size_ == "head_dim_256_seq_4096": + size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 256 && a.N1 == 256)" + elif size_ == "head_dim_128_seq_4096": + size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)" + elif size_ == "head_dim_64_seq_4096": + size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 64 && a.N1 == 64)" + elif size_ == "head_dim_32_seq_4096": + size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 32 && a.N1 == 32)" + elif size_ == "head_dim_128_seq_2048": + size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 128 && a.N1 == 128)" + elif size_ == "head_dim_64_seq_2048": + size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 64 && a.N1 == 64)" + elif size_ == "head_dim_128_seq_1024": + size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)" + elif size_ == "head_dim_64_seq_1024": + size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 64 && a.N1 == 64)" + elif size_ == "head_dim_128_seq_512": + size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 128 && a.N1 == 128)" + elif size_ == "head_dim_64_seq_512": + size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 64 && a.N1 == 64)" + else: + size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)" + + inner_str += self.API_INNER_CASE.format( + F_if=get_if_str(size_, len_in_size, False), + F_VEC_COND=size_cond, + F_trait_name=ins.trait_name + ) + size_str += inner_str + + d_str += size_str + + api_base = self.API_BASE.format( + F_traits_define=self.API_TRAITS_DEFINE, + F_dispatch=d_str + ) + return api_base + + def get_blobs(self, args): + h_traits = self.h_traits + h_instance = self.h_instance + + # Define kernel configurations for different size categories + trait_dict = { + "head_dim_128_seq_16384": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32), + ], + "head_dim_64_seq_16384": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32), + ], + "head_dim_128_seq_8192": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32), + ], + "head_dim_64_seq_8192": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32), + ], + "head_dim_256_seq_4096": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64), + ], + "head_dim_128_seq_4096": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32), + ], + "head_dim_64_seq_4096": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64), + ], + "head_dim_32_seq_4096": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 32, 32, 32, 32, 32, 32), + ], + "head_dim_128_seq_2048": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32), + ], + "head_dim_64_seq_2048": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32), + ], + "head_dim_128_seq_1024": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32), + ], + "head_dim_64_seq_1024": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32), + ], + "head_dim_128_seq_512": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32), + ], + "head_dim_64_seq_512": [ + h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64), + ], + } + + # Toy example only support fp16 + dtype_combinations = [ + "fp16,fp16,fp16,fp16" + # "bf16,bf16,bf16,bf16" + ] + + total_blob = [] + for dtype_pair in dtype_combinations: + for size_category in trait_dict: + traits = trait_dict[size_category] + # Convert data types for the current dtype_pair + q_type, k_type, v_type, o_type = dtype_pair.split(',') + current_traits = [] + for t in traits: + new_t = copy.copy(t) + new_t.F_SaccDataType = 'fp32' # accumulation in fp32 + new_t.F_SMPLComputeDataType = 'fp32' # softmax compute in fp32 + new_t.F_PDataType = q_type + new_t.F_OaccDataType = 'fp32' # output accumulation in fp32 + current_traits.append(new_t) + + total_blob.append(h_instance(dtype_pair, size_category, current_traits)) + + return total_blob + + def list_blobs(self, args) -> None: + w_p = Path(self.working_path) + list_p = w_p / 'flash_attention_fwd_blobs.txt' + blobs = self.get_blobs(args) + + with list_p.open('w') as list_f: + # API related files + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") + # Kernel instance files + for b in blobs: + list_f.write(str(w_p / (b.name + ".cpp")) + "\n") + + def gen_blobs(self, args) -> None: + w_p = Path(self.working_path) + w_str = self.content_api(args) + (w_p / (self.name_api + ".cpp")).write_text(w_str) + (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + + blobs = self.get_blobs(args) + for b in blobs: + (w_p / (b.name + ".cpp")).write_text(b.content) + +def list_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + FlashAttentionFwdCodegen(args.working_path, args.filter).list_blobs(args) + +def gen_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + FlashAttentionFwdCodegen(args.working_path, args.filter).gen_blobs(args) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for Flash Attention kernel", + ) + parser.add_argument( + "-a", + "--api", + default='fwd', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." + ) + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated" + ) + parser.add_argument( + "-l", + "--list_blobs", + action='store_true', + help="list all the kernels to a file" + ) + parser.add_argument( + "-g", + "--gen_blobs", + action='store_true', + help="generate all kernels into different tile" + ) + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate" + ) + + args = parser.parse_args() + + if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): + print('gen_blobs/list_blobs must specify only one option') + sys.exit() + + p = Path(args.working_path) + if not p.exists(): + p.mkdir() + + if args.list_blobs: + list_blobs(args) + else: + gen_blobs(args) diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/reference_batched_gemm.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/reference_batched_gemm.hpp new file mode 100644 index 0000000000..111e59a835 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/reference_batched_gemm.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_batched_gemm(const ck_tile::HostTensor& a_b_m_k, + const ck_tile::HostTensor& b_b_n_k, + ck_tile::HostTensor& c_b_m_n) +{ + const int N = b_b_n_k.mDesc.get_lengths()[1]; + const int K = b_b_n_k.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_b_m_k(batch, m, k); + BDataType v_b = b_b_n_k(batch, n, k); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + c_b_m_n(batch, m, n) = ck_tile::type_convert(v_acc); + } + }; + + ck_tile::make_ParallelTensorFunctor( + f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/reference_batched_softmax.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/reference_batched_softmax.hpp new file mode 100644 index 0000000000..cc75ba8599 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/reference_batched_softmax.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_batched_softmax(const ck_tile::HostTensor& a_b_m_n, + ck_tile::HostTensor& b_b_m_n) +{ + const int N = a_b_m_n.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + AccDataType v_max = std::numeric_limits::lowest(); + + // max + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_b_m_n(batch, m, n); + + v_max = v_max < v_a ? v_a : v_max; + } + + AccDataType v_exp_sum = 0; + + // sum + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_b_m_n(batch, m, n); + + v_exp_sum += ck_tile::exp(v_a - v_max); + } + + // elementwise + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_b_m_n(batch, m, n); + + b_b_m_n(batch, m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum; + } + }; + + ck_tile::make_ParallelTensorFunctor( + f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} diff --git a/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/tile_gemm_shape.hpp b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/tile_gemm_shape.hpp new file mode 100644 index 0000000000..b9877ec1a1 --- /dev/null +++ b/example/ck_tile/tutorial/04_codegen_flash_attention_fwd/tile_gemm_shape.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/tutorial/CMakeLists.txt b/example/ck_tile/tutorial/CMakeLists.txt new file mode 100644 index 0000000000..b3a49ca5d5 --- /dev/null +++ b/example/ck_tile/tutorial/CMakeLists.txt @@ -0,0 +1,9 @@ +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(00_add_basic) +add_subdirectory(01_add) +add_subdirectory(02_gemm) +add_subdirectory(03_flash_attention_fwd) +add_subdirectory(04_codegen_flash_attention_fwd) diff --git a/example/ck_tile/tutorial/README.md b/example/ck_tile/tutorial/README.md new file mode 100644 index 0000000000..fd73b0c6ce --- /dev/null +++ b/example/ck_tile/tutorial/README.md @@ -0,0 +1,112 @@ + + +# CK_TILE Toy Example + +This repository demonstrates a toy example implemented using ck_tile + +## Build Instructions + +Follow these steps to build the examples: + +```sh +cd composable_kernel +mkdir build +cd build + +cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx942" \ + -Dkernel=N .. +``` + +### Compile Examples + +#### **Elementwise Add Example** +```sh +make -j add +``` + +#### **GEMM Example** +```sh +make -j basic_gemm +``` + +#### **Flash Attention Forward Example** +```sh +make -j basic_flash_attention_fwd +``` + +## Running Examples + +### **Elementwise Add** +```sh +./bin/add +``` + +### **GEMM Example** +```sh +./bin/basic_gemm 1 +``` + +### **Flash Attention Forward Example** +```sh +./bin/basic_flash_attention_fwd 1 1 +``` + +## Advanced part +#### **GEMM Example** +##### Follow these steps to build and run the different kernels: +```sh + +cd composable_kernel +mkdir build +cd build + +# for naive kernel +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=N .. && make -j basic_gemm + +# for kernel A +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=A .. && make -j basic_gemm + +# for kernel B +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=B .. && make -j basic_gemm + +... + +# for kernel H +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=H .. && make -j basic_gemm + +``` + +```sh +./bin/basic_gemm 1 +``` + +#### **Flash Attention Forward Example** +##### Follow these steps to build the kernels + +```sh +# for naive kernel +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" .. && make -j basic_flash_attention_fwd + +# for optimized kernel +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -DENABLE_TOY_FA_FWD_OPT=ON .. && make -j basic_flash_attention_fwd +``` +```sh +./bin/basic_flash_attention_fwd 1 1 +``` + + +##### Follow these steps to build the codegen instances + +```sh +mkdir build +cd build +../script/cmake-ck-release.sh .. gfx942 +make -j codegen_basic_flash_attention_fwd +``` + +```sh +./bin/codegen_basic_flash_attention_fwd 1 1 64 16384 16384 128 128 +``` diff --git a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp index 7a4ba2eb79..bf3c2bb30e 100644 --- a/include/ck_tile/core/tensor/slice_tile.hpp +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,13 +80,10 @@ set_slice_tile(static_distributed_tensor(); constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>(); constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>(); - static_assert(std::is_same_v, "wrong!"); - - dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); + dst_tile.set_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); } } // namespace ck_tile