mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Add MakeBlock2TileMap in 04_codegen_flash_attention_fwd
This commit is contained in:
@@ -16,6 +16,63 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
return [=](index_t block_1d_id) {
|
||||
constexpr index_t M01 = 4;
|
||||
constexpr index_t GroupNum = 8;
|
||||
|
||||
const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
|
||||
const auto update_M0 =
|
||||
((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;
|
||||
|
||||
const auto xcd_id = block_1d_id % GroupNum;
|
||||
|
||||
const auto l_block_id = block_1d_id - (xcd_id % 2);
|
||||
|
||||
const auto ridn = GroupNum * M01 * (update_N0 / 2);
|
||||
const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn;
|
||||
const auto lu = (l_block_id % GroupNum) + rid * ridn;
|
||||
|
||||
const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
|
||||
const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;
|
||||
|
||||
auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
|
||||
auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
|
||||
|
||||
const auto total_update_size = update_N0 * update_M0;
|
||||
|
||||
if(block_1d_id >= total_update_size)
|
||||
{
|
||||
auto x = (block_1d_id + 1) - total_update_size;
|
||||
auto rlen = N0 - update_N0;
|
||||
|
||||
auto rm = 0;
|
||||
auto rn = 0;
|
||||
if(rlen > 0)
|
||||
{
|
||||
rm = (x - 1) / rlen;
|
||||
rn = x % rlen;
|
||||
}
|
||||
|
||||
if(rlen > 0 and rm < M0)
|
||||
{
|
||||
n = rn + update_N0;
|
||||
m = rm;
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x - rlen * M0;
|
||||
rm = (x - 1) / update_N0;
|
||||
rn = x % update_N0;
|
||||
n = rn;
|
||||
m = update_M0 + rm;
|
||||
}
|
||||
}
|
||||
return make_multi_index(m, n);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename QDataType, typename KDataType, typename VDataType, typename ODataType>
|
||||
struct FlashAttnArgs
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user