Add multiple tutorial examples

This commit is contained in:
Clement Lin
2025-05-18 17:24:14 +08:00
parent 6342f6b5e8
commit a010920134
52 changed files with 7578 additions and 5 deletions

View File

@@ -0,0 +1,21 @@
set(EXAMPLE_ADD_BASIC "add_basic")
message("adding example ${EXAMPLE_ADD_BASIC}")
add_executable(${EXAMPLE_ADD_BASIC} EXCLUDE_FROM_ALL add_basic.cpp)
target_include_directories(${EXAMPLE_ADD_BASIC} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_ADD_BASIC_COMPILE_OPTIONS)
# generate assembly
# list(APPEND EXAMPLE_ADD_BASIC_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_ADD_BASIC_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_ADD_BASIC} PRIVATE ${EXAMPLE_ADD_BASIC_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "reference_add_vector.hpp"
#include "add_basic.hpp"
#include <cstring>
// This example demonstrates how to use the ck_tile library to perform an elementwise vector
// addition using a custom kernel. The kernel is defined in the vector_add.hpp file, and the
// reference implementation is provided in the reference_vector_add.hpp file.
// parse command line arguments
// -m: size of the vectors
// -v: validation flag (1 for validation, 0 for no validation)
// -prec: precision of the data type (fp16, fp32, int8, int32)
// -warmup: number of warmup iterations (number of kernel launches before measuring performance)
// -repeat: number of repeat iterations (number of kernel launches to measure performance)
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "41943040", "m dimension")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
using XDataType = DataType; // input data type
using ComputeDataType = float; // compute data type
using YDataType = DataType; // output data type
ck_tile::index_t m = arg_parser.get_int("m"); // size of the vectors
int do_validation = arg_parser.get_int("v"); // do we verify the result on cpu
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
ck_tile::HostTensor<XDataType> x_host_a(
{m}); // length input vector A, if given two arguments (m, n) the HostTensor will be created
// with shape (m, n)
ck_tile::HostTensor<XDataType> x_host_b(
{m}); // length input vector B, if given two arguments (m, n) the HostTensor will be created
// with shape (m, n)
ck_tile::HostTensor<YDataType> y_host_ref({m});
ck_tile::HostTensor<YDataType> y_host_dev({m});
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(
x_host_a); // fill the input vector A with random values
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_b);
ck_tile::DeviceMem x_buf_a(
x_host_a.get_element_space_size_in_bytes()); // allocate device memory for input vector A
// (this a wrapper over hipMalloc)
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
x_buf_a.ToDevice(
x_host_a
.data()); // copy the input vector A to device memory, this is a wrapper over hipMemcpy
x_buf_b.ToDevice(x_host_b.data());
// Dividing the problem into blocktile, warptile, and vector
// The blocktile is the size of the tile that will be processed by a single thread block (also
// called work group) The warptile is the size of the tile that will be processed by a single
// warp (also called wavefront) The vector is the size of the tile that will be processed by a
// single thread (also called work item) The problem is divided into blocks of size BlockTile,
// each block is further divided into warps of size WarpTile and each warp is composed of 64 or
// 32 threads of size Vector each of the thread in a warp will process one vector worth elements
// of the data
using BlockTile = ck_tile::sequence<8192>; // Size of the block tile (Entire problem is divided
// into blocks of this size)
using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp
// will cover some part of blockTile)
using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp
using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread
// will cover some part of WarpTile)
// Interpretation of above configurations
// Each thread will cover 1 element (Vector)
// Each WarpTile will cover 64 elements (WarpTile) --> since 64 threads in a warp
// if we have 8 warps in a block (BlockWarps) then we have 8 * 64 = 512 threads in a block
// if 8 warps are not enough to cover the entire blockTile then each of the 8 concurrent warps
// will iterate over the blockTile several times
constexpr ck_tile::index_t kBlockSize = 512;
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
std::cout << "block x-size = " << BlockTile::at(ck_tile::number<0>{}) << std::endl;
std::cout << "grid size " << kGridSize << std::endl;
using Shape = ck_tile::AddVectorShape<BlockWarps, BlockTile, WarpTile, Vector>;
std::cout << "Problem Shape:: M = " << m << std::endl;
std::cout << "BlockTile: " << BlockTile::at(ck_tile::number<0>{}) << std::endl;
std::cout << "Number of Blocks in Grid: " << m / BlockTile::at(ck_tile::number<0>{})
<< std::endl;
std::cout << "BlockWarps: " << BlockWarps::at(ck_tile::number<0>{}) << std::endl;
std::cout << "WarpTile: " << WarpTile::at(ck_tile::number<0>{}) << std::endl;
std::cout << "Vector: " << Vector::at(ck_tile::number<0>{}) << std::endl;
std::cout << "Repeat: " << Shape::Repeat_M
<< std::endl; // number of times a warp will iterate over the blockTile, covering
// different parts of the blockTile
std::cout << "Threads per Block: " << kBlockSize << std::endl;
std::cout << "ThreadBlocks per CU: " << kBlockPerCu << std::endl;
// What is a Problem in CKTile?
// A Problem defines the shape of the data, the precision of the data
using Problem = ck_tile::AddVectorProblem<XDataType, ComputeDataType, YDataType, Shape>;
// What is a Policy in CKTile?
// A Policy defines how to map the data between threads and data in memory
// The kernel is the function that will be executed on the device
// It requires a Problem and Policy to be defined
using Kernel = ck_tile::AddVectorKernel<Problem>;
// The kernel is launched with the following parameters:
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, // wrapper over hipStreamCreate
ck_tile::make_kernel<kBlockSize, kBlockPerCu>( // numOfThreadsPerBlock, numOfBlocksPerCU
Kernel{}, // kernel
kGridSize, // number of blocks in the grid
kBlockSize, // number of threads in a block
0, // shared memory size
static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()), // input vector A
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()), // input vector B
static_cast<YDataType*>(y_buf.GetDeviceBuffer()), // output vector
m));
std::size_t num_btype = sizeof(XDataType) * m + sizeof(YDataType) * m;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(do_validation)
{
ck_tile::reference_add_vector<XDataType, YDataType>(x_host_a, x_host_b, y_host_ref);
y_buf.FromDevice(y_host_dev.mData.data());
pass = ck_tile::check_err(y_host_dev, y_host_ref);
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
}

View File

@@ -0,0 +1,140 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
// struct that holds the tile size of the block, warp, and vector
// and the number of warps per block
// and the number of threads per warp
// and the number of times the warp tile is repeated in the block tile
// and the block size
template <typename BlockWarps, typename BlockTile, typename WarpTile, typename Vector>
struct AddVectorShape
{
static constexpr index_t Block_M = BlockTile::at(number<0>{});
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
static constexpr index_t Vector_M = Vector::at(number<0>{});
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t Repeat_M =
Block_M /
(WarpPerBlock_M * Warp_M); // Number of times the warp tile is repeated in the block tile
static constexpr index_t BlockSize =
warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
struct AddVectorProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
};
// data mapping beween threads and memory
struct AddDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // Replicate
tuple<sequence<S::Repeat_M,
S::WarpPerBlock_M,
S::ThreadPerWarp_M,
S::Vector_M>>, // Hierarchical
tuple<sequence<1>, sequence<1>>, // Parallel
tuple<sequence<1>, sequence<2>>, // Parallel
sequence<1, 1>, // Yield
sequence<0, 3>>{} // Yield
);
}
};
template <typename Problem_, typename Policy_ = AddDefaultPolicy>
struct AddVectorKernel
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
// body of the kernel
CK_TILE_DEVICE void
operator()(const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M) const
{
using S = typename Problem::BlockShape;
// create tensor view for the input and output data, this defines how the data is laid out
// in memory
const auto x_m_n_a = make_naive_tensor_view<address_space_enum::global>(
p_x_a,
make_tuple(M),
make_tuple(1),
number<S::Vector_M>{}); // raw pointer, shape of the tensor, stride of the tensor, and
// lastGarunteedVectorLength
const auto x_m_n_b = make_naive_tensor_view<address_space_enum::global>(
p_x_b, make_tuple(M), make_tuple(1), number<S::Vector_M>{});
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M), make_tuple(1), number<S::Vector_M>{});
// origin of the block tile
const auto iM = get_block_id() * S::Block_M;
// creating tile windows for the input and output data
auto x_window_a = make_tile_window(x_m_n_a,
make_tuple(number<S::Block_M>{}),
{iM},
Policy::template MakeXBlockTileDistribution<Problem>());
auto x_window_b = make_tile_window(x_m_n_b,
make_tuple(number<S::Block_M>{}),
{iM},
Policy::template MakeXBlockTileDistribution<Problem>());
auto y_window = make_tile_window(y_m_n,
make_tuple(number<S::Block_M>{}),
{iM},
Policy::template MakeXBlockTileDistribution<Problem>());
// Load tile data
const auto xa =
load_tile(x_window_a); // load tile data from global tensor view, load from where? what?
// how many? logical memory layout? all are defined in x_window_a
const auto xb = load_tile(x_window_b);
auto y_compute = load_tile(y_window);
// Process the vector add
constexpr auto spans = decltype(xa)::get_distributed_spans(); // shape of the tile
sweep_tile_span(spans[number<0>{}], [&](auto idx) { // iterate over the tile
const auto tile_idx = make_tuple(idx);
const auto a_val = type_convert<ComputeDataType>(xa[tile_idx]);
const auto b_val = type_convert<ComputeDataType>(xb[tile_idx]);
y_compute(tile_idx) = a_val + b_val;
});
// Store results
store_tile(y_window,
cast_tile<YDataType>(y_compute)); // store the result back to global tensor view
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,31 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename XDataType, typename YDataType>
CK_TILE_HOST void reference_add_vector(const HostTensor<XDataType>& xa_m_n,
const HostTensor<XDataType>& xb_m_n,
HostTensor<YDataType>& y_m_n)
{
auto f = [&](auto m) {
const int N = 1;
for(int n = 0; n < N; ++n)
{
y_m_n(m, n) = ck_tile::type_convert<YDataType>(xa_m_n(m, n)) +
ck_tile::type_convert<YDataType>(xb_m_n(m, n));
}
};
make_ParallelTensorFunctor(f,
y_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,21 @@
set(EXAMPLE_ADD "add")
message("adding example ${EXAMPLE_ADD}")
add_executable(${EXAMPLE_ADD} EXCLUDE_FROM_ALL add.cpp)
target_include_directories(${EXAMPLE_ADD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_ADD_COMPILE_OPTIONS)
# generate assembly
# list(APPEND EXAMPLE_ADD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_ADD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_ADD} PRIVATE ${EXAMPLE_ADD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,117 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "reference_add.hpp"
#include "add.hpp"
#include <cstring>
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "10240", "m dimension")
.insert("n", "4096", "n dimension")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "200", "cold iter")
.insert("repeat", "1000", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
using XDataType = DataType;
using ComputeDataType = float;
using YDataType = DataType;
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
ck_tile::HostTensor<XDataType> x_host_a({m, n});
ck_tile::HostTensor<XDataType> x_host_b({m, n});
ck_tile::HostTensor<YDataType> y_host_ref({m, n});
ck_tile::HostTensor<YDataType> y_host_dev({m, n});
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_a);
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_b);
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
x_buf_a.ToDevice(x_host_a.data());
x_buf_b.ToDevice(x_host_b.data());
using BlockWarps =
ck_tile::sequence<1, 8>; // number of concurrent warps in one block (if 8 warps * 64 threads
// per warp, 512 threads in one block are NEEDED)
using BlockTile =
ck_tile::sequence<1, 4096>; // shape of one blockTile (elements covered by one block)
using WarpTile = ck_tile::sequence<1, 512>; // shape of one warpTile (elements covered by one
// warp (64 threads))
using Vector = ck_tile::sequence<1, 8>; // shape of one vector (elements covered by one thread)
constexpr ck_tile::index_t kBlockSize =
512; // number of blockWarps * number of threads per warp
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
std::cout << "block x-size = " << BlockTile::at(ck_tile::number<0>{}) << std::endl;
std::cout << "grid size " << kGridSize << std::endl;
using Shape = ck_tile::AddShape<BlockWarps, BlockTile, WarpTile, Vector>;
using Porblem = ck_tile::AddProblem<XDataType, ComputeDataType, YDataType, Shape>;
using Kernel = ck_tile::Add<Porblem>;
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n));
std::size_t num_btype = 2 * sizeof(XDataType) * m * n + sizeof(YDataType) * m * n;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(do_validation)
{
ck_tile::reference_add<XDataType, YDataType>(x_host_a, x_host_b, y_host_ref);
y_buf.FromDevice(y_host_dev.mData.data());
pass = ck_tile::check_err(y_host_dev, y_host_ref);
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
}

