mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Fix register spilling and K0 tile size issues
This commit is contained in:
@@ -58,8 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = a_element_func;
|
||||
|
||||
@@ -91,6 +91,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -140,7 +140,7 @@ struct FlashAttentionFwdImpl
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
|
||||
q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
|
||||
|
||||
const auto k_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{});
|
||||
@@ -149,7 +149,10 @@ struct FlashAttentionFwdImpl
|
||||
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram, make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}), {iM0, 0});
|
||||
q_dram,
|
||||
make_tuple(number<kM0PerBlock>{}, number<kHeadDim>{}),
|
||||
{iM0, 0},
|
||||
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
|
||||
@@ -160,9 +163,8 @@ struct FlashAttentionFwdImpl
|
||||
{iN1, 0},
|
||||
MakeVDramTileDistribution());
|
||||
|
||||
// Q in Register
|
||||
auto q_reg_tensor = make_static_distributed_tensor<QDataType>(
|
||||
BlockGemm0Policy::template MakeARegBlockDescriptor<BlockGemm0Problem>());
|
||||
// Q in register
|
||||
auto q_reg_tensor = load_tile(q_dram_window);
|
||||
|
||||
// V LDS and LDS window
|
||||
// V LDS occupies the same LDS allocation Q/K LDS
|
||||
@@ -212,15 +214,10 @@ struct FlashAttentionFwdImpl
|
||||
// loop over Column of S (J loop)
|
||||
index_t iN0 = 0;
|
||||
|
||||
// Cold Q_Reg_Cache
|
||||
s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr);
|
||||
do
|
||||
{
|
||||
// Hot Q_Reg_Cache
|
||||
if(iN0 > 0)
|
||||
{
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
}
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
|
||||
// S{j}
|
||||
const auto s =
|
||||
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
|
||||
|
||||
@@ -58,8 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = a_element_func;
|
||||
|
||||
@@ -91,6 +91,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -140,7 +140,7 @@ struct FlashAttentionFwdImpl
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
|
||||
q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{});
|
||||
|
||||
const auto k_dram = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{});
|
||||
@@ -149,7 +149,10 @@ struct FlashAttentionFwdImpl
|
||||
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram, make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}), {iM0, 0});
|
||||
q_dram,
|
||||
make_tuple(number<kM0PerBlock>{}, number<kHeadDim>{}),
|
||||
{iM0, 0},
|
||||
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
|
||||
@@ -160,9 +163,8 @@ struct FlashAttentionFwdImpl
|
||||
{iN1, 0},
|
||||
MakeVDramTileDistribution());
|
||||
|
||||
// Q in Register
|
||||
auto q_reg_tensor = make_static_distributed_tensor<QDataType>(
|
||||
BlockGemm0Policy::template MakeARegBlockDescriptor<BlockGemm0Problem>());
|
||||
// Q in register
|
||||
auto q_reg_tensor = load_tile(q_dram_window);
|
||||
|
||||
// V LDS and LDS window
|
||||
// V LDS occupies the same LDS allocation Q/K LDS
|
||||
@@ -212,15 +214,10 @@ struct FlashAttentionFwdImpl
|
||||
// loop over Column of S (J loop)
|
||||
index_t iN0 = 0;
|
||||
|
||||
// Cold Q_Reg_Cache
|
||||
s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr);
|
||||
do
|
||||
{
|
||||
// Hot Q_Reg_Cache
|
||||
if(iN0 > 0)
|
||||
{
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
}
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
|
||||
// S{j}
|
||||
const auto s =
|
||||
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
|
||||
|
||||
Reference in New Issue
Block a user