Fix register spilling and K0 tile size issues

This commit is contained in:
MHYang
2025-04-18 10:15:17 +00:00
committed by Philip Maybank
parent eb737b8f82
commit 4a264eb9ed
6 changed files with 32 additions and 28 deletions

View File

@@ -58,8 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
ignore = a_element_func;

View File

@@ -91,6 +91,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
return a_block_dstr;
}
// Tile distribution used for the A-operand DRAM window.
// It is deliberately identical to the A register-block descriptor: this
// function simply forwards to MakeARegBlockDescriptor<Problem>(), so a tile
// read through a window built with this distribution has the same layout as
// the register block descriptor describes.
template <typename Problem>
__host__ __device__ static constexpr auto MakeADramTileDistribution()
{
    // Name the result to make explicit that the DRAM-side distribution *is*
    // the register-side descriptor, not a separately derived one.
    auto a_dram_dstr = MakeARegBlockDescriptor<Problem>();
    return a_dram_dstr;
}
};
} // namespace ck_tile

View File

@@ -140,7 +140,7 @@ struct FlashAttentionFwdImpl
// Q/K/V DRAM and DRAM window
const auto q_dram = make_naive_tensor_view<address_space_enum::global>(
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
const auto k_dram = make_naive_tensor_view<address_space_enum::global>(
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{});
@@ -149,7 +149,10 @@ struct FlashAttentionFwdImpl
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
auto q_dram_window = make_tile_window(
q_dram, make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}), {iM0, 0});
q_dram,
make_tuple(number<kM0PerBlock>{}, number<kHeadDim>{}),
{iM0, 0},
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
@@ -160,9 +163,8 @@ struct FlashAttentionFwdImpl
{iN1, 0},
MakeVDramTileDistribution());
// Q in Register
auto q_reg_tensor = make_static_distributed_tensor<QDataType>(
BlockGemm0Policy::template MakeARegBlockDescriptor<BlockGemm0Problem>());
// Q in register
auto q_reg_tensor = load_tile(q_dram_window);
// V LDS and LDS window
// V LDS occupies the same LDS allocation as Q/K LDS
@@ -212,15 +214,10 @@ struct FlashAttentionFwdImpl
// loop over Column of S (J loop)
index_t iN0 = 0;
// Cold Q_Reg_Cache
s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr);
do
{
// Hot Q_Reg_Cache
if(iN0 > 0)
{
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
}
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
// S{j}
const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);

View File

@@ -58,8 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
ignore = a_element_func;

View File

@@ -91,6 +91,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
return a_block_dstr;
}
// Tile distribution used for the A-operand DRAM window.
// It is deliberately identical to the A register-block descriptor: this
// function simply forwards to MakeARegBlockDescriptor<Problem>(), so a tile
// read through a window built with this distribution has the same layout as
// the register block descriptor describes.
template <typename Problem>
__host__ __device__ static constexpr auto MakeADramTileDistribution()
{
    // Name the result to make explicit that the DRAM-side distribution *is*
    // the register-side descriptor, not a separately derived one.
    auto a_dram_dstr = MakeARegBlockDescriptor<Problem>();
    return a_dram_dstr;
}
};
} // namespace ck_tile

View File

@@ -140,7 +140,7 @@ struct FlashAttentionFwdImpl
// Q/K/V DRAM and DRAM window
const auto q_dram = make_naive_tensor_view<address_space_enum::global>(
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
const auto k_dram = make_naive_tensor_view<address_space_enum::global>(
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{});
@@ -149,7 +149,10 @@ struct FlashAttentionFwdImpl
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
auto q_dram_window = make_tile_window(
q_dram, make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}), {iM0, 0});
q_dram,
make_tuple(number<kM0PerBlock>{}, number<kHeadDim>{}),
{iM0, 0},
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
@@ -160,9 +163,8 @@ struct FlashAttentionFwdImpl
{iN1, 0},
MakeVDramTileDistribution());
// Q in Register
auto q_reg_tensor = make_static_distributed_tensor<QDataType>(
BlockGemm0Policy::template MakeARegBlockDescriptor<BlockGemm0Problem>());
// Q in register
auto q_reg_tensor = load_tile(q_dram_window);
// V LDS and LDS window
// V LDS occupies the same LDS allocation as Q/K LDS
@@ -212,15 +214,10 @@ struct FlashAttentionFwdImpl
// loop over Column of S (J loop)
index_t iN0 = 0;
// Cold Q_Reg_Cache
s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr);
do
{
// Hot Q_Reg_Cache
if(iN0 > 0)
{
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
}
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
// S{j}
const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);