View File

@@ -0,0 +1,166 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
template <typename BlockWarps, // num warps along seq<M, N>
typename BlockTile, // block size, seq<M, N>
typename WarpTile, // warp size, seq<M, N>
typename Vector> // contiguous pixels(vector size) along seq<M, N>
struct AddShape
{
static constexpr index_t Block_M = BlockTile::at(number<0>{}); // elements along M in one Block
static constexpr index_t Block_N = BlockTile::at(number<1>{}); // elements along N in one Block
static constexpr index_t Warp_M = WarpTile::at(number<0>{}); // elements along M in one Warp
static constexpr index_t Warp_N = WarpTile::at(number<1>{}); // elements along N in one Warp
static constexpr index_t Vector_M = Vector::at(number<0>{}); // elements along M in one Vector
static constexpr index_t Vector_N = Vector::at(number<1>{}); // elements along N in one Vector
static constexpr index_t WarpPerBlock_M =
BlockWarps::at(number<0>{}); // num concurrent warps along M
static constexpr index_t WarpPerBlock_N =
BlockWarps::at(number<1>{}); // num concurrent warps along N
static constexpr index_t ThreadPerWarp_M =
Warp_M /
Vector_M; // num threads along M in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64)
static constexpr index_t ThreadPerWarp_N =
Warp_N /
Vector_N; // num threads along N in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64)
static constexpr index_t Repeat_M =
Block_M /
(WarpPerBlock_M *
Warp_M); // num of time a warp iterates along M to ensure the entire block is covered
static constexpr index_t Repeat_N =
Block_N /
(WarpPerBlock_N *
Warp_N); // num of time a warp iterates along N to ensure the entire block is covered
static constexpr index_t BlockSize =
warpSize *
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); // num of threads in one block
};
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
struct AddProblem
{
using XDataType = remove_cvref_t<XDataType_>; // data type of input tensor
using ComputeDataType = remove_cvref_t<ComputeDataType_>; // data type of compute tensor
using YDataType = remove_cvref_t<YDataType_>; // data type of output tensor
using BlockShape = remove_cvref_t<BlockShape_>; // block shapes and sizes
};
struct AddDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M,
S::WarpPerBlock_M,
S::ThreadPerWarp_M,
S::Vector_M>, // how many sub division is a block divided in
sequence<S::Repeat_N,
S::WarpPerBlock_N,
S::ThreadPerWarp_N,
S::Vector_N>>, // how many sub division is a block divided in
tuple<sequence<1, 2>, sequence<1, 2>>, // What are the shapes of those sub divisions
tuple<sequence<1, 1>, sequence<2, 2>>, // What are the shapes of those sub divisions
sequence<1, 1, 2, 2>, // How much data does a thread work on and how many iterations
// of warps are there
sequence<0, 3, 0, 3>>{}); // How much data does a thread work on and how many
// iterations of warps are there
}
};
template <typename Problem_, typename Policy_ = AddDefaultPolicy>
struct Add
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
CK_TILE_DEVICE void operator()(
const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M, index_t N) const
{
using S = typename Problem::BlockShape;
const auto x_m_n_a = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::slc>(
p_x_a,
make_tuple(M, N),
make_tuple(N, 1),
number<S::Vector_N>{},
number<1>{}); // raw data, shape of tensor, stride of tensor, lastGarunteedVectorLength,
// lastGarunteedVectorStride
const auto x_m_n_b = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::slc>(
p_x_b, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
const auto y_m_n = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::slc>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
const auto iM = get_block_id() * S::Block_M; // origin of the block along
auto x_window_a = make_tile_window(x_m_n_a,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{iM, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
auto x_window_b = make_tile_window(x_m_n_b,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{iM, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
auto y_window = make_tile_window(y_m_n,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{iM, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto xa = load_tile(x_window_a);
const auto xb = load_tile(x_window_b);
auto y_compute = load_tile(y_window);
constexpr auto spans = decltype(xa)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
const auto x = ck_tile::type_convert<ComputeDataType>(xa[i_j_idx]);
const auto y = ck_tile::type_convert<ComputeDataType>(xb[i_j_idx]);
y_compute(i_j_idx) = x + y;
});
});
store_tile(y_window, cast_tile<YDataType>(y_compute));
move_tile_window(x_window_a, {0, S::Block_N});
move_tile_window(x_window_b, {0, S::Block_N});
move_tile_window(y_window, {0, S::Block_N});
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,31 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename XDataType, typename YDataType>
CK_TILE_HOST void reference_add(const HostTensor<XDataType>& xa_m_n,
const HostTensor<XDataType>& xb_m_n,
HostTensor<YDataType>& y_m_n)
{
auto f = [&](auto m) {
const int N = xa_m_n.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
y_m_n(m, n) = ck_tile::type_convert<YDataType>(xa_m_n(m, n)) +
ck_tile::type_convert<YDataType>(xb_m_n(m, n));
}
};
make_ParallelTensorFunctor(f,
y_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,26 @@
set(EXAMPLE_BASIC_GEMM "basic_gemm")
message("adding example ${EXAMPLE_BASIC_GEMM}")
add_executable(${EXAMPLE_BASIC_GEMM} EXCLUDE_FROM_ALL gemm.cpp)
target_include_directories(${EXAMPLE_BASIC_GEMM} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS)
# generate assembly
# list(APPEND EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
if(DEFINED kernel)
message("Compiling with Kernel: ${kernel}")
target_compile_definitions(${EXAMPLE_BASIC_GEMM} PRIVATE KERNEL_${kernel}=1)
endif()
target_compile_options(${EXAMPLE_BASIC_GEMM} PRIVATE ${EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,372 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_asmem_bsmem_creg_default_policy.hpp"
namespace ck_tile {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegDefaultPolicy>
struct BlockGemmASmemBSmemCReg
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using WarpGemm = remove_cvref_t<
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
static constexpr index_t MWarp =
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
static constexpr index_t NWarp =
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
#if defined(ENABLE_PREFETCH)
// A block tile distribution for load from lds
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarp * WarpGemm::kM);
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
// B block tile distribution for load from lds
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarp * WarpGemm::kN);
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile aWarpTile;
BLdsTile bWarpTile;
// Prefetch from LDS to warp register
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
aWarpTile = load_tile(a_block_window);
bWarpTile = load_tile(b_block_window);
}
#endif
// C += A * B
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
#if !defined(ENABLE_PREFETCH)
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// Construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
a_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
b_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// Read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
#if defined(ENABLE_PREFETCH)
#pragma message("local data share prefetch")
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
#else
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
#endif
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// Read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
#if defined(ENABLE_PREFETCH)
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
#else
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
#endif
// Read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// Warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// Write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// C = A * B
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
#if !defined(ENABLE_PREFETCH)
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
// Construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
a_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
b_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
// Construct C-Block-Tensor
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
// Hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// Read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
#if defined(ENABLE_PREFETCH)
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
#else
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
#endif
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// Read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
#if defined(ENABLE_PREFETCH)
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
#else
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
#endif
// Read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// Warp GEMM
if constexpr(KIterPerWarp == 0)
{
// c = a * b
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
// c += a * b
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
// Write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,100 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "config.h"
namespace ck_tile {
// Default policy for BlockGemmASmemBSmemCReg
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmASmemBSmemCRegDefaultPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if defined(ADJUST_BLOCK_TILE_SHAPE)
constexpr index_t kMWarp = 2;
constexpr index_t kNWarp = 2;
#else
constexpr index_t kMWarp = 4;
constexpr index_t kNWarp = 1;
#endif
#if defined(NAIVE_IMPLEMENTATION)
#pragma message("mfma m32 n32 k8")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
}
#elif defined(USING_MFMA_32x32x_8x2)
#pragma message("mfma m32 n32 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
}
#elif defined(USING_MFMA_16x16x16)
#pragma message("mfma m16 n16 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
}
#elif defined(USING_MFMA_16x16x_16x2)
#pragma message("mfma m16 n16 k32")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp);
}
#endif
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,412 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = ck_tile::BlockGemmPipelineAGmemBGmemCRegDefaultPolicy>
struct BlockGemmPipelineAGmemBGmemCReg
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
#if defined(ENABLE_INSTRUCTION_SCH)
static constexpr index_t kPackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetSmemPack() { return Policy::template GetSmemPack<Problem>(); }
static constexpr bool HasHotLoop = Problem::HasHotLoop;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemm::MWarp;
constexpr index_t WaveNumN = BlockGemm::NWarp;
constexpr index_t AB_LDS_RW_Width = GetSmemPack();
constexpr index_t A_Buffer_Load_Inst_Num =
kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t B_LDS_Write_Inst_Num =
kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock /
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) /
// sizeof(BDataType)
// ? sizeof(ComputeDataType) /
// sizeof(ADataType) : sizeof(ComputeDataType)
// / sizeof(BDataType);
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
constexpr auto num_mfma_per_dswrite_a =
(num_mfma_per_issue - num_dswrite_per_issue_a * 2 >= 1) ? 2 : 1;
constexpr auto num_mfma_per_dswrite_b =
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_a, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_mfma_per_dswrite_a *
num_dswrite_per_issue_a,
0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_mfma_per_dswrite_b *
num_dswrite_per_issue_b,
0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
#endif
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// -----------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
#if defined(ENABLE_PREFETCH)
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()));
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()));
#else
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
#endif
// Block GEMM
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
// -------------------------------------------------------------------------------------
// Gemm pipeline start
// Initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
#if defined(ENABLE_PREFETCH)
#pragma message("global prefetch")
// Prefetch
// Global read 0
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
if(num_loop > 1)
{
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// LDS write 0
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
// Global read 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
// Prefetch from LDS to warp register in block gemm
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
}
__builtin_amdgcn_sched_barrier(0);
// Main body
if(num_loop > 2)
{
index_t iCounter = 0;
do
{
block_sync_lds();
// LDS write 1
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
// Global read 2
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// Prefetch from LDS to warp register in block gemm
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
#if defined(ENABLE_INSTRUCTION_SCH)
HotLoopScheduler();
#endif
__builtin_amdgcn_sched_barrier(0);
iCounter += 1;
} while(iCounter < (num_loop - 2));
}
// Tail
if(num_loop > 1)
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
}
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
#else
// non-prefetch
index_t iCounter = num_loop;
while(iCounter > 0)
{
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
iCounter--;
}
#endif
return c_block_tile;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,352 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "block_gemm_asmem_bsmem_creg.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "config.h"
namespace ck_tile {
// Default policy for BlockGemmPipelineAGmemBGmemCReg
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
{
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = 8;
#if defined(NAIVE_IMPLEMENTATION)
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_K_FIRST)
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_MN_FIRST)
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
number<kMPerBlock / MLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
number<kKPerBlock / kKPack * MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif
return a_lds_block_desc;
}
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = 8;
#if defined(PADDING_K_FIRST) || defined(NAIVE_IMPLEMENTATION)
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_K_FIRST)
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_MN_FIRST)
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
number<kNPerBlock / NLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
number<kKPerBlock / kKPack * NLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif
return b_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
#if defined(ENABLE_INSTRUCTION_SCH)
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Assume DataType is even!
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
PackedSize == 2)
{
return (PackedSize * 32 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (PackedSize * 16 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
{
return (PackedSize * 8 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
{
return (PackedSize * 4 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
{
return (PackedSize * 2 / sizeof(DataType));
}
else
{
return PackedSize;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Problem::TransposeC;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPack()
{
constexpr index_t kKPack = 8;
return kKPack;
}
#endif
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
return BlockGemmASmemBSmemCReg<Problem>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#if defined(KERNEL_A)
#define PADDING_K_FIRST
#define USING_MFMA_32x32x_8x2
#elif defined(KERNEL_B)
#define PADDING_K_FIRST
#define USING_MFMA_16x16x16
#elif defined(KERNEL_C)
#define PADDING_K_FIRST
#define USING_MFMA_16x16x_16x2
#elif defined(KERNEL_D)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#elif defined(KERNEL_E)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#elif defined(KERNEL_F)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#define ENABLE_PREFETCH
#elif defined(KERNEL_G)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#define ENABLE_PREFETCH
#define ENABLE_INSTRUCTION_SCH
#elif defined(KERNEL_H)
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#define ENABLE_PREFETCH
#define ENABLE_INSTRUCTION_SCH
#define ENABLE_CACHE_AWARE_WG_SCH
#else
#define NAIVE_IMPLEMENTATION
#endif

View File

@@ -0,0 +1,202 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include "config.h"
#include "ck_tile/host.hpp"
#include "gemm.hpp"
#include "reference_gemm.hpp"
/*
* Toy code of GEMM
* Assume simplest case.
* A [M, K]
* B [N, K]
* C [M, N]
*/
// elementwise lambda
struct CElementFunction
{
template <typename X>
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
{
return x;
}
};
int main(int argc, char* argv[])
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
ck_tile::index_t verification = 0;
ck_tile::index_t M = 3328;
ck_tile::index_t N = 4096;
ck_tile::index_t K = 4096;
if(argc == 2)
{
verification = std::stoi(argv[1]);
}
if(argc == 5)
{
verification = std::stoi(argv[1]);
M = std::stoi(argv[2]);
N = std::stoi(argv[3]);
K = std::stoi(argv[4]);
}
#if defined(KERNEL_A)
printf("*** Kernel A test *** \n");
printf(" --> Using mfma_32x32x(8x2)\n");
#elif defined(KERNEL_B)
printf("*** Kernel B test *** \n");
printf(" --> Using mfma_16x16x16\n");
#elif defined(KERNEL_C)
printf("*** Kernel C test *** \n");
printf(" --> Using mfma_16x16x(16x2)\n");
#elif defined(KERNEL_D)
printf("*** Kernel D test *** \n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based bank-conflict-free\n");
#elif defined(KERNEL_E)
printf("*** Kernel E test ***\n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based bank-conflict-free\n");
printf(" --> Adjust block tile shape\n");
#elif defined(KERNEL_F)
printf("*** Kernel F test ***\n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based bank-conflict-free\n");
printf(" --> Adjust block tile shape\n");
printf(" --> Enable prefetch\n");
#elif defined(KERNEL_G)
printf("*** Kernel G test ***\n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based bank-conflict-free\n");
printf(" --> Adjust block tile shape\n");
printf(" --> Enable prefetch\n");
printf(" --> Enable instruction schedule\n");
#elif defined(KERNEL_H)
printf("*** Kernel H test ***\n");
printf(" --> Using mfma_16x16x(16x2)\n");
printf(" --> XOR-based bank-conflict-free\n");
printf(" --> Adjust block tile shape\n");
printf(" --> Enable prefetch\n");
printf(" --> Enable instruction schedule\n");
printf(" --> Enable cache-aware thread blocks schedule\n");
#else
printf("*** Naive implementation test ***\n");
#endif
const ck_tile::index_t Lda = K;
const ck_tile::index_t Ldb = K;
const ck_tile::index_t Ldc = N;
const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};
// host verify
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
a_buf.ToDevice(a_host.mData.data());
b_buf.ToDevice(b_host.mData.data());
// Alignment
constexpr ck_tile::index_t kAAlignment = 8;
constexpr ck_tile::index_t kBAlignment = 8;
constexpr ck_tile::index_t kCAlignment = 8;
constexpr ck_tile::index_t kBlockSize = 256;
#ifdef ADJUST_BLOCK_TILE_SHAPE
constexpr ck_tile::index_t kGemmMPerBlock = 128;
constexpr ck_tile::index_t kGemmKPerBlock = 64;
#else
constexpr ck_tile::index_t kGemmMPerBlock = 256;
constexpr ck_tile::index_t kGemmKPerBlock = 32;
#endif
constexpr ck_tile::index_t kGemmNPerBlock = 128;
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
std::cout << "grid size " << kGridSize << std::endl;
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
using gemm_kernel = ck_tile::Gemm<ADataType,
BDataType,
AccDataType,
CDataType,
CElementFunction,
kAAlignment,
kBAlignment,
kCAlignment,
kBlockSize,
kGemmMPerBlock,
kGemmNPerBlock,
kGemmKPerBlock>;
float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
gemm_kernel{},
kGridSize,
kBlockSize,
0,
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
M,
N,
K,
Lda,
Ldb,
Ldc,
CElementFunction{}));
auto pass = true;
if(verification)
{
// reference gemm
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
reference_basic_gemm<ADataType, ADataType, AccDataType, CDataType>(
a_host, b_host, c_host_ref);
c_buf.FromDevice(c_host_dev.mData.data());
pass &= ck_tile::check_err(c_host_dev, c_host_ref);
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
return !pass;
}

View File

@@ -0,0 +1,195 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
#include "config.h"
#include "grid_gemm.hpp"
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename CElementFunction_>
struct GridGemmProblem
{
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CDataType = CDataType_;
using CElementFunction = CElementFunction_;
};
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
static constexpr index_t kM = kMPerTile;
static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
};
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
// C = A * B
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename CElementFunction,
index_t kAAlignment,
index_t kBAlignment,
index_t kCAlignment,
index_t kBlockSize_,
index_t kMPerBlock_,
index_t kNPerBlock_,
index_t kKPerBlock_>
struct Gemm
{
using GridGemmProblem =
GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;
struct GridGemmPolicy
{
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kMPerBlock_;
static constexpr index_t kNPerBlock = kNPerBlock_;
static constexpr index_t kKPerBlock = kKPerBlock_;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
{
#if defined(ENABLE_CACHE_AWARE_WG_SCH)
return [=](index_t block_1d_id) {
constexpr index_t M01 = 4;
constexpr index_t GroupNum = 8;
const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
const auto update_M0 =
((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;
const auto xcd_id = block_1d_id % GroupNum;
const auto l_block_id = block_1d_id - (xcd_id % 2);
const auto ridn = GroupNum * M01 * (update_N0 / 2);
const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn;
const auto lu = (l_block_id % GroupNum) + rid * ridn;
const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
const auto sub_M0_id =
(l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;
auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
const auto total_update_size = update_N0 * update_M0;
if(block_1d_id >= total_update_size)
{
auto x = (block_1d_id + 1) - total_update_size;
auto rlen = N0 - update_N0;
auto rm = 0;
auto rn = 0;
if(rlen > 0)
{
rm = (x - 1) / rlen;
rn = x % rlen;
}
if(rlen > 0 and rm < M0)
{
n = rn + update_N0;
m = rm;
}
else
{
x = x - rlen * M0;
rm = (x - 1) / update_N0;
rn = x % update_N0;
n = rn;
m = update_M0 + rm;
}
}
return make_multi_index(m, n);
};
#else
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
return [unmerge](index_t block_id) {
multi_index<2> unmerged;
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
};
#endif
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
{
using BlockGemmPipelineProblem_ =
BlockGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
kBlockSize,
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
}
};
using GridGemm = GridGemm<GridGemmProblem, GridGemmPolicy>;
CK_TILE_DEVICE void operator()(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const index_t M,
const index_t N,
const index_t K,
const index_t Lda,
const index_t Ldb,
const index_t Ldc,
const CElementFunction& c_element_func) const
{
const auto a_dram = [&] {
return make_naive_tensor_view<address_space_enum::global>(
p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});
}();
const auto b_dram = [&] {
return make_naive_tensor_view<address_space_enum::global>(
p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});
}();
const auto c_dram = [&] {
return make_naive_tensor_view<address_space_enum::global>(
p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});
}();
GridGemm{}(a_dram, b_dram, c_dram, c_element_func);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,74 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename Problem, typename Policy>
struct GridGemm
{
using ADataType = typename Problem::ADataType;
using BDataType = typename Problem::BDataType;
using CDataType = typename Problem::CDataType;
using AccDataType = typename Problem::AccDataType;
using CElementFunction = typename Problem::CElementFunction;
static constexpr auto kMPerBlock = Policy::kMPerBlock;
static constexpr auto kNPerBlock = Policy::kNPerBlock;
static constexpr auto kKPerBlock = Policy::kKPerBlock;
template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
const BGridTensorView& b_grid,
CGridTensorView& c_grid,
const CElementFunction& c_element_func) const
{
const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});
// divide problem
const auto id_block = get_block_id();
const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);
const auto id_tile = block2tile(id_block);
const auto iM =
__builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) * kMPerBlock);
const auto iN =
__builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) * kNPerBlock);
// A block window
auto a_block_window = make_tile_window(
a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {iM, 0});
// B block window
auto b_block_window = make_tile_window(
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
const auto acc_block_tile =
block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);
// cast to CDataType and apply CElementFunction
const auto c_block_tile = tile_elementwise_in(
[&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
acc_block_tile);
// store C
auto c_window = make_tile_window(
c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, iN});
store_tile(c_window, c_block_tile);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
const ck_tile::HostTensor<BDataType>& b_n_k,
ck_tile::HostTensor<CDataType>& c_m_n)
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = b_n_k.mDesc.get_lengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_m_k(m, k);
BDataType v_b = b_n_k(n, k);
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_acc);
}
};
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(
std::thread::hardware_concurrency());
}

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
struct StreamConfig
{
hipStream_t stream_id_ = nullptr;
bool time_kernel_ = false;
int log_level_ = 0;
};

