diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index 56767bec66..324d7924eb 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -16,6 +16,63 @@ namespace ck_tile { +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) +{ + return [=](index_t block_1d_id) { + constexpr index_t M01 = 4; + constexpr index_t GroupNum = 8; + + const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + + const auto xcd_id = block_1d_id % GroupNum; + + const auto l_block_id = block_1d_id - (xcd_id % 2); + + const auto ridn = GroupNum * M01 * (update_N0 / 2); + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; + + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; + + auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); + auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); + + const auto total_update_size = update_N0 * update_M0; + + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; + auto rlen = N0 - update_N0; + + auto rm = 0; + auto rn = 0; + if(rlen > 0) + { + rm = (x - 1) / rlen; + rn = x % rlen; + } + + if(rlen > 0 and rm < M0) + { + n = rn + update_N0; + m = rm; + } + else + { + x = x - rlen * M0; + rm = (x - 1) / update_N0; + rn = x % update_N0; + n = rn; + m = update_M0 + rm; + } + } + return make_multi_index(m, n); + }; +} + template struct FlashAttnArgs {