mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Fix flash attention 1 tile case
This commit is contained in:
@@ -10,200 +10,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg<Problem, BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy>
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// Move this part into Policy?
|
||||
__host__ __device__ static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution used to construct both distributed tensor for local buffer store
|
||||
// and read. without buffer address info
|
||||
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
|
||||
|
||||
// B tile in LDS, blockWindow
|
||||
BDataType* p_b_lds =
|
||||
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
// This tensor view used to construct both tile window for lds store and read, with buffer
|
||||
// address info
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A Reg tensor for store, also used for block GEMM
|
||||
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_copy_reg_tensor, b_lds_gemm_window)){};
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// block buffer write 0
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
// store_tile -> shuffle store tile
|
||||
store_tile(a_copy_reg_tensor, a_block_tile_tmp);
|
||||
// global read 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// LDS write 0
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
// global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 2;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_reg_tensor, a_block_tile_tmp);
|
||||
// global read i + 2
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// LDS write i + 1
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
// global read i + 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
iCounter--;
|
||||
|
||||
} while(iCounter > 0);
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_reg_tensor, a_block_tile_tmp);
|
||||
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
__device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
@@ -248,12 +54,12 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = a_element_func;
|
||||
@@ -312,12 +118,13 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
sequence<0, 0>{},
|
||||
@@ -327,6 +134,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -334,7 +142,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (i_k0)*kKPerBlock>{},
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
@@ -356,15 +164,18 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
@@ -378,11 +189,10 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops)*kKPerBlock>{}),
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
}
|
||||
|
||||
// store_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor);
|
||||
set_slice_tile(a_reg_block_tensor_tmp,
|
||||
a_copy_reg_tensor,
|
||||
sequence<0, 0>{},
|
||||
@@ -402,7 +212,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
"wrong!");
|
||||
|
||||
static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = b_element_func;
|
||||
@@ -414,7 +224,6 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// A Reg tensor for store, also used for block GEMM
|
||||
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
|
||||
// store_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp);
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_reg_block_tensor_tmp,
|
||||
@@ -458,14 +267,16 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -473,7 +284,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (i_k0)*kKPerBlock>{},
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
@@ -488,16 +299,18 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -505,7 +318,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops)*kKPerBlock>{}),
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,200 +10,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg<Problem, BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy>
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// Move this part into Policy?
|
||||
__host__ __device__ static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
__host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution used to construct both distributed tensor for local buffer store
|
||||
// and read. without buffer address info
|
||||
constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();
|
||||
|
||||
// B tile in LDS, blockWindow
|
||||
BDataType* p_b_lds =
|
||||
static_cast<BDataType*>(static_cast<void*>(static_cast<char*>(p_smem)));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
// This tensor view used to construct both tile window for lds store and read, with buffer
|
||||
// address info
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A Reg tensor for store, also used for block GEMM
|
||||
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_copy_reg_tensor, b_lds_gemm_window)){};
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// block buffer write 0
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
// store_tile -> shuffle store tile
|
||||
store_tile(a_copy_reg_tensor, a_block_tile_tmp);
|
||||
// global read 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// LDS write 0
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
// global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 2;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_reg_tensor, a_block_tile_tmp);
|
||||
// global read i + 2
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// LDS write i + 1
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
// global read i + 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
iCounter--;
|
||||
|
||||
} while(iCounter > 0);
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_reg_tensor, a_block_tile_tmp);
|
||||
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
__device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
@@ -248,12 +54,12 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = a_element_func;
|
||||
@@ -312,12 +118,13 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
sequence<0, 0>{},
|
||||
@@ -327,6 +134,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -334,7 +142,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (i_k0)*kKPerBlock>{},
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
@@ -356,15 +164,18 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_block_tile,
|
||||
@@ -378,11 +189,10 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops)*kKPerBlock>{}),
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
}
|
||||
|
||||
// store_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor);
|
||||
set_slice_tile(a_reg_block_tensor_tmp,
|
||||
a_copy_reg_tensor,
|
||||
sequence<0, 0>{},
|
||||
@@ -402,7 +212,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
"wrong!");
|
||||
|
||||
static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = b_element_func;
|
||||
@@ -414,7 +224,6 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// A Reg tensor for store, also used for block GEMM
|
||||
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
|
||||
// store_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp);
|
||||
|
||||
set_slice_tile(a_copy_reg_tensor,
|
||||
a_reg_block_tensor_tmp,
|
||||
@@ -458,14 +267,16 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -473,7 +284,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (i_k0)*kKPerBlock>{},
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
@@ -488,16 +299,18 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -505,7 +318,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops)*kKPerBlock>{}),
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user