View File

@@ -0,0 +1,38 @@
set(EXAMPLE_BASIC_FLASH_ATTENTION "basic_flash_attention_fwd")
message("adding example ${EXAMPLE_BASIC_FLASH_ATTENTION}")
add_executable(${EXAMPLE_BASIC_FLASH_ATTENTION} EXCLUDE_FROM_ALL flash_attention_fwd.cpp)
target_include_directories(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF)
if(ENABLE_TOY_FA_FWD_OPT)
message("Compiling with toy FA fwd optimization")
target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_OPT)
endif()
option(ENABLE_TOY_FA_FWD_QK_SWIZZLE "Enable toy FA fwd QK swizzle" OFF)
if(ENABLE_TOY_FA_FWD_QK_SWIZZLE)
message("Compiling with toy FA fwd QK swizzle")
target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_QK_SWIZZLE)
endif()
option(ENABLE_TOY_FA_FWD_CACHE_AWARE "Enable toy FA fwd cache aware" OFF)
if(ENABLE_TOY_FA_FWD_CACHE_AWARE)
message("Compiling with toy FA fwd cache aware")
target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_CACHE_AWARE)
endif()
target_compile_options(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE ${EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// Problem Description for BlockGemmARegBSmemCReg
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmARegBSmemCRegProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,565 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_areg_bsmem_creg_problem.hpp"
#include "block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using BlockGemmPolicy = Policy;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kPackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// B block tile distribution for load from lds
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr auto config =
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN);
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MPerXDL = WG::kM;
constexpr index_t NPerXDL = WG::kN;
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = config.template get<1>();
constexpr index_t B_LDS_RW_Width = SmemPack;
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
// B split schedule
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_b_issue_cycle =
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
constexpr auto num_mfma_per_dswrite_b =
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_mfma_per_dswrite_b *
num_dswrite_per_issue_b,
0); // MFMA
});
// stage 2
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp>
__device__ void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BLdsTile& b_block_tensor_tmp) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
__device__ void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
b_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
__device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
b_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Construct C-Block-Tensor
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
namespace ck_tile {
struct BlockGemmARegBSmemCRegV1DefaultPolicy
{
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr(kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
else if constexpr(kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
}
else if constexpr(kM0 == 128)
{
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#else
#pragma message("Enable toy FA fwd QK swizzle")
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
}
else
{
static_assert(false, "Unsupported configuration for warp execution.");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,39 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
namespace ck_tile {
struct BlockGemmARegBSmemCRegV1K8Policy
{
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr(kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
}
else if constexpr(kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
}
else if constexpr(kM0 == 128)
{
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}, 4, 1);
#endif
}
else
{
static_assert(false, "Unsupported configuration for warp execution.");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,397 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy>
struct BlockGemmPipelineAGmemBGmemCReg
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t k_loops = Policy::AKDim / kKPerBlock;
// Move this part into Policy?
__host__ __device__ static constexpr index_t GetStaticLdsSize()
{
return sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
// Cold A Register Cache
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename ARegBlockTensorTmp>
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
ARegBlockTensorTmp& a_reg_block_tensor_tmp,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
ignore = a_element_func;
ignore = b_element_func;
// A tile in RegblockTensor
// This tensor distribution used to construct both distributed tensor for local buffer store
// and read. without buffer address info
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
// B tile in LDS, blockWindow
BDataType* p_b_lds =
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
// This tensor view used to construct both tile window for lds store and read, with buffer
// address info
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A Reg tensor for store, also used for block GEMM
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
// Acc register tile
auto c_block_tile = decltype(block_gemm(
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
b_lds_gemm_window)){};
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(k_loops > 1)
{
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
sequence<0, 0>{},
sequence<kMPerBlock, kKPerBlock>{});
a_block_tile = load_tile(a_copy_dram_window);
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
}
__builtin_amdgcn_sched_barrier(0);
if constexpr(k_loops > 2)
{
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
sequence<0, (i_k0 + 1) * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 2) * kKPerBlock>{});
a_block_tile = load_tile(a_copy_dram_window);
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
block_gemm.HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
});
}
// tail
{
if constexpr(k_loops > 1)
{
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
}
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{});
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
b_copy_lds_window);
}
set_slice_tile(a_reg_block_tensor_tmp,
a_copy_reg_tensor,
sequence<0, 0>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{});
return c_block_tile;
}
// Hot A Register Cache
template <typename BDramBlockWindowTmp, typename BElementFunction, typename ARegBlockTensorTmp>
__host__ __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const ARegBlockTensorTmp& a_reg_block_tensor_tmp,
void* p_smem) const
{
static_assert(
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
ignore = b_element_func;
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
// A tile in RegblockTensor
// This tensor distribution used to construct both distributed tensor for local buffer store
// and read. without buffer address info
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
// A Reg tensor for store, also used for block GEMM
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
set_slice_tile(a_copy_reg_tensor,
a_reg_block_tensor_tmp,
sequence<0, 0>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{});
// B tile in LDS, blockWindow
BDataType* p_b_lds =
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
// This tensor view used to construct both tile window for lds store and read, with buffer
// address info
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
// Acc register tile
auto c_block_tile = decltype(block_gemm(
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
b_lds_gemm_window)){};
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
#if !defined(TOY_FA_FWD_OPT)
static_for<0, k_loops, 1>{}([&](auto i_k0) {
auto b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
});
#else
using BLdsTile = typename decltype(block_gemm)::BLdsTile;
BLdsTile bWarpTile;
// Global read 0
auto b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
if constexpr(k_loops > 1)
{
// LDS write 0
store_tile(b_copy_lds_window, b_block_tile);
// Global read 1
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
block_sync_lds();
// LDS read 0
bWarpTile = load_tile(b_lds_gemm_window);
}
if constexpr(k_loops > 2)
{
__builtin_amdgcn_sched_barrier(0);
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
// LDS write 1
store_tile(b_copy_lds_window, b_block_tile);
// Global read 2
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
bWarpTile);
block_sync_lds();
// LDS read 1
bWarpTile = load_tile(b_lds_gemm_window);
block_gemm.HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
});
}
// tail
{
if constexpr(k_loops > 1)
{
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
bWarpTile);
block_sync_lds();
}
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
bWarpTile = load_tile(b_lds_gemm_window);
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
bWarpTile);
}
#endif
return c_block_tile;
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename ARegBlockTensorTmp>
__device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
ARegBlockTensorTmp& a_reg_block_tensor_tmp,
void* p_smem) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
a_reg_block_tensor_tmp,
p_smem);
}
template <typename BDramBlockWindowTmp, typename ARegBlockTensorTmp>
__device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const ARegBlockTensorTmp& a_reg_block_tensor_tmp,
void* p_smem) const
{
return operator()(
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
a_reg_block_tensor_tmp,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,144 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
namespace ck_tile {
template <index_t AKDim_>
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
{
static constexpr index_t AKDim = AKDim_;
template <typename Problem>
__host__ __device__ static constexpr auto GetBlockGemm()
{
using BlockGemmPolicy = BlockGemmARegBSmemCRegV1K8Policy;
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
{
constexpr auto blockgemm = GetBlockGemm<Problem>();
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!");
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = AKDim;
constexpr auto config =
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem, kMPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
return a_block_dstr;
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeADramTileDistribution()
{
return MakeARegBlockDescriptor<Problem>();
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = 8;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
number<kNPerBlock / NLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
number<kKPerBlock / kKPack * NLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,25 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,197 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include "ck_tile/host.hpp"
#include "reference_batched_gemm.hpp"
#include "reference_batched_softmax.hpp"
#include "flash_attention_fwd.hpp"
/*
* Toy code of flash attention forward pass
* Assume simplest case.
* Q [Batch, HeadNum, SeqenceLengthQ, HeadDim]
* K [Batch, HeadNum, SeqenceLengthK, HeadDim]
* V [Batch, HeadNum, HeadDim, SeqenceLengthK]
* O [Batch, HeadNum, SeqenceLengthQ, HeadDim]
*/
int main(int argc, char* argv[])
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using SaccDataType = float;
using SMPLComputeDataType = float;
using PDataType = ck_tile::half_t;
using OaccDataType = float;
using ODataType = ck_tile::half_t;
ck_tile::index_t Batch = 64; // Batch Number * Head Number
ck_tile::index_t M0 = 4096; // SequenceLengthQ
ck_tile::index_t N0 = 4096; // SequencelengthK
ck_tile::index_t K0 = 128; // HeadDim
ck_tile::index_t N1 = 128; // HeadDim
ck_tile::index_t verification = 0;
ck_tile::index_t init_method = 1;
if(argc == 3)
{
init_method = std::stoi(argv[1]);
verification = std::stoi(argv[2]);
}
else if(argc == 8)
{
init_method = std::stoi(argv[1]);
verification = std::stoi(argv[2]);
Batch = std::stoi(argv[3]);
M0 = std::stoi(argv[4]);
N0 = std::stoi(argv[5]);
K0 = std::stoi(argv[6]);
N1 = std::stoi(argv[7]);
}
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};
std::array<ck_tile::index_t, 3> q_strides{M0 * K0, K0, 1};
std::array<ck_tile::index_t, 3> k_lengths{Batch, N0, K0};
std::array<ck_tile::index_t, 3> k_strides{N0 * K0, K0, 1};
std::array<ck_tile::index_t, 3> v_lengths{Batch, N1, N0};
std::array<ck_tile::index_t, 3> v_strides{N1 * N0, N0, 1};
std::array<ck_tile::index_t, 3> s_lengths{Batch, M0, N0};
std::array<ck_tile::index_t, 3> s_strides{M0 * N0, N0, 1};
std::array<ck_tile::index_t, 3> p_lengths{Batch, M0, N0};
std::array<ck_tile::index_t, 3> p_strides{M0 * N0, N0, 1};
std::array<ck_tile::index_t, 3> o_lengths{Batch, M0, N1};
std::array<ck_tile::index_t, 3> o_strides{M0 * N1, N1, 1};
// host verify
ck_tile::HostTensor<QDataType> q_host(q_lengths, q_strides);
ck_tile::HostTensor<KDataType> k_host(k_lengths, k_strides);
ck_tile::HostTensor<VDataType> v_host(v_lengths, v_strides);
ck_tile::HostTensor<ODataType> o_host_dev(o_lengths, o_strides);
switch(init_method)
{
case 0: break;
case 1:
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
break;
case 2:
ck_tile::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck_tile::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
break;
default:
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host_dev.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.mData.data());
k_buf.ToDevice(k_host.mData.data());
v_buf.ToDevice(v_host.mData.data());
constexpr ck_tile::index_t kM0PerBlock = 128;
constexpr ck_tile::index_t kN0PerBlock = 128;
constexpr ck_tile::index_t kK0PerBlock = 32;
constexpr ck_tile::index_t kN1PerBlock = 128;
constexpr ck_tile::index_t kK1PerBlock = 32;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kHeadDim = 128;
ck_tile::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
std::cout << "grid size " << kGridSize << std::endl;
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
ck_tile::FlashAttentionFwd<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{},
kGridSize,
kBlockSize,
0,
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
M0,
N0,
K0,
N1,
Batch,
K0, // StrideQ
K0, // StrideK
N0, // StrideV
N1, // StrideO
M0 * K0, // BatchStrideQ
N0 * K0, // BatchStrideK
N1 * N0, // BatchStrideV
M0 * N1)); // BatchStrideO
// reference
auto pass = true;
if(verification)
{
o_buf.FromDevice(o_host_dev.mData.data());
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
ck_tile::HostTensor<PDataType> p_host_ref(p_lengths, p_strides);
ck_tile::HostTensor<ODataType> o_host_ref(o_lengths, o_strides);
ck_tile::reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host, k_host, s_host_ref);
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref);
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, v_host, o_host_ref);
pass &= ck_tile::check_err(o_host_dev, o_host_ref);
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
std::size_t flop =
std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
std::size_t num_btype =
sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
return !pass;
}

View File

@@ -0,0 +1,178 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_pipeline_problem.hpp"
#include "block_gemm_areg_bsmem_creg_v1.hpp"
#include "flash_attention_fwd_impl.hpp"
namespace ck_tile {
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
{
return [=](index_t block_1d_id) {
constexpr index_t M01 = 4;
constexpr index_t GroupNum = 8;
const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
const auto update_M0 =
((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;
const auto xcd_id = block_1d_id % GroupNum;
const auto l_block_id = block_1d_id - (xcd_id % 2);
const auto ridn = GroupNum * M01 * (update_N0 / 2);
const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn;
const auto lu = (l_block_id % GroupNum) + rid * ridn;
const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;
auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
const auto total_update_size = update_N0 * update_M0;
if(block_1d_id >= total_update_size)
{
auto x = (block_1d_id + 1) - total_update_size;
auto rlen = N0 - update_N0;
auto rm = 0;
auto rn = 0;
if(rlen > 0)
{
rm = (x - 1) / rlen;
rn = x % rlen;
}
if(rlen > 0 and rm < M0)
{
n = rn + update_N0;
m = rm;
}
else
{
x = x - rlen * M0;
rm = (x - 1) / update_N0;
rn = x % update_N0;
n = rn;
m = update_M0 + rm;
}
}
return make_multi_index(m, n);
};
}
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
index_t kBlockSize,
index_t kHeadDim,
index_t kM0PerBlock,
index_t kN0PerBlock,
index_t kK0PerBlock,
index_t kN1PerBlock,
index_t kK1PerBlock>
struct FlashAttentionFwd
{
__device__ void operator()(const QDataType* q_ptr,
const KDataType* k_ptr,
const VDataType* v_ptr,
ODataType* o_ptr,
const index_t M0,
const index_t N0,
const index_t K0,
const index_t N1,
const index_t /* Batch */,
const index_t StrideQ,
const index_t StrideK,
const index_t StrideV,
const index_t StrideO,
const index_t BatchStrideQ,
const index_t BatchStrideK,
const index_t BatchStrideV,
const index_t BatchStrideO) const
{
const index_t id_block = get_block_id();
const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock);
const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock);
#if defined(TOY_FA_FWD_CACHE_AWARE)
#pragma message("Enable toy FA fwd cache aware")
const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1);
const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0;
const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0);
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) %
num_tile_m0 * kM0PerBlock);
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) %
num_tile_n1 * kN1PerBlock);
#else
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return make_tuple(quotient, modulus);
};
const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
#endif
const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{};
kernel_impl(q_ptr + iBatch * BatchStrideQ,
k_ptr + iBatch * BatchStrideK,
v_ptr + iBatch * BatchStrideV,
o_ptr + iBatch * BatchStrideO,
M0,
N0,
K0,
N1,
StrideQ,
StrideK,
StrideV,
StrideO,
iM0,
iN1);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,440 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp"
#include "block_gemm_pipeline_problem.hpp"
#include "block_gemm_areg_bsmem_creg_v1.hpp"
#include "tile_gemm_shape.hpp"
namespace ck_tile {
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
index_t kBlockSize,
index_t kHeadDim,
index_t kM0PerBlock,
index_t kN0PerBlock,
index_t kK0PerBlock,
index_t kN1PerBlock,
index_t kK1PerBlock>
struct FlashAttentionFwdImpl
{
// block gemm0 pipeline
using BlockGemm0Problem =
BlockGemmPipelineProblem<QDataType,
KDataType,
SaccDataType,
kBlockSize,
TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>;
using BlockGemm0Policy =
BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>;
using BlockGemm0Pipeline = BlockGemmPipelineAGmemBGmemCReg<BlockGemm0Problem, BlockGemm0Policy>;
// block gemm1
using BlockGemm1 = BlockGemmARegBSmemCRegV1<
BlockGemmARegBSmemCRegProblem<PDataType,
VDataType,
OaccDataType,
kBlockSize,
TileGemmShape<kM0PerBlock, kN1PerBlock, kK1PerBlock>>,
BlockGemmARegBSmemCRegV1DefaultPolicy>;
// 3d, with padding
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
constexpr index_t kKPack = 4;
#else
constexpr index_t kKPack = 8;
#endif
constexpr auto dataTypeSize = sizeof(VDataType);
constexpr auto NLdsLayer =
(32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
number<kNPerBlock / NLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
number<kKPerBlock / kKPack * NLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
__device__ static constexpr auto MakeVDramTileDistribution()
{
using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
__device__ static constexpr index_t GetStaticLdsSize()
{
return max(BlockGemm0Pipeline::GetStaticLdsSize(),
static_cast<index_t>(MakeVLdsBlockDescriptor().get_element_space_size() *
sizeof(VDataType)));
}
__device__ void operator()(const QDataType* q_ptr,
const KDataType* k_ptr,
const VDataType* v_ptr,
ODataType* o_ptr,
const index_t M0,
const index_t N0,
const index_t K0,
const index_t N1,
const index_t StrideQ,
const index_t StrideK,
const index_t StrideV,
const index_t StrideO,
const index_t iM0,
const index_t iN1) const
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
// Block GEMM0 pipeline and Block GEMM1
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
constexpr auto gemm1 = BlockGemm1{};
// allocate LDS
__shared__ char smem_ptr[GetStaticLdsSize()];
// Q/K/V DRAM and DRAM window
const auto q_dram = make_naive_tensor_view<address_space_enum::global>(
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
const auto k_dram = make_naive_tensor_view<address_space_enum::global>(
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{});
const auto v_dram = make_naive_tensor_view<address_space_enum::global>(
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}),
{iM0, 0},
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{iN1, 0},
MakeVDramTileDistribution());
// Q in register
auto q_reg_tensor = load_tile(q_dram_window);
// V LDS and LDS window
// V LDS occupies the same LDS allocation Q/K LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());
#if defined(TOY_FA_FWD_OPT)
// V LDS tile window for store
auto v_copy_lds_window =
make_tile_window(v_lds,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
v_dram_window.get_tile_distribution());
// V LDS tile for block GEMM
auto v_lds_gemm_window =
make_tile_window(v_lds,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
#else
auto v_lds_window = make_tile_window(
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
#endif
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SaccBlockTileType =
decltype(gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, nullptr));
using SBlockTileType = decltype(tile_elementwise_in(
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>,
SaccBlockTileType{}));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm1(
get_slice_tile(
PBlockTileType{}, sequence<0, 0>{}, sequence<kM0PerBlock, kK1PerBlock>{}),
v_dram_window));
// init Sacc, Oacc, M, L
auto s_acc = SaccBlockTileType{};
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
tile_elementwise_inout([](auto& e) { e = 0; }, o_acc);
tile_elementwise_inout(
[](auto& e) { e = std::numeric_limits<SMPLComputeDataType>::lowest(); }, m);
tile_elementwise_inout([](auto& e) { e = 0; }, l);
// loop over Column of S (J loop)
index_t iN0 = 0;
do
{
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
// S{j}
const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
#if defined(TOY_FA_FWD_OPT)
// prefetch load v tile
auto v_prefetch = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1PerBlock});
#endif
// m_local = rowmax(S{j})
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s, sequence<1>{}, f_max, std::numeric_limits<SMPLComputeDataType>::lowest());
block_tile_reduce_sync(m_local, f_max);
// m{j-1}
const auto m_old = m;
// m{j}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
// Pcompute{j}
auto p_compute =
make_static_distributed_tensor<SMPLComputeDataType>(s.get_tile_distribution());
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(p_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
p_compute(i_j_idx) = exp(s[i_j_idx] - m[i_idx]);
});
});
// rowsum(Pcompute{j})
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0});
block_tile_reduce_sync(rowsum_p, f_sum);
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
block_sync_lds();
#if !defined(TOY_FA_FWD_OPT)
// type cast Pcompute{j} into P{j}
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Oacc{j}
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
move_tile_window(v_dram_window, {0, kK1PerBlock});
store_tile(v_lds_window, v);
block_sync_lds();
gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
block_sync_lds();
});
#else
using VLdsTile = typename decltype(gemm1)::BLdsTile;
VLdsTile vWarpTile;
// type cast Pcompute{j} into P{j}
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Oacc{j}
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
if constexpr(k1_loops > 1)
{
store_tile(v_copy_lds_window, v_prefetch);
v_prefetch = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1PerBlock});
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
}
if constexpr(k1_loops > 2)
{
__builtin_amdgcn_sched_barrier(0);
static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) {
block_sync_lds();
// LDS write 1
store_tile(v_copy_lds_window, v_prefetch);
// Global read 2
v_prefetch = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1PerBlock});
gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
vWarpTile);
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
gemm1.template HotLoopScheduler<8, 4>();
__builtin_amdgcn_sched_barrier(0);
});
}
// tail
{
if constexpr(k1_loops > 1)
{
gemm1(o_acc,
get_slice_tile(p,
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
vWarpTile);
block_sync_lds();
}
store_tile(v_copy_lds_window, v_prefetch);
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
gemm1(o_acc,
get_slice_tile(p,
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
sequence<kM0PerBlock, kN0PerBlock>{}),
vWarpTile);
block_sync_lds();
}
#endif
// move tile windows
move_tile_window(k_dram_window, {kN0PerBlock, 0});
iN0 += kN0PerBlock;
} while(iN0 < N0);
// Oacc
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = 1 / l[i_idx];
sweep_tile_span(o_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
// type cast Oacc into O
const auto o = tile_elementwise_in(type_convert<ODataType, OaccDataType>, o_acc);
// O DRAM and O DRAM window
auto o_dram = make_naive_tensor_view<address_space_enum::global>(
o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), number<32>{}, number<1>{});
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<kM0PerBlock>{}, number<kN1PerBlock>{}),
{iM0, iN1},
o.get_tile_distribution());
// store O
store_tile(o_dram_window, o);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_batched_gemm(const ck_tile::HostTensor<ADataType>& a_b_m_k,
const ck_tile::HostTensor<BDataType>& b_b_n_k,
ck_tile::HostTensor<CDataType>& c_b_m_n)
{
const int N = b_b_n_k.mDesc.get_lengths()[1];
const int K = b_b_n_k.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_b_m_k(batch, m, k);
BDataType v_b = b_b_n_k(batch, n, k);
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(v_acc);
}
};
ck_tile::make_ParallelTensorFunctor(
f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}

