Clean-up pipeline

This commit is contained in:
PoYen, Chen
2024-07-22 03:14:10 +00:00
parent fffd6799e6
commit 01865d2ae4

View File

@@ -153,106 +153,65 @@ struct BlockFmhaFwdAppendKVPipeline
rotary_sin_block_window_tmp.get_window_origin(),
Policy::template MakeRotaryCosSinInterleaveDramTileDistribution<Problem>());
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
#if 0
auto rotary_tile = rotary_sin_tile;
constexpr auto spans = decltype(rotary_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
rotary_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ksmem[row * (kTileSizeD / 2) + col] = rotary_tile(i_j_idx);
});
});
block_sync_lds();
DEVICE_DEBUG_STMTS
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
{
for(int row = 0; row < kTileSizeSk; ++row)
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
constexpr index_t KPerThread = 16 / sizeof(KDataType);
static_assert(kTileSizeD % KPerThread == 0);
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
if((start_x + KPerThread) <= rotary_dim)
{
printf("[DEVICE] rotary_smem[%3d] = ", row);
for(int col = 0; col < rotary_dim / 2; ++col)
{
printf("%11.7f", type_convert<float>(ksmem[row * (kTileSizeD / 2) + col]));
}
printf("\n");
}
}
#endif
#define DUMP_KNEW 0
constexpr index_t KPerThread = 16 / sizeof(KDataType);
static_assert(kTileSizeD % KPerThread == 0);
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
static_assert(thread_buffer_size % KPerThread == 0);
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
if((start_x + KPerThread) <= rotary_dim)
{
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
auto knew_other_dram_window = knew_dram_window;
DEVICE_DEBUG_STMTS
{
auto origin = knew_other_dram_window.get_window_origin();
printf("after move window, origin = (%3d, %3d)\n",
origin.at(number<0>{}),
origin.at(number<1>{}));
}
move_tile_window(knew_other_dram_window,
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
DEVICE_DEBUG_STMTS
{
auto origin = knew_other_dram_window.get_window_origin();
printf("after move window, origin = (%3d, %3d)\n",
origin.at(number<0>{}),
origin.at(number<1>{}));
}
auto knew_other_tile = load_tile(knew_other_dram_window);
#if !DUMP_KNEW
{
constexpr auto spans = decltype(knew_other_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
knew_other_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ksmem[row * kTileSizeD + col] = knew_other_tile(i_j_idx);
});
knew_tile.thread_buf_[idx] = left * cos - right * sin;
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
});
}
#endif
#if !defined(DUMP_KNEW)
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
static_assert(thread_buffer_size % KPerThread == 0);
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<float>(knew_tile.thread_buf_[idx]);
const auto right = type_convert<float>(knew_tile.thread_buf_[idx + 1]);
const auto cos = type_convert<float>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin = type_convert<float>(rotary_sin_tile.thread_buf_[idx / 2]);
knew_tile.thread_buf_[idx] = left * cos - right * sin;
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
});
#endif
}
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); }
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
{
constexpr index_t KPerThread = 16 / sizeof(KDataType);
static_assert(kTileSizeD % KPerThread == 0);
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
#if DUMP_KNEW
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
if((start_x + KPerThread) <= rotary_dim)
{
auto knew_other_dram_window = knew_dram_window;
DEVICE_DEBUG_STMTS
{
auto origin = knew_other_dram_window.get_window_origin();
printf("after move window, origin = (%3d, %3d)\n",
origin.at(number<0>{}),
origin.at(number<1>{}));
}
move_tile_window(knew_other_dram_window,
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
DEVICE_DEBUG_STMTS
{
auto origin = knew_other_dram_window.get_window_origin();
printf("after move window, origin = (%3d, %3d)\n",
origin.at(number<0>{}),
origin.at(number<1>{}));
}
auto knew_other_tile = load_tile(knew_other_dram_window);
}
}
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
{
constexpr auto spans = decltype(knew_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
@@ -268,7 +227,6 @@ struct BlockFmhaFwdAppendKVPipeline
});
});
}
#endif
block_sync_lds();
@@ -276,11 +234,7 @@ struct BlockFmhaFwdAppendKVPipeline
{
for(int row = 0; row < 7; ++row)
{
#if DUMP_KNEW
printf("[DEVICE] knew_tile[%3d] = ", row);
#else
printf("[DEVICE] knew_other_tile[%3d] = ", row);
#endif
for(int col = 0; col < kTileSizeD; ++col)
{
@@ -305,29 +259,6 @@ struct BlockFmhaFwdAppendKVPipeline
Policy::template MakeVnewDramTileDistribution<Problem>());
auto vnew_tile = load_tile(vnew_dram_window);
#if defined(ENABLE_PIPELINE_DEBUG_PRINT)
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == TID)
{
printf("[POYENC][DEVICE] tid: %d\n", TID);
constexpr auto spans = decltype(vnew_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
vnew_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
printf("[POYENC][DEVICE] vnew_tile(%2d,%2d): %11.7f\n",
row,
col,
type_convert<float>(vnew_tile(i_j_idx)));
});
});
}
#endif
store_tile(v_dram_block_window_tmp, vnew_tile);
}