mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Clean-up pipeline
This commit is contained in:
@@ -153,106 +153,65 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
rotary_sin_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeRotaryCosSinInterleaveDramTileDistribution<Problem>());
|
||||
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
#if 0
|
||||
auto rotary_tile = rotary_sin_tile;
|
||||
constexpr auto spans = decltype(rotary_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
rotary_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
ksmem[row * (kTileSizeD / 2) + col] = rotary_tile(i_j_idx);
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
DEVICE_DEBUG_STMTS
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
for(int row = 0; row < kTileSizeSk; ++row)
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
printf("[DEVICE] rotary_smem[%3d] = ", row);
|
||||
for(int col = 0; col < rotary_dim / 2; ++col)
|
||||
{
|
||||
printf("%11.7f", type_convert<float>(ksmem[row * (kTileSizeD / 2) + col]));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#define DUMP_KNEW 0
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
auto knew_other_dram_window = knew_dram_window;
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
auto origin = knew_other_dram_window.get_window_origin();
|
||||
printf("after move window, origin = (%3d, %3d)\n",
|
||||
origin.at(number<0>{}),
|
||||
origin.at(number<1>{}));
|
||||
}
|
||||
move_tile_window(knew_other_dram_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
auto origin = knew_other_dram_window.get_window_origin();
|
||||
printf("after move window, origin = (%3d, %3d)\n",
|
||||
origin.at(number<0>{}),
|
||||
origin.at(number<1>{}));
|
||||
}
|
||||
auto knew_other_tile = load_tile(knew_other_dram_window);
|
||||
|
||||
#if !DUMP_KNEW
|
||||
{
|
||||
constexpr auto spans = decltype(knew_other_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
knew_other_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
ksmem[row * kTileSizeD + col] = knew_other_tile(i_j_idx);
|
||||
});
|
||||
knew_tile.thread_buf_[idx] = left * cos - right * sin;
|
||||
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(DUMP_KNEW)
|
||||
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
knew_tile.thread_buf_[idx] = left * cos - right * sin;
|
||||
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
|
||||
DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); }
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
#if DUMP_KNEW
|
||||
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
auto knew_other_dram_window = knew_dram_window;
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
auto origin = knew_other_dram_window.get_window_origin();
|
||||
printf("after move window, origin = (%3d, %3d)\n",
|
||||
origin.at(number<0>{}),
|
||||
origin.at(number<1>{}));
|
||||
}
|
||||
move_tile_window(knew_other_dram_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
auto origin = knew_other_dram_window.get_window_origin();
|
||||
printf("after move window, origin = (%3d, %3d)\n",
|
||||
origin.at(number<0>{}),
|
||||
origin.at(number<1>{}));
|
||||
}
|
||||
auto knew_other_tile = load_tile(knew_other_dram_window);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
|
||||
{
|
||||
constexpr auto spans = decltype(knew_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -268,7 +227,6 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
});
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -276,11 +234,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
{
|
||||
for(int row = 0; row < 7; ++row)
|
||||
{
|
||||
#if DUMP_KNEW
|
||||
printf("[DEVICE] knew_tile[%3d] = ", row);
|
||||
#else
|
||||
printf("[DEVICE] knew_other_tile[%3d] = ", row);
|
||||
#endif
|
||||
|
||||
for(int col = 0; col < kTileSizeD; ++col)
|
||||
{
|
||||
@@ -305,29 +259,6 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
Policy::template MakeVnewDramTileDistribution<Problem>());
|
||||
|
||||
auto vnew_tile = load_tile(vnew_dram_window);
|
||||
|
||||
#if defined(ENABLE_PIPELINE_DEBUG_PRINT)
|
||||
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
|
||||
{
|
||||
printf("[POYENC][DEVICE] tid: %d\n", TID);
|
||||
constexpr auto spans = decltype(vnew_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
vnew_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
printf("[POYENC][DEVICE] vnew_tile(%2d,%2d): %11.7f\n",
|
||||
row,
|
||||
col,
|
||||
type_convert<float>(vnew_tile(i_j_idx)));
|
||||
});
|
||||
});
|
||||
}
|
||||
#endif
|
||||
store_tile(v_dram_block_window_tmp, vnew_tile);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user