View File

@@ -0,0 +1,48 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_batched_softmax(const ck_tile::HostTensor<ADataType>& a_b_m_n,
ck_tile::HostTensor<BDataType>& b_b_m_n)
{
const int N = a_b_m_n.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
AccDataType v_max = std::numeric_limits<ADataType>::lowest();
// max
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
v_max = v_max < v_a ? v_a : v_max;
}
AccDataType v_exp_sum = 0;
// sum
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
v_exp_sum += ck_tile::exp(v_a - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
b_b_m_n(batch, m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
}
};
ck_tile::make_ParallelTensorFunctor(
f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}

View File

@@ -0,0 +1,18 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
static constexpr index_t kM = kMPerTile;
static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
};
} // namespace ck_tile

View File

@@ -0,0 +1,66 @@
set(FLASH_ATTENTION_FWD_KNOWN_APIS "fwd")
set(FLASH_ATTENTION_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${FLASH_ATTENTION_FWD_KNOWN_APIS}) & link, or \"all\".")
if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL "all")
set(FLASH_ATTENTION_FWD_ENABLE_APIS ${FLASH_ATTENTION_FWD_KNOWN_APIS})
endif()
option(TOY_FA_FWD_OPT "Enable toy flash attention forward optimization" ON)
option(TOY_FA_FWD_QK_SWIZZLE "Enable toy flash attention forward QK swizzle" OFF)
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--list_blobs
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list Flash Attention kernels via Python. ${ret}")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/flash_attention_fwd_blobs.txt FLASH_ATTENTION_FWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--gen_blobs
)
set(EXAMPLE_FA "codegen_basic_flash_attention_fwd")
message("adding example ${EXAMPLE_FA}")
add_executable(${EXAMPLE_FA}
EXCLUDE_FROM_ALL
flash_attention_fwd.cpp
)
target_compile_definitions(${EXAMPLE_FA}
PRIVATE
$<$<BOOL:${TOY_FA_FWD_OPT}>:TOY_FA_FWD_OPT=1>
)
target_include_directories(${EXAMPLE_FA}
PRIVATE
${CMAKE_CURRENT_LIST_DIR}
)
target_sources(${EXAMPLE_FA} PRIVATE ${FLASH_ATTENTION_FWD_GEN_BLOBS})
message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}")
set(EXAMPLE_FA_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FA_COMPILE_OPTIONS
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
)
target_compile_options(${EXAMPLE_FA}
PRIVATE
${EXAMPLE_FA_COMPILE_OPTIONS}
)
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// Problem Description for BlockGemmARegBSmemCReg
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmARegBSmemCRegProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,565 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_areg_bsmem_creg_problem.hpp"
#include "block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using BlockGemmPolicy = Policy;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kPackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// B block tile distribution for load from lds
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr auto config =
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN);
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MPerXDL = WG::kM;
constexpr index_t NPerXDL = WG::kN;
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNumM = config.template get<1>();
constexpr index_t B_LDS_RW_Width = SmemPack;
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
// B split schedule
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_b_issue_cycle =
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
constexpr auto num_mfma_per_dswrite_b =
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_mfma_per_dswrite_b *
num_dswrite_per_issue_b,
0); // MFMA
});
// stage 2
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp>
__device__ void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BLdsTile& b_block_tensor_tmp) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
static constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
__device__ void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
b_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
__device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
b_block_window_tmp.get_window_origin().at(number<1>{})},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Construct C-Block-Tensor
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
namespace ck_tile {
struct BlockGemmARegBSmemCRegV1DefaultPolicy
{
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr(kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
else if constexpr(kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
}
else if constexpr(kM0 == 128)
{
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
#else
#pragma message("Enable toy FA fwd QK swizzle")
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
}
else
{
static_assert(false, "Unsupported configuration for warp execution.");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,39 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
namespace ck_tile {
struct BlockGemmARegBSmemCRegV1K8Policy
{
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr(kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
}
else if constexpr(kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
}
else if constexpr(kM0 == 128)
{
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#else
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}, 4, 1);
#endif
}
else
{
static_assert(false, "Unsupported configuration for warp execution.");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,397 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy>
struct BlockGemmPipelineAGmemBGmemCReg
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t k_loops = Policy::AKDim / kKPerBlock;
// Move this part into Policy?
__host__ __device__ static constexpr index_t GetStaticLdsSize()
{
return sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
// Cold A Register Cache
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
typename ARegBlockTensorTmp>
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
ARegBlockTensorTmp& a_reg_block_tensor_tmp,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
ignore = a_element_func;
ignore = b_element_func;
// A tile in RegblockTensor
// This tensor distribution used to construct both distributed tensor for local buffer store
// and read. without buffer address info
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
// B tile in LDS, blockWindow
BDataType* p_b_lds =
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
// This tensor view used to construct both tile window for lds store and read, with buffer
// address info
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A Reg tensor for store, also used for block GEMM
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
// Acc register tile
auto c_block_tile = decltype(block_gemm(
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
b_lds_gemm_window)){};
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(k_loops > 1)
{
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
sequence<0, 0>{},
sequence<kMPerBlock, kKPerBlock>{});
a_block_tile = load_tile(a_copy_dram_window);
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
}
__builtin_amdgcn_sched_barrier(0);
if constexpr(k_loops > 2)
{
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
sequence<0, (i_k0 + 1) * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 2) * kKPerBlock>{});
a_block_tile = load_tile(a_copy_dram_window);
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
block_gemm.HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
});
}
// tail
{
if constexpr(k_loops > 1)
{
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
}
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{});
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
b_copy_lds_window);
}
set_slice_tile(a_reg_block_tensor_tmp,
a_copy_reg_tensor,
sequence<0, 0>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{});
return c_block_tile;
}
// Hot A Register Cache
template <typename BDramBlockWindowTmp, typename BElementFunction, typename ARegBlockTensorTmp>
__host__ __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const ARegBlockTensorTmp& a_reg_block_tensor_tmp,
void* p_smem) const
{
static_assert(
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
ignore = b_element_func;
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
// A tile in RegblockTensor
// This tensor distribution used to construct both distributed tensor for local buffer store
// and read. without buffer address info
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
// A Reg tensor for store, also used for block GEMM
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
set_slice_tile(a_copy_reg_tensor,
a_reg_block_tensor_tmp,
sequence<0, 0>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{});
// B tile in LDS, blockWindow
BDataType* p_b_lds =
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
// This tensor view used to construct both tile window for lds store and read, with buffer
// address info
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
// Acc register tile
auto c_block_tile = decltype(block_gemm(
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
b_lds_gemm_window)){};
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
#if !defined(TOY_FA_FWD_OPT)
static_for<0, k_loops, 1>{}([&](auto i_k0) {
auto b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
});
#else
using BLdsTile = typename decltype(block_gemm)::BLdsTile;
BLdsTile bWarpTile;
// Global read 0
auto b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
if constexpr(k_loops > 1)
{
// LDS write 0
store_tile(b_copy_lds_window, b_block_tile);
// Global read 1
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
block_sync_lds();
// LDS read 0
bWarpTile = load_tile(b_lds_gemm_window);
}
if constexpr(k_loops > 2)
{
__builtin_amdgcn_sched_barrier(0);
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
// LDS write 1
store_tile(b_copy_lds_window, b_block_tile);
// Global read 2
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
bWarpTile);
block_sync_lds();
// LDS read 1
bWarpTile = load_tile(b_lds_gemm_window);
block_gemm.HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
});
}
// tail
{
if constexpr(k_loops > 1)
{
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
bWarpTile);
block_sync_lds();
}
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
bWarpTile = load_tile(b_lds_gemm_window);
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
bWarpTile);
}
#endif
return c_block_tile;
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename ARegBlockTensorTmp>
__device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
ARegBlockTensorTmp& a_reg_block_tensor_tmp,
void* p_smem) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
a_reg_block_tensor_tmp,
p_smem);
}
template <typename BDramBlockWindowTmp, typename ARegBlockTensorTmp>
__device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const ARegBlockTensorTmp& a_reg_block_tensor_tmp,
void* p_smem) const
{
return operator()(
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
a_reg_block_tensor_tmp,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,144 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
namespace ck_tile {
template <index_t AKDim_>
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
{
static constexpr index_t AKDim = AKDim_;
template <typename Problem>
__host__ __device__ static constexpr auto GetBlockGemm()
{
using BlockGemmPolicy = BlockGemmARegBSmemCRegV1K8Policy;
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
{
constexpr auto blockgemm = GetBlockGemm<Problem>();
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!");
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = AKDim;
constexpr auto config =
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem, kMPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
constexpr index_t NWarp = config.template get<2>();
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
return a_block_dstr;
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeADramTileDistribution()
{
return MakeARegBlockDescriptor<Problem>();
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = 8;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
number<kNPerBlock / NLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
number<kKPerBlock / kKPack * NLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,25 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,173 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include "ck_tile/host.hpp"
#include "reference_batched_gemm.hpp"
#include "reference_batched_softmax.hpp"
#include "flash_attention_fwd.hpp"
/*
* Toy code of flash attention forward pass
* Assume simplest case.
* Q [Batch, HeadNum, SeqenceLengthQ, HeadDim]
* K [Batch, HeadNum, SeqenceLengthK, HeadDim]
* V [Batch, HeadNum, HeadDim, SeqenceLengthK]
* O [Batch, HeadNum, SeqenceLengthQ, HeadDim]
*/
int main(int argc, char* argv[])
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using SaccDataType = float;
using SMPLComputeDataType = float;
using PDataType = ck_tile::half_t;
using OaccDataType = float;
using ODataType = ck_tile::half_t;
ck_tile::index_t Batch = 64; // Batch Number * Head Number
ck_tile::index_t M0 = 4096; // SequenceLengthQ
ck_tile::index_t N0 = 4096; // SequencelengthK
ck_tile::index_t K0 = 128; // HeadDim
ck_tile::index_t N1 = 128; // HeadDim
ck_tile::index_t verification = 0;
ck_tile::index_t init_method = 1;
if(argc == 3)
{
init_method = std::stoi(argv[1]);
verification = std::stoi(argv[2]);
}
else if(argc == 8)
{
init_method = std::stoi(argv[1]);
verification = std::stoi(argv[2]);
Batch = std::stoi(argv[3]);
M0 = std::stoi(argv[4]);
N0 = std::stoi(argv[5]);
K0 = std::stoi(argv[6]);
N1 = std::stoi(argv[7]);
}
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};
std::array<ck_tile::index_t, 3> q_strides{M0 * K0, K0, 1};
std::array<ck_tile::index_t, 3> k_lengths{Batch, N0, K0};
std::array<ck_tile::index_t, 3> k_strides{N0 * K0, K0, 1};
std::array<ck_tile::index_t, 3> v_lengths{Batch, N1, N0};
std::array<ck_tile::index_t, 3> v_strides{N1 * N0, N0, 1};
std::array<ck_tile::index_t, 3> s_lengths{Batch, M0, N0};
std::array<ck_tile::index_t, 3> s_strides{M0 * N0, N0, 1};
std::array<ck_tile::index_t, 3> p_lengths{Batch, M0, N0};
std::array<ck_tile::index_t, 3> p_strides{M0 * N0, N0, 1};
std::array<ck_tile::index_t, 3> o_lengths{Batch, M0, N1};
std::array<ck_tile::index_t, 3> o_strides{M0 * N1, N1, 1};
// host verify
ck_tile::HostTensor<QDataType> q_host(q_lengths, q_strides);
ck_tile::HostTensor<KDataType> k_host(k_lengths, k_strides);
ck_tile::HostTensor<VDataType> v_host(v_lengths, v_strides);
ck_tile::HostTensor<ODataType> o_host_dev(o_lengths, o_strides);
switch(init_method)
{
case 0: break;
case 1:
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
break;
case 2:
ck_tile::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck_tile::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
break;
default:
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host_dev.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.mData.data());
k_buf.ToDevice(k_host.mData.data());
v_buf.ToDevice(v_host.mData.data());
// Construct the FlashAttnArgs object with your arguments
ck_tile::FlashAttnArgs<QDataType, KDataType, VDataType, ODataType> flash_attention_args{
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
M0,
N0,
K0,
N1,
Batch,
K0, // strideQ
K0, // strideK
N0, // strideV
N1, // strideO
M0 * K0, // batchStrideQ
N0 * K0, // batchStrideK
N1 * N0, // batchStrideV
M0 * N1 // batchStrideO
};
float ave_time = ck_tile::flash_attention_fwd<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType>(flash_attention_args,
ck_tile::stream_config{nullptr, true});
// reference
auto pass = true;
if(verification)
{
o_buf.FromDevice(o_host_dev.mData.data());
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
ck_tile::HostTensor<PDataType> p_host_ref(p_lengths, p_strides);
ck_tile::HostTensor<ODataType> o_host_ref(o_lengths, o_strides);
ck_tile::reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
q_host, k_host, s_host_ref);
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref);
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref, v_host, o_host_ref);
pass &= ck_tile::check_err(o_host_dev, o_host_ref);
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
std::size_t flop =
std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
std::size_t num_btype =
sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
return !pass;
}

View File

@@ -0,0 +1,148 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_pipeline_problem.hpp"
#include "block_gemm_areg_bsmem_creg_v1.hpp"
#include "flash_attention_fwd_impl.hpp"
namespace ck_tile {
template <typename QDataType, typename KDataType, typename VDataType, typename ODataType>
struct FlashAttnArgs
{
// Pointers to device buffers for Q, K, V, O
QDataType* q_ptr;
KDataType* k_ptr;
VDataType* v_ptr;
ODataType* o_ptr;
// Problem sizes
index_t M0;
index_t N0;
index_t K0;
index_t N1;
index_t Batch;
// Strides within a batch
index_t strideQ;
index_t strideK;
index_t strideV;
index_t strideO;
// Batch strides
index_t batchStrideQ;
index_t batchStrideK;
index_t batchStrideV;
index_t batchStrideO;
};
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
index_t kBlockSize,
index_t kHeadDim,
index_t kM0PerBlock,
index_t kN0PerBlock,
index_t kK0PerBlock,
index_t kN1PerBlock,
index_t kK1PerBlock>
struct FlashAttentionFwd
{
__device__ void operator()(const QDataType* q_ptr,
const KDataType* k_ptr,
const VDataType* v_ptr,
ODataType* o_ptr,
const index_t M0,
const index_t N0,
const index_t K0,
const index_t N1,
const index_t /* Batch */,
const index_t StrideQ,
const index_t StrideK,
const index_t StrideV,
const index_t StrideO,
const index_t BatchStrideQ,
const index_t BatchStrideK,
const index_t BatchStrideV,
const index_t BatchStrideO) const
{
const index_t id_block = get_block_id();
const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock);
const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock);
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return make_tuple(quotient, modulus);
};
const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{};
kernel_impl(q_ptr + iBatch * BatchStrideQ,
k_ptr + iBatch * BatchStrideK,
v_ptr + iBatch * BatchStrideV,
o_ptr + iBatch * BatchStrideO,
M0,
N0,
K0,
N1,
StrideQ,
StrideK,
StrideV,
StrideO,
iM0,
iN1);
}
};
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType>
float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const stream_config& stream_config);
} // namespace ck_tile

