Use tensor_descriptor to locate LSEacc elements

This commit is contained in:
PoYen, Chen
2024-06-12 02:32:33 +00:00
parent ec82f3bbd6
commit b994668714
2 changed files with 63 additions and 10 deletions

View File

@@ -67,8 +67,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
/// TODO: add padding to avoid bank conflict
return (kM0 * kMaxSplits * sizeof(LSEDataType));
return Policy::template GetSmemSize<Problem>();
}
template <typename LSEaccDramBlockWindowTmp,
@@ -87,8 +86,17 @@ struct BlockFmhaFwdSplitKVCombinePipeline
index_t max_seqlen_q,
void* smem_ptr) const
{
// LSEacc tile in LDS
LSEDataType* lse_acc_lds_ptr =
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
#if 0
auto lse_acc_lds = make_tensor_view<address_space_enum::lds>(
lse_acc_lds_ptr, Policy::template MakeLSEaccLdsBlockDescriptor<Problem>());
auto lse_acc_lds_window =
make_tile_window(lse_acc_lds, make_tuple(number<kM0>{}, number<kMaxSplits>{}), {0, 0});
#endif
auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>();
auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution<Problem>();
auto lse_acc_dram_window =
@@ -116,13 +124,15 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(row, col));
if(row < num_splits && col < seqlen_q)
{
lse_acc_lds_ptr[row + col * kMaxSplits] = lse_acc(distributed_indices);
lse_acc_lds_ptr[offset] = lse_acc(distributed_indices);
}
else
{
lse_acc_lds_ptr[row + col * kMaxSplits] = -numeric<LSEDataType>::infinity();
lse_acc_lds_ptr[offset] = -numeric<LSEDataType>::infinity();
}
});
});
@@ -148,9 +158,11 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
if(col < num_splits)
{
lse_accum(distributed_indices) = lse_acc_lds_ptr[col + row * kMaxSplits];
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
}
else
{
@@ -190,11 +202,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
if(col < num_splits)
{
// from shared memory
p_compute(i_j_idx) = ck_tile::exp(lse_acc_lds_ptr[col + row * kMaxSplits] -
get_validated_m(lse_max(i_idx)));
p_compute(i_j_idx) =
ck_tile::exp(lse_acc_lds_ptr[offset] - get_validated_m(lse_max(i_idx)));
}
});
});
@@ -240,8 +254,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
for(index_t col = 0; col < num_splits; ++col)
{
lse_acc_lds_ptr[col + row * kMaxSplits] = ck_tile::exp(
lse_acc_lds_ptr[col + row * kMaxSplits] - lse_logsum(distributed_indices));
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
lse_acc_lds_ptr[offset] =
ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices));
}
});
}
@@ -291,7 +307,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
LSEDataType lse_scale = lse_acc_lds_ptr[i_split + row * kMaxSplits];
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(i_split, row));
LSEDataType lse_scale = lse_acc_lds_ptr[offset];
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
});
});

View File

@@ -31,6 +31,13 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return 16 / sizeof(ODataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return sizeof(typename Problem::LSEDataType) *
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
@@ -73,6 +80,33 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
}
}
// 3d + padding
// [kMaxSplits x kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
lse_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lse_acc_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccTDramTileDistribution()
{