basic gemm softmax

This commit is contained in:
huizzhan
2025-08-05 04:26:23 +00:00
parent 2a78da4708
commit edf79c7064
13 changed files with 1966 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
# Standalone ck_tile example: basic GEMM + softmax.
#
# The target is created with a plain add_executable() and EXCLUDE_FROM_ALL instead of
# add_example_executable(), so it is not pulled into "make all/install/check" and has to be
# built explicitly by name.
set(EXAMPLE_REDUCE "tile_example_basic_gemm_softmax")
message("adding example ${EXAMPLE_REDUCE}")

# Per-target compile options.
# NOTE: undefined-func-template is turned off so the sources compile without explicitly
# declared function specializations.
# To also emit the generated assembly, add: -v --save-temps -Wno-gnu-line-marker
set(EXAMPLE_REDUCE_COMPILE_OPTIONS
    -Wno-undefined-func-template
    -Wno-float-equal)

add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL gemm_softmax.cpp)
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})

# Optional kernel selection: configuring with -Dkernel=<X> defines KERNEL_<X>=1 for the sources.
if(DEFINED kernel)
    message("Compiling with Kernel: ${kernel}")
    target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel}=1)
endif()

# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,58 @@
# CK_TILE Toy Example
This repository demonstrates a toy GEMM + softmax example implemented using ck_tile.
## Build Instructions
Follow these steps to build the examples:
```sh
cd composable_kernel
mkdir build
cd build
cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx942" \
-Dkernel=N ..
```
### Compile Examples
#### **GEMM Softmax Example**
```sh
make -j128 tile_example_basic_gemm_softmax
```
## Running Examples
### **GEMM Softmax Example**
```sh
./bin/tile_example_basic_gemm_softmax 1 4096 256 7168
```
## Advanced part
#### **GEMM Example**
##### Follow these steps to build the different kernels (the kernel is selected with `-Dkernel=<letter>`); run each build as shown in "Running Examples":
```sh
cd composable_kernel
mkdir build
cd build
# for naive kernel
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=N .. && make -j128 tile_example_basic_gemm_softmax
# for kernel A
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=A .. && make -j128 tile_example_basic_gemm_softmax
# for kernel B
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=B .. && make -j128 tile_example_basic_gemm_softmax
# ... (repeat with the kernel letters in between)
# for kernel H
cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=H .. && make -j128 tile_example_basic_gemm_softmax
```

View File

@@ -0,0 +1,372 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "block_gemm_asmem_bsmem_creg_default_policy.hpp"
namespace ck_tile {
/// @brief Block-level GEMM with both operands read from shared memory.
///
///  - A is a block window on shared memory
///  - B is a block window on shared memory
///  - C is a block distributed tensor
///
/// The warp-level GEMM and the (MWarp x NWarp) warp layout come from
/// `Policy::GetWarpGemmMWarpNWarp<Problem>()`, which returns the tuple
/// (warp gemm instance, MWarp, NWarp).
///
/// Two build modes are selected by the ENABLE_PREFETCH macro:
///  - not defined: every warp tile of A/B is loaded from its LDS window inside the hot loop;
///  - defined: the operands are taken from the member tiles `aWarpTile` / `bWarpTile`, which the
///    caller must have filled with LocalPrefetch() beforehand. In that mode the window arguments
///    of operator() are only used for compile-time shape/type checks.
///
/// @tparam Problem supplies ADataType, BDataType, CDataType and BlockGemmShape (kM, kN, kK).
/// @tparam Policy  supplies GetWarpGemmMWarpNWarp<Problem>().
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegDefaultPolicy>
struct BlockGemmASmemBSmemCReg
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    // Element 0 of the policy tuple is the warp gemm, elements 1 and 2 are the warp counts.
    using WarpGemm = remove_cvref_t<
        decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
    static constexpr index_t MWarp =
        Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
    static constexpr index_t NWarp =
        Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();

    using AWarpDstr = typename WarpGemm::AWarpDstr;
    using BWarpDstr = typename WarpGemm::BWarpDstr;
    using CWarpDstr = typename WarpGemm::CWarpDstr;

    using AWarpTensor = typename WarpGemm::AWarpTensor;
    using BWarpTensor = typename WarpGemm::BWarpTensor;
    using CWarpTensor = typename WarpGemm::CWarpTensor;

    // Y-dimension lengths of one warp tile; appended to the (iter, iter) outer index when
    // slicing a warp tile out of a block tile.
    static constexpr auto a_warp_y_lengths =
        to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
    static constexpr auto b_warp_y_lengths =
        to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
    static constexpr auto c_warp_y_lengths =
        to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

    // All-zero Y origins of matching rank, used as the slice start inside a warp tile.
    static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
    static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
    static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

#if defined(ENABLE_PREFETCH)
    /// A block tile distribution for load from lds.
    /// Outer encoding: (MIterPerWarp, MWarp) x (KIterPerWarp), replicated over NWarp,
    /// embedded with the warp gemm's own A encoding.
    CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
    {
        constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarp * WarpGemm::kM);
        constexpr index_t KPerBlock    = BlockGemmShape::kK;
        constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});

        return a_block_dstr_encode;
    }

    /// B block tile distribution for load from lds.
    /// Outer encoding: (NIterPerWarp, NWarp) x (KIterPerWarp), replicated over MWarp,
    /// embedded with the warp gemm's own B encoding.
    CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
    {
        constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarp * WarpGemm::kN);
        constexpr index_t KPerBlock    = BlockGemmShape::kK;
        constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

        constexpr auto b_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<MWarp>,
                                       tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<0, 1>>,
                                       tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});

        return b_block_dstr_encode;
    }

    static constexpr auto ALdsTileDistr =
        decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
    static constexpr auto BLdsTileDistr =
        decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};

    using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
    using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));

    // Operand tiles filled by LocalPrefetch() and consumed by operator().
    ALdsTile aWarpTile;
    BLdsTile bWarpTile;

    /// Prefetch from LDS to warp register: loads both block windows into the member tiles.
    template <typename ASmemBlockWindow, typename BSmemBlockWindow>
    CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
                                      const BSmemBlockWindow& b_block_window)
    {
        aWarpTile = load_tile(a_block_window);
        bWarpTile = load_tile(b_block_window);
    }
#endif

    /// C += A * B
    ///
    /// @param c_block_tensor     accumulator; read and written in place for every warp tile.
    /// @param a_block_window_tmp A window of shape (kM, kK); checked at compile time.
    /// @param b_block_window_tmp B window of shape (kN, kK); checked at compile time.
    template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
                                   [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
    {
        static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
                          std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
                          std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                      "wrong!");

        constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];

        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

#if !defined(ENABLE_PREFETCH)
        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        // Position of this warp inside the MWarp x NWarp grid.
        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() % NWarp;

        // Construct A-warp-window
        auto a_warp_window_tmp = make_tile_window(
            a_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
            {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
             a_block_window_tmp.get_window_origin().at(number<1>{})},
            make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
            MIterPerWarp>
            a_warp_windows;

        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;

                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });

        // Construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
            {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
             b_block_window_tmp.get_window_origin().at(number<1>{})},
            make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;

        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });
#endif

        // hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // Read A warp tensor from A block tensor
                AWarpTensor a_warp_tensor;
#if defined(ENABLE_PREFETCH)
#pragma message("local data share prefetch")
                a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
#else
                a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
#endif
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // Read B warp tensor from B block tensor
                    BWarpTensor b_warp_tensor;
#if defined(ENABLE_PREFETCH)
                    b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
#else
                    b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
#endif
                    // Read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // Warp GEMM
                    WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // Write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    /// C = A * B
    ///
    /// Builds a fresh C block tensor, fills it with the product and returns it.
    ///
    /// @param a_block_window_tmp A window of shape (kM, kK); checked at compile time.
    /// @param b_block_window_tmp B window of shape (kN, kK); checked at compile time.
    /// @return the C block distributed tensor.
    template <typename ABlockWindowTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
                                   [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
    {
        static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
                          std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
                      "wrong!");

        constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];

        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

#if !defined(ENABLE_PREFETCH)
        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        // Position of this warp inside the MWarp x NWarp grid.
        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() % NWarp;

        // Construct A-warp-window
        auto a_warp_window_tmp = make_tile_window(
            a_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
            {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
             a_block_window_tmp.get_window_origin().at(number<1>{})},
            make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
            MIterPerWarp>
            a_warp_windows;

        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;

                move_tile_window(a_warp_windows(mIter)(kIter),
                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });

        // Construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
            {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
             b_block_window_tmp.get_window_origin().at(number<1>{})},
            make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;

        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            });
        });
#endif

        static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");

        // Construct C-Block-Tensor
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        // NOTE: the tensor is not initialized here; the first K step below overwrites it.
        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);

        // Hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // Read A warp tensor from A block tensor
                AWarpTensor a_warp_tensor;
#if defined(ENABLE_PREFETCH)
                a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
#else
                a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
#endif
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // Read B warp tensor from B block tensor
                    BWarpTensor b_warp_tensor;
#if defined(ENABLE_PREFETCH)
                    b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
#else
                    b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
#endif
                    CWarpTensor c_warp_tensor;

                    // Warp GEMM
                    // The branch must key on the current K step, not on the K step count:
                    // on the first step there is nothing valid to accumulate onto yet.
                    if constexpr(decltype(kIter)::value == 0)
                    {
                        // c = a * b
                        c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
                    }
                    else
                    {
                        // c += a * b
                        c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                        WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
                    }

                    // Write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });

        return c_block_tensor;
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,100 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "config.h"
namespace ck_tile {
// Default policy for BlockGemmASmemBSmemCReg
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmASmemBSmemCRegDefaultPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if defined(ADJUST_BLOCK_TILE_SHAPE)
constexpr index_t kMWarp = 2;
constexpr index_t kNWarp = 2;
#else
constexpr index_t kMWarp = 4;
constexpr index_t kNWarp = 1;
#endif
#if defined(NAIVE_IMPLEMENTATION)
#pragma message("mfma m32 n32 k8")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
}
#elif defined(USING_MFMA_32x32x_8x2)
#pragma message("mfma m32 n32 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
}
#elif defined(USING_MFMA_16x16x16)
#pragma message("mfma m16 n16 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
}
#elif defined(USING_MFMA_16x16x_16x2)
#pragma message("mfma m16 n16 k32")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp);
}
#endif
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,464 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = ck_tile::BlockGemmPipelineAGmemBGmemCRegDefaultPolicy>
struct BlockGemmSoftmaxPipelineAGmemBGmemCReg
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = float;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
// Bytes of LDS needed by the pipeline: the A tile rounded up to a multiple of 16 bytes,
// followed by the B tile (not rounded).
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
    const auto a_tile_bytes =
        sizeof(ADataType) *
        Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size();
    const auto b_tile_bytes =
        sizeof(BDataType) *
        Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();

    // Round the A region up so that the B region starts on a 16-byte boundary.
    const auto a_tile_bytes_aligned = integer_divide_ceil(a_tile_bytes, 16) * 16;

    return a_tile_bytes_aligned + b_tile_bytes;
}
#if defined(ENABLE_INSTRUCTION_SCH)
// Compile-time configuration consumed by HotLoopScheduler().
// PackedSize trait of the A element type (presumably the number of logical values packed into
// one storage element -- confirm against numeric_traits).
static constexpr index_t kPackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// The three getters below simply forward to the pipeline policy.
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetSmemPack() { return Policy::template GetSmemPack<Problem>(); }
// Taken as-is from Problem.
static constexpr bool HasHotLoop = Problem::HasHotLoop;
// Emits the sequence of __builtin_amdgcn_sched_group_barrier hints for one main-loop iteration
// (operator() calls this inside its loop when ENABLE_INSTRUCTION_SCH is defined).
// Every count below is a compile-time constant derived from the block tile shape, the block
// size, the vector/LDS access widths and the warp-gemm shape; nothing is computed at run time.
// Barrier masks used, as annotated on the calls: 0x008 MFMA, 0x020 VMEM read, 0x100 DS read,
// 0x200 DS write.
// The order of the barrier calls is the schedule itself -- do not reorder them.
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
// NOTE(review): wavefront size is hard-coded to 64 -- confirm for every intended GPU target.
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemm::MWarp;
constexpr index_t WaveNumN = BlockGemm::NWarp;
constexpr index_t AB_LDS_RW_Width = GetSmemPack();
// Per-thread instruction counts: tile elements divided by (threads * elements per access).
constexpr index_t A_Buffer_Load_Inst_Num =
kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t B_LDS_Write_Inst_Num =
kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
// LDS reads are scaled by the warp count along the other dimension (WaveNumN for A,
// WaveNumM for B).
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
// MFMA instructions per wave: block tile volume / number of waves / one MFMA's volume.
constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock /
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// NOTE(review): kPackedSize is derived from ADataType but is also used in the B expressions
// below -- confirm this is intended when A and B element types differ.
constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
// Cycle figures used by the rate computation: 16 for a 16-wide N warp tile, otherwise 32;
// 8 issue cycles for a 16-byte DS read, otherwise 4.
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
// DS reads grouped with one MFMA: ceil((mfma_cycle - 4) / (2 * issue_cycle)).
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
// Number of MFMAs needed to cover all DS reads: ceil(reads / rate).
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(CDataType) / sizeof(ADataType) >
// sizeof(CDataType) /
// sizeof(BDataType)
// ? sizeof(CDataType) /
// sizeof(ADataType) : sizeof(CDataType)
// / sizeof(BDataType);
// MFMAs left for stage 1 after reserving those paired with DS reads in stage 2, spread
// evenly over the global (buffer) loads.
// NOTE(review): the divisions below assume non-zero buffer-load counts; a tile too small
// for the chosen vector size would make them divide by zero at compile time.
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
// Two MFMAs per DS write if the per-issue MFMA budget allows it, otherwise one.
constexpr auto num_mfma_per_dswrite_a =
(num_mfma_per_issue - num_dswrite_per_issue_a * 2 >= 1) ? 2 : 1;
constexpr auto num_mfma_per_dswrite_b =
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
// Stage 1, A part: per buffer load -> {DS write, MFMAs} groups, then the VMEM read, then the
// MFMAs remaining in this issue's budget.
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_a, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_mfma_per_dswrite_a *
num_dswrite_per_issue_a,
0); // MFMA
});
// Stage 1, B part: same pattern with the B counts.
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008,
num_mfma_per_issue - num_mfma_per_dswrite_b *
num_dswrite_per_issue_b,
0); // MFMA
});
// stage 2
// A part: one group of DS reads followed by one MFMA per step; the else branch emits the
// remainder group when fewer than a full rate's worth of reads is left.
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
// B part: same pattern with the B counts.
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
#endif
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// -----------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
#if defined(ENABLE_PREFETCH)
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()));
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()));
#else
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
#endif
// Block GEMM
auto block_gemm = BlockGemm();
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
// -------------------------------------------------------------------------------------
// Gemm pipeline start
// Initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
#if defined(ENABLE_PREFETCH)
#pragma message("global prefetch")
// Prefetch
// Global read 0
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
if(num_loop > 1)
{
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// LDS write 0
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
// Global read 1
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
// Prefetch from LDS to warp register in block gemm
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
}
__builtin_amdgcn_sched_barrier(0);
// Main body
if(num_loop > 2)
{
index_t iCounter = 0;
do
{
block_sync_lds();
// LDS write 1
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
// Global read 2
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// Prefetch from LDS to warp register in block gemm
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
#if defined(ENABLE_INSTRUCTION_SCH)
HotLoopScheduler();
#endif
__builtin_amdgcn_sched_barrier(0);
iCounter += 1;
} while(iCounter < (num_loop - 2));
}
// Tail
if(num_loop > 1)
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
}
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
#else
// non-prefetch
index_t iCounter = num_loop;
while(iCounter > 0)
{
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
iCounter--;
}
#endif
// apply softmax for c_block_tile
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// m_local = rowmax(c_block_tile)
auto m_local = block_tile_reduce<ComputeDataType>(
c_block_tile, sequence<1>{}, f_max, std::numeric_limits<ComputeDataType>::lowest());
block_tile_reduce_sync(m_local, f_max);
// Pcompute{j} = sum(exp(x - m_local))
auto p_compute =
make_static_distributed_tensor<ComputeDataType>(c_block_tile.get_tile_distribution());
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
p_compute(i_j_idx) = exp(c_block_tile[i_j_idx] - m_local[i_idx]);
});
});
// rowsum for p_compute{i, j}
auto rowsum_p = block_tile_reduce<ComputeDataType>(
p_compute, sequence<1>{}, f_sum, ComputeDataType{0});
block_tile_reduce_sync(rowsum_p, f_sum);
// softmax = p_compute{i, j} / rowsum_p
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
p_compute(i_j_idx) = p_compute[i_j_idx] / rowsum_p[i_idx];
});
});
// CDramBlockWindowTmp c_dram_block_window_tmp = c_dram_block_window;
// store_tile(c_dram_block_window_tmp, type_convert<CDataType>(p_compute));
return p_compute;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,352 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "block_gemm_asmem_bsmem_creg.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "config.h"
namespace ck_tile {
// Default policy for BlockGemmPipelineAGmemBGmemCReg
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
{
// Builds the compile-time LDS tensor descriptor for the A block tile.
// Every variant ends in a 2-D (kMPerBlock, kKPerBlock) view; the layout macro selected in
// config.h decides how that view maps onto storage:
//   NAIVE_IMPLEMENTATION               - dense rows, row stride = kKPerBlock
//   PADDING_K_FIRST                    - row stride enlarged by one extra kKPack group
//   PADDING_MN_FIRST                   - K0-major storage, each K0 slab enlarged by one kKPack group
//   USING_XOR_BASED_BANK_CONFLICT_FREE - XOR-permuted layout
// If none of these macros is defined, a_lds_block_desc is never declared and the function
// does not compile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// K is split as K0 x kKPack with K0 = kKPerBlock / kKPack (integer division).
constexpr index_t kKPack = 8;
#if defined(NAIVE_IMPLEMENTATION)
// (M, K0, kKPack) with strides (kKPerBlock, kKPack, 1): plain row-major storage.
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
// Merge (K0, kKPack) back into a single K dimension -> (M, K).
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_K_FIRST)
// Same (M, K0, kKPack) shape, but the row stride is (K0 + 1) * kKPack, i.e. one unused
// kKPack group follows every row.
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_MN_FIRST)
// (K0, M, kKPack) with strides ((M + 1) * kKPack, kKPack, 1): K0 is the slowest dimension
// and every K0 slab carries one unused kKPack group.
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
// M comes from lower dim 1; K is the merge of lower dims (0, 2) -> (M, K).
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto DataTypeSize = sizeof(ADataType);
// MLdsLayer = max(1, 128 / kKPerBlock / sizeof(ADataType)).
// NOTE(review): 32 * 4 presumably stands for 32 LDS banks x 4 bytes - confirm.
constexpr auto MLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
// Base storage: (K0 * MLdsLayer, M / MLdsLayer, kKPack).
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
number<kMPerBlock / MLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
// XOR transform over the (M / MLdsLayer, K0 * MLdsLayer) index pair; kKPack passes through.
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
number<kKPerBlock / kKPack * MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
// Unmerge dim 0 into (MLdsLayer, K0) -> 4-D (MLdsLayer, M / MLdsLayer, K0, kKPack).
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
// Merge (M / MLdsLayer, MLdsLayer) -> M and (K0, kKPack) -> K, giving the final (M, K) view.
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif
return a_lds_block_desc;
}
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = 8;
#if defined(PADDING_K_FIRST) || defined(NAIVE_IMPLEMENTATION)
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_K_FIRST)
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(PADDING_MN_FIRST)
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
number<kNPerBlock / NLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
number<kKPerBlock / kKPack * NLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif
return b_lds_block_desc;
}
// Static tile distribution used to read the (kMPerBlock x kKPerBlock) A tile.
// The K extent is split into kKGroups x kVecLen, where kVecLen is the number of ADataType
// elements in 16 bytes; the M extent is split into kMRepeat x kMWarps x kMLanes.
// All quotients are integer divisions; no divisibility is checked here.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
    using AType = remove_cvref_t<typename Problem::ADataType>;
    using Shape = typename Problem::BlockGemmShape;

    // K split: 16 bytes worth of elements per group.
    constexpr index_t kVecLen  = 16 / sizeof(AType);
    constexpr index_t kKGroups = Shape::kK / kVecLen;

    // M split: lanes of one warp left over after covering kKGroups, then warps per block
    // (coalesce reading for each block), then the remaining repeat count.
    constexpr index_t kMLanes  = get_warp_size() / kKGroups;
    constexpr index_t kMWarps  = Problem::kBlockSize / get_warp_size();
    constexpr index_t kMRepeat = Shape::kM / (kMLanes * kMWarps);

    using Encoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<kMRepeat, kMWarps, kMLanes>,
                                         sequence<kKGroups, kVecLen>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>;

    return make_static_tile_distribution(Encoding{});
}
// Static tile distribution used to read the (kNPerBlock x kKPerBlock) B tile.
// The K extent is split into kKGroups x kVecLen, where kVecLen is the number of BDataType
// elements in 16 bytes; the N extent is split into kNRepeat x kNWarps x kNLanes.
// All quotients are integer divisions; no divisibility is checked here.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
    using BType = remove_cvref_t<typename Problem::BDataType>;
    using Shape = typename Problem::BlockGemmShape;

    // K split: 16 bytes worth of elements per group.
    constexpr index_t kVecLen  = 16 / sizeof(BType);
    constexpr index_t kKGroups = Shape::kK / kVecLen;

    // N split: lanes of one warp left over after covering kKGroups, then warps per block
    // (coalesce reading for each block), then the remaining repeat count.
    constexpr index_t kNLanes  = get_warp_size() / kKGroups;
    constexpr index_t kNWarps  = Problem::kBlockSize / get_warp_size();
    constexpr index_t kNRepeat = Shape::kN / (kNLanes * kNWarps);

    using Encoding =
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<kNRepeat, kNWarps, kNLanes>,
                                         sequence<kKGroups, kVecLen>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<1>, sequence<2, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 1>>;

    return make_static_tile_distribution(Encoding{});
}
#if defined(ENABLE_INSTRUCTION_SCH)
// Compile-time index constants 0, 1 and 2 (only declared when ENABLE_INSTRUCTION_SCH is set).
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
// Returns a vector length, in DataType elements, chosen from the candidates
// PackedSize * {32, 16, 8, 4, 2} / sizeof(DataType), tried from widest to narrowest.
// A candidate is accepted when it divides both XPerTile and the per-thread element count
// (MNPerBlock * KPerBlock / BlockSize); if none is accepted, PackedSize is returned.
// The 32-byte candidate is additionally restricted to PackedSize == 2.
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// Elements of the MNPerBlock x KPerBlock tile handled by each thread (integer division).
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Assume DataType is even!
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
PackedSize == 2)
{
return (PackedSize * 32 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (PackedSize * 16 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
{
return (PackedSize * 8 / sizeof(DataType));
}
// NOTE(review): the two guards below require sizeof(DataType) >= PackedSize * {4, 2}, so a
// 2-byte type can never select the 4-byte candidate, and a type wider than the candidate
// would make the divisor 0. The comparison looks inverted (<= expected) - confirm intent.
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
{
return (PackedSize * 4 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
{
return (PackedSize * 2 / sizeof(DataType));
}
else
{
// Fallback: no wider candidate fits.
return PackedSize;
}
}
// Vector load size for A: GetGlobalVectorLoadSize evaluated on the (kM x kK) block tile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
    using Shape = typename Problem::BlockGemmShape;
    using AType = remove_cvref_t<typename Problem::ADataType>;

    constexpr index_t kTileM = Shape::kM;
    constexpr index_t kTileK = Shape::kK;

    return GetGlobalVectorLoadSize<Problem, AType, kTileM, kTileK>();
}
// Vector load size for B: GetGlobalVectorLoadSize evaluated on the (kN x kK) block tile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
    using Shape = typename Problem::BlockGemmShape;
    using BType = remove_cvref_t<typename Problem::BDataType>;

    constexpr index_t kTileN = Shape::kN;
    constexpr index_t kTileK = Shape::kK;

    return GetGlobalVectorLoadSize<Problem, BType, kTileN, kTileK>();
}
// Forwards the Problem's TransposeC flag.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
    const auto transpose_c = Problem::TransposeC;
    return transpose_c;
}
// Pack size reported for shared memory: a fixed 8, the same value the LDS descriptors in
// this policy use for kKPack.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPack()
{
    return index_t{8};
}
#endif
// Block-level GEMM operator used by the pipeline: a default-constructed
// BlockGemmASmemBSmemCReg specialised on Problem.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
    using BlockGemm = BlockGemmASmemBSmemCReg<Problem>;
    return BlockGemm{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,38 @@
// Build-time kernel variant selection.
// The build passes -DKERNEL_<X>=1 (CMake option -Dkernel=<X>); each variant enables a set of
// feature macros that the policy, pipeline and host code test with #if defined(...).
// With no KERNEL_<X> defined, the naive implementation is selected.
//
// This header is included from several files of the example, so it is guarded.
#pragma once

#if defined(KERNEL_A)
// Padded LDS rows + 32x32x(8x2) MFMA.
#define PADDING_K_FIRST
#define USING_MFMA_32x32x_8x2
#elif defined(KERNEL_B)
// Padded LDS rows + 16x16x16 MFMA.
#define PADDING_K_FIRST
#define USING_MFMA_16x16x16
#elif defined(KERNEL_C)
// Padded LDS rows + 16x16x(16x2) MFMA.
#define PADDING_K_FIRST
#define USING_MFMA_16x16x_16x2
#elif defined(KERNEL_D)
// XOR-based LDS layout instead of padding.
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#elif defined(KERNEL_E)
// D + adjusted block tile shape.
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#elif defined(KERNEL_F)
// E + global prefetch.
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#define ENABLE_PREFETCH
#elif defined(KERNEL_G)
// F + explicit instruction scheduling.
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#define ENABLE_PREFETCH
#define ENABLE_INSTRUCTION_SCH
#elif defined(KERNEL_H)
// G + cache-aware workgroup-to-tile mapping.
#define USING_MFMA_16x16x_16x2
#define USING_XOR_BASED_BANK_CONFLICT_FREE
#define ADJUST_BLOCK_TILE_SHAPE
#define ENABLE_PREFETCH
#define ENABLE_INSTRUCTION_SCH
#define ENABLE_CACHE_AWARE_WG_SCH
#else
#define NAIVE_IMPLEMENTATION
#endif

View File

@@ -0,0 +1,195 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
#include "config.h"
#include "grid_gemm.hpp"
namespace ck_tile {
// Type bundle describing the grid-level problem: element types of A, B, the accumulator and
// C, plus the functor type applied to each output element.
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename CElementFunction_>
struct GridGemmProblem
{
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CDataType = CDataType_;
using CElementFunction = CElementFunction_;
};
// Compile-time tile extents along M, N and K.
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
struct TileGemmShape
{
static constexpr index_t kM = kMPerTile;
static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
};
// Type bundle consumed by the block-level pipeline and its policy: element types (with
// cv/ref qualifiers removed), the block tile shape and the workgroup size.
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
// C = A * B
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename CElementFunction,
index_t kAAlignment,
index_t kBAlignment,
index_t kCAlignment,
index_t kBlockSize_,
index_t kMPerBlock_,
index_t kNPerBlock_,
index_t kKPerBlock_>
struct Gemm
{
using GridGemmProblem =
GridGemmProblem<ADataType, BDataType, AccDataType, CDataType, CElementFunction>;
struct GridGemmPolicy
{
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kMPerBlock_;
static constexpr index_t kNPerBlock = kNPerBlock_;
static constexpr index_t kKPerBlock = kKPerBlock_;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
{
#if defined(ENABLE_CACHE_AWARE_WG_SCH)
return [=](index_t block_1d_id) {
constexpr index_t M01 = 4;
constexpr index_t GroupNum = 8;
const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
const auto update_M0 =
((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;
const auto xcd_id = block_1d_id % GroupNum;
const auto l_block_id = block_1d_id - (xcd_id % 2);
const auto ridn = GroupNum * M01 * (update_N0 / 2);
const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn;
const auto lu = (l_block_id % GroupNum) + rid * ridn;
const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
const auto sub_M0_id =
(l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;
auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
const auto total_update_size = update_N0 * update_M0;
if(block_1d_id >= total_update_size)
{
auto x = (block_1d_id + 1) - total_update_size;
auto rlen = N0 - update_N0;
auto rm = 0;
auto rn = 0;
if(rlen > 0)
{
rm = (x - 1) / rlen;
rn = x % rlen;
}
if(rlen > 0 and rm < M0)
{
n = rn + update_N0;
m = rm;
}
else
{
x = x - rlen * M0;
rm = (x - 1) / update_N0;
rn = x % update_N0;
n = rn;
m = update_M0 + rm;
}
}
return make_multi_index(m, n);
};
#else
const auto unmerge = make_merge_transform(make_tuple(N0, M0));
return [unmerge](index_t block_id) {
multi_index<2> unmerged;
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
};
#endif
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
{
using BlockGemmPipelineProblem_ =
BlockGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
kBlockSize,
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
return BlockGemmSoftmaxPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
}
};
using GridGemm = GridGemm<GridGemmProblem, GridGemmPolicy>;
CK_TILE_DEVICE void operator()(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const index_t M,
const index_t N,
const index_t K,
const index_t Lda,
const index_t Ldb,
const index_t Ldc,
const CElementFunction& c_element_func) const
{
const auto a_dram = [&] {
return make_naive_tensor_view<address_space_enum::global>(
p_a, make_tuple(M, K), make_tuple(Lda, 1), number<kAAlignment>{}, number<1>{});
}();
const auto b_dram = [&] {
return make_naive_tensor_view<address_space_enum::global>(
p_b, make_tuple(N, K), make_tuple(Ldb, 1), number<kBAlignment>{}, number<1>{});
}();
const auto c_dram = [&] {
return make_naive_tensor_view<address_space_enum::global>(
p_c, make_tuple(M, N), make_tuple(Ldc, 1), number<kCAlignment>{}, number<1>{});
}();
GridGemm{}(a_dram, b_dram, c_dram, c_element_func);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,202 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include "config.h"
#include "ck_tile/host.hpp"
#include "gemm.hpp"
#include "reference_gemm.hpp"
/*
 * Toy code of GEMM followed by a softmax along N.
 * Assume simplest case.
 * A [M, K]
 * B [N, K]
 * C [M, N]
 */
// Element-wise functor applied to every output element: identity (returns its argument).
struct CElementFunction
{
template <typename X>
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
{
return x;
}
};
// Host driver: fills A [M, K] and B [N, K] with random integer values, launches and times the
// GEMM + softmax kernel, optionally checks the result against the host reference and prints
// the achieved throughput.
//
// Usage: <exe> [verification]  or  <exe> verification M N K
// Any other argument count keeps the defaults. Returns 0 on success, 1 when verification fails.
int main(int argc, char* argv[])
{
    using ADataType   = ck_tile::half_t;
    using BDataType   = ck_tile::half_t;
    using AccDataType = float;
    using CDataType   = ck_tile::half_t;

    ck_tile::index_t verification = 0;
    ck_tile::index_t M            = 3328;
    ck_tile::index_t N            = 4096;
    ck_tile::index_t K            = 4096;

    if(argc == 2)
    {
        verification = std::stoi(argv[1]);
    }
    if(argc == 5)
    {
        verification = std::stoi(argv[1]);
        M            = std::stoi(argv[2]);
        N            = std::stoi(argv[3]);
        K            = std::stoi(argv[4]);
    }

    // Report which kernel variant config.h selected.
#if defined(KERNEL_A)
    printf("*** Kernel A test *** \n");
    printf(" --> Using mfma_32x32x(8x2)\n");
#elif defined(KERNEL_B)
    printf("*** Kernel B test *** \n");
    printf(" --> Using mfma_16x16x16\n");
#elif defined(KERNEL_C)
    printf("*** Kernel C test *** \n");
    printf(" --> Using mfma_16x16x(16x2)\n");
#elif defined(KERNEL_D)
    printf("*** Kernel D test *** \n");
    printf(" --> Using mfma_16x16x(16x2)\n");
    printf(" --> XOR-based bank-conflict-free\n");
#elif defined(KERNEL_E)
    printf("*** Kernel E test ***\n");
    printf(" --> Using mfma_16x16x(16x2)\n");
    printf(" --> XOR-based bank-conflict-free\n");
    printf(" --> Adjust block tile shape\n");
#elif defined(KERNEL_F)
    printf("*** Kernel F test ***\n");
    printf(" --> Using mfma_16x16x(16x2)\n");
    printf(" --> XOR-based bank-conflict-free\n");
    printf(" --> Adjust block tile shape\n");
    printf(" --> Enable prefetch\n");
#elif defined(KERNEL_G)
    printf("*** Kernel G test ***\n");
    printf(" --> Using mfma_16x16x(16x2)\n");
    printf(" --> XOR-based bank-conflict-free\n");
    printf(" --> Adjust block tile shape\n");
    printf(" --> Enable prefetch\n");
    printf(" --> Enable instruction schedule\n");
#elif defined(KERNEL_H)
    printf("*** Kernel H test ***\n");
    printf(" --> Using mfma_16x16x(16x2)\n");
    printf(" --> XOR-based bank-conflict-free\n");
    printf(" --> Adjust block tile shape\n");
    printf(" --> Enable prefetch\n");
    printf(" --> Enable instruction schedule\n");
    printf(" --> Enable cache-aware thread blocks schedule\n");
#else
    printf("*** Naive implementation test ***\n");
#endif

    // Packed row-major layouts.
    const ck_tile::index_t Lda = K;
    const ck_tile::index_t Ldb = K;
    const ck_tile::index_t Ldc = N;

    const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
    const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
    const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
    const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
    const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
    const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};

    // host verify
    ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
    ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
    ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);

    ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
    ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);

    ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());

    a_buf.ToDevice(a_host.mData.data());
    b_buf.ToDevice(b_host.mData.data());

    // Alignment
    constexpr ck_tile::index_t kAAlignment = 8;
    constexpr ck_tile::index_t kBAlignment = 8;
    constexpr ck_tile::index_t kCAlignment = 8;

    constexpr ck_tile::index_t kBlockSize = 256;
#ifdef ADJUST_BLOCK_TILE_SHAPE
    constexpr ck_tile::index_t kGemmMPerBlock = 128;
    constexpr ck_tile::index_t kGemmKPerBlock = 64;
#else
    constexpr ck_tile::index_t kGemmMPerBlock = 128;
    constexpr ck_tile::index_t kGemmKPerBlock = 16;
#endif
    constexpr ck_tile::index_t kGemmNPerBlock = 256;

    // The grid size below and the kernel's K loop count (K / kKPerBlock in GridGemm) use
    // truncating division and the tensor views are not padded, so partial tiles are silently
    // dropped. Warn instead of producing an unexplained mismatch.
    if(M % kGemmMPerBlock != 0 || N % kGemmNPerBlock != 0 || K % kGemmKPerBlock != 0)
    {
        std::cout << "warning: M, N, K should be multiples of " << kGemmMPerBlock << ", "
                  << kGemmNPerBlock << ", " << kGemmKPerBlock
                  << "; partial tiles are not computed" << std::endl;
    }
    // The block pipeline reduces max/sum inside one block tile only, whereas the host
    // reference normalises over the whole row of N elements.
    if(N != kGemmNPerBlock)
    {
        std::cout << "warning: softmax is reduced inside one block tile (" << kGemmNPerBlock
                  << " columns); the full-row reference only matches when N == "
                  << kGemmNPerBlock << std::endl;
    }

    ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
    std::cout << "grid size " << kGridSize << std::endl;

    constexpr ck_tile::index_t kWarpPerCu    = 8; // 2 warps per SIMD
    constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
    constexpr ck_tile::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;

    using gemm_kernel = ck_tile::Gemm<ADataType,
                                      BDataType,
                                      AccDataType,
                                      CDataType,
                                      CElementFunction,
                                      kAAlignment,
                                      kBAlignment,
                                      kCAlignment,
                                      kBlockSize,
                                      kGemmMPerBlock,
                                      kGemmNPerBlock,
                                      kGemmKPerBlock>;

    // Timed launch; the returned time is used as milliseconds by the report below.
    float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000},
                                            ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
                                                gemm_kernel{},
                                                kGridSize,
                                                kBlockSize,
                                                0,
                                                static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
                                                static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
                                                static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
                                                M,
                                                N,
                                                K,
                                                Lda,
                                                Ldb,
                                                Ldc,
                                                CElementFunction{}));

    auto pass = true;
    if(verification)
    {
        // reference gemm + softmax; the second template argument is B's element type
        // (it used to repeat ADataType).
        ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
        reference_basic_gemm_softmax<ADataType, BDataType, AccDataType, CDataType>(
            a_host, b_host, c_host_ref);

        c_buf.FromDevice(c_host_dev.mData.data());
        pass &= ck_tile::check_err(c_host_dev, c_host_ref);
        std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
    }

    // Throughput report: flop / 1e9 / ms == TFlop/s, bytes / 1e6 / ms == GB/s.
    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    return !pass;
}

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// Grid-level driver: each workgroup picks one (kMPerBlock x kNPerBlock) output tile, runs the
// block pipeline supplied by Policy over the K dimension and stores the converted tile to C.
template <typename Problem, typename Policy>
struct GridGemm
{
    using ADataType        = typename Problem::ADataType;
    using BDataType        = typename Problem::BDataType;
    using CDataType        = typename Problem::CDataType;
    using AccDataType      = typename Problem::AccDataType;
    using ComputeDataType  = float;
    using CElementFunction = typename Problem::CElementFunction;

    static constexpr auto kMPerBlock = Policy::kMPerBlock;
    static constexpr auto kNPerBlock = Policy::kNPerBlock;
    static constexpr auto kKPerBlock = Policy::kKPerBlock;

    // a_grid is [M, K], b_grid is [N, K], c_grid is [M, N] (lengths are read from the views).
    // c_element_func is applied to every element after conversion to CDataType.
    template <typename AGridTensorView, typename BGridTensorView, typename CGridTensorView>
    CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid,
                                   const BGridTensorView& b_grid,
                                   CGridTensorView& c_grid,
                                   const CElementFunction& c_element_func) const
    {
        // Problem extents.
        const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{});
        const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{});
        const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{});

        // Map this workgroup to its output tile.
        const auto block_id = get_block_id();
        const auto m_tiles  = integer_divide_ceil(M, kMPerBlock);
        const auto n_tiles  = integer_divide_ceil(N, kNPerBlock);

        const auto tile_map = Policy::template MakeBlock2TileMap<Problem>(m_tiles, n_tiles);
        const auto tile_idx = tile_map(block_id);

        // Element offsets of the tile origin, passed through readfirstlane.
        const auto m_origin =
            __builtin_amdgcn_readfirstlane(tile_idx.template get(number<0>{}) * kMPerBlock);
        const auto n_origin =
            __builtin_amdgcn_readfirstlane(tile_idx.template get(number<1>{}) * kNPerBlock);

        // Input windows start at K offset 0.
        auto a_window = make_tile_window(
            a_grid, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {m_origin, 0});
        auto b_window = make_tile_window(
            b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {n_origin, 0});

        constexpr auto pipeline = Policy::template GetBlockGemmPipeline<Problem>();

        // Statically sized shared-memory scratch for the pipeline.
        __shared__ char smem[pipeline.GetStaticLdsSize()];

        // Output window.
        auto c_window = make_tile_window(
            c_grid, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {m_origin, n_origin});

        // Number of K steps (integer division).
        const auto num_k_loops = K / kKPerBlock;
        const auto result_tile = pipeline(a_window, b_window, num_k_loops, smem);

        // Convert to CDataType, then apply the element-wise functor.
        const auto to_c = [&](const auto& value) {
            return c_element_func(type_convert<CDataType>(value));
        };
        const auto c_tile = tile_elementwise_in(to_c, result_tile);

        store_tile(c_window, c_tile);
    }
};
} // namespace ck_tile

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstddef>
#include <limits>
#include <thread>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
// Host reference for the GEMM + softmax example.
//
// For every row m: acc[n] = sum_k a_m_k(m, k) * b_n_k(n, k), then
// c_m_n(m, n) = exp(acc[n] - max_n acc) / sum_n exp(acc[n] - max_n acc).
//
// a_m_k is [M, K], b_n_k is [N, K], c_m_n is [M, N] and receives the result.
// The GEMM result and all softmax intermediates are kept in AccDataType; the value is
// converted to CDataType exactly once, on the final store. (Staging the raw accumulators in
// c_m_n rounds them to CDataType and can overflow a half-precision output for large K.)
// Rows are processed in parallel on std::thread::hardware_concurrency() threads.
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_basic_gemm_softmax(const ck_tile::HostTensor<ADataType>& a_m_k,
                                  const ck_tile::HostTensor<BDataType>& b_n_k,
                                  ck_tile::HostTensor<CDataType>& c_m_n)
{
    const std::size_t N = b_n_k.mDesc.get_lengths()[0];
    const std::size_t K = b_n_k.mDesc.get_lengths()[1];

    auto f = [&](auto m) {
        // Per-row scratch in accumulator precision (one instance per row, so the parallel
        // row workers never share it).
        std::vector<AccDataType> row(N);

        // GEMM row and running maximum.
        AccDataType v_max = std::numeric_limits<AccDataType>::lowest();
        for(std::size_t n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0;
            for(std::size_t k = 0; k < K; ++k)
            {
                const ADataType v_a = a_m_k(m, k);
                const BDataType v_b = b_n_k(n, k);
                v_acc += ck_tile::type_convert<AccDataType>(v_a) *
                         ck_tile::type_convert<AccDataType>(v_b);
            }
            row[n] = v_acc;
            v_max  = v_max < v_acc ? v_acc : v_max;
        }

        // Shifted exponentials and their sum.
        AccDataType v_exp_sum = 0;
        for(std::size_t n = 0; n < N; ++n)
        {
            row[n] = ck_tile::exp(row[n] - v_max);
            v_exp_sum += row[n];
        }

        // Normalise and store.
        for(std::size_t n = 0; n < N; ++n)
        {
            c_m_n(m, n) = ck_tile::type_convert<CDataType>(row[n] / v_exp_sum);
        }
    };

    ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(
        std::thread::hardware_concurrency());
}

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
// Launch options bundle: HIP stream handle, a kernel-timing flag and a log level.
struct StreamConfig
{
hipStream_t stream_id_ = nullptr; // defaults to the null stream handle
bool time_kernel_ = false; // timing disabled by default
int log_level_ = 0;
};

1
example/ck_tile/CMakeLists.txt Normal file → Executable file
View File

@@ -23,3 +23,4 @@ add_subdirectory(20_grouped_convolution)
add_subdirectory(21_elementwise)
add_subdirectory(35_batched_transpose)
add_subdirectory(38_block_scale_gemm)
add_subdirectory(39_gemm_softmax)