View File

@@ -0,0 +1,440 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp"
#include "block_gemm_pipeline_problem.hpp"
#include "block_gemm_areg_bsmem_creg_v1.hpp"
#include "tile_gemm_shape.hpp"
namespace ck_tile {
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
index_t kBlockSize,
index_t kHeadDim,
index_t kM0PerBlock,
index_t kN0PerBlock,
index_t kK0PerBlock,
index_t kN1PerBlock,
index_t kK1PerBlock>
struct FlashAttentionFwdImpl
{
// block gemm0 pipeline
using BlockGemm0Problem =
BlockGemmPipelineProblem<QDataType,
KDataType,
SaccDataType,
kBlockSize,
TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>;
using BlockGemm0Policy =
BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>;
using BlockGemm0Pipeline = BlockGemmPipelineAGmemBGmemCReg<BlockGemm0Problem, BlockGemm0Policy>;
// block gemm1
using BlockGemm1 = BlockGemmARegBSmemCRegV1<
BlockGemmARegBSmemCRegProblem<PDataType,
VDataType,
OaccDataType,
kBlockSize,
TileGemmShape<kM0PerBlock, kN1PerBlock, kK1PerBlock>>,
BlockGemmARegBSmemCRegV1DefaultPolicy>;
// 3d, with padding
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
constexpr index_t kKPack = 4;
#else
constexpr index_t kKPack = 8;
#endif
constexpr auto dataTypeSize = sizeof(VDataType);
constexpr auto NLdsLayer =
(32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
number<kNPerBlock / NLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
number<kKPerBlock / kKPack * NLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
__device__ static constexpr auto MakeVDramTileDistribution()
{
using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
__device__ static constexpr index_t GetStaticLdsSize()
{
return max(BlockGemm0Pipeline::GetStaticLdsSize(),
static_cast<index_t>(MakeVLdsBlockDescriptor().get_element_space_size() *
sizeof(VDataType)));
}
__device__ void operator()(const QDataType* q_ptr,
const KDataType* k_ptr,
const VDataType* v_ptr,
ODataType* o_ptr,
const index_t M0,
const index_t N0,
const index_t K0,
const index_t N1,
const index_t StrideQ,
const index_t StrideK,
const index_t StrideV,
const index_t StrideO,
const index_t iM0,
const index_t iN1) const
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
// Block GEMM0 pipeline and Block GEMM1
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
constexpr auto gemm1 = BlockGemm1{};
// allocate LDS
__shared__ char smem_ptr[GetStaticLdsSize()];
// Q/K/V DRAM and DRAM window
const auto q_dram = make_naive_tensor_view<address_space_enum::global>(
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
const auto k_dram = make_naive_tensor_view<address_space_enum::global>(
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{});
const auto v_dram = make_naive_tensor_view<address_space_enum::global>(
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}),
{iM0, 0},
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{iN1, 0},
MakeVDramTileDistribution());
// Q in register
auto q_reg_tensor = load_tile(q_dram_window);
// V LDS and LDS window
// V LDS occupies the same LDS allocation Q/K LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());
#if defined(TOY_FA_FWD_OPT)
// V LDS tile window for store
auto v_copy_lds_window =
make_tile_window(v_lds,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
v_dram_window.get_tile_distribution());
// V LDS tile for block GEMM
auto v_lds_gemm_window =
make_tile_window(v_lds,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
#else
auto v_lds_window = make_tile_window(
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
#endif
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SaccBlockTileType =
decltype(gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, nullptr));
using SBlockTileType = decltype(tile_elementwise_in(
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>,
SaccBlockTileType{}));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm1(
get_slice_tile(
PBlockTileType{}, sequence<0, 0>{}, sequence<kM0PerBlock, kK1PerBlock>{}),
v_dram_window));
// init Sacc, Oacc, M, L
auto s_acc = SaccBlockTileType{};
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
tile_elementwise_inout([](auto& e) { e = 0; }, o_acc);
tile_elementwise_inout(
[](auto& e) { e = std::numeric_limits<SMPLComputeDataType>::lowest(); }, m);
tile_elementwise_inout([](auto& e) { e = 0; }, l);
// loop over Column of S (J loop)
index_t iN0 = 0;
do
{
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
// S{j}
const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
#if defined(TOY_FA_FWD_OPT)
// prefetch load v tile
auto v_prefetch = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1PerBlock});
#endif
// m_local = rowmax(S{j})
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s, sequence<1>{}, f_max, std::numeric_limits<SMPLComputeDataType>::lowest());
block_tile_reduce_sync(m_local, f_max);
// m{j-1}
const auto m_old = m;
// m{j}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
// Pcompute{j}
auto p_compute =
make_static_distributed_tensor<SMPLComputeDataType>(s.get_tile_distribution());
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(p_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
p_compute(i_j_idx) = exp(s[i_j_idx] - m[i_idx]);
});
});
// rowsum(Pcompute{j})
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0});
block_tile_reduce_sync(rowsum_p, f_sum);
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
block_sync_lds();
#if !defined(TOY_FA_FWD_OPT)
// type cast Pcompute{j} into P{j}
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Oacc{j}
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
move_tile_window(v_dram_window, {0, kK1PerBlock});
store_tile(v_lds_window, v);
block_sync_lds();
gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
block_sync_lds();
});
#else
using VLdsTile = typename decltype(gemm1)::BLdsTile;
VLdsTile vWarpTile;
// type cast Pcompute{j} into P{j}
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Oacc{j}
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
if constexpr(k1_loops > 1)
{
store_tile(v_copy_lds_window, v_prefetch);
v_prefetch = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1PerBlock});
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
}
if constexpr(k1_loops > 2)
{
__builtin_amdgcn_sched_barrier(0);
static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) {
block_sync_lds();
// LDS write 1
store_tile(v_copy_lds_window, v_prefetch);
// Global read 2
v_prefetch = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1PerBlock});
gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
vWarpTile);
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
gemm1.template HotLoopScheduler<8, 4>();
__builtin_amdgcn_sched_barrier(0);
});
}
// tail
{
if constexpr(k1_loops > 1)
{
gemm1(o_acc,
get_slice_tile(p,
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
vWarpTile);
block_sync_lds();
}
store_tile(v_copy_lds_window, v_prefetch);
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
gemm1(o_acc,
get_slice_tile(p,
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
sequence<kM0PerBlock, kN0PerBlock>{}),
vWarpTile);
block_sync_lds();
}
#endif
// move tile windows
move_tile_window(k_dram_window, {kN0PerBlock, 0});
iN0 += kN0PerBlock;
} while(iN0 < N0);
// Oacc
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = 1 / l[i_idx];
sweep_tile_span(o_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
// type cast Oacc into O
const auto o = tile_elementwise_in(type_convert<ODataType, OaccDataType>, o_acc);
// O DRAM and O DRAM window
auto o_dram = make_naive_tensor_view<address_space_enum::global>(
o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), number<32>{}, number<1>{});
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<kM0PerBlock>{}, number<kN1PerBlock>{}),
{iM0, iN1},
o.get_tile_distribution());
// store O
store_tile(o_dram_window, o);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,578 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
import argparse
from enum import IntEnum
from pathlib import Path
import sys
from typing import List, Optional, Any
import functools
import itertools
import copy
from dataclasses import dataclass
def get_if_str(size_, total, last_else=True):
if size_ == "head_dim_128_seq_16384":
return 'if'
else:
return 'else if'
DATA_TYPE_MAP = {'fp32': 'float',
'fp16': 'ck_tile::half_t',
'bf16': 'ck_tile::bf16_t'}
def BOOL_MAP(b_) -> str:
return 'true' if b_ else 'false'
class FlashAttentionFwdCodegen:
API_TRAITS_DEFINE = """
template <typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
index_t kBlockSize_ = 256,
index_t kHeadDim_ = 128,
index_t kM0PerBlock_ = 128,
index_t kN0PerBlock_ = 128,
index_t kK0PerBlock_ = 64,
index_t kN1PerBlock_ = 128,
index_t kK1PerBlock_ = 64>
struct flash_attention_fwd_traits_
{
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kHeadDim = kHeadDim_;
static constexpr index_t kM0PerBlock = kM0PerBlock_;
static constexpr index_t kN0PerBlock = kN0PerBlock_;
static constexpr index_t kK0PerBlock = kK0PerBlock_;
static constexpr index_t kN1PerBlock = kN1PerBlock_;
static constexpr index_t kK1PerBlock = kK1PerBlock_;
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
};
template <typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
ck_tile::index_t kBlockSize = 256,
ck_tile::index_t kHeadDim = 128,
ck_tile::index_t kM0PerBlock = 128,
ck_tile::index_t kN0PerBlock = 128,
ck_tile::index_t kK0PerBlock = 64,
ck_tile::index_t kN1PerBlock = 128,
ck_tile::index_t kK1PerBlock = 64>
using traits_ = flash_attention_fwd_traits_<SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>;
"""
API_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "flash_attention_fwd.hpp"
namespace ck_tile {{
{F_traits_define}
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template <typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename Traits_>
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config);
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType>
float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config) {{
float r = -1;
{F_dispatch}
return r;
}}
template float flash_attention_fwd<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, float, float, ck_tile::half_t, float, ck_tile::half_t>(
const FlashAttnArgs<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t>&,
const ck_tile::stream_config&);
}}
"""
API_INNER_CASE = """ {F_if} {F_VEC_COND}
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(a, stream_config);
"""
INSTANCE_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "flash_attention_fwd_api_common.hpp"
namespace ck_tile {
// clang-format off
//
{F_instance_def}
// clang-format on
}
"""
def __init__(self, working_path, kernel_filter):
self.working_path = working_path
self.kernel_filter = kernel_filter
@dataclass
class h_traits:
F_SaccDataType: str
F_SMPLComputeDataType: str
F_PDataType: str
F_OaccDataType: str
F_kBlockSize: int
F_kHeadDim: int
F_kM0PerBlock: int
F_kN0PerBlock: int
F_kK0PerBlock: int
F_kN1PerBlock: int
F_kK1PerBlock: int
@property
def trait_name(self) -> str:
return (f"{DATA_TYPE_MAP[self.F_SaccDataType]}, "
f"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, "
f"{DATA_TYPE_MAP[self.F_PDataType]}, "
f"{DATA_TYPE_MAP[self.F_OaccDataType]}, "
f"{self.F_kBlockSize}, {self.F_kHeadDim}, "
f"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, "
f"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}")
@property
def def_name(self) -> str:
return (f"template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, "
f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, "
f"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, "
f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, "
"const ck_tile::stream_config&);")
@dataclass
class h_instance:
F_DataTypePair: str # "q,k,v,o"
F_SizeCategory: str # "small", "medium", "large"
instance_list: List[Any] # List[h_traits]
INSTANCE_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "flash_attention_fwd_api_common.hpp"
namespace ck_tile {{
// clang-format off
//
{F_instance_def}
// clang-format on
}}
"""
@property
def name(self) -> str:
q_type, k_type, v_type, o_type = self.F_DataTypePair.split(',')
dtype_str = f"{q_type}_{k_type}_{v_type}_{o_type}"
return f"flash_attention_fwd_{dtype_str}_{self.F_SizeCategory}"
@property
def content(self) -> str:
instance_defs = '\n'.join(ins.def_name for ins in self.instance_list)
return self.INSTANCE_BASE.format(F_instance_def=instance_defs)
@property
def name_api(self) -> str:
return "flash_attention_fwd_api"
@property
def name_common_header(self) -> str:
return "flash_attention_fwd_api_common"
@property
def content_common_header(self) -> str:
return f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "flash_attention_fwd.hpp"
namespace ck_tile {{
template <typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
index_t kBlockSize_ = 256,
index_t kHeadDim_ = 128,
index_t kM0PerBlock_ = 128,
index_t kN0PerBlock_ = 128,
index_t kK0PerBlock_ = 64,
index_t kN1PerBlock_ = 128,
index_t kK1PerBlock_ = 64>
struct flash_attention_fwd_traits_
{{
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kHeadDim = kHeadDim_;
static constexpr index_t kM0PerBlock = kM0PerBlock_;
static constexpr index_t kN0PerBlock = kN0PerBlock_;
static constexpr index_t kK0PerBlock = kK0PerBlock_;
static constexpr index_t kN1PerBlock = kN1PerBlock_;
static constexpr index_t kK1PerBlock = kK1PerBlock_;
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
}};
template <typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
ck_tile::index_t kBlockSize,
ck_tile::index_t kHeadDim,
ck_tile::index_t kM0PerBlock,
ck_tile::index_t kN0PerBlock,
ck_tile::index_t kK0PerBlock,
ck_tile::index_t kN1PerBlock,
ck_tile::index_t kK1PerBlock>
using traits_ = flash_attention_fwd_traits_<SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>;
template <typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename Traits_>
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config) {{
using SaccDataType = typename Traits_::SaccDataType;
using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
using PDataType = typename Traits_::PDataType;
using OaccDataType = typename Traits_::OaccDataType;
index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
if(stream_config.log_level_ > 0)
std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush;
return ck_tile::launch_kernel(stream_config,
ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
ck_tile::FlashAttentionFwd<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
Traits_::kBlockSize,
Traits_::kHeadDim,
Traits_::kM0PerBlock,
Traits_::kN0PerBlock,
Traits_::kK0PerBlock,
Traits_::kN1PerBlock,
Traits_::kK1PerBlock>{{}},
kGridSize,
Traits_::kBlockSize,
0,
a.q_ptr,
a.k_ptr,
a.v_ptr,
a.o_ptr,
a.M0,
a.N0,
a.K0,
a.N1,
a.Batch,
a.strideQ, // StrideQ
a.strideK, // StrideK
a.strideV, // StrideV
a.strideO, // StrideO
a.batchStrideQ, // BatchStrideQ
a.batchStrideK, // BatchStrideK
a.batchStrideV, // BatchStrideV
a.batchStrideO)); // BatchStrideO
}}
}}
"""
def content_api(self, args) -> str:
# Sort based on dtype
t_dtype_dict = {}
blobs = self.get_blobs(args)
for blob in blobs:
if blob.F_DataTypePair not in t_dtype_dict:
t_dtype_dict[blob.F_DataTypePair] = {}
if blob.F_SizeCategory not in t_dtype_dict[blob.F_DataTypePair]:
t_dtype_dict[blob.F_DataTypePair][blob.F_SizeCategory] = []
t_dtype_dict[blob.F_DataTypePair][blob.F_SizeCategory].append(blob)
d_str = ''
for i_d, dtype_ in enumerate(t_dtype_dict):
blob_per_t = t_dtype_dict[dtype_]
size_str = ''
for i_size, size_ in enumerate(blob_per_t):
blob_per_size = blob_per_t[size_]
inner_str = ""
for i_b, b_ in enumerate(blob_per_size):
for i_ins, ins in enumerate(b_.instance_list):
idx_in_size = i_b * len(b_.instance_list) + i_ins
len_in_size = sum(len(b.instance_list) for b in blob_per_size)
size_cond = ""
if size_ == "head_dim_128_seq_16384":
size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_16384":
size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_128_seq_8192":
size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_8192":
size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_256_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 256 && a.N1 == 256)"
elif size_ == "head_dim_128_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_32_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 32 && a.N1 == 32)"
elif size_ == "head_dim_128_seq_2048":
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_2048":
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_128_seq_1024":
size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_1024":
size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_128_seq_512":
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_512":
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 64 && a.N1 == 64)"
else:
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
inner_str += self.API_INNER_CASE.format(
F_if=get_if_str(size_, len_in_size, False),
F_VEC_COND=size_cond,
F_trait_name=ins.trait_name
)
size_str += inner_str
d_str += size_str
api_base = self.API_BASE.format(
F_traits_define=self.API_TRAITS_DEFINE,
F_dispatch=d_str
)
return api_base
def get_blobs(self, args):
h_traits = self.h_traits
h_instance = self.h_instance
# Define kernel configurations for different size categories
trait_dict = {
"head_dim_128_seq_16384": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_16384": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
],
"head_dim_128_seq_8192": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_8192": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
],
"head_dim_256_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
],
"head_dim_128_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
],
"head_dim_32_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 32, 32, 32, 32, 32, 32),
],
"head_dim_128_seq_2048": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_2048": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
],
"head_dim_128_seq_1024": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_1024": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
],
"head_dim_128_seq_512": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
],
"head_dim_64_seq_512": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
],
}
# Toy example only support fp16
dtype_combinations = [
"fp16,fp16,fp16,fp16"
# "bf16,bf16,bf16,bf16"
]
total_blob = []
for dtype_pair in dtype_combinations:
for size_category in trait_dict:
traits = trait_dict[size_category]
# Convert data types for the current dtype_pair
q_type, k_type, v_type, o_type = dtype_pair.split(',')
current_traits = []
for t in traits:
new_t = copy.copy(t)
new_t.F_SaccDataType = 'fp32' # accumulation in fp32
new_t.F_SMPLComputeDataType = 'fp32' # softmax compute in fp32
new_t.F_PDataType = q_type
new_t.F_OaccDataType = 'fp32' # output accumulation in fp32
current_traits.append(new_t)
total_blob.append(h_instance(dtype_pair, size_category, current_traits))
return total_blob
def list_blobs(self, args) -> None:
w_p = Path(self.working_path)
list_p = w_p / 'flash_attention_fwd_blobs.txt'
blobs = self.get_blobs(args)
with list_p.open('w') as list_f:
# API related files
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
# Kernel instance files
for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
def gen_blobs(self, args) -> None:
w_p = Path(self.working_path)
w_str = self.content_api(args)
(w_p / (self.name_api + ".cpp")).write_text(w_str)
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
blobs = self.get_blobs(args)
for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content)
def list_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
FlashAttentionFwdCodegen(args.working_path, args.filter).list_blobs(args)
def gen_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
FlashAttentionFwdCodegen(args.working_path, args.filter).gen_blobs(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for Flash Attention kernel",
)
parser.add_argument(
"-a",
"--api",
default='fwd',
required=False,
help="supply API(s) to generate (default: fwd). separated by comma."
)
parser.add_argument(
"-w",
"--working_path",
default="./",
required=False,
help="the path where all the blobs are going to be generated"
)
parser.add_argument(
"-l",
"--list_blobs",
action='store_true',
help="list all the kernels to a file"
)
parser.add_argument(
"-g",
"--gen_blobs",
action='store_true',
help="generate all kernels into different tile"
)
parser.add_argument(
"-f",
"--filter",
required=False,
help="filter out kernels that need to generate"
)
args = parser.parse_args()
if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)):
print('gen_blobs/list_blobs must specify only one option')
sys.exit()
p = Path(args.working_path)
if not p.exists():
p.mkdir()
if args.list_blobs:
list_blobs(args)
else:
gen_blobs(args)

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_batched_gemm(const ck_tile::HostTensor<ADataType>& a_b_m_k,
const ck_tile::HostTensor<BDataType>& b_b_n_k,
ck_tile::HostTensor<CDataType>& c_b_m_n)
{
const int N = b_b_n_k.mDesc.get_lengths()[1];
const int K = b_b_n_k.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_b_m_k(batch, m, k);
BDataType v_b = b_b_n_k(batch, n, k);
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(v_acc);
}
};
ck_tile::make_ParallelTensorFunctor(
f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}

View File

@@ -0,0 +1,48 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_batched_softmax(const ck_tile::HostTensor<ADataType>& a_b_m_n,
ck_tile::HostTensor<BDataType>& b_b_m_n)
{
const int N = a_b_m_n.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
AccDataType v_max = std::numeric_limits<ADataType>::lowest();
// max
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
v_max = v_max < v_a ? v_a : v_max;
}
AccDataType v_exp_sum = 0;
// sum
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
v_exp_sum += ck_tile::exp(v_a - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_b_m_n(batch, m, n);
b_b_m_n(batch, m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
}
};
ck_tile::make_ParallelTensorFunctor(
f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}

View File

@@ -0,0 +1,18 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
static constexpr index_t kM = kMPerTile;
static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
};
} // namespace ck_tile

View File

@@ -0,0 +1,9 @@
include_directories(AFTER
${CMAKE_CURRENT_LIST_DIR}
)
add_subdirectory(00_add_basic)
add_subdirectory(01_add)
add_subdirectory(02_gemm)
add_subdirectory(03_flash_attention_fwd)
add_subdirectory(04_codegen_flash_attention_fwd)

View File

@@ -0,0 +1,112 @@
# CK_TILE Toy Example
This repository demonstrates a toy example implemented using ck_tile
## Build Instructions
Follow these steps to build the examples:
```sh
cd composable_kernel
mkdir build
cd build
cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx942" \
-Dkernel=N ..
```
### Compile Examples
#### **Elementwise Add Example**
```sh
make -j add
```
#### **GEMM Example**
```sh
make -j basic_gemm
```
#### **Flash Attention Forward Example**
```sh
make -j basic_flash_attention_fwd
```
## Running Examples
### **Elementwise Add**
```sh
./bin/add
```
### **GEMM Example**
```sh
./bin/basic_gemm 1
```
### **Flash Attention Forward Example**
```sh
./bin/basic_flash_attention_fwd 1 1
```
## Advanced part
#### **GEMM Example**
##### Follow these steps to build and run the different kernels:
```sh
cd composable_kernel
mkdir build
cd build
# for naive kernel
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=N .. && make -j basic_gemm
# for kernel A
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=A .. && make -j basic_gemm
# for kernel B
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=B .. && make -j basic_gemm
...
# for kernel H
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=H .. && make -j basic_gemm
```
```sh
./bin/basic_gemm 1
```
#### **Flash Attention Forward Example**
##### Follow these steps to build the kernels
```sh
# for naive kernel
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" .. && make -j basic_flash_attention_fwd
# for optimized kernel
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -DENABLE_TOY_FA_FWD_OPT=ON .. && make -j basic_flash_attention_fwd
```
```sh
./bin/basic_flash_attention_fwd 1 1
```
##### Follow these steps to build the codegen instances
```sh
mkdir build
cd build
../script/cmake-ck-release.sh .. gfx942
make -j codegen_basic_flash_attention_fwd
```
```sh
./bin/codegen_basic_flash_attention_fwd 1 1 64 16384 16384 128 128
```