mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Add multiple tutorial examples
This commit is contained in:
21
example/ck_tile/tutorial/00_add_basic/CMakeLists.txt
Normal file
21
example/ck_tile/tutorial/00_add_basic/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
set(EXAMPLE_ADD_BASIC "add_basic")
|
||||
|
||||
message("adding example ${EXAMPLE_ADD_BASIC}")
|
||||
|
||||
add_executable(${EXAMPLE_ADD_BASIC} EXCLUDE_FROM_ALL add_basic.cpp)
|
||||
target_include_directories(${EXAMPLE_ADD_BASIC} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_ADD_BASIC_COMPILE_OPTIONS)
|
||||
|
||||
# generate assembly
|
||||
# list(APPEND EXAMPLE_ADD_BASIC_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_ADD_BASIC_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
target_compile_options(${EXAMPLE_ADD_BASIC} PRIVATE ${EXAMPLE_ADD_BASIC_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
169
example/ck_tile/tutorial/00_add_basic/add_basic.cpp
Normal file
169
example/ck_tile/tutorial/00_add_basic/add_basic.cpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "reference_add_vector.hpp"
|
||||
#include "add_basic.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// This example demonstrates how to use the ck_tile library to perform an elementwise vector
|
||||
// addition using a custom kernel. The kernel is defined in the vector_add.hpp file, and the
|
||||
// reference implementation is provided in the reference_vector_add.hpp file.
|
||||
|
||||
// parse command line arguments
|
||||
// -m: size of the vectors
|
||||
// -v: validation flag (1 for validation, 0 for no validation)
|
||||
// -prec: precision of the data type (fp16, fp32, int8, int32)
|
||||
// -warmup: number of warmup iterations (number of kernel launches before measuring performance)
|
||||
// -repeat: number of repeat iterations (number of kernel launches to measure performance)
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "41943040", "m dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType; // input data type
|
||||
using ComputeDataType = float; // compute data type
|
||||
using YDataType = DataType; // output data type
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m"); // size of the vectors
|
||||
int do_validation = arg_parser.get_int("v"); // do we verify the result on cpu
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host_a(
|
||||
{m}); // length input vector A, if given two arguments (m, n) the HostTensor will be created
|
||||
// with shape (m, n)
|
||||
ck_tile::HostTensor<XDataType> x_host_b(
|
||||
{m}); // length input vector B, if given two arguments (m, n) the HostTensor will be created
|
||||
// with shape (m, n)
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(
|
||||
x_host_a); // fill the input vector A with random values
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_b);
|
||||
|
||||
ck_tile::DeviceMem x_buf_a(
|
||||
x_host_a.get_element_space_size_in_bytes()); // allocate device memory for input vector A
|
||||
// (this a wrapper over hipMalloc)
|
||||
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf_a.ToDevice(
|
||||
x_host_a
|
||||
.data()); // copy the input vector A to device memory, this is a wrapper over hipMemcpy
|
||||
x_buf_b.ToDevice(x_host_b.data());
|
||||
|
||||
// Dividing the problem into blocktile, warptile, and vector
|
||||
// The blocktile is the size of the tile that will be processed by a single thread block (also
|
||||
// called work group) The warptile is the size of the tile that will be processed by a single
|
||||
// warp (also called wavefront) The vector is the size of the tile that will be processed by a
|
||||
// single thread (also called work item) The problem is divided into blocks of size BlockTile,
|
||||
// each block is further divided into warps of size WarpTile and each warp is composed of 64 or
|
||||
// 32 threads of size Vector each of the thread in a warp will process one vector worth elements
|
||||
// of the data
|
||||
using BlockTile = ck_tile::sequence<8192>; // Size of the block tile (Entire problem is divided
|
||||
// into blocks of this size)
|
||||
using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp
|
||||
// will cover some part of blockTile)
|
||||
using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp
|
||||
using Vector = ck_tile::sequence<1>; // How many elements are covered by a thread (Each thread
|
||||
// will cover some part of WarpTile)
|
||||
|
||||
// Interpretation of above configurations
|
||||
// Each thread will cover 1 element (Vector)
|
||||
// Each WarpTile will cover 64 elements (WarpTile) --> since 64 threads in a warp
|
||||
// if we have 8 warps in a block (BlockWarps) then we have 8 * 64 = 512 threads in a block
|
||||
// if 8 warps are not enough to cover the entire blockTile then each of the 8 concurrent warps
|
||||
// will iterate over the blockTile several times
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 512;
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "block x-size = " << BlockTile::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::AddVectorShape<BlockWarps, BlockTile, WarpTile, Vector>;
|
||||
std::cout << "Problem Shape:: M = " << m << std::endl;
|
||||
std::cout << "BlockTile: " << BlockTile::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "Number of Blocks in Grid: " << m / BlockTile::at(ck_tile::number<0>{})
|
||||
<< std::endl;
|
||||
std::cout << "BlockWarps: " << BlockWarps::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "WarpTile: " << WarpTile::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "Vector: " << Vector::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "Repeat: " << Shape::Repeat_M
|
||||
<< std::endl; // number of times a warp will iterate over the blockTile, covering
|
||||
// different parts of the blockTile
|
||||
std::cout << "Threads per Block: " << kBlockSize << std::endl;
|
||||
std::cout << "ThreadBlocks per CU: " << kBlockPerCu << std::endl;
|
||||
|
||||
// What is a Problem in CKTile?
|
||||
// A Problem defines the shape of the data, the precision of the data
|
||||
using Problem = ck_tile::AddVectorProblem<XDataType, ComputeDataType, YDataType, Shape>;
|
||||
|
||||
// What is a Policy in CKTile?
|
||||
// A Policy defines how to map the data between threads and data in memory
|
||||
|
||||
// The kernel is the function that will be executed on the device
|
||||
// It requires a Problem and Policy to be defined
|
||||
using Kernel = ck_tile::AddVectorKernel<Problem>;
|
||||
|
||||
// The kernel is launched with the following parameters:
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, // wrapper over hipStreamCreate
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>( // numOfThreadsPerBlock, numOfBlocksPerCU
|
||||
Kernel{}, // kernel
|
||||
kGridSize, // number of blocks in the grid
|
||||
kBlockSize, // number of threads in a block
|
||||
0, // shared memory size
|
||||
static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()), // input vector A
|
||||
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()), // input vector B
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()), // output vector
|
||||
m));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * m + sizeof(YDataType) * m;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_add_vector<XDataType, YDataType>(x_host_a, x_host_b, y_host_ref);
|
||||
y_buf.FromDevice(y_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(y_host_dev, y_host_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
140
example/ck_tile/tutorial/00_add_basic/add_basic.hpp
Normal file
140
example/ck_tile/tutorial/00_add_basic/add_basic.hpp
Normal file
@@ -0,0 +1,140 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// struct that holds the tile size of the block, warp, and vector
|
||||
// and the number of warps per block
|
||||
// and the number of threads per warp
|
||||
// and the number of times the warp tile is repeated in the block tile
|
||||
// and the block size
|
||||
template <typename BlockWarps, typename BlockTile, typename WarpTile, typename Vector>
|
||||
struct AddVectorShape
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
||||
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
|
||||
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{});
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
|
||||
static constexpr index_t Repeat_M =
|
||||
Block_M /
|
||||
(WarpPerBlock_M * Warp_M); // Number of times the warp tile is repeated in the block tile
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
|
||||
struct AddVectorProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
};
|
||||
|
||||
// data mapping beween threads and memory
|
||||
struct AddDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>, // Replicate
|
||||
tuple<sequence<S::Repeat_M,
|
||||
S::WarpPerBlock_M,
|
||||
S::ThreadPerWarp_M,
|
||||
S::Vector_M>>, // Hierarchical
|
||||
tuple<sequence<1>, sequence<1>>, // Parallel
|
||||
tuple<sequence<1>, sequence<2>>, // Parallel
|
||||
sequence<1, 1>, // Yield
|
||||
sequence<0, 3>>{} // Yield
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = AddDefaultPolicy>
|
||||
struct AddVectorKernel
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
// body of the kernel
|
||||
CK_TILE_DEVICE void
|
||||
operator()(const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// create tensor view for the input and output data, this defines how the data is laid out
|
||||
// in memory
|
||||
const auto x_m_n_a = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x_a,
|
||||
make_tuple(M),
|
||||
make_tuple(1),
|
||||
number<S::Vector_M>{}); // raw pointer, shape of the tensor, stride of the tensor, and
|
||||
// lastGarunteedVectorLength
|
||||
|
||||
const auto x_m_n_b = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x_b, make_tuple(M), make_tuple(1), number<S::Vector_M>{});
|
||||
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, make_tuple(M), make_tuple(1), number<S::Vector_M>{});
|
||||
|
||||
// origin of the block tile
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// creating tile windows for the input and output data
|
||||
auto x_window_a = make_tile_window(x_m_n_a,
|
||||
make_tuple(number<S::Block_M>{}),
|
||||
{iM},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto x_window_b = make_tile_window(x_m_n_b,
|
||||
make_tuple(number<S::Block_M>{}),
|
||||
{iM},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_m_n,
|
||||
make_tuple(number<S::Block_M>{}),
|
||||
{iM},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
// Load tile data
|
||||
const auto xa =
|
||||
load_tile(x_window_a); // load tile data from global tensor view, load from where? what?
|
||||
// how many? logical memory layout? all are defined in x_window_a
|
||||
const auto xb = load_tile(x_window_b);
|
||||
auto y_compute = load_tile(y_window);
|
||||
|
||||
// Process the vector add
|
||||
constexpr auto spans = decltype(xa)::get_distributed_spans(); // shape of the tile
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx) { // iterate over the tile
|
||||
const auto tile_idx = make_tuple(idx);
|
||||
const auto a_val = type_convert<ComputeDataType>(xa[tile_idx]);
|
||||
const auto b_val = type_convert<ComputeDataType>(xb[tile_idx]);
|
||||
y_compute(tile_idx) = a_val + b_val;
|
||||
|
||||
});
|
||||
|
||||
// Store results
|
||||
store_tile(y_window,
|
||||
cast_tile<YDataType>(y_compute)); // store the result back to global tensor view
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,31 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType, typename YDataType>
|
||||
CK_TILE_HOST void reference_add_vector(const HostTensor<XDataType>& xa_m_n,
|
||||
const HostTensor<XDataType>& xb_m_n,
|
||||
HostTensor<YDataType>& y_m_n)
|
||||
{
|
||||
auto f = [&](auto m) {
|
||||
const int N = 1;
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
y_m_n(m, n) = ck_tile::type_convert<YDataType>(xa_m_n(m, n)) +
|
||||
ck_tile::type_convert<YDataType>(xb_m_n(m, n));
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f,
|
||||
y_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
21
example/ck_tile/tutorial/01_add/CMakeLists.txt
Normal file
21
example/ck_tile/tutorial/01_add/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
set(EXAMPLE_ADD "add")
|
||||
|
||||
message("adding example ${EXAMPLE_ADD}")
|
||||
|
||||
add_executable(${EXAMPLE_ADD} EXCLUDE_FROM_ALL add.cpp)
|
||||
target_include_directories(${EXAMPLE_ADD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_ADD_COMPILE_OPTIONS)
|
||||
|
||||
# generate assembly
|
||||
# list(APPEND EXAMPLE_ADD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_ADD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
target_compile_options(${EXAMPLE_ADD} PRIVATE ${EXAMPLE_ADD_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
117
example/ck_tile/tutorial/01_add/add.cpp
Normal file
117
example/ck_tile/tutorial/01_add/add.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "reference_add.hpp"
|
||||
#include "add.hpp"
|
||||
#include <cstring>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "10240", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "200", "cold iter")
|
||||
.insert("repeat", "1000", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host_a({m, n});
|
||||
ck_tile::HostTensor<XDataType> x_host_b({m, n});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_a);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_b);
|
||||
|
||||
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf_a.ToDevice(x_host_a.data());
|
||||
x_buf_b.ToDevice(x_host_b.data());
|
||||
|
||||
using BlockWarps =
|
||||
ck_tile::sequence<1, 8>; // number of concurrent warps in one block (if 8 warps * 64 threads
|
||||
// per warp, 512 threads in one block are NEEDED)
|
||||
using BlockTile =
|
||||
ck_tile::sequence<1, 4096>; // shape of one blockTile (elements covered by one block)
|
||||
using WarpTile = ck_tile::sequence<1, 512>; // shape of one warpTile (elements covered by one
|
||||
// warp (64 threads))
|
||||
using Vector = ck_tile::sequence<1, 8>; // shape of one vector (elements covered by one thread)
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize =
|
||||
512; // number of blockWarps * number of threads per warp
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "block x-size = " << BlockTile::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::AddShape<BlockWarps, BlockTile, WarpTile, Vector>;
|
||||
using Porblem = ck_tile::AddProblem<XDataType, ComputeDataType, YDataType, Shape>;
|
||||
|
||||
using Kernel = ck_tile::Add<Porblem>;
|
||||
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
m,
|
||||
n));
|
||||
|
||||
std::size_t num_btype = 2 * sizeof(XDataType) * m * n + sizeof(YDataType) * m * n;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_add<XDataType, YDataType>(x_host_a, x_host_b, y_host_ref);
|
||||
y_buf.FromDevice(y_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(y_host_dev, y_host_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
166
example/ck_tile/tutorial/01_add/add.hpp
Normal file
166
example/ck_tile/tutorial/01_add/add.hpp
Normal file
@@ -0,0 +1,166 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockWarps, // num warps along seq<M, N>
|
||||
typename BlockTile, // block size, seq<M, N>
|
||||
typename WarpTile, // warp size, seq<M, N>
|
||||
typename Vector> // contiguous pixels(vector size) along seq<M, N>
|
||||
struct AddShape
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{}); // elements along M in one Block
|
||||
static constexpr index_t Block_N = BlockTile::at(number<1>{}); // elements along N in one Block
|
||||
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{}); // elements along M in one Warp
|
||||
static constexpr index_t Warp_N = WarpTile::at(number<1>{}); // elements along N in one Warp
|
||||
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{}); // elements along M in one Vector
|
||||
static constexpr index_t Vector_N = Vector::at(number<1>{}); // elements along N in one Vector
|
||||
|
||||
static constexpr index_t WarpPerBlock_M =
|
||||
BlockWarps::at(number<0>{}); // num concurrent warps along M
|
||||
static constexpr index_t WarpPerBlock_N =
|
||||
BlockWarps::at(number<1>{}); // num concurrent warps along N
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M =
|
||||
Warp_M /
|
||||
Vector_M; // num threads along M in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64)
|
||||
static constexpr index_t ThreadPerWarp_N =
|
||||
Warp_N /
|
||||
Vector_N; // num threads along N in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64)
|
||||
|
||||
static constexpr index_t Repeat_M =
|
||||
Block_M /
|
||||
(WarpPerBlock_M *
|
||||
Warp_M); // num of time a warp iterates along M to ensure the entire block is covered
|
||||
static constexpr index_t Repeat_N =
|
||||
Block_N /
|
||||
(WarpPerBlock_N *
|
||||
Warp_N); // num of time a warp iterates along N to ensure the entire block is covered
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
warpSize *
|
||||
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); // num of threads in one block
|
||||
};
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
|
||||
struct AddProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>; // data type of input tensor
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>; // data type of compute tensor
|
||||
using YDataType = remove_cvref_t<YDataType_>; // data type of output tensor
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // block shapes and sizes
|
||||
};
|
||||
|
||||
struct AddDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M,
|
||||
S::WarpPerBlock_M,
|
||||
S::ThreadPerWarp_M,
|
||||
S::Vector_M>, // how many sub division is a block divided in
|
||||
sequence<S::Repeat_N,
|
||||
S::WarpPerBlock_N,
|
||||
S::ThreadPerWarp_N,
|
||||
S::Vector_N>>, // how many sub division is a block divided in
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // What are the shapes of those sub divisions
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // What are the shapes of those sub divisions
|
||||
sequence<1, 1, 2, 2>, // How much data does a thread work on and how many iterations
|
||||
// of warps are there
|
||||
sequence<0, 3, 0, 3>>{}); // How much data does a thread work on and how many
|
||||
// iterations of warps are there
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = AddDefaultPolicy>
|
||||
struct Add
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(
|
||||
const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
const auto x_m_n_a = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_x_a,
|
||||
make_tuple(M, N),
|
||||
make_tuple(N, 1),
|
||||
number<S::Vector_N>{},
|
||||
number<1>{}); // raw data, shape of tensor, stride of tensor, lastGarunteedVectorLength,
|
||||
// lastGarunteedVectorStride
|
||||
|
||||
const auto x_m_n_b = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_x_b, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
|
||||
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
|
||||
|
||||
const auto iM = get_block_id() * S::Block_M; // origin of the block along
|
||||
|
||||
auto x_window_a = make_tile_window(x_m_n_a,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto x_window_b = make_tile_window(x_m_n_b,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_m_n,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto xa = load_tile(x_window_a);
|
||||
const auto xb = load_tile(x_window_b);
|
||||
auto y_compute = load_tile(y_window);
|
||||
|
||||
constexpr auto spans = decltype(xa)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
const auto x = ck_tile::type_convert<ComputeDataType>(xa[i_j_idx]);
|
||||
const auto y = ck_tile::type_convert<ComputeDataType>(xb[i_j_idx]);
|
||||
y_compute(i_j_idx) = x + y;
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_window, cast_tile<YDataType>(y_compute));
|
||||
move_tile_window(x_window_a, {0, S::Block_N});
|
||||
move_tile_window(x_window_b, {0, S::Block_N});
|
||||
move_tile_window(y_window, {0, S::Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
31
example/ck_tile/tutorial/01_add/reference_add.hpp
Normal file
31
example/ck_tile/tutorial/01_add/reference_add.hpp
Normal file
@@ -0,0 +1,31 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType, typename YDataType>
|
||||
CK_TILE_HOST void reference_add(const HostTensor<XDataType>& xa_m_n,
|
||||
const HostTensor<XDataType>& xb_m_n,
|
||||
HostTensor<YDataType>& y_m_n)
|
||||
{
|
||||
auto f = [&](auto m) {
|
||||
const int N = xa_m_n.mDesc.get_lengths()[1];
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
y_m_n(m, n) = ck_tile::type_convert<YDataType>(xa_m_n(m, n)) +
|
||||
ck_tile::type_convert<YDataType>(xb_m_n(m, n));
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f,
|
||||
y_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
26
example/ck_tile/tutorial/02_gemm/CMakeLists.txt
Normal file
26
example/ck_tile/tutorial/02_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
set(EXAMPLE_BASIC_GEMM "basic_gemm")
|
||||
|
||||
message("adding example ${EXAMPLE_BASIC_GEMM}")
|
||||
|
||||
add_executable(${EXAMPLE_BASIC_GEMM} EXCLUDE_FROM_ALL gemm.cpp)
|
||||
target_include_directories(${EXAMPLE_BASIC_GEMM} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS)
|
||||
|
||||
# generate assembly
|
||||
# list(APPEND EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
if(DEFINED kernel)
|
||||
message("Compiling with Kernel: ${kernel}")
|
||||
target_compile_definitions(${EXAMPLE_BASIC_GEMM} PRIVATE KERNEL_${kernel}=1)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_BASIC_GEMM} PRIVATE ${EXAMPLE_BASIC_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
372
example/ck_tile/tutorial/02_gemm/block_gemm_asmem_bsmem_creg.hpp
Normal file
372
example/ck_tile/tutorial/02_gemm/block_gemm_asmem_bsmem_creg.hpp
Normal file
@@ -0,0 +1,372 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegDefaultPolicy>
|
||||
struct BlockGemmASmemBSmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
static constexpr index_t MWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
|
||||
static constexpr index_t NWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
// A block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile aWarpTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Prefetch from LDS to warp register
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
aWarpTile = load_tile(a_block_window);
|
||||
bWarpTile = load_tile(b_block_window);
|
||||
}
|
||||
#endif
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
#if !defined(ENABLE_PREFETCH)
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message("local data share prefetch")
|
||||
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
#else
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
#endif
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
#else
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
#endif
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
#if !defined(ENABLE_PREFETCH)
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
#else
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
#endif
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
#else
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
#endif
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
// c += a * b
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,100 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#include "config.h"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmASmemBSmemCReg
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if defined(ADJUST_BLOCK_TILE_SHAPE)
|
||||
constexpr index_t kMWarp = 2;
|
||||
constexpr index_t kNWarp = 2;
|
||||
#else
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
#endif
|
||||
|
||||
#if defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message("mfma m32 n32 k8")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_32x32x_8x2)
|
||||
#pragma message("mfma m32 n32 k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_16x16x16)
|
||||
#pragma message("mfma m16 n16 k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_16x16x_16x2)
|
||||
#pragma message("mfma m16 n16 k32")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,412 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = ck_tile::BlockGemmPipelineAGmemBGmemCRegDefaultPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16) *
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPack() { return Policy::template GetSmemPack<Problem>(); }
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemm::MWarp;
|
||||
constexpr index_t WaveNumN = BlockGemm::NWarp;
|
||||
|
||||
constexpr index_t AB_LDS_RW_Width = GetSmemPack();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock /
|
||||
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16
|
||||
? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_a_mfma =
|
||||
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) /
|
||||
// sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) /
|
||||
// sizeof(ADataType) : sizeof(ComputeDataType)
|
||||
// / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
constexpr auto num_mfma_per_dswrite_a =
|
||||
(num_mfma_per_issue - num_dswrite_per_issue_a * 2 >= 1) ? 2 : 1;
|
||||
constexpr auto num_mfma_per_dswrite_b =
|
||||
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_a, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_mfma_per_dswrite_a *
|
||||
num_dswrite_per_issue_a,
|
||||
0); // MFMA
|
||||
});
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_mfma_per_dswrite_b *
|
||||
num_dswrite_per_issue_b,
|
||||
0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
|
||||
ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()));
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
#endif
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message("global prefetch")
|
||||
// Prefetch
|
||||
// Global read 0
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
if(num_loop > 1)
|
||||
{
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write 0
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Prefetch from LDS to warp register in block gemm
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Main body
|
||||
if(num_loop > 2)
|
||||
{
|
||||
index_t iCounter = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 2
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Prefetch from LDS to warp register in block gemm
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
HotLoopScheduler();
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
iCounter += 1;
|
||||
} while(iCounter < (num_loop - 2));
|
||||
}
|
||||
|
||||
// Tail
|
||||
if(num_loop > 1)
|
||||
{
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
#else
|
||||
// non-prefetch
|
||||
index_t iCounter = num_loop;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
#endif
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,352 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "block_gemm_asmem_bsmem_creg.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
#include "config.h"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmPipelineAGmemBGmemCReg
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
{
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
#if defined(NAIVE_IMPLEMENTATION)
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(PADDING_K_FIRST)
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(PADDING_MN_FIRST)
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
|
||||
number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
#endif
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
#if defined(PADDING_K_FIRST) || defined(NAIVE_IMPLEMENTATION)
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(PADDING_K_FIRST)
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(PADDING_MN_FIRST)
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
#endif
|
||||
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
PackedSize == 2)
|
||||
{
|
||||
return (PackedSize * 32 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 8 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
|
||||
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 4 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
|
||||
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 2 / sizeof(DataType));
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
return Problem::TransposeC;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPack()
|
||||
{
|
||||
constexpr index_t kKPack = 8;
|
||||
return kKPack;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
return BlockGemmASmemBSmemCReg<Problem>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
40
example/ck_tile/tutorial/02_gemm/config.h
Normal file
40
example/ck_tile/tutorial/02_gemm/config.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#if defined(KERNEL_A)
|
||||
#define PADDING_K_FIRST
|
||||
#define USING_MFMA_32x32x_8x2
|
||||
#elif defined(KERNEL_B)
|
||||
#define PADDING_K_FIRST
|
||||
#define USING_MFMA_16x16x16
|
||||
#elif defined(KERNEL_C)
|
||||
#define PADDING_K_FIRST
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#elif defined(KERNEL_D)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
#elif defined(KERNEL_E)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
#define ADJUST_BLOCK_TILE_SHAPE
|
||||
#elif defined(KERNEL_F)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
#define ADJUST_BLOCK_TILE_SHAPE
|
||||
#define ENABLE_PREFETCH
|
||||
#elif defined(KERNEL_G)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
#define ADJUST_BLOCK_TILE_SHAPE
|
||||
#define ENABLE_PREFETCH
|
||||
#define ENABLE_INSTRUCTION_SCH
|
||||
#elif defined(KERNEL_H)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
#define ADJUST_BLOCK_TILE_SHAPE
|
||||
#define ENABLE_PREFETCH
|
||||
#define ENABLE_INSTRUCTION_SCH
|
||||
#define ENABLE_CACHE_AWARE_WG_SCH
|
||||
#else
|
||||
#define NAIVE_IMPLEMENTATION
|
||||
#endif
|
||||
202
example/ck_tile/tutorial/02_gemm/gemm.cpp
Normal file
202
example/ck_tile/tutorial/02_gemm/gemm.cpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "config.h"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
|
||||
/*
|
||||
* Toy code of GEMM
|
||||
* Assume simplest case.
|
||||
* A [M, K]
|
||||
* B [N, K]
|
||||
* C [M, N]
|
||||
*/
|
||||
|
||||
// elementwise lambda
|
||||
struct CElementFunction
|
||||
{
|
||||
template <typename X>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
|
||||
{
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t M = 3328;
|
||||
ck_tile::index_t N = 4096;
|
||||
ck_tile::index_t K = 4096;
|
||||
|
||||
if(argc == 2)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
}
|
||||
if(argc == 5)
|
||||
{
|
||||
verification = std::stoi(argv[1]);
|
||||
M = std::stoi(argv[2]);
|
||||
N = std::stoi(argv[3]);
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
#if defined(KERNEL_A)
|
||||
printf("*** Kernel A test *** \n");
|
||||
printf(" --> Using mfma_32x32x(8x2)\n");
|
||||
#elif defined(KERNEL_B)
|
||||
printf("*** Kernel B test *** \n");
|
||||
printf(" --> Using mfma_16x16x16\n");
|
||||
#elif defined(KERNEL_C)
|
||||
printf("*** Kernel C test *** \n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
#elif defined(KERNEL_D)
|
||||
printf("*** Kernel D test *** \n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank-conflict-free\n");
|
||||
#elif defined(KERNEL_E)
|
||||
printf("*** Kernel E test ***\n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank-conflict-free\n");
|
||||
printf(" --> Adjust block tile shape\n");
|
||||
#elif defined(KERNEL_F)
|
||||
printf("*** Kernel F test ***\n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank-conflict-free\n");
|
||||
printf(" --> Adjust block tile shape\n");
|
||||
printf(" --> Enable prefetch\n");
|
||||
#elif defined(KERNEL_G)
|
||||
printf("*** Kernel G test ***\n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank-conflict-free\n");
|
||||
printf(" --> Adjust block tile shape\n");
|
||||
printf(" --> Enable prefetch\n");
|
||||
printf(" --> Enable instruction schedule\n");
|
||||
#elif defined(KERNEL_H)
|
||||
printf("*** Kernel H test ***\n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank-conflict-free\n");
|
||||
printf(" --> Adjust block tile shape\n");
|
||||
printf(" --> Enable prefetch\n");
|
||||
printf(" --> Enable instruction schedule\n");
|
||||
printf(" --> Enable cache-aware thread blocks schedule\n");
|
||||
#else
|
||||
printf("*** Naive implementation test ***\n");
|
||||
#endif
|
||||
|
||||
const ck_tile::index_t Lda = K;
|
||||
const ck_tile::index_t Ldb = K;
|
||||
const ck_tile::index_t Ldc = N;
|
||||
|
||||
const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
||||
const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
|
||||
|
||||
const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
||||
const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
|
||||
|
||||
const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
||||
const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
||||
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
||||
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.mData.data());
|
||||
b_buf.ToDevice(b_host.mData.data());
|
||||
|
||||
// Alignment
|
||||
constexpr ck_tile::index_t kAAlignment = 8;
|
||||
constexpr ck_tile::index_t kBAlignment = 8;
|
||||
constexpr ck_tile::index_t kCAlignment = 8;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
#ifdef ADJUST_BLOCK_TILE_SHAPE
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 128;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 64;
|
||||
#else
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 256;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 32;
|
||||
#endif
|
||||
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
using gemm_kernel = ck_tile::Gemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CElementFunction,
|
||||
kAAlignment,
|
||||
kBAlignment,
|
||||
kCAlignment,
|
||||
kBlockSize,
|
||||
kGemmMPerBlock,
|
||||
kGemmNPerBlock,
|
||||
kGemmKPerBlock>;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
gemm_kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
Lda,
|
||||
Ldb,
|
||||
Ldc,
|
||||
CElementFunction{}));
|
||||
auto pass = true;
|
||||
|
||||
if(verification)
|
||||
{
|
||||
// reference gemm
|
||||
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
||||
reference_basic_gemm<ADataType, ADataType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_host_ref);
|
||||
c_buf.FromDevice(c_host_dev.mData.data());
|
||||
pass &= ck_tile::check_err(c_host_dev, c_host_ref);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
195
example/ck_tile/tutorial/02_gemm/gemm.hpp
Normal file
195
example/ck_tile/tutorial/02_gemm/gemm.hpp
Normal file
@@ -0,0 +1,195 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "config.h"
|
||||
#include "grid_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename CElementFunction_>
|
||||
struct GridGemmProblem
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CDataType = CDataType_;
|
||||
|
||||
using CElementFunction = CElementFunction_;
|
||||
};
|
||||
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
struct TileGemmShape
|
||||
{
|
||||
static constexpr index_t kM = kMPerTile;
|
||||
static constexpr index_t kN = kNPerTile;
|
||||
static constexpr index_t kK = kKPerTile;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename CElementFunction,
|
||||
index_t kAAlignment,
|
||||
index_t kBAlignment,
|
||||
index_t kCAlignment,
|
||||
index_t kBlockSize_,
|
||||
index_t kMPerBlock_,
|
||||
index_t kNPerBlock_,
|
||||
index_t kKPerBlock_>
|
||||
struct Gemm
|
||||
{
|
||||
using GridGemmProblem =
|
||||
GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;
|
||||
|
||||
struct GridGemmPolicy
|
||||
{
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kMPerBlock_;
|
||||
static constexpr index_t kNPerBlock = kNPerBlock_;
|
||||
static constexpr index_t kKPerBlock = kKPerBlock_;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
#if defined(ENABLE_CACHE_AWARE_WG_SCH)
|
||||
return [=](index_t block_1d_id) {
|
||||
constexpr index_t M01 = 4;
|
||||
constexpr index_t GroupNum = 8;
|
||||
|
||||
const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
|
||||
const auto update_M0 =
|
||||
((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;
|
||||
|
||||
const auto xcd_id = block_1d_id % GroupNum;
|
||||
|
||||
const auto l_block_id = block_1d_id - (xcd_id % 2);
|
||||
|
||||
const auto ridn = GroupNum * M01 * (update_N0 / 2);
|
||||
const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn;
|
||||
const auto lu = (l_block_id % GroupNum) + rid * ridn;
|
||||
|
||||
const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
|
||||
const auto sub_M0_id =
|
||||
(l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;
|
||||
|
||||
auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
|
||||
auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
|
||||
|
||||
const auto total_update_size = update_N0 * update_M0;
|
||||
|
||||
if(block_1d_id >= total_update_size)
|
||||
{
|
||||
auto x = (block_1d_id + 1) - total_update_size;
|
||||
auto rlen = N0 - update_N0;
|
||||
|
||||
auto rm = 0;
|
||||
auto rn = 0;
|
||||
if(rlen > 0)
|
||||
{
|
||||
rm = (x - 1) / rlen;
|
||||
rn = x % rlen;
|
||||
}
|
||||
|
||||
if(rlen > 0 and rm < M0)
|
||||
{
|
||||
n = rn + update_N0;
|
||||
m = rm;
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x - rlen * M0;
|
||||
rm = (x - 1) / update_N0;
|
||||
rn = x % update_N0;
|
||||
n = rn;
|
||||
m = update_M0 + rm;
|
||||
}
|
||||
}
|
||||
return make_multi_index(m, n);
|
||||
};
|
||||
#else
|
||||
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
|
||||
|
||||
return [unmerge](index_t block_id) {
|
||||
multi_index<2> unmerged;
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
};
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
|
||||
{
|
||||
using BlockGemmPipelineProblem_ =
|
||||
BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
|
||||
return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
|
||||
}
|
||||
};
|
||||
|
||||
using GridGemm = GridGemm<GridGemmProblem, GridGemmPolicy>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t Lda,
|
||||
const index_t Ldb,
|
||||
const index_t Ldc,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto a_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
const auto b_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
const auto c_dram = [&] {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});
|
||||
}();
|
||||
|
||||
GridGemm{}(a_dram, b_dram, c_dram, c_element_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
74
example/ck_tile/tutorial/02_gemm/grid_gemm.hpp
Normal file
74
example/ck_tile/tutorial/02_gemm/grid_gemm.hpp
Normal file
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GridGemm
|
||||
{
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
using CDataType = typename Problem::CDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using CElementFunction = typename Problem::CElementFunction;
|
||||
|
||||
static constexpr auto kMPerBlock = Policy::kMPerBlock;
|
||||
static constexpr auto kNPerBlock = Policy::kNPerBlock;
|
||||
static constexpr auto kKPerBlock = Policy::kKPerBlock;
|
||||
|
||||
template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
|
||||
CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
|
||||
const BGridTensorView& b_grid,
|
||||
CGridTensorView& c_grid,
|
||||
const CElementFunction& c_element_func) const
|
||||
{
|
||||
const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
|
||||
const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});
|
||||
|
||||
// divide problem
|
||||
const auto id_block = get_block_id();
|
||||
|
||||
const auto num_tile_m = integer_divide_ceil(M, kMPerBlock);
|
||||
const auto num_tile_n = integer_divide_ceil(N, kNPerBlock);
|
||||
|
||||
const auto block2tile = Policy::template MakeBlock2TileMap<Problem>(num_tile_m, num_tile_n);
|
||||
|
||||
const auto id_tile = block2tile(id_block);
|
||||
|
||||
const auto iM =
|
||||
__builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) * kMPerBlock);
|
||||
const auto iN =
|
||||
__builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) * kNPerBlock);
|
||||
|
||||
// A block window
|
||||
auto a_block_window = make_tile_window(
|
||||
a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {iM, 0});
|
||||
|
||||
// B block window
|
||||
auto b_block_window = make_tile_window(
|
||||
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
|
||||
|
||||
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
|
||||
|
||||
__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()];
|
||||
|
||||
const auto acc_block_tile =
|
||||
block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char);
|
||||
|
||||
// cast to CDataType and apply CElementFunction
|
||||
const auto c_block_tile = tile_elementwise_in(
|
||||
[&](const auto& acc) { return c_element_func(type_convert<CDataType>(acc)); },
|
||||
acc_block_tile);
|
||||
|
||||
// store C
|
||||
auto c_window = make_tile_window(
|
||||
c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, iN});
|
||||
|
||||
store_tile(c_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
37
example/ck_tile/tutorial/02_gemm/reference_gemm.hpp
Normal file
37
example/ck_tile/tutorial/02_gemm/reference_gemm.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
const ck_tile::HostTensor<BDataType>& b_n_k,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n)
|
||||
{
|
||||
const int N = b_n_k.mDesc.get_lengths()[0];
|
||||
const int K = b_n_k.mDesc.get_lengths()[1];
|
||||
|
||||
auto f = [&](auto m) {
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_n_k(n, k);
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_acc);
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
14
example/ck_tile/tutorial/02_gemm/stream_config.hpp
Normal file
14
example/ck_tile/tutorial/02_gemm/stream_config.hpp
Normal file
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
struct StreamConfig
|
||||
{
|
||||
hipStream_t stream_id_ = nullptr;
|
||||
bool time_kernel_ = false;
|
||||
int log_level_ = 0;
|
||||
};
|
||||
@@ -0,0 +1,38 @@
|
||||
set(EXAMPLE_BASIC_FLASH_ATTENTION "basic_flash_attention_fwd")
|
||||
|
||||
message("adding example ${EXAMPLE_BASIC_FLASH_ATTENTION}")
|
||||
|
||||
add_executable(${EXAMPLE_BASIC_FLASH_ATTENTION} EXCLUDE_FROM_ALL flash_attention_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS)
|
||||
|
||||
# list(APPEND EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF)
|
||||
if(ENABLE_TOY_FA_FWD_OPT)
|
||||
message("Compiling with toy FA fwd optimization")
|
||||
target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_OPT)
|
||||
endif()
|
||||
|
||||
option(ENABLE_TOY_FA_FWD_QK_SWIZZLE "Enable toy FA fwd QK swizzle" OFF)
|
||||
if(ENABLE_TOY_FA_FWD_QK_SWIZZLE)
|
||||
message("Compiling with toy FA fwd QK swizzle")
|
||||
target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_QK_SWIZZLE)
|
||||
endif()
|
||||
|
||||
option(ENABLE_TOY_FA_FWD_CACHE_AWARE "Enable toy FA fwd cache aware" OFF)
|
||||
if(ENABLE_TOY_FA_FWD_CACHE_AWARE)
|
||||
message("Compiling with toy FA fwd cache aware")
|
||||
target_compile_definitions(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE TOY_FA_FWD_CACHE_AWARE)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_BASIC_FLASH_ATTENTION} PRIVATE ${EXAMPLE_BASIC_FLASH_ATTENTION_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Problem Description for BlockGemmARegBSmemCReg
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmARegBSmemCRegProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,565 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "block_gemm_areg_bsmem_creg_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmARegBSmemCRegV1DefaultPolicy>
|
||||
struct BlockGemmARegBSmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using BlockGemmPolicy = Policy;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN);
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
constexpr index_t MPerXDL = WG::kM;
|
||||
constexpr index_t NPerXDL = WG::kN;
|
||||
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = config.template get<1>();
|
||||
|
||||
constexpr index_t B_LDS_RW_Width = SmemPack;
|
||||
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
|
||||
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// B split schedule
|
||||
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
constexpr auto num_mfma_per_dswrite_b =
|
||||
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
|
||||
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_mfma_per_dswrite_b *
|
||||
num_dswrite_per_issue_b,
|
||||
0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BLdsTile& b_block_tensor_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#else
|
||||
#pragma message("Enable toy FA fwd QK swizzle")
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BlockGemmARegBSmemCRegV1K8Policy
|
||||
{
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,397 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t k_loops = Policy::AKDim / kKPerBlock;
|
||||
|
||||
// Move this part into Policy?
|
||||
__host__ __device__ static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
// Cold A Register Cache
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename ARegBlockTensorTmp>
|
||||
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
ARegBlockTensorTmp& a_reg_block_tensor_tmp,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = a_element_func;
|
||||
ignore = b_element_func;
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution used to construct both distributed tensor for local buffer store
|
||||
// and read. without buffer address info
|
||||
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
|
||||
|
||||
// B tile in LDS, blockWindow
|
||||
BDataType* p_b_lds =
|
||||
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
// This tensor view used to construct both tile window for lds store and read, with buffer
|
||||
// address info
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A Reg tensor for store, also used for block GEMM
|
||||
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
sequence<0, 0>{},
|
||||
sequence<kMPerBlock, kKPerBlock>{});
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
sequence<0, (i_k0 + 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 2) * kKPerBlock>{});
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{});
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
}
|
||||
|
||||
set_slice_tile(a_reg_block_tensor_tmp,
|
||||
a_copy_reg_tensor,
|
||||
sequence<0, 0>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{});
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
// Hot A Register Cache
|
||||
template <typename BDramBlockWindowTmp, typename BElementFunction, typename ARegBlockTensorTmp>
|
||||
__host__ __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const ARegBlockTensorTmp& a_reg_block_tensor_tmp,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = b_element_func;
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution used to construct both distributed tensor for local buffer store
|
||||
// and read. without buffer address info
|
||||
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
|
||||
|
||||
// A Reg tensor for store, also used for block GEMM
|
||||
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_reg_block_tensor_tmp,
|
||||
sequence<0, 0>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{});
|
||||
|
||||
// B tile in LDS, blockWindow
|
||||
BDataType* p_b_lds =
|
||||
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
// This tensor view used to construct both tile window for lds store and read, with buffer
|
||||
// address info
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
static_for<0, k_loops, 1>{}([&](auto i_k0) {
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using BLdsTile = typename decltype(block_gemm)::BLdsTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Global read 0
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS read 0
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS read 1
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
}
|
||||
#endif
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename ARegBlockTensorTmp>
|
||||
__device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
ARegBlockTensorTmp& a_reg_block_tensor_tmp,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
a_reg_block_tensor_tmp,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename BDramBlockWindowTmp, typename ARegBlockTensorTmp>
|
||||
__device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const ARegBlockTensorTmp& a_reg_block_tensor_tmp,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
a_reg_block_tensor_tmp,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,144 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
{
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockGemmPolicy = BlockGemmARegBSmemCRegV1K8Policy;
|
||||
|
||||
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = AKDim;
|
||||
|
||||
constexpr auto config =
|
||||
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem, kMPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,25 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,197 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include "reference_batched_gemm.hpp"
|
||||
#include "reference_batched_softmax.hpp"
|
||||
#include "flash_attention_fwd.hpp"
|
||||
|
||||
/*
|
||||
* Toy code of flash attention forward pass
|
||||
* Assume simplest case.
|
||||
* Q [Batch, HeadNum, SeqenceLengthQ, HeadDim]
|
||||
* K [Batch, HeadNum, SeqenceLengthK, HeadDim]
|
||||
* V [Batch, HeadNum, HeadDim, SeqenceLengthK]
|
||||
* O [Batch, HeadNum, SeqenceLengthQ, HeadDim]
|
||||
*/
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using SaccDataType = float;
|
||||
using SMPLComputeDataType = float;
|
||||
using PDataType = ck_tile::half_t;
|
||||
using OaccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
|
||||
if(argc == 3)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
verification = std::stoi(argv[2]);
|
||||
}
|
||||
else if(argc == 8)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
verification = std::stoi(argv[2]);
|
||||
Batch = std::stoi(argv[3]);
|
||||
M0 = std::stoi(argv[4]);
|
||||
N0 = std::stoi(argv[5]);
|
||||
K0 = std::stoi(argv[6]);
|
||||
N1 = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};
|
||||
std::array<ck_tile::index_t, 3> q_strides{M0 * K0, K0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> k_lengths{Batch, N0, K0};
|
||||
std::array<ck_tile::index_t, 3> k_strides{N0 * K0, K0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> v_lengths{Batch, N1, N0};
|
||||
std::array<ck_tile::index_t, 3> v_strides{N1 * N0, N0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> s_lengths{Batch, M0, N0};
|
||||
std::array<ck_tile::index_t, 3> s_strides{M0 * N0, N0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> p_lengths{Batch, M0, N0};
|
||||
std::array<ck_tile::index_t, 3> p_strides{M0 * N0, N0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> o_lengths{Batch, M0, N1};
|
||||
std::array<ck_tile::index_t, 3> o_strides{M0 * N1, N1, 1};
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<QDataType> q_host(q_lengths, q_strides);
|
||||
ck_tile::HostTensor<KDataType> k_host(k_lengths, k_strides);
|
||||
ck_tile::HostTensor<VDataType> v_host(v_lengths, v_strides);
|
||||
ck_tile::HostTensor<ODataType> o_host_dev(o_lengths, o_strides);
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
|
||||
break;
|
||||
case 2:
|
||||
ck_tile::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
|
||||
break;
|
||||
default:
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
|
||||
}
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.mData.data());
|
||||
k_buf.ToDevice(k_host.mData.data());
|
||||
v_buf.ToDevice(v_host.mData.data());
|
||||
|
||||
constexpr ck_tile::index_t kM0PerBlock = 128;
|
||||
constexpr ck_tile::index_t kN0PerBlock = 128;
|
||||
constexpr ck_tile::index_t kK0PerBlock = 32;
|
||||
constexpr ck_tile::index_t kN1PerBlock = 128;
|
||||
constexpr ck_tile::index_t kK1PerBlock = 32;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
constexpr ck_tile::index_t kHeadDim = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
ck_tile::FlashAttentionFwd<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
|
||||
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
|
||||
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
|
||||
static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
|
||||
M0,
|
||||
N0,
|
||||
K0,
|
||||
N1,
|
||||
Batch,
|
||||
K0, // StrideQ
|
||||
K0, // StrideK
|
||||
N0, // StrideV
|
||||
N1, // StrideO
|
||||
M0 * K0, // BatchStrideQ
|
||||
N0 * K0, // BatchStrideK
|
||||
N1 * N0, // BatchStrideV
|
||||
M0 * N1)); // BatchStrideO
|
||||
|
||||
// reference
|
||||
auto pass = true;
|
||||
if(verification)
|
||||
{
|
||||
o_buf.FromDevice(o_host_dev.mData.data());
|
||||
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
|
||||
ck_tile::HostTensor<PDataType> p_host_ref(p_lengths, p_strides);
|
||||
ck_tile::HostTensor<ODataType> o_host_ref(o_lengths, o_strides);
|
||||
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host, k_host, s_host_ref);
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref);
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref, v_host, o_host_ref);
|
||||
|
||||
pass &= ck_tile::check_err(o_host_dev, o_host_ref);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop =
|
||||
std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
|
||||
std::size_t num_btype =
|
||||
sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
|
||||
sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "flash_attention_fwd_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
return [=](index_t block_1d_id) {
|
||||
constexpr index_t M01 = 4;
|
||||
constexpr index_t GroupNum = 8;
|
||||
|
||||
const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
|
||||
const auto update_M0 =
|
||||
((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;
|
||||
|
||||
const auto xcd_id = block_1d_id % GroupNum;
|
||||
|
||||
const auto l_block_id = block_1d_id - (xcd_id % 2);
|
||||
|
||||
const auto ridn = GroupNum * M01 * (update_N0 / 2);
|
||||
const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn;
|
||||
const auto lu = (l_block_id % GroupNum) + rid * ridn;
|
||||
|
||||
const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
|
||||
const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;
|
||||
|
||||
auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
|
||||
auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
|
||||
|
||||
const auto total_update_size = update_N0 * update_M0;
|
||||
|
||||
if(block_1d_id >= total_update_size)
|
||||
{
|
||||
auto x = (block_1d_id + 1) - total_update_size;
|
||||
auto rlen = N0 - update_N0;
|
||||
|
||||
auto rm = 0;
|
||||
auto rn = 0;
|
||||
if(rlen > 0)
|
||||
{
|
||||
rm = (x - 1) / rlen;
|
||||
rn = x % rlen;
|
||||
}
|
||||
|
||||
if(rlen > 0 and rm < M0)
|
||||
{
|
||||
n = rn + update_N0;
|
||||
m = rm;
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x - rlen * M0;
|
||||
rm = (x - 1) / update_N0;
|
||||
rn = x % update_N0;
|
||||
n = rn;
|
||||
m = update_M0 + rm;
|
||||
}
|
||||
}
|
||||
return make_multi_index(m, n);
|
||||
};
|
||||
}
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
// P[M0, N0] = Softmax(S[M0, N0])
|
||||
// O[M0, N1] = P[M0, N0] * V[N1, N0]
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType,
|
||||
index_t kBlockSize,
|
||||
index_t kHeadDim,
|
||||
index_t kM0PerBlock,
|
||||
index_t kN0PerBlock,
|
||||
index_t kK0PerBlock,
|
||||
index_t kN1PerBlock,
|
||||
index_t kK1PerBlock>
|
||||
struct FlashAttentionFwd
|
||||
{
|
||||
__device__ void operator()(const QDataType* q_ptr,
|
||||
const KDataType* k_ptr,
|
||||
const VDataType* v_ptr,
|
||||
ODataType* o_ptr,
|
||||
const index_t M0,
|
||||
const index_t N0,
|
||||
const index_t K0,
|
||||
const index_t N1,
|
||||
const index_t /* Batch */,
|
||||
const index_t StrideQ,
|
||||
const index_t StrideK,
|
||||
const index_t StrideV,
|
||||
const index_t StrideO,
|
||||
const index_t BatchStrideQ,
|
||||
const index_t BatchStrideK,
|
||||
const index_t BatchStrideV,
|
||||
const index_t BatchStrideO) const
|
||||
{
|
||||
const index_t id_block = get_block_id();
|
||||
|
||||
const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock);
|
||||
const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock);
|
||||
|
||||
#if defined(TOY_FA_FWD_CACHE_AWARE)
|
||||
#pragma message("Enable toy FA fwd cache aware")
|
||||
const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1);
|
||||
|
||||
const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0;
|
||||
const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0);
|
||||
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
|
||||
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) %
|
||||
num_tile_m0 * kM0PerBlock);
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) %
|
||||
num_tile_n1 * kN1PerBlock);
|
||||
|
||||
#else
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
|
||||
return make_tuple(quotient, modulus);
|
||||
};
|
||||
const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
|
||||
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
|
||||
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
|
||||
|
||||
#endif
|
||||
|
||||
const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>{};
|
||||
|
||||
kernel_impl(q_ptr + iBatch * BatchStrideQ,
|
||||
k_ptr + iBatch * BatchStrideK,
|
||||
v_ptr + iBatch * BatchStrideV,
|
||||
o_ptr + iBatch * BatchStrideO,
|
||||
M0,
|
||||
N0,
|
||||
K0,
|
||||
N1,
|
||||
StrideQ,
|
||||
StrideK,
|
||||
StrideV,
|
||||
StrideO,
|
||||
iM0,
|
||||
iN1);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,440 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp"
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "tile_gemm_shape.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
// P[M0, N0] = Softmax(S[M0, N0])
|
||||
// O[M0, N1] = P[M0, N0] * V[N1, N0]
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType,
|
||||
index_t kBlockSize,
|
||||
index_t kHeadDim,
|
||||
index_t kM0PerBlock,
|
||||
index_t kN0PerBlock,
|
||||
index_t kK0PerBlock,
|
||||
index_t kN1PerBlock,
|
||||
index_t kK1PerBlock>
|
||||
struct FlashAttentionFwdImpl
|
||||
{
|
||||
// block gemm0 pipeline
|
||||
using BlockGemm0Problem =
|
||||
BlockGemmPipelineProblem<QDataType,
|
||||
KDataType,
|
||||
SaccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>;
|
||||
|
||||
using BlockGemm0Policy =
|
||||
BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>;
|
||||
|
||||
using BlockGemm0Pipeline = BlockGemmPipelineAGmemBGmemCReg<BlockGemm0Problem, BlockGemm0Policy>;
|
||||
|
||||
// block gemm1
|
||||
using BlockGemm1 = BlockGemmARegBSmemCRegV1<
|
||||
BlockGemmARegBSmemCRegProblem<PDataType,
|
||||
VDataType,
|
||||
OaccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kM0PerBlock, kN1PerBlock, kK1PerBlock>>,
|
||||
BlockGemmARegBSmemCRegV1DefaultPolicy>;
|
||||
|
||||
// 3d, with padding
|
||||
__device__ static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = kN1PerBlock;
|
||||
constexpr index_t kKPerBlock = kK1PerBlock;
|
||||
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
|
||||
constexpr index_t kKPack = 4;
|
||||
#else
|
||||
constexpr index_t kKPack = 8;
|
||||
#endif
|
||||
|
||||
constexpr auto dataTypeSize = sizeof(VDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
using BDataType = VDataType;
|
||||
|
||||
constexpr index_t kNPerBlock = kN1PerBlock;
|
||||
constexpr index_t kKPerBlock = kK1PerBlock;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return max(BlockGemm0Pipeline::GetStaticLdsSize(),
|
||||
static_cast<index_t>(MakeVLdsBlockDescriptor().get_element_space_size() *
|
||||
sizeof(VDataType)));
|
||||
}
|
||||
|
||||
__device__ void operator()(const QDataType* q_ptr,
|
||||
const KDataType* k_ptr,
|
||||
const VDataType* v_ptr,
|
||||
ODataType* o_ptr,
|
||||
const index_t M0,
|
||||
const index_t N0,
|
||||
const index_t K0,
|
||||
const index_t N1,
|
||||
const index_t StrideQ,
|
||||
const index_t StrideK,
|
||||
const index_t StrideV,
|
||||
const index_t StrideO,
|
||||
const index_t iM0,
|
||||
const index_t iN1) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetStaticLdsSize()];
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
|
||||
|
||||
const auto k_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{});
|
||||
|
||||
const auto v_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}),
|
||||
{iM0, 0},
|
||||
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{iN1, 0},
|
||||
MakeVDramTileDistribution());
|
||||
// Q in register
|
||||
auto q_reg_tensor = load_tile(q_dram_window);
|
||||
|
||||
// V LDS and LDS window
|
||||
// V LDS occupies the same LDS allocation Q/K LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// V LDS tile window for store
|
||||
auto v_copy_lds_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
|
||||
// V LDS tile for block GEMM
|
||||
auto v_lds_gemm_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
#endif
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SaccBlockTileType =
|
||||
decltype(gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, nullptr));
|
||||
|
||||
using SBlockTileType = decltype(tile_elementwise_in(
|
||||
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
|
||||
|
||||
using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>,
|
||||
SaccBlockTileType{}));
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm1(
|
||||
get_slice_tile(
|
||||
PBlockTileType{}, sequence<0, 0>{}, sequence<kM0PerBlock, kK1PerBlock>{}),
|
||||
v_dram_window));
|
||||
|
||||
// init Sacc, Oacc, M, L
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
tile_elementwise_inout([](auto& e) { e = 0; }, o_acc);
|
||||
tile_elementwise_inout(
|
||||
[](auto& e) { e = std::numeric_limits<SMPLComputeDataType>::lowest(); }, m);
|
||||
tile_elementwise_inout([](auto& e) { e = 0; }, l);
|
||||
|
||||
// loop over Column of S (J loop)
|
||||
index_t iN0 = 0;
|
||||
|
||||
do
|
||||
{
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
|
||||
// S{j}
|
||||
const auto s =
|
||||
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// prefetch load v tile
|
||||
auto v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
#endif
|
||||
// m_local = rowmax(S{j})
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s, sequence<1>{}, f_max, std::numeric_limits<SMPLComputeDataType>::lowest());
|
||||
|
||||
block_tile_reduce_sync(m_local, f_max);
|
||||
|
||||
// m{j-1}
|
||||
const auto m_old = m;
|
||||
|
||||
// m{j}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
// Pcompute{j}
|
||||
auto p_compute =
|
||||
make_static_distributed_tensor<SMPLComputeDataType>(s.get_tile_distribution());
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(p_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
sweep_tile_span(p_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
// rowsum(Pcompute{j})
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0});
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum);
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto tmp = exp(m_old[i_idx] - m[i_idx]);
|
||||
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
block_sync_lds();
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
|
||||
|
||||
// Oacc{j}
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
store_tile(v_lds_window, v);
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using VLdsTile = typename decltype(gemm1)::BLdsTile;
|
||||
VLdsTile vWarpTile;
|
||||
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
|
||||
|
||||
// Oacc{j}
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
}
|
||||
if constexpr(k1_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
|
||||
// Global read 2
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
// move tile windows
|
||||
move_tile_window(k_dram_window, {kN0PerBlock, 0});
|
||||
iN0 += kN0PerBlock;
|
||||
} while(iN0 < N0);
|
||||
|
||||
// Oacc
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto tmp = 1 / l[i_idx];
|
||||
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
// type cast Oacc into O
|
||||
const auto o = tile_elementwise_in(type_convert<ODataType, OaccDataType>, o_acc);
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
auto o_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<kM0PerBlock>{}, number<kN1PerBlock>{}),
|
||||
{iM0, iN1},
|
||||
o.get_tile_distribution());
|
||||
|
||||
// store O
|
||||
store_tile(o_dram_window, o);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
void reference_batched_gemm(const ck_tile::HostTensor<ADataType>& a_b_m_k,
|
||||
const ck_tile::HostTensor<BDataType>& b_b_n_k,
|
||||
ck_tile::HostTensor<CDataType>& c_b_m_n)
|
||||
{
|
||||
const int N = b_b_n_k.mDesc.get_lengths()[1];
|
||||
const int K = b_b_n_k.mDesc.get_lengths()[2];
|
||||
|
||||
auto f = [&](auto batch, auto m) {
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_b_m_k(batch, m, k);
|
||||
BDataType v_b = b_b_n_k(batch, n, k);
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(v_acc);
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename ADataType, typename AccDataType, typename BDataType>
|
||||
void reference_batched_softmax(const ck_tile::HostTensor<ADataType>& a_b_m_n,
|
||||
ck_tile::HostTensor<BDataType>& b_b_m_n)
|
||||
{
|
||||
const int N = a_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
auto f = [&](auto batch, auto m) {
|
||||
AccDataType v_max = std::numeric_limits<ADataType>::lowest();
|
||||
|
||||
// max
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
const ADataType v_a = a_b_m_n(batch, m, n);
|
||||
|
||||
v_max = v_max < v_a ? v_a : v_max;
|
||||
}
|
||||
|
||||
AccDataType v_exp_sum = 0;
|
||||
|
||||
// sum
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
const ADataType v_a = a_b_m_n(batch, m, n);
|
||||
|
||||
v_exp_sum += ck_tile::exp(v_a - v_max);
|
||||
}
|
||||
|
||||
// elementwise
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
const ADataType v_a = a_b_m_n(batch, m, n);
|
||||
|
||||
b_b_m_n(batch, m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
struct TileGemmShape
|
||||
{
|
||||
static constexpr index_t kM = kMPerTile;
|
||||
static constexpr index_t kN = kNPerTile;
|
||||
static constexpr index_t kK = kKPerTile;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,66 @@
|
||||
set(FLASH_ATTENTION_FWD_KNOWN_APIS "fwd")
|
||||
set(FLASH_ATTENTION_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${FLASH_ATTENTION_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(FLASH_ATTENTION_FWD_ENABLE_APIS ${FLASH_ATTENTION_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
|
||||
option(TOY_FA_FWD_OPT "Enable toy flash attention forward optimization" ON)
|
||||
option(TOY_FA_FWD_QK_SWIZZLE "Enable toy flash attention forward QK swizzle" OFF)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list Flash Attention kernels via Python. ${ret}")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/flash_attention_fwd_blobs.txt FLASH_ATTENTION_FWD_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--gen_blobs
|
||||
)
|
||||
|
||||
set(EXAMPLE_FA "codegen_basic_flash_attention_fwd")
|
||||
message("adding example ${EXAMPLE_FA}")
|
||||
|
||||
add_executable(${EXAMPLE_FA}
|
||||
EXCLUDE_FROM_ALL
|
||||
flash_attention_fwd.cpp
|
||||
)
|
||||
|
||||
target_compile_definitions(${EXAMPLE_FA}
|
||||
PRIVATE
|
||||
$<$<BOOL:${TOY_FA_FWD_OPT}>:TOY_FA_FWD_OPT=1>
|
||||
)
|
||||
|
||||
target_include_directories(${EXAMPLE_FA}
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
target_sources(${EXAMPLE_FA} PRIVATE ${FLASH_ATTENTION_FWD_GEN_BLOBS})
|
||||
|
||||
message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}")
|
||||
|
||||
set(EXAMPLE_FA_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_FA_COMPILE_OPTIONS
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
)
|
||||
|
||||
target_compile_options(${EXAMPLE_FA}
|
||||
PRIVATE
|
||||
${EXAMPLE_FA_COMPILE_OPTIONS}
|
||||
)
|
||||
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Problem Description for BlockGemmARegBSmemCReg
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmARegBSmemCRegProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,565 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "block_gemm_areg_bsmem_creg_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem, typename Policy = BlockGemmARegBSmemCRegV1DefaultPolicy>
|
||||
struct BlockGemmARegBSmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using BlockGemmPolicy = Policy;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN);
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
constexpr index_t MPerXDL = WG::kM;
|
||||
constexpr index_t NPerXDL = WG::kN;
|
||||
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = config.template get<1>();
|
||||
|
||||
constexpr index_t B_LDS_RW_Width = SmemPack;
|
||||
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
|
||||
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// B split schedule
|
||||
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
constexpr auto num_mfma_per_dswrite_b =
|
||||
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
|
||||
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_mfma_per_dswrite_b *
|
||||
num_dswrite_per_issue_b,
|
||||
0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BLdsTile& b_block_tensor_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
#else
|
||||
#pragma message("Enable toy FA fwd QK swizzle")
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BlockGemmARegBSmemCRegV1K8Policy
|
||||
{
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
#else
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,397 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t k_loops = Policy::AKDim / kKPerBlock;
|
||||
|
||||
// Move this part into Policy?
|
||||
__host__ __device__ static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
// Cold A Register Cache
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename ARegBlockTensorTmp>
|
||||
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
ARegBlockTensorTmp& a_reg_block_tensor_tmp,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = a_element_func;
|
||||
ignore = b_element_func;
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution used to construct both distributed tensor for local buffer store
|
||||
// and read. without buffer address info
|
||||
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
|
||||
|
||||
// B tile in LDS, blockWindow
|
||||
BDataType* p_b_lds =
|
||||
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
// This tensor view used to construct both tile window for lds store and read, with buffer
|
||||
// address info
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A Reg tensor for store, also used for block GEMM
|
||||
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
sequence<0, 0>{},
|
||||
sequence<kMPerBlock, kKPerBlock>{});
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
sequence<0, (i_k0 + 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 2) * kKPerBlock>{});
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{});
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
}
|
||||
|
||||
set_slice_tile(a_reg_block_tensor_tmp,
|
||||
a_copy_reg_tensor,
|
||||
sequence<0, 0>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{});
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
// Hot A Register Cache
|
||||
template <typename BDramBlockWindowTmp, typename BElementFunction, typename ARegBlockTensorTmp>
|
||||
__host__ __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const ARegBlockTensorTmp& a_reg_block_tensor_tmp,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = b_element_func;
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution used to construct both distributed tensor for local buffer store
|
||||
// and read. without buffer address info
|
||||
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
|
||||
|
||||
// A Reg tensor for store, also used for block GEMM
|
||||
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_reg_block_tensor_tmp,
|
||||
sequence<0, 0>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{});
|
||||
|
||||
// B tile in LDS, blockWindow
|
||||
BDataType* p_b_lds =
|
||||
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
// This tensor view used to construct both tile window for lds store and read, with buffer
|
||||
// address info
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
static_for<0, k_loops, 1>{}([&](auto i_k0) {
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using BLdsTile = typename decltype(block_gemm)::BLdsTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Global read 0
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS read 0
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS read 1
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
}
|
||||
#endif
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename ARegBlockTensorTmp>
|
||||
__device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
ARegBlockTensorTmp& a_reg_block_tensor_tmp,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
a_reg_block_tensor_tmp,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename BDramBlockWindowTmp, typename ARegBlockTensorTmp>
|
||||
__device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const ARegBlockTensorTmp& a_reg_block_tensor_tmp,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
a_reg_block_tensor_tmp,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,144 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
{
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockGemmPolicy = BlockGemmARegBSmemCRegV1K8Policy;
|
||||
|
||||
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = AKDim;
|
||||
|
||||
constexpr auto config =
|
||||
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem, kMPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,25 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
struct BlockGemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,173 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include "reference_batched_gemm.hpp"
|
||||
#include "reference_batched_softmax.hpp"
|
||||
#include "flash_attention_fwd.hpp"
|
||||
|
||||
/*
|
||||
* Toy code of flash attention forward pass
|
||||
* Assume simplest case.
|
||||
* Q [Batch, HeadNum, SeqenceLengthQ, HeadDim]
|
||||
* K [Batch, HeadNum, SeqenceLengthK, HeadDim]
|
||||
* V [Batch, HeadNum, HeadDim, SeqenceLengthK]
|
||||
* O [Batch, HeadNum, SeqenceLengthQ, HeadDim]
|
||||
*/
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using SaccDataType = float;
|
||||
using SMPLComputeDataType = float;
|
||||
using PDataType = ck_tile::half_t;
|
||||
using OaccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
|
||||
if(argc == 3)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
verification = std::stoi(argv[2]);
|
||||
}
|
||||
else if(argc == 8)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
verification = std::stoi(argv[2]);
|
||||
Batch = std::stoi(argv[3]);
|
||||
M0 = std::stoi(argv[4]);
|
||||
N0 = std::stoi(argv[5]);
|
||||
K0 = std::stoi(argv[6]);
|
||||
N1 = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};
|
||||
std::array<ck_tile::index_t, 3> q_strides{M0 * K0, K0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> k_lengths{Batch, N0, K0};
|
||||
std::array<ck_tile::index_t, 3> k_strides{N0 * K0, K0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> v_lengths{Batch, N1, N0};
|
||||
std::array<ck_tile::index_t, 3> v_strides{N1 * N0, N0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> s_lengths{Batch, M0, N0};
|
||||
std::array<ck_tile::index_t, 3> s_strides{M0 * N0, N0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> p_lengths{Batch, M0, N0};
|
||||
std::array<ck_tile::index_t, 3> p_strides{M0 * N0, N0, 1};
|
||||
|
||||
std::array<ck_tile::index_t, 3> o_lengths{Batch, M0, N1};
|
||||
std::array<ck_tile::index_t, 3> o_strides{M0 * N1, N1, 1};
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<QDataType> q_host(q_lengths, q_strides);
|
||||
ck_tile::HostTensor<KDataType> k_host(k_lengths, k_strides);
|
||||
ck_tile::HostTensor<VDataType> v_host(v_lengths, v_strides);
|
||||
ck_tile::HostTensor<ODataType> o_host_dev(o_lengths, o_strides);
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
|
||||
break;
|
||||
case 2:
|
||||
ck_tile::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
|
||||
break;
|
||||
default:
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
|
||||
}
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.mData.data());
|
||||
k_buf.ToDevice(k_host.mData.data());
|
||||
v_buf.ToDevice(v_host.mData.data());
|
||||
|
||||
// Construct the FlashAttnArgs object with your arguments
|
||||
ck_tile::FlashAttnArgs<QDataType, KDataType, VDataType, ODataType> flash_attention_args{
|
||||
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
|
||||
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
|
||||
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
|
||||
static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
|
||||
M0,
|
||||
N0,
|
||||
K0,
|
||||
N1,
|
||||
Batch,
|
||||
K0, // strideQ
|
||||
K0, // strideK
|
||||
N0, // strideV
|
||||
N1, // strideO
|
||||
M0 * K0, // batchStrideQ
|
||||
N0 * K0, // batchStrideK
|
||||
N1 * N0, // batchStrideV
|
||||
M0 * N1 // batchStrideO
|
||||
};
|
||||
|
||||
float ave_time = ck_tile::flash_attention_fwd<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType>(flash_attention_args,
|
||||
ck_tile::stream_config{nullptr, true});
|
||||
|
||||
// reference
|
||||
auto pass = true;
|
||||
if(verification)
|
||||
{
|
||||
o_buf.FromDevice(o_host_dev.mData.data());
|
||||
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
|
||||
ck_tile::HostTensor<PDataType> p_host_ref(p_lengths, p_strides);
|
||||
ck_tile::HostTensor<ODataType> o_host_ref(o_lengths, o_strides);
|
||||
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host, k_host, s_host_ref);
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref);
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref, v_host, o_host_ref);
|
||||
|
||||
pass &= ck_tile::check_err(o_host_dev, o_host_ref);
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
std::size_t flop =
|
||||
std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
|
||||
std::size_t num_btype =
|
||||
sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
|
||||
sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "flash_attention_fwd_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename QDataType, typename KDataType, typename VDataType, typename ODataType>
|
||||
struct FlashAttnArgs
|
||||
{
|
||||
// Pointers to device buffers for Q, K, V, O
|
||||
QDataType* q_ptr;
|
||||
KDataType* k_ptr;
|
||||
VDataType* v_ptr;
|
||||
ODataType* o_ptr;
|
||||
|
||||
// Problem sizes
|
||||
index_t M0;
|
||||
index_t N0;
|
||||
index_t K0;
|
||||
index_t N1;
|
||||
index_t Batch;
|
||||
|
||||
// Strides within a batch
|
||||
index_t strideQ;
|
||||
index_t strideK;
|
||||
index_t strideV;
|
||||
index_t strideO;
|
||||
|
||||
// Batch strides
|
||||
index_t batchStrideQ;
|
||||
index_t batchStrideK;
|
||||
index_t batchStrideV;
|
||||
index_t batchStrideO;
|
||||
};
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
// P[M0, N0] = Softmax(S[M0, N0])
|
||||
// O[M0, N1] = P[M0, N0] * V[N1, N0]
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType,
|
||||
index_t kBlockSize,
|
||||
index_t kHeadDim,
|
||||
index_t kM0PerBlock,
|
||||
index_t kN0PerBlock,
|
||||
index_t kK0PerBlock,
|
||||
index_t kN1PerBlock,
|
||||
index_t kK1PerBlock>
|
||||
struct FlashAttentionFwd
|
||||
{
|
||||
__device__ void operator()(const QDataType* q_ptr,
|
||||
const KDataType* k_ptr,
|
||||
const VDataType* v_ptr,
|
||||
ODataType* o_ptr,
|
||||
const index_t M0,
|
||||
const index_t N0,
|
||||
const index_t K0,
|
||||
const index_t N1,
|
||||
const index_t /* Batch */,
|
||||
const index_t StrideQ,
|
||||
const index_t StrideK,
|
||||
const index_t StrideV,
|
||||
const index_t StrideO,
|
||||
const index_t BatchStrideQ,
|
||||
const index_t BatchStrideK,
|
||||
const index_t BatchStrideV,
|
||||
const index_t BatchStrideO) const
|
||||
{
|
||||
const index_t id_block = get_block_id();
|
||||
|
||||
const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock);
|
||||
const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock);
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
|
||||
return make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
|
||||
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
|
||||
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
|
||||
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
|
||||
|
||||
const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>{};
|
||||
|
||||
kernel_impl(q_ptr + iBatch * BatchStrideQ,
|
||||
k_ptr + iBatch * BatchStrideK,
|
||||
v_ptr + iBatch * BatchStrideV,
|
||||
o_ptr + iBatch * BatchStrideO,
|
||||
M0,
|
||||
N0,
|
||||
K0,
|
||||
N1,
|
||||
StrideQ,
|
||||
StrideK,
|
||||
StrideV,
|
||||
StrideO,
|
||||
iM0,
|
||||
iN1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType>
|
||||
float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const stream_config& stream_config);
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,440 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp"
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "tile_gemm_shape.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
// P[M0, N0] = Softmax(S[M0, N0])
|
||||
// O[M0, N1] = P[M0, N0] * V[N1, N0]
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType,
|
||||
index_t kBlockSize,
|
||||
index_t kHeadDim,
|
||||
index_t kM0PerBlock,
|
||||
index_t kN0PerBlock,
|
||||
index_t kK0PerBlock,
|
||||
index_t kN1PerBlock,
|
||||
index_t kK1PerBlock>
|
||||
struct FlashAttentionFwdImpl
|
||||
{
|
||||
// block gemm0 pipeline
|
||||
using BlockGemm0Problem =
|
||||
BlockGemmPipelineProblem<QDataType,
|
||||
KDataType,
|
||||
SaccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kM0PerBlock, kN0PerBlock, kK0PerBlock>>;
|
||||
|
||||
using BlockGemm0Policy =
|
||||
BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>;
|
||||
|
||||
using BlockGemm0Pipeline = BlockGemmPipelineAGmemBGmemCReg<BlockGemm0Problem, BlockGemm0Policy>;
|
||||
|
||||
// block gemm1
|
||||
using BlockGemm1 = BlockGemmARegBSmemCRegV1<
|
||||
BlockGemmARegBSmemCRegProblem<PDataType,
|
||||
VDataType,
|
||||
OaccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kM0PerBlock, kN1PerBlock, kK1PerBlock>>,
|
||||
BlockGemmARegBSmemCRegV1DefaultPolicy>;
|
||||
|
||||
// 3d, with padding
|
||||
__device__ static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = kN1PerBlock;
|
||||
constexpr index_t kKPerBlock = kK1PerBlock;
|
||||
#if !defined(TOY_FA_FWD_QK_SWIZZLE)
|
||||
constexpr index_t kKPack = 4;
|
||||
#else
|
||||
constexpr index_t kKPack = 8;
|
||||
#endif
|
||||
|
||||
constexpr auto dataTypeSize = sizeof(VDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
using BDataType = VDataType;
|
||||
|
||||
constexpr index_t kNPerBlock = kN1PerBlock;
|
||||
constexpr index_t kKPerBlock = kK1PerBlock;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return max(BlockGemm0Pipeline::GetStaticLdsSize(),
|
||||
static_cast<index_t>(MakeVLdsBlockDescriptor().get_element_space_size() *
|
||||
sizeof(VDataType)));
|
||||
}
|
||||
|
||||
__device__ void operator()(const QDataType* q_ptr,
|
||||
const KDataType* k_ptr,
|
||||
const VDataType* v_ptr,
|
||||
ODataType* o_ptr,
|
||||
const index_t M0,
|
||||
const index_t N0,
|
||||
const index_t K0,
|
||||
const index_t N1,
|
||||
const index_t StrideQ,
|
||||
const index_t StrideK,
|
||||
const index_t StrideV,
|
||||
const index_t StrideO,
|
||||
const index_t iM0,
|
||||
const index_t iN1) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetStaticLdsSize()];
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
|
||||
|
||||
const auto k_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{});
|
||||
|
||||
const auto v_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}),
|
||||
{iM0, 0},
|
||||
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{iN1, 0},
|
||||
MakeVDramTileDistribution());
|
||||
// Q in register
|
||||
auto q_reg_tensor = load_tile(q_dram_window);
|
||||
|
||||
// V LDS and LDS window
|
||||
// V LDS occupies the same LDS allocation Q/K LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// V LDS tile window for store
|
||||
auto v_copy_lds_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
|
||||
// V LDS tile for block GEMM
|
||||
auto v_lds_gemm_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
#endif
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SaccBlockTileType =
|
||||
decltype(gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, nullptr));
|
||||
|
||||
using SBlockTileType = decltype(tile_elementwise_in(
|
||||
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
|
||||
|
||||
using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>,
|
||||
SaccBlockTileType{}));
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm1(
|
||||
get_slice_tile(
|
||||
PBlockTileType{}, sequence<0, 0>{}, sequence<kM0PerBlock, kK1PerBlock>{}),
|
||||
v_dram_window));
|
||||
|
||||
// init Sacc, Oacc, M, L
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
tile_elementwise_inout([](auto& e) { e = 0; }, o_acc);
|
||||
tile_elementwise_inout(
|
||||
[](auto& e) { e = std::numeric_limits<SMPLComputeDataType>::lowest(); }, m);
|
||||
tile_elementwise_inout([](auto& e) { e = 0; }, l);
|
||||
|
||||
// loop over Column of S (J loop)
|
||||
index_t iN0 = 0;
|
||||
|
||||
do
|
||||
{
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
|
||||
// S{j}
|
||||
const auto s =
|
||||
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// prefetch load v tile
|
||||
auto v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
#endif
|
||||
// m_local = rowmax(S{j})
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s, sequence<1>{}, f_max, std::numeric_limits<SMPLComputeDataType>::lowest());
|
||||
|
||||
block_tile_reduce_sync(m_local, f_max);
|
||||
|
||||
// m{j-1}
|
||||
const auto m_old = m;
|
||||
|
||||
// m{j}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
|
||||
// Pcompute{j}
|
||||
auto p_compute =
|
||||
make_static_distributed_tensor<SMPLComputeDataType>(s.get_tile_distribution());
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(p_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
sweep_tile_span(p_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - m[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
// rowsum(Pcompute{j})
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0});
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum);
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto tmp = exp(m_old[i_idx] - m[i_idx]);
|
||||
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
block_sync_lds();
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
|
||||
|
||||
// Oacc{j}
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
store_tile(v_lds_window, v);
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using VLdsTile = typename decltype(gemm1)::BLdsTile;
|
||||
VLdsTile vWarpTile;
|
||||
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
|
||||
|
||||
// Oacc{j}
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
}
|
||||
if constexpr(k1_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
|
||||
// Global read 2
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
// move tile windows
|
||||
move_tile_window(k_dram_window, {kN0PerBlock, 0});
|
||||
iN0 += kN0PerBlock;
|
||||
} while(iN0 < N0);
|
||||
|
||||
// Oacc
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto tmp = 1 / l[i_idx];
|
||||
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
// type cast Oacc into O
|
||||
const auto o = tile_elementwise_in(type_convert<ODataType, OaccDataType>, o_acc);
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
auto o_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto o_dram_window =
|
||||
make_tile_window(o_dram,
|
||||
make_tuple(number<kM0PerBlock>{}, number<kN1PerBlock>{}),
|
||||
{iM0, iN1},
|
||||
o.get_tile_distribution());
|
||||
|
||||
// store O
|
||||
store_tile(o_dram_window, o);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,578 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional, Any
|
||||
import functools
|
||||
import itertools
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
def get_if_str(size_, total, last_else=True):
|
||||
if size_ == "head_dim_128_seq_16384":
|
||||
return 'if'
|
||||
else:
|
||||
return 'else if'
|
||||
|
||||
DATA_TYPE_MAP = {'fp32': 'float',
|
||||
'fp16': 'ck_tile::half_t',
|
||||
'bf16': 'ck_tile::bf16_t'}
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
return 'true' if b_ else 'false'
|
||||
|
||||
class FlashAttentionFwdCodegen:
|
||||
API_TRAITS_DEFINE = """
|
||||
|
||||
template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kHeadDim_ = 128,
|
||||
index_t kM0PerBlock_ = 128,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
|
||||
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kHeadDim = kHeadDim_;
|
||||
static constexpr index_t kM0PerBlock = kM0PerBlock_;
|
||||
static constexpr index_t kN0PerBlock = kN0PerBlock_;
|
||||
static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
};
|
||||
|
||||
template <typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
ck_tile::index_t kBlockSize = 256,
|
||||
ck_tile::index_t kHeadDim = 128,
|
||||
ck_tile::index_t kM0PerBlock = 128,
|
||||
ck_tile::index_t kN0PerBlock = 128,
|
||||
ck_tile::index_t kK0PerBlock = 64,
|
||||
ck_tile::index_t kN1PerBlock = 128,
|
||||
ck_tile::index_t kK1PerBlock = 64>
|
||||
using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>;
|
||||
"""
|
||||
|
||||
API_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "flash_attention_fwd.hpp"
|
||||
|
||||
namespace ck_tile {{
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename Traits_>
|
||||
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType>
|
||||
float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config) {{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
|
||||
template float flash_attention_fwd<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, float, float, ck_tile::half_t, float, ck_tile::half_t>(
|
||||
const FlashAttnArgs<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t>&,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
}}
|
||||
"""
|
||||
|
||||
API_INNER_CASE = """ {F_if} {F_VEC_COND}
|
||||
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(a, stream_config);
|
||||
"""
|
||||
|
||||
INSTANCE_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "flash_attention_fwd_api_common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// clang-format off
|
||||
//
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, working_path, kernel_filter):
|
||||
self.working_path = working_path
|
||||
self.kernel_filter = kernel_filter
|
||||
|
||||
@dataclass
|
||||
class h_traits:
|
||||
F_SaccDataType: str
|
||||
F_SMPLComputeDataType: str
|
||||
F_PDataType: str
|
||||
F_OaccDataType: str
|
||||
F_kBlockSize: int
|
||||
F_kHeadDim: int
|
||||
F_kM0PerBlock: int
|
||||
F_kN0PerBlock: int
|
||||
F_kK0PerBlock: int
|
||||
F_kN1PerBlock: int
|
||||
F_kK1PerBlock: int
|
||||
|
||||
@property
|
||||
def trait_name(self) -> str:
|
||||
return (f"{DATA_TYPE_MAP[self.F_SaccDataType]}, "
|
||||
f"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, "
|
||||
f"{DATA_TYPE_MAP[self.F_PDataType]}, "
|
||||
f"{DATA_TYPE_MAP[self.F_OaccDataType]}, "
|
||||
f"{self.F_kBlockSize}, {self.F_kHeadDim}, "
|
||||
f"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, "
|
||||
f"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}")
|
||||
|
||||
@property
|
||||
def def_name(self) -> str:
|
||||
return (f"template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, "
|
||||
f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, "
|
||||
f"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, "
|
||||
f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, "
|
||||
"const ck_tile::stream_config&);")
|
||||
|
||||
@dataclass
|
||||
class h_instance:
|
||||
F_DataTypePair: str # "q,k,v,o"
|
||||
F_SizeCategory: str # "small", "medium", "large"
|
||||
instance_list: List[Any] # List[h_traits]
|
||||
|
||||
INSTANCE_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "flash_attention_fwd_api_common.hpp"
|
||||
|
||||
namespace ck_tile {{
|
||||
// clang-format off
|
||||
//
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
}}
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
q_type, k_type, v_type, o_type = self.F_DataTypePair.split(',')
|
||||
dtype_str = f"{q_type}_{k_type}_{v_type}_{o_type}"
|
||||
return f"flash_attention_fwd_{dtype_str}_{self.F_SizeCategory}"
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
instance_defs = '\n'.join(ins.def_name for ins in self.instance_list)
|
||||
return self.INSTANCE_BASE.format(F_instance_def=instance_defs)
|
||||
|
||||
@property
|
||||
def name_api(self) -> str:
|
||||
return "flash_attention_fwd_api"
|
||||
|
||||
@property
|
||||
def name_common_header(self) -> str:
|
||||
return "flash_attention_fwd_api_common"
|
||||
|
||||
@property
|
||||
def content_common_header(self) -> str:
|
||||
return f"""// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "flash_attention_fwd.hpp"
|
||||
|
||||
namespace ck_tile {{
|
||||
|
||||
template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kHeadDim_ = 128,
|
||||
index_t kM0PerBlock_ = 128,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
|
||||
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kHeadDim = kHeadDim_;
|
||||
static constexpr index_t kM0PerBlock = kM0PerBlock_;
|
||||
static constexpr index_t kN0PerBlock = kN0PerBlock_;
|
||||
static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
}};
|
||||
|
||||
|
||||
template <typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
ck_tile::index_t kBlockSize,
|
||||
ck_tile::index_t kHeadDim,
|
||||
ck_tile::index_t kM0PerBlock,
|
||||
ck_tile::index_t kN0PerBlock,
|
||||
ck_tile::index_t kK0PerBlock,
|
||||
ck_tile::index_t kN1PerBlock,
|
||||
ck_tile::index_t kK1PerBlock>
|
||||
using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>;
|
||||
|
||||
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename Traits_>
|
||||
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config) {{
|
||||
using SaccDataType = typename Traits_::SaccDataType;
|
||||
using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
|
||||
using PDataType = typename Traits_::PDataType;
|
||||
using OaccDataType = typename Traits_::OaccDataType;
|
||||
|
||||
index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(stream_config,
|
||||
ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
|
||||
ck_tile::FlashAttentionFwd<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType,
|
||||
Traits_::kBlockSize,
|
||||
Traits_::kHeadDim,
|
||||
Traits_::kM0PerBlock,
|
||||
Traits_::kN0PerBlock,
|
||||
Traits_::kK0PerBlock,
|
||||
Traits_::kN1PerBlock,
|
||||
Traits_::kK1PerBlock>{{}},
|
||||
kGridSize,
|
||||
Traits_::kBlockSize,
|
||||
0,
|
||||
a.q_ptr,
|
||||
a.k_ptr,
|
||||
a.v_ptr,
|
||||
a.o_ptr,
|
||||
a.M0,
|
||||
a.N0,
|
||||
a.K0,
|
||||
a.N1,
|
||||
a.Batch,
|
||||
a.strideQ, // StrideQ
|
||||
a.strideK, // StrideK
|
||||
a.strideV, // StrideV
|
||||
a.strideO, // StrideO
|
||||
a.batchStrideQ, // BatchStrideQ
|
||||
a.batchStrideK, // BatchStrideK
|
||||
a.batchStrideV, // BatchStrideV
|
||||
a.batchStrideO)); // BatchStrideO
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
def content_api(self, args) -> str:
|
||||
# Sort based on dtype
|
||||
t_dtype_dict = {}
|
||||
blobs = self.get_blobs(args)
|
||||
|
||||
for blob in blobs:
|
||||
if blob.F_DataTypePair not in t_dtype_dict:
|
||||
t_dtype_dict[blob.F_DataTypePair] = {}
|
||||
if blob.F_SizeCategory not in t_dtype_dict[blob.F_DataTypePair]:
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_SizeCategory] = []
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_SizeCategory].append(blob)
|
||||
|
||||
d_str = ''
|
||||
for i_d, dtype_ in enumerate(t_dtype_dict):
|
||||
blob_per_t = t_dtype_dict[dtype_]
|
||||
size_str = ''
|
||||
|
||||
for i_size, size_ in enumerate(blob_per_t):
|
||||
blob_per_size = blob_per_t[size_]
|
||||
inner_str = ""
|
||||
|
||||
for i_b, b_ in enumerate(blob_per_size):
|
||||
for i_ins, ins in enumerate(b_.instance_list):
|
||||
idx_in_size = i_b * len(b_.instance_list) + i_ins
|
||||
len_in_size = sum(len(b.instance_list) for b in blob_per_size)
|
||||
|
||||
size_cond = ""
|
||||
if size_ == "head_dim_128_seq_16384":
|
||||
size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_16384":
|
||||
size_cond = "(a.M0 <= 16384 && a.N0 <= 16384 && a.M0 > 8192 && a.N0 > 8192 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_128_seq_8192":
|
||||
size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_8192":
|
||||
size_cond = "(a.M0 <= 8192 && a.N0 <= 8192 && a.M0 > 4096 && a.N0 > 4096 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_256_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 256 && a.N1 == 256)"
|
||||
elif size_ == "head_dim_128_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_32_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 32 && a.N1 == 32)"
|
||||
elif size_ == "head_dim_128_seq_2048":
|
||||
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_2048":
|
||||
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 1024 && a.N0 > 1024 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_128_seq_1024":
|
||||
size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_1024":
|
||||
size_cond = "(a.M0 <= 1024 && a.N0 <= 1024 && a.M0 > 512 && a.N0 > 512 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_128_seq_512":
|
||||
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_512":
|
||||
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 64 && a.N1 == 64)"
|
||||
else:
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
|
||||
|
||||
inner_str += self.API_INNER_CASE.format(
|
||||
F_if=get_if_str(size_, len_in_size, False),
|
||||
F_VEC_COND=size_cond,
|
||||
F_trait_name=ins.trait_name
|
||||
)
|
||||
size_str += inner_str
|
||||
|
||||
d_str += size_str
|
||||
|
||||
api_base = self.API_BASE.format(
|
||||
F_traits_define=self.API_TRAITS_DEFINE,
|
||||
F_dispatch=d_str
|
||||
)
|
||||
return api_base
|
||||
|
||||
def get_blobs(self, args):
|
||||
h_traits = self.h_traits
|
||||
h_instance = self.h_instance
|
||||
|
||||
# Define kernel configurations for different size categories
|
||||
trait_dict = {
|
||||
"head_dim_128_seq_16384": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_16384": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
|
||||
],
|
||||
"head_dim_128_seq_8192": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_8192": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
|
||||
],
|
||||
"head_dim_256_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
|
||||
],
|
||||
"head_dim_128_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
|
||||
],
|
||||
"head_dim_32_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 32, 32, 32, 32, 32, 32),
|
||||
],
|
||||
"head_dim_128_seq_2048": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_2048": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
|
||||
],
|
||||
"head_dim_128_seq_1024": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_1024": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 32, 64, 32),
|
||||
],
|
||||
"head_dim_128_seq_512": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
],
|
||||
"head_dim_64_seq_512": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
|
||||
],
|
||||
}
|
||||
|
||||
# Toy example only support fp16
|
||||
dtype_combinations = [
|
||||
"fp16,fp16,fp16,fp16"
|
||||
# "bf16,bf16,bf16,bf16"
|
||||
]
|
||||
|
||||
total_blob = []
|
||||
for dtype_pair in dtype_combinations:
|
||||
for size_category in trait_dict:
|
||||
traits = trait_dict[size_category]
|
||||
# Convert data types for the current dtype_pair
|
||||
q_type, k_type, v_type, o_type = dtype_pair.split(',')
|
||||
current_traits = []
|
||||
for t in traits:
|
||||
new_t = copy.copy(t)
|
||||
new_t.F_SaccDataType = 'fp32' # accumulation in fp32
|
||||
new_t.F_SMPLComputeDataType = 'fp32' # softmax compute in fp32
|
||||
new_t.F_PDataType = q_type
|
||||
new_t.F_OaccDataType = 'fp32' # output accumulation in fp32
|
||||
current_traits.append(new_t)
|
||||
|
||||
total_blob.append(h_instance(dtype_pair, size_category, current_traits))
|
||||
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self, args) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
list_p = w_p / 'flash_attention_fwd_blobs.txt'
|
||||
blobs = self.get_blobs(args)
|
||||
|
||||
with list_p.open('w') as list_f:
|
||||
# API related files
|
||||
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
|
||||
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
|
||||
# Kernel instance files
|
||||
for b in blobs:
|
||||
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
|
||||
|
||||
def gen_blobs(self, args) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
w_str = self.content_api(args)
|
||||
(w_p / (self.name_api + ".cpp")).write_text(w_str)
|
||||
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
|
||||
|
||||
blobs = self.get_blobs(args)
|
||||
for b in blobs:
|
||||
(w_p / (b.name + ".cpp")).write_text(b.content)
|
||||
|
||||
def list_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
if api == 'fwd':
|
||||
FlashAttentionFwdCodegen(args.working_path, args.filter).list_blobs(args)
|
||||
|
||||
def gen_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
if api == 'fwd':
|
||||
FlashAttentionFwdCodegen(args.working_path, args.filter).gen_blobs(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for Flash Attention kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--api",
|
||||
default='fwd',
|
||||
required=False,
|
||||
help="supply API(s) to generate (default: fwd). separated by comma."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--working_path",
|
||||
default="./",
|
||||
required=False,
|
||||
help="the path where all the blobs are going to be generated"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
action='store_true',
|
||||
help="list all the kernels to a file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--gen_blobs",
|
||||
action='store_true',
|
||||
help="generate all kernels into different tile"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
required=False,
|
||||
help="filter out kernels that need to generate"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)):
|
||||
print('gen_blobs/list_blobs must specify only one option')
|
||||
sys.exit()
|
||||
|
||||
p = Path(args.working_path)
|
||||
if not p.exists():
|
||||
p.mkdir()
|
||||
|
||||
if args.list_blobs:
|
||||
list_blobs(args)
|
||||
else:
|
||||
gen_blobs(args)
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
void reference_batched_gemm(const ck_tile::HostTensor<ADataType>& a_b_m_k,
|
||||
const ck_tile::HostTensor<BDataType>& b_b_n_k,
|
||||
ck_tile::HostTensor<CDataType>& c_b_m_n)
|
||||
{
|
||||
const int N = b_b_n_k.mDesc.get_lengths()[1];
|
||||
const int K = b_b_n_k.mDesc.get_lengths()[2];
|
||||
|
||||
auto f = [&](auto batch, auto m) {
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_b_m_k(batch, m, k);
|
||||
BDataType v_b = b_b_n_k(batch, n, k);
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(v_acc);
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename ADataType, typename AccDataType, typename BDataType>
|
||||
void reference_batched_softmax(const ck_tile::HostTensor<ADataType>& a_b_m_n,
|
||||
ck_tile::HostTensor<BDataType>& b_b_m_n)
|
||||
{
|
||||
const int N = a_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
auto f = [&](auto batch, auto m) {
|
||||
AccDataType v_max = std::numeric_limits<ADataType>::lowest();
|
||||
|
||||
// max
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
const ADataType v_a = a_b_m_n(batch, m, n);
|
||||
|
||||
v_max = v_max < v_a ? v_a : v_max;
|
||||
}
|
||||
|
||||
AccDataType v_exp_sum = 0;
|
||||
|
||||
// sum
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
const ADataType v_a = a_b_m_n(batch, m, n);
|
||||
|
||||
v_exp_sum += ck_tile::exp(v_a - v_max);
|
||||
}
|
||||
|
||||
// elementwise
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
const ADataType v_a = a_b_m_n(batch, m, n);
|
||||
|
||||
b_b_m_n(batch, m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
struct TileGemmShape
|
||||
{
|
||||
static constexpr index_t kM = kMPerTile;
|
||||
static constexpr index_t kN = kNPerTile;
|
||||
static constexpr index_t kK = kKPerTile;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
9
example/ck_tile/tutorial/CMakeLists.txt
Normal file
9
example/ck_tile/tutorial/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
include_directories(AFTER
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
add_subdirectory(00_add_basic)
|
||||
add_subdirectory(01_add)
|
||||
add_subdirectory(02_gemm)
|
||||
add_subdirectory(03_flash_attention_fwd)
|
||||
add_subdirectory(04_codegen_flash_attention_fwd)
|
||||
112
example/ck_tile/tutorial/README.md
Normal file
112
example/ck_tile/tutorial/README.md
Normal file
@@ -0,0 +1,112 @@
|
||||
|
||||
|
||||
# CK_TILE Toy Example
|
||||
|
||||
This repository demonstrates a toy example implemented using ck_tile
|
||||
|
||||
## Build Instructions
|
||||
|
||||
Follow these steps to build the examples:
|
||||
|
||||
```sh
|
||||
cd composable_kernel
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
-Dkernel=N ..
|
||||
```
|
||||
|
||||
### Compile Examples
|
||||
|
||||
#### **Elementwise Add Example**
|
||||
```sh
|
||||
make -j add
|
||||
```
|
||||
|
||||
#### **GEMM Example**
|
||||
```sh
|
||||
make -j basic_gemm
|
||||
```
|
||||
|
||||
#### **Flash Attention Forward Example**
|
||||
```sh
|
||||
make -j basic_flash_attention_fwd
|
||||
```
|
||||
|
||||
## Running Examples
|
||||
|
||||
### **Elementwise Add**
|
||||
```sh
|
||||
./bin/add
|
||||
```
|
||||
|
||||
### **GEMM Example**
|
||||
```sh
|
||||
./bin/basic_gemm 1
|
||||
```
|
||||
|
||||
### **Flash Attention Forward Example**
|
||||
```sh
|
||||
./bin/basic_flash_attention_fwd 1 1
|
||||
```
|
||||
|
||||
## Advanced part
|
||||
#### **GEMM Example**
|
||||
##### Follow these steps to build and run the different kernels:
|
||||
```sh
|
||||
|
||||
cd composable_kernel
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
# for naive kernel
|
||||
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=N .. && make -j basic_gemm
|
||||
|
||||
# for kernel A
|
||||
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=A .. && make -j basic_gemm
|
||||
|
||||
# for kernel B
|
||||
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=B .. && make -j basic_gemm
|
||||
|
||||
...
|
||||
|
||||
# for kernel H
|
||||
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=H .. && make -j basic_gemm
|
||||
|
||||
```
|
||||
|
||||
```sh
|
||||
./bin/basic_gemm 1
|
||||
```
|
||||
|
||||
#### **Flash Attention Forward Example**
|
||||
##### Follow these steps to build the kernels
|
||||
|
||||
```sh
|
||||
# for naive kernel
|
||||
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" .. && make -j basic_flash_attention_fwd
|
||||
|
||||
# for optimized kernel
|
||||
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -DENABLE_TOY_FA_FWD_OPT=ON .. && make -j basic_flash_attention_fwd
|
||||
```
|
||||
```sh
|
||||
./bin/basic_flash_attention_fwd 1 1
|
||||
```
|
||||
|
||||
|
||||
##### Follow these steps to build the codegen instances
|
||||
|
||||
```sh
|
||||
mkdir build
|
||||
cd build
|
||||
../script/cmake-ck-release.sh .. gfx942
|
||||
make -j codegen_basic_flash_attention_fwd
|
||||
```
|
||||
|
||||
```sh
|
||||
./bin/codegen_basic_flash_attention_fwd 1 1 64 16384 16384 128 128
|
||||
```
|
||||
Reference in New Issue
Block a user