mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Use tensor_descriptor to locate LSEacc elements
This commit is contained in:
@@ -67,8 +67,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
/// Shared-memory size required by this pipeline.
/// NOTE(review): this listing is a rendered diff without +/- markers; the hunk header
/// (-67,8 +67,7) implies two lines were removed and one added, so the two `return`
/// statements below are presumably the before/after versions of the same line -- confirm
/// against the actual file.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
/// TODO: add padding to avoid bank conflict
|
||||
// presumably the pre-change line: dense kM0 x kMaxSplits buffer of LSEDataType, no padding
return (kM0 * kMaxSplits * sizeof(LSEDataType));
|
||||
// presumably the post-change line: the size is delegated to the policy
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename LSEaccDramBlockWindowTmp,
|
||||
@@ -87,8 +86,17 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
index_t max_seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
// LSEacc tile in LDS
|
||||
LSEDataType* lse_acc_lds_ptr =
|
||||
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
#if 0
|
||||
auto lse_acc_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_acc_lds_ptr, Policy::template MakeLSEaccLdsBlockDescriptor<Problem>());
|
||||
auto lse_acc_lds_window =
|
||||
make_tile_window(lse_acc_lds, make_tuple(number<kM0>{}, number<kMaxSplits>{}), {0, 0});
|
||||
#endif
|
||||
|
||||
auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution<Problem>();
|
||||
auto lse_acc_dram_window =
|
||||
@@ -116,13 +124,15 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
auto offset =
|
||||
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(row, col));
|
||||
if(row < num_splits && col < seqlen_q)
|
||||
{
|
||||
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
|
||||
lse_acc_lds_ptr[offset] = lse_acc(distributed_indices);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_acc_lds_ptr[row + col * kMaxSplits] = -numeric<LSEDataType>::infinity();
|
||||
lse_acc_lds_ptr[offset] = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -148,9 +158,11 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
auto offset =
|
||||
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
|
||||
if(col < num_splits)
|
||||
{
|
||||
lse_accum(distributed_indices) = lse_acc_lds_ptr[col + row * kMaxSplits];
|
||||
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -190,11 +202,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
auto offset =
|
||||
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
|
||||
if(col < num_splits)
|
||||
{
|
||||
// from shared memory
|
||||
p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] -
|
||||
get_validated_m(lse_max(i_idx)));
|
||||
p_compute(i_j_idx) =
|
||||
ck_tile::exp(lse_acc_lds_ptr[offset] - get_validated_m(lse_max(i_idx)));
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -240,8 +254,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
for(index_t col = 0; col < num_splits; ++col)
|
||||
{
|
||||
lse_acc_lds_ptr[col + row * kMaxSplits] = ck_tile::exp(
|
||||
lse_acc_lds_ptr[col + row * kMaxSplits] - lse_logsum(distributed_indices));
|
||||
auto offset =
|
||||
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
|
||||
lse_acc_lds_ptr[offset] =
|
||||
ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices));
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -291,7 +307,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
|
||||
auto offset =
|
||||
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(i_split, row));
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[offset];
|
||||
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -31,6 +31,13 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
return 16 / sizeof(ODataType);
|
||||
}
|
||||
|
||||
/// Size of the buffer described by MakeLSEaccLdsBlockDescriptor<Problem>(), computed as
/// sizeof(LSEDataType) multiplied by the descriptor's element space size.
/// The result is taken from the descriptor rather than from kM0 * kMaxSplits directly, so
/// it follows whatever strides the descriptor declares.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
// get_element_space_size() presumably counts every element spanned by the layout,
// padding included -- TODO confirm against the tensor descriptor API
return sizeof(typename Problem::LSEDataType) *
|
||||
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
|
||||
{
|
||||
@@ -73,6 +80,33 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the descriptor used to compute element offsets into the LSEacc buffer.
// A 3-D strided descriptor is created first and then reshaped into a 2-D view whose
// coordinates are (split index, row index); callers pass such 2-D tuples to
// calculate_offset().
// 3d + padding
|
||||
// [kMaxSplits x kM0]
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
|
||||
{
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
|
||||
// output dim 0 of the returned view has length kMaxSplits
constexpr index_t kMPerBlock = Problem::kMaxSplits;
|
||||
// output dim 1 of the returned view has length kM0
constexpr index_t kNPerBlock = Problem::kM0;
|
||||
// number of LSEDataType elements that fit in 16 bytes
constexpr index_t NPack = 16 / sizeof(LSEDataType);
|
||||
|
||||
// 3-D layout; kNPerBlock / NPack is an integer division, so this assumes kM0 is a
// multiple of NPack -- TODO confirm
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
// lengths: (kM0 / NPack, kMaxSplits, NPack)
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
|
||||
// strides: the outermost stride is (kMaxSplits + 1) * NPack while the two inner
// dimensions only cover kMaxSplits * NPack elements, leaving NPack unused elements
// (the padding) after every outer slice
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
|
||||
// NOTE(review): meaning of the two trailing arguments is not visible here; presumably
// vector-length / offset hints for the last dimension. If LSEDataType is 4 bytes wide,
// NPack is 4, which is smaller than this 8 -- confirm that is intended.
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
// Reshape to 2-D: 3-D dim 1 (kMaxSplits) passes through as output dim 0, and 3-D dims
// 0 and 2 (kM0 / NPack and NPack) are merged into output dim 1 of length kM0.
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
|
||||
lse_acc_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
|
||||
// source (3-D) dimensions consumed by each transform
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
// destination (2-D) dimension produced by each transform
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lse_acc_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user