Add MakeBlock2TileMap in 04_codegen_flash_attention_fwd

This commit is contained in:
BoboFang
2025-04-23 09:47:37 +00:00
committed by BoboFang
parent 1a91c220a1
commit 068d9fdbf7

View File

@@ -16,6 +16,63 @@
namespace ck_tile {
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
{
return [=](index_t block_1d_id) {
constexpr index_t M01 = 4;
constexpr index_t GroupNum = 8;
const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
const auto update_M0 =
((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;
const auto xcd_id = block_1d_id % GroupNum;
const auto l_block_id = block_1d_id - (xcd_id % 2);
const auto ridn = GroupNum * M01 * (update_N0 / 2);
const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn;
const auto lu = (l_block_id % GroupNum) + rid * ridn;
const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;
auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
const auto total_update_size = update_N0 * update_M0;
if(block_1d_id >= total_update_size)
{
auto x = (block_1d_id + 1) - total_update_size;
auto rlen = N0 - update_N0;
auto rm = 0;
auto rn = 0;
if(rlen > 0)
{
rm = (x - 1) / rlen;
rn = x % rlen;
}
if(rlen > 0 and rm < M0)
{
n = rn + update_N0;
m = rm;
}
else
{
x = x - rlen * M0;
rm = (x - 1) / update_N0;
rn = x % update_N0;
n = rn;
m = update_M0 + rm;
}
}
return make_multi_index(m, n);
};
}
template <typename QDataType, typename KDataType, typename VDataType, typename ODataType>
struct FlashAttnArgs
{