diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 11794e67be..a1c21ed1ed 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -153,106 +153,65 @@ struct BlockFmhaFwdAppendKVPipeline rotary_sin_block_window_tmp.get_window_origin(), Policy::template MakeRotaryCosSinInterleaveDramTileDistribution()); - auto rotary_cos_tile = load_tile(rotary_cos_window); - auto rotary_sin_tile = load_tile(rotary_sin_window); - -#if 0 - auto rotary_tile = rotary_sin_tile; - constexpr auto spans = decltype(rotary_tile)::get_distributed_spans(); - sweep_tile_span(spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - rotary_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = tile_idx.at(number<0>{}); - const auto col = tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - ksmem[row * (kTileSizeD / 2) + col] = rotary_tile(i_j_idx); - }); - }); - - block_sync_lds(); - - DEVICE_DEBUG_STMTS + if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) { - for(int row = 0; row < kTileSizeSk; ++row) + auto rotary_cos_tile = load_tile(rotary_cos_window); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t KPerThread = 16 / sizeof(KDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + + if((start_x + KPerThread) <= rotary_dim) { - printf("[DEVICE] rotary_smem[%3d] = ", row); - for(int col = 0; col < rotary_dim / 2; ++col) - { - printf("%11.7f", type_convert(ksmem[row * (kTileSizeD / 2) + col])); - } - printf("\n"); - } - } -#endif -#define DUMP_KNEW 0 - constexpr index_t KPerThread = 16 / sizeof(KDataType); - static_assert(kTileSizeD % KPerThread == 0); - constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; - index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; + constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size(); + static_assert(thread_buffer_size % KPerThread == 0); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(knew_tile.thread_buf_[idx]); + const auto right = type_convert(knew_tile.thread_buf_[idx + 1]); - if((start_x + KPerThread) <= rotary_dim) - { - bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx / 2]); - auto knew_other_dram_window = knew_dram_window; - DEVICE_DEBUG_STMTS - { - auto origin = knew_other_dram_window.get_window_origin(); - printf("after move window, origin = (%3d, %3d)\n", - origin.at(number<0>{}), - origin.at(number<1>{})); - } - move_tile_window(knew_other_dram_window, - {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); - DEVICE_DEBUG_STMTS - { - auto origin = knew_other_dram_window.get_window_origin(); - printf("after move window, origin = (%3d, %3d)\n", - origin.at(number<0>{}), - origin.at(number<1>{})); - } - auto knew_other_tile = load_tile(knew_other_dram_window); - -#if !DUMP_KNEW - { - constexpr auto spans = decltype(knew_other_tile)::get_distributed_spans(); - sweep_tile_span(spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - knew_other_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = tile_idx.at(number<0>{}); - const auto col = tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - ksmem[row * kTileSizeD + col] = knew_other_tile(i_j_idx); - }); + knew_tile.thread_buf_[idx] = left * cos - right * sin; + knew_tile.thread_buf_[idx + 1] = right * cos + left * sin; }); } -#endif - -#if !defined(DUMP_KNEW) - constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size(); - static_assert(thread_buffer_size % KPerThread == 0); - static_for<0, thread_buffer_size, 2>{}([&](auto idx) { - const auto left = type_convert(knew_tile.thread_buf_[idx]); - const auto right = type_convert(knew_tile.thread_buf_[idx + 1]); - - const auto cos = type_convert(rotary_cos_tile.thread_buf_[idx / 2]); - const auto sin = type_convert(rotary_sin_tile.thread_buf_[idx / 2]); - - knew_tile.thread_buf_[idx] = left * cos - right * sin; - knew_tile.thread_buf_[idx + 1] = right * cos + left * sin; - }); -#endif } -#if defined(ENABLE_DEVICE_DEBUG_STMTS) - DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); } + else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED + { + constexpr index_t KPerThread = 16 / sizeof(KDataType); + static_assert(kTileSizeD % KPerThread == 0); + constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread; + index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread; -#if DUMP_KNEW + bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); + + if((start_x + KPerThread) <= rotary_dim) + { + auto knew_other_dram_window = knew_dram_window; + DEVICE_DEBUG_STMTS + { + auto origin = knew_other_dram_window.get_window_origin(); + printf("after move window, origin = (%3d, %3d)\n", + origin.at(number<0>{}), + origin.at(number<1>{})); + } + move_tile_window(knew_other_dram_window, + {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + DEVICE_DEBUG_STMTS + { + auto origin = knew_other_dram_window.get_window_origin(); + printf("after move window, origin = (%3d, %3d)\n", + origin.at(number<0>{}), + origin.at(number<1>{})); + } + auto knew_other_tile = load_tile(knew_other_dram_window); + } + } + +#if defined(ENABLE_DEVICE_DEBUG_STMTS) { constexpr auto spans = decltype(knew_tile)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { @@ -268,7 +227,6 @@ struct BlockFmhaFwdAppendKVPipeline }); }); } -#endif block_sync_lds(); @@ -276,11 +234,7 @@ struct BlockFmhaFwdAppendKVPipeline { for(int row = 0; row < 7; ++row) { -#if DUMP_KNEW printf("[DEVICE] knew_tile[%3d] = ", row); -#else - printf("[DEVICE] knew_other_tile[%3d] = ", row); -#endif for(int col = 0; col < kTileSizeD; ++col) { @@ -305,29 +259,6 @@ struct BlockFmhaFwdAppendKVPipeline Policy::template MakeVnewDramTileDistribution()); auto vnew_tile = load_tile(vnew_dram_window); - -#if defined(ENABLE_PIPELINE_DEBUG_PRINT) - if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID) - { - printf("[POYENC][DEVICE] tid: %d\n", TID); - constexpr auto spans = decltype(vnew_tile)::get_distributed_spans(); - sweep_tile_span(spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - vnew_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = tile_idx.at(number<0>{}); - const auto col = tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - printf("[POYENC][DEVICE] vnew_tile(%2d,%2d): %11.7f\n", - row, - col, - type_convert(vnew_tile(i_j_idx))); - }); - }); - } -#endif store_tile(v_dram_block_window_tmp, vnew_tile); }