mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
tempsave. asynccopy+trload sanity checked
This commit is contained in:
@@ -274,8 +274,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
@@ -338,16 +338,11 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(1 == k0_loops);
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 == k1_loops);
|
||||
|
||||
block_sync_lds();
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
// move K tile windows
|
||||
i_page_block_k =
|
||||
k_page_block_navigator.move_tile_window(i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
|
||||
k_dram_window = make_tile_window(k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
@@ -369,70 +364,20 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
|
||||
// auto v_temp = load_tile(v_dram_window);
|
||||
// printf("Tid: %02d, v_temp: %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
|
||||
// %04x %04x %04x %04x %04x| %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
|
||||
// %04x %04x %04x %04x %04x|\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+7>{})))));
|
||||
|
||||
// move V tile windows
|
||||
i_page_block_v =
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {kK1, 0});
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
|
||||
// CK_PRINT<decltype(v_dram_window.get_num_of_access())>();
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
gemm_0(
|
||||
s_acc,
|
||||
get_slice_tile(
|
||||
q_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(
|
||||
k_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kN0, k0_loops * kK0>{}));
|
||||
|
||||
// printf("Tid: %02d, SAcc: %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// s_acc.get_thread_buffer()(number<0>{}),
|
||||
// s_acc.get_thread_buffer()(number<1>{}),
|
||||
// s_acc.get_thread_buffer()(number<2>{}),
|
||||
// s_acc.get_thread_buffer()(number<3>{}),
|
||||
// s_acc.get_thread_buffer()(number<4>{}),
|
||||
// s_acc.get_thread_buffer()(number<5>{}),
|
||||
// s_acc.get_thread_buffer()(number<6>{}),
|
||||
// s_acc.get_thread_buffer()(number<7>{}));
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_tile);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
@@ -546,15 +491,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
// move K tile windows after current status checked
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
|
||||
k_dram_window = make_tile_window(k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
// In Nwarp=1 and NXdl=32, GEMM0 output naturally fit the input of GEMM1
|
||||
// Otherwise shuffle through LDS so that the tile layout is consistent with required by
|
||||
block_sync_lds();
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
|
||||
// Gemm1
|
||||
auto s_new = [&]() {
|
||||
if constexpr(kNWarp > 1)
|
||||
@@ -641,28 +587,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
Policy::template MakePRegTileDistribution<Problem>());
|
||||
p_tile.get_thread_buffer() = cast_tile<PDataType>(p_compute).get_thread_buffer();
|
||||
|
||||
// printf("Tid: %02d, PCompute: %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// p_compute.get_thread_buffer()(number<0>{}),
|
||||
// p_compute.get_thread_buffer()(number<1>{}),
|
||||
// p_compute.get_thread_buffer()(number<2>{}),
|
||||
// p_compute.get_thread_buffer()(number<3>{}),
|
||||
// p_compute.get_thread_buffer()(number<4>{}),
|
||||
// p_compute.get_thread_buffer()(number<5>{}),
|
||||
// p_compute.get_thread_buffer()(number<6>{}),
|
||||
// p_compute.get_thread_buffer()(number<7>{}));
|
||||
|
||||
// printf("Tid: %02d, p_tile: %04x %04x %04x %04x %04x %04x %04x %04x\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<7>{})))));
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -703,61 +627,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
block_sync_lds_direct_load<k_vmem_insts>();
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
// printf("Tid: %02d, v_tile: %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
|
||||
// %04x %04x %04x %04x %04x| %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
|
||||
// %04x %04x %04x %04x %04x|\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+7>{})))));
|
||||
|
||||
// block_sync_lds();
|
||||
// CK_PRINT<decltype(p_tile)>();
|
||||
// ck_tile::tile_distribution_encoding<
|
||||
// ck_tile::sequence<1>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 1, 16>, ck_tile::sequence<1, 2, 4, 4>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 1>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 2>>,
|
||||
// ck_tile::sequence<1, 2, 2, 2>,
|
||||
// ck_tile::sequence<0, 0, 1, 3>>;
|
||||
// CK_PRINT<decltype(v_tile)>();
|
||||
// ck_tile::tile_distribution_encoding<
|
||||
// ck_tile::sequence<1>,
|
||||
// ck_tile::tuple<ck_tile::sequence<4, 1, 16>, ck_tile::sequence<1, 2, 4, 4>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0, 1>, ck_tile::sequence<2, 1>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0, 1>, ck_tile::sequence<2, 2>>,
|
||||
// ck_tile::sequence<1, 2, 2, 2>,
|
||||
// ck_tile::sequence<0, 0, 1, 3>>;
|
||||
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(
|
||||
|
||||
@@ -106,6 +106,33 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
constexpr index_t K1 = min(MaxVectorSize, ElemPerThread);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user