mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Implement prefetch and instruction schedule
This commit is contained in:
@@ -306,11 +306,11 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message("global prefetch")
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message("global prefetch")
|
||||
// Prefetch
|
||||
// Global read 0
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
@@ -325,7 +325,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 0
|
||||
// Global read 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
@@ -347,11 +347,11 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 0
|
||||
// LDS write 1
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 0
|
||||
// Global read 2
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
@@ -387,18 +387,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
#else
|
||||
// non-prefetch
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
index_t iCounter = num_loop;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
|
||||
@@ -10,6 +10,13 @@ set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template so the sources compile without explicitly declaring function specializations
|
||||
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF)
|
||||
if(ENABLE_TOY_FA_FWD_OPT)
|
||||
message("Compiling with toy FA fwd optimization")
|
||||
# target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT)
|
||||
add_definitions(-DTOY_FA_FWD_OPT)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
|
||||
@@ -29,6 +29,35 @@ struct BlockGemmARegBSmemCRegV1
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN);
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
@@ -118,6 +147,129 @@ struct BlockGemmARegBSmemCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp>
|
||||
__device__ void operator() (CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BLdsTile& b_block_tensor_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// construct A-block-tensor from A-block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from the B block tensor (already loaded from LDS)
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
|
||||
@@ -13,16 +13,13 @@ namespace ck_tile {
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, index_t kHeadDim>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
Problem,
|
||||
BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>>
|
||||
template <typename Problem, typename Policy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
@@ -222,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
ignore = b_element_func;
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution is used to construct the distributed tensors for both local buffer store
|
||||
// and read, without buffer address info
|
||||
@@ -261,62 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0},
|
||||
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
static_for<0, k_loops, 1>{}([&](auto i_k0) {
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using BLdsTile = typename decltype(block_gemm)::BLdsTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Global read 0
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS read 0
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
// LDS read 1
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -324,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
}
|
||||
|
||||
#endif
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,42 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// NOTE: Assume A is K-Major
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
return policy_impl::make_a_reg_block_descriptor<Problem, BlockGemm>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
return policy_impl::make_b_lds_block_descriptor_3d_pad<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
return policy_impl::make_a_dram_tile_distribution_skip_lds<Problem, BlockGemm>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
return policy_impl::make_b_dram_tile_distribution<Problem>();
|
||||
}
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto GetBlockGemm()
|
||||
@@ -47,13 +20,7 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
|
||||
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
: BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
{
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
@@ -93,11 +60,88 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace policy_impl {
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
transform_tensor_descriptor(a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
__host__ __device__ static constexpr auto make_a_reg_block_descriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_a_dram_tile_distribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds()
|
||||
{
|
||||
constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K2 =
|
||||
WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane;
|
||||
// // 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t K0 = kKPerBlock / (K1 * K2);
|
||||
|
||||
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t M1 = MWarp;
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_b_dram_tile_distribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto get_block_gemm()
|
||||
{
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy;
|
||||
|
||||
return BlockGemmASmemBSmemCReg<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
} // namespace policy_impl
|
||||
} // namespace ck_tile
|
||||
@@ -8,11 +8,11 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "flash_attention_fwd_impl.hpp"
|
||||
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
@@ -115,7 +115,8 @@ struct FlashAttentionFwd
|
||||
const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock);
|
||||
const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock);
|
||||
|
||||
#if defined(GEMM_OPT)
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
#pragma message("Enable toy FA fwd opt")
|
||||
const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1);
|
||||
|
||||
const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0;
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "tile_gemm_shape.hpp"
|
||||
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
@@ -152,6 +153,10 @@ struct FlashAttentionFwdImpl
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetStaticLdsSize()];
|
||||
|
||||
@@ -179,7 +184,6 @@ struct FlashAttentionFwdImpl
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{iN1, 0},
|
||||
MakeVDramTileDistribution());
|
||||
|
||||
// Q in register
|
||||
auto q_reg_tensor = load_tile(q_dram_window);
|
||||
|
||||
@@ -188,12 +192,22 @@ struct FlashAttentionFwdImpl
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// V LDS tile window for store
|
||||
auto v_copy_lds_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
|
||||
// V LDS tile for block GEMM
|
||||
auto v_lds_gemm_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
#endif
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
@@ -239,9 +253,10 @@ struct FlashAttentionFwdImpl
|
||||
const auto s =
|
||||
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// prefetch load v tile
|
||||
const auto v_prefetch = load_tile(v_dram_window);
|
||||
|
||||
auto v_prefetch = load_tile(v_dram_window);
|
||||
#endif
|
||||
// m_local = rowmax(S{j})
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s, sequence<1>{}, f_max, std::numeric_limits<SMPLComputeDataType>::lowest());
|
||||
@@ -291,10 +306,30 @@ struct FlashAttentionFwdImpl
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v_prefetch);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
|
||||
|
||||
// Oacc{j}
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
store_tile(v_lds_window, v);
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using VLdsTile = typename decltype(gemm1)::BLdsTile;
|
||||
VLdsTile vWarpTile;
|
||||
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
@@ -304,34 +339,60 @@ struct FlashAttentionFwdImpl
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
}
|
||||
if constexpr(k1_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v);
|
||||
|
||||
// LDS write 1
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
|
||||
// Global read 2
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr (k1_loops > 1)
|
||||
{
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
v_lds_window);
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
// move tile windows
|
||||
move_tile_window(k_dram_window, {kN0PerBlock, 0});
|
||||
iN0 += kN0PerBlock;
|
||||
|
||||
@@ -26,6 +26,250 @@ struct BlockGemmARegBSmemCRegV1
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// PackedSize reported by numeric_traits for ADataType (presumably the number of
// logical elements stored per storage unit of that type).
// NOTE(review): HotLoopScheduler divides a BDataType byte width by this value;
// confirm that A and B are always expected to share the same packing.
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN);
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
// Static tile distribution built from MakeBBlockDistributionEncode(); stored as a
// value-initialized object of that distribution type.
static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
// Type of a static distributed tensor of BDataType laid out with BLdsTileDistr.
// This is the B operand type accepted by the operator() overload that takes a
// pre-loaded tile.
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
// Emits compile-time scheduling hints (__builtin_amdgcn_sched_group_barrier) that
// interleave MFMA instructions with the B-operand memory instructions of one
// hot-loop iteration. Instruction counts are estimated from the block GEMM shape,
// kBlockSize and the warp GEMM configuration; the function reads or writes no data.
//
// Template parameters:
//   VectorSizeB - number of B elements assumed per global/buffer load instruction.
//   SmemPack    - number of B elements assumed per LDS read or write instruction.
//
// Group-barrier masks used below: 0x008 = MFMA, 0x020 = VMEM read,
// 0x100 = DS read, 0x200 = DS write.
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// Warp GEMM type and warp layout chosen by the policy for this block shape.
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
constexpr index_t MPerXDL = WG::kM;
|
||||
constexpr index_t NPerXDL = WG::kN;
|
||||
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = config.template get<1>();
|
||||
|
||||
constexpr index_t B_LDS_RW_Width = SmemPack;
|
||||
|
||||
// Estimated per-thread instruction counts for moving one B block tile:
// tile elements divided by (threads per block * elements per instruction).
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
|
||||
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
// NOTE(review): the LDS read count is additionally scaled by WaveNumM
// (config.get<1>()) - presumably because each warp along M reads the same B
// data; confirm.
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
// MFMA count: block tile volume divided by the number of waves and by the
// volume of one warp-level MFMA (MPerXDL * NPerXDL * KPerXDL).
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// B split schedule
|
||||
// When one LDS access is 16 bytes wide the estimated read count is used as-is,
// otherwise it is halved.
// NOTE(review): kPackedSize is derived from ADataType but scales a BDataType
// width here; confirm A and B share the same packing.
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
// Cycle estimates used to decide how many DS reads are paired with one MFMA.
// NOTE(review): the constants (16/32, 8/4 and the "- 4") are hardware timing
// assumptions that cannot be derived from this file.
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
// ceil((mfma_cycle - 4) / (2 * ds_read_b_issue_cycle))
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
// Number of MFMA slots needed to cover all DS reads (ceil division).
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
// The MFMAs left after reserving stage 2's share are divided evenly over the
// buffer loads. Each buffer load gets its share of DS writes, and each DS write
// is followed by 1 or 2 MFMAs.
// NOTE(review): the divisions below require num_buffer_load_inst_b != 0.
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
constexpr auto num_mfma_per_dswrite_b =
|
||||
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
|
||||
|
||||
// Per buffer load: (DS write + MFMAs) groups, then one VMEM read, then the
// MFMAs of this issue slot that were not consumed by the DS-write groups.
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_mfma_per_dswrite_b *
|
||||
num_dswrite_per_issue_b,
|
||||
0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
// One group of DS reads followed by one MFMA, repeated num_dsread_b_mfma times.
// A full group holds ds_read_b_mfma_rate reads; otherwise the remainder is used.
// NOTE(review): when num_ds_read_inst_b is not a multiple of
// ds_read_b_mfma_rate (and there are at least two groups), the last two
// iterations both take the else-branch and request only the remainder, so the
// hinted DS reads sum to less than num_ds_read_inst_b - confirm this is intended.
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B, with B supplied as an already-loaded BLdsTile.
//
// Accumulates the block GEMM into c_block_tensor by iterating over the K, M and N
// warp-level iterations and invoking the warp GEMM (WG) on thread-buffer slices of
// A, B and C.
//
//   c_block_tensor     - accumulator tile, updated in place; its distribution
//                        encoding must equal the C encoding constructed below
//                        (enforced by static_assert).
//   a_block_tensor_tmp - A tile; only its thread buffer is used, copied into a
//                        tensor that carries the A distribution built below.
//   b_block_tensor_tmp - B tile of type BLdsTile (distribution from
//                        MakeBBlockDistributionEncode()).
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp>
|
||||
__device__ void operator() (CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BLdsTile& b_block_tensor_tmp) const
|
||||
{
|
||||
// Element types of the operands must match the problem's A/B/C data types.
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// Tile extents are taken from the operand types: M and K from A, N from C.
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
// Warp GEMM type and warp layout chosen by the policy.
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
// Number of warp-GEMM invocations per warp along each dimension.
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
// Block-level ("outer") encodings for A and C; the warp-level encodings are
// embedded into them below.
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// construct A-block-tensor from A-block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
// Y-dimension lengths of one warp tile, used as the slice extents below.
// NOTE(review): only b_warp_y_lengths is declared `static`; its A and C
// siblings are plain constexpr locals - confirm whether that is intentional.
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// All-zero Y origins for the warp-level part of each slice index.
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from the B block tensor (already loaded tile)
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
|
||||
@@ -13,16 +13,13 @@ namespace ck_tile {
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, index_t kHeadDim>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
Problem,
|
||||
BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>>
|
||||
template <typename Problem, typename Policy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
@@ -134,6 +131,8 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -158,6 +157,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -217,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
ignore = b_element_func;
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution used to construct both distributed tensor for local buffer store
|
||||
// and read. without buffer address info
|
||||
@@ -256,58 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0},
|
||||
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
static_for<0, k_loops, 1>{}([&](auto i_k0) {
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using BLdsTile = typename decltype(block_gemm)::BLdsTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Global read 0
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS read 0
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
// LDS read 1
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -315,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
}
|
||||
|
||||
#endif
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,42 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// NOTE: Assume A is K-Major
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
return policy_impl::make_a_reg_block_descriptor<Problem, BlockGemm>();
|
||||
}
|
||||
|
||||
// Returns the LDS block descriptor for the B tile by forwarding to
// policy_impl::make_b_lds_block_descriptor_3d_pad<Problem>().
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
return policy_impl::make_b_lds_block_descriptor_3d_pad<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
return policy_impl::make_a_dram_tile_distribution_skip_lds<Problem, BlockGemm>();
|
||||
}
|
||||
|
||||
// Returns the tile distribution used to load the B tile from DRAM by forwarding
// to policy_impl::make_b_dram_tile_distribution<Problem>().
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
return policy_impl::make_b_dram_tile_distribution<Problem>();
|
||||
}
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto GetBlockGemm()
|
||||
@@ -47,13 +20,7 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
|
||||
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
: BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
{
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
@@ -93,11 +60,88 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
|
||||
// The A tile is loaded from DRAM with the same distribution that
// MakeARegBlockDescriptor<Problem>() returns, so this simply forwards to it.
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace policy_impl {
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
transform_tensor_descriptor(a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
__host__ __device__ static constexpr auto make_a_reg_block_descriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_a_dram_tile_distribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds()
|
||||
{
|
||||
constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K2 =
|
||||
WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane;
|
||||
// // 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t K0 = kKPerBlock / (K1 * K2);
|
||||
|
||||
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t M1 = MWarp;
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_b_dram_tile_distribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto get_block_gemm()
|
||||
{
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy;
|
||||
|
||||
return BlockGemmASmemBSmemCReg<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
} // namespace policy_impl
|
||||
} // namespace ck_tile
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "flash_attention_fwd_impl.hpp"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "tile_gemm_shape.hpp"
|
||||
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
@@ -152,6 +153,10 @@ struct FlashAttentionFwdImpl
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetStaticLdsSize()];
|
||||
|
||||
@@ -179,7 +184,6 @@ struct FlashAttentionFwdImpl
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{iN1, 0},
|
||||
MakeVDramTileDistribution());
|
||||
|
||||
// Q in register
|
||||
auto q_reg_tensor = load_tile(q_dram_window);
|
||||
|
||||
@@ -188,12 +192,22 @@ struct FlashAttentionFwdImpl
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// V LDS tile window for store
|
||||
auto v_copy_lds_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
|
||||
// V LDS tile for block GEMM
|
||||
auto v_lds_gemm_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
#endif
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
@@ -239,9 +253,10 @@ struct FlashAttentionFwdImpl
|
||||
const auto s =
|
||||
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// prefetch load v tile
|
||||
const auto v_prefetch = load_tile(v_dram_window);
|
||||
|
||||
auto v_prefetch = load_tile(v_dram_window);
|
||||
#endif
|
||||
// m_local = rowmax(S{j})
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s, sequence<1>{}, f_max, std::numeric_limits<SMPLComputeDataType>::lowest());
|
||||
@@ -291,10 +306,30 @@ struct FlashAttentionFwdImpl
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v_prefetch);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
|
||||
|
||||
// Oacc{j}
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
store_tile(v_lds_window, v);
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using VLdsTile = typename decltype(gemm1)::BLdsTile;
|
||||
VLdsTile vWarpTile;
|
||||
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
@@ -305,29 +340,59 @@ struct FlashAttentionFwdImpl
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
}
|
||||
if constexpr(k1_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v);
|
||||
|
||||
// LDS write 1
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
|
||||
// Global read 2
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr (k1_loops > 1)
|
||||
{
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
v_lds_window);
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
// move tile windows
|
||||
move_tile_window(k_dram_window, {kN0PerBlock, 0});
|
||||
iN0 += kN0PerBlock;
|
||||
|
||||
Reference in New Issue
Block a user