Fix flash attention 1 tile case

This commit is contained in:
mhYang
2025-04-11 15:44:53 +00:00
parent bfadc59277
commit 44eaa337f6
2 changed files with 72 additions and 446 deletions

View File

@@ -10,200 +10,6 @@
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// Block GEMM pipeline specialization selected by
// BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy.
//
// Both operands are read through DRAM tile windows. The A tile is kept in a
// static distributed tensor (no LDS staging for A), the B tile is staged
// through LDS, and the accumulator is returned as a distributed tensor.
template <typename Problem>
struct BlockGemmPipelineAGmemBGmemCReg<Problem, BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy>
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kMPerBlock = BlockGemmShape::kM;
    static constexpr index_t kNPerBlock = BlockGemmShape::kN;
    static constexpr index_t kKPerBlock = BlockGemmShape::kK;

    // Move this part into Policy?
    // Bytes of LDS needed by this pipeline: only the B block descriptor is
    // counted here.
    __host__ __device__ static constexpr index_t GetStaticLdsSize()
    {
        return sizeof(BDataType) *
               Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }

    // Runs the pipelined block GEMM over `num_loop` K-tiles of width kKPerBlock.
    //
    // a_dram_block_window_tmp : A window, lengths {kMPerBlock, kKPerBlock}
    // a_element_func          : element-wise function applied to every A tile
    // b_dram_block_window_tmp : B window, lengths {kNPerBlock, kKPerBlock}
    // b_element_func          : element-wise function applied to every B tile
    // num_loop                : number of K-tiles to accumulate; must be >= 1
    // p_smem                  : LDS buffer used for the B tile
    //                           (GetStaticLdsSize() is the size this struct reports)
    //
    // Returns the accumulator tile produced by the policy's block GEMM.
    //
    // NOTE(review): the barriers below sit under conditions that depend on
    // num_loop, exactly as the original loop did; this presumes num_loop is
    // identical for every thread of the block - confirm against callers.
    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                        const AElementFunction& a_element_func,
                                        const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                        const BElementFunction& b_element_func,
                                        index_t num_loop,
                                        void* p_smem) const
    {
        static_assert(
            std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // A tile in RegblockTensor
        // This tensor distribution is used to construct the distributed tensor for both the
        // local buffer store and read; it carries no buffer address info.
        constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();

        // B tile in LDS, blockWindow
        BDataType* p_b_lds = static_cast<BDataType*>(p_smem);

        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();

        // This tensor view is used to construct the tile windows for both the LDS store and
        // read; it carries the buffer address info.
        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A Reg tensor for store, also used for block GEMM
        auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);

        // B DRAM tile window for load
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store
        auto b_copy_lds_window =
            make_tile_window(b_lds_block,
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             {0, 0},
                             b_copy_dram_window.get_tile_distribution());

        // B LDS tile for block GEMM
        auto b_lds_gemm_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
        constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();

        // Acc register tile
        auto c_block_tile = decltype(block_gemm(a_copy_reg_tensor, b_lds_gemm_window)){};

        // Tile 1 exists only when there is more than one K-tile. Without this
        // guard the single-tile case would fetch and accumulate tiles beyond
        // the requested K range.
        const bool has_next_tile = num_loop > 1;

        // prefetch
        // global read 0
        auto a_block_tile = load_tile(a_copy_dram_window);
        auto b_block_tile = load_tile(b_copy_dram_window);

        {
            if(has_next_tile)
            {
                // move to 1
                move_tile_window(a_copy_dram_window, {0, kKPerBlock});
                move_tile_window(b_copy_dram_window, {0, kKPerBlock});
            }

            // Initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // block buffer write 0
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
            // store_tile -> shuffle store tile
            store_tile(a_copy_reg_tensor, a_block_tile_tmp);

            if(has_next_tile)
            {
                // global read 1
                a_block_tile = load_tile(a_copy_dram_window);
            }

            // LDS write 0
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
            store_tile(b_copy_lds_window, b_block_tile_tmp);

            if(has_next_tile)
            {
                // global read 1
                b_block_tile = load_tile(b_copy_dram_window);
            }
        }

        // Main loop: runs num_loop - 2 times. It is pre-tested on purpose: a
        // do/while here would execute once even for num_loop <= 2, adding an
        // extra GEMM and an extra tile fetch past the K range.
        for(index_t iCounter = num_loop - 2; iCounter > 0; --iCounter)
        {
            block_sync_lds();

            // GEMM i
            block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);

            block_sync_lds();

            // move to i + 2
            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
            move_tile_window(b_copy_dram_window, {0, kKPerBlock});

            // block buffer write i + 1
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
            store_tile(a_copy_reg_tensor, a_block_tile_tmp);

            // global read i + 2
            a_block_tile = load_tile(a_copy_dram_window);

            // LDS write i + 1
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
            store_tile(b_copy_lds_window, b_block_tile_tmp);

            // global read i + 2
            b_block_tile = load_tile(b_copy_dram_window);
        }

        // tail
        {
            block_sync_lds();

            // GEMM num_loop - 2 (or the only tile when num_loop == 1)
            block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);

            if(has_next_tile)
            {
                block_sync_lds();

                // block buffer / LDS write num_loop - 1
                const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
                store_tile(a_copy_reg_tensor, a_block_tile_tmp);

                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
                store_tile(b_copy_lds_window, b_block_tile_tmp);

                block_sync_lds();

                // GEMM num_loop - 1
                block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);
            }
        }

        return c_block_tile;
    }

    // Convenience overload: same pipeline with identity element-wise functions
    // for both A and B.
    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               index_t num_loop,
                               void* p_smem) const
    {
        return operator()(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem);
    }
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
@@ -248,12 +54,12 @@ struct BlockGemmPipelineAGmemBGmemCReg<
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
ignore = a_element_func;
@@ -312,12 +118,13 @@ struct BlockGemmPipelineAGmemBGmemCReg<
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(k_loops > 1)
{
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
sequence<0, 0>{},
@@ -327,6 +134,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
}
if constexpr(k_loops > 2)
{
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
@@ -334,7 +142,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (i_k0)*kKPerBlock>{},
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
@@ -356,15 +164,18 @@ struct BlockGemmPipelineAGmemBGmemCReg<
// tail
{
block_sync_lds();
if constexpr(k_loops > 1)
{
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
block_sync_lds();
}
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
@@ -378,11 +189,10 @@ struct BlockGemmPipelineAGmemBGmemCReg<
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops)*kKPerBlock>{}),
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
b_copy_lds_window);
}
// store_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor);
set_slice_tile(a_reg_block_tensor_tmp,
a_copy_reg_tensor,
sequence<0, 0>{},
@@ -402,7 +212,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
"wrong!");
static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
ignore = b_element_func;
@@ -414,7 +224,6 @@ struct BlockGemmPipelineAGmemBGmemCReg<
// A Reg tensor for store, also used for block GEMM
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
// store_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp);
set_slice_tile(a_copy_reg_tensor,
a_reg_block_tensor_tmp,
@@ -458,14 +267,16 @@ struct BlockGemmPipelineAGmemBGmemCReg<
b_lds_gemm_window)){};
auto b_block_tile = load_tile(b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(k_loops > 1)
{
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
}
if constexpr(k_loops > 2)
{
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
@@ -473,7 +284,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (i_k0)*kKPerBlock>{},
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
@@ -488,16 +299,18 @@ struct BlockGemmPipelineAGmemBGmemCReg<
// tail
{
block_sync_lds();
if constexpr(k_loops > 1)
{
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
}
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
@@ -505,7 +318,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops)*kKPerBlock>{}),
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
b_copy_lds_window);
}

View File

@@ -10,200 +10,6 @@
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// Block GEMM pipeline specialization selected by
// BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy.
//
// Both operands are read through DRAM tile windows. The A tile is kept in a
// static distributed tensor (no LDS staging for A), the B tile is staged
// through LDS, and the accumulator is returned as a distributed tensor.
template <typename Problem>
struct BlockGemmPipelineAGmemBGmemCReg<Problem, BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy>
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kMPerBlock = BlockGemmShape::kM;
    static constexpr index_t kNPerBlock = BlockGemmShape::kN;
    static constexpr index_t kKPerBlock = BlockGemmShape::kK;

    // Move this part into Policy?
    // Bytes of LDS needed by this pipeline: only the B block descriptor is
    // counted here.
    __host__ __device__ static constexpr index_t GetStaticLdsSize()
    {
        return sizeof(BDataType) *
               Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
    }

    // Runs the pipelined block GEMM over `num_loop` K-tiles of width kKPerBlock.
    //
    // a_dram_block_window_tmp : A window, lengths {kMPerBlock, kKPerBlock}
    // a_element_func          : element-wise function applied to every A tile
    // b_dram_block_window_tmp : B window, lengths {kNPerBlock, kKPerBlock}
    // b_element_func          : element-wise function applied to every B tile
    // num_loop                : number of K-tiles to accumulate; must be >= 1
    // p_smem                  : LDS buffer used for the B tile
    //                           (GetStaticLdsSize() is the size this struct reports)
    //
    // Returns the accumulator tile produced by the policy's block GEMM.
    //
    // NOTE(review): the barriers below sit under conditions that depend on
    // num_loop, exactly as the original loop did; this presumes num_loop is
    // identical for every thread of the block - confirm against callers.
    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                        const AElementFunction& a_element_func,
                                        const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                        const BElementFunction& b_element_func,
                                        index_t num_loop,
                                        void* p_smem) const
    {
        static_assert(
            std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // A tile in RegblockTensor
        // This tensor distribution is used to construct the distributed tensor for both the
        // local buffer store and read; it carries no buffer address info.
        constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor<Problem>();

        // B tile in LDS, blockWindow
        BDataType* p_b_lds = static_cast<BDataType*>(p_smem);

        constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();

        // This tensor view is used to construct the tile windows for both the LDS store and
        // read; it carries the buffer address info.
        auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A Reg tensor for store, also used for block GEMM
        auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);

        // B DRAM tile window for load
        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

        // B LDS tile window for store
        auto b_copy_lds_window =
            make_tile_window(b_lds_block,
                             make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
                             {0, 0},
                             b_copy_dram_window.get_tile_distribution());

        // B LDS tile for block GEMM
        auto b_lds_gemm_window = make_tile_window(
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
        constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();

        // Acc register tile
        auto c_block_tile = decltype(block_gemm(a_copy_reg_tensor, b_lds_gemm_window)){};

        // Tile 1 exists only when there is more than one K-tile. Without this
        // guard the single-tile case would fetch and accumulate tiles beyond
        // the requested K range.
        const bool has_next_tile = num_loop > 1;

        // prefetch
        // global read 0
        auto a_block_tile = load_tile(a_copy_dram_window);
        auto b_block_tile = load_tile(b_copy_dram_window);

        {
            if(has_next_tile)
            {
                // move to 1
                move_tile_window(a_copy_dram_window, {0, kKPerBlock});
                move_tile_window(b_copy_dram_window, {0, kKPerBlock});
            }

            // Initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // block buffer write 0
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
            // store_tile -> shuffle store tile
            store_tile(a_copy_reg_tensor, a_block_tile_tmp);

            if(has_next_tile)
            {
                // global read 1
                a_block_tile = load_tile(a_copy_dram_window);
            }

            // LDS write 0
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
            store_tile(b_copy_lds_window, b_block_tile_tmp);

            if(has_next_tile)
            {
                // global read 1
                b_block_tile = load_tile(b_copy_dram_window);
            }
        }

        // Main loop: runs num_loop - 2 times. It is pre-tested on purpose: a
        // do/while here would execute once even for num_loop <= 2, adding an
        // extra GEMM and an extra tile fetch past the K range.
        for(index_t iCounter = num_loop - 2; iCounter > 0; --iCounter)
        {
            block_sync_lds();

            // GEMM i
            block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);

            block_sync_lds();

            // move to i + 2
            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
            move_tile_window(b_copy_dram_window, {0, kKPerBlock});

            // block buffer write i + 1
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
            store_tile(a_copy_reg_tensor, a_block_tile_tmp);

            // global read i + 2
            a_block_tile = load_tile(a_copy_dram_window);

            // LDS write i + 1
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
            store_tile(b_copy_lds_window, b_block_tile_tmp);

            // global read i + 2
            b_block_tile = load_tile(b_copy_dram_window);
        }

        // tail
        {
            block_sync_lds();

            // GEMM num_loop - 2 (or the only tile when num_loop == 1)
            block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);

            if(has_next_tile)
            {
                block_sync_lds();

                // block buffer / LDS write num_loop - 1
                const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
                store_tile(a_copy_reg_tensor, a_block_tile_tmp);

                const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
                store_tile(b_copy_lds_window, b_block_tile_tmp);

                block_sync_lds();

                // GEMM num_loop - 1
                block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window);
            }
        }

        return c_block_tile;
    }

    // Convenience overload: same pipeline with identity element-wise functions
    // for both A and B.
    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               index_t num_loop,
                               void* p_smem) const
    {
        return operator()(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem);
    }
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
@@ -248,12 +54,12 @@ struct BlockGemmPipelineAGmemBGmemCReg<
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
ignore = a_element_func;
@@ -312,12 +118,13 @@ struct BlockGemmPipelineAGmemBGmemCReg<
auto a_block_tile = load_tile(a_copy_dram_window);
auto b_block_tile = load_tile(b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(k_loops > 1)
{
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
sequence<0, 0>{},
@@ -327,6 +134,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
}
if constexpr(k_loops > 2)
{
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
@@ -334,7 +142,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (i_k0)*kKPerBlock>{},
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
@@ -356,15 +164,18 @@ struct BlockGemmPipelineAGmemBGmemCReg<
// tail
{
block_sync_lds();
if constexpr(k_loops > 1)
{
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
block_sync_lds();
}
set_slice_tile(a_copy_reg_tensor,
a_block_tile,
@@ -378,11 +189,10 @@ struct BlockGemmPipelineAGmemBGmemCReg<
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops)*kKPerBlock>{}),
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
b_copy_lds_window);
}
// store_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor);
set_slice_tile(a_reg_block_tensor_tmp,
a_copy_reg_tensor,
sequence<0, 0>{},
@@ -402,7 +212,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
"wrong!");
static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
ignore = b_element_func;
@@ -414,7 +224,6 @@ struct BlockGemmPipelineAGmemBGmemCReg<
// A Reg tensor for store, also used for block GEMM
auto a_copy_reg_tensor = make_static_distributed_tensor<ADataType>(a_reg_block_dstr);
// store_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp);
set_slice_tile(a_copy_reg_tensor,
a_reg_block_tensor_tmp,
@@ -458,14 +267,16 @@ struct BlockGemmPipelineAGmemBGmemCReg<
b_lds_gemm_window)){};
auto b_block_tile = load_tile(b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(k_loops > 1)
{
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
b_block_tile = load_tile(b_copy_dram_window);
}
if constexpr(k_loops > 2)
{
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
@@ -473,7 +284,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (i_k0)*kKPerBlock>{},
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
@@ -488,16 +299,18 @@ struct BlockGemmPipelineAGmemBGmemCReg<
// tail
{
block_sync_lds();
if constexpr(k_loops > 1)
{
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
}
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
@@ -505,7 +318,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops)*kKPerBlock>{}),
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
b_copy_lds_window);
}