tempsave, trload+asyncload done

This commit is contained in:
aska-0096
2025-07-21 05:55:55 +00:00
parent afd96d8180
commit 1b468bac0b
6 changed files with 168 additions and 17 deletions

View File

@@ -788,7 +788,7 @@ struct FmhaFwdDecodeKernel
amd_buffer_coherence_enum::SYSTEM_NT1>(
data, // will update this pointer if using paged-kvcache
make_tuple(length, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
make_tuple(kargs.hdim_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});

View File

@@ -349,8 +349,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
k_dram_window = make_tile_window(k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>());
// constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
// constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
do
{
@@ -370,12 +370,50 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
}
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
// auto v_temp = load_tile(v_dram_window);
// printf("Tid: %02d, v_temp: %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
// %04x %04x %04x %04x %04x| %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
// %04x %04x %04x %04x %04x|\n",
// get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<0>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<1>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<2>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<3>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<4>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<5>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<6>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<7>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+0>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+1>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+2>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+3>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+4>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+5>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+6>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+7>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+0>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+1>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+2>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+3>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+4>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+5>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+6>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+7>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+0>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+1>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+2>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+3>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+4>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+5>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+6>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+7>{})))));
// move V tile windows
i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {kK1, 0});
// CK_PRINT<decltype(v_dram_window.get_num_of_access())>();
// block_sync_lds_direct_load<v_vmem_insts>();
block_sync_lds_direct_load<v_vmem_insts>();
auto k_tile = load_tile(k_lds_read_window);
gemm_0(
@@ -385,6 +423,17 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
get_slice_tile(
k_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kN0, k0_loops * kK0>{}));
// printf("Tid: %02d, SAcc: %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf\n",
// get_thread_local_1d_id(),
// s_acc.get_thread_buffer()(number<0>{}),
// s_acc.get_thread_buffer()(number<1>{}),
// s_acc.get_thread_buffer()(number<2>{}),
// s_acc.get_thread_buffer()(number<3>{}),
// s_acc.get_thread_buffer()(number<4>{}),
// s_acc.get_thread_buffer()(number<5>{}),
// s_acc.get_thread_buffer()(number<6>{}),
// s_acc.get_thread_buffer()(number<7>{}));
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -592,6 +641,28 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
Policy::template MakePRegTileDistribution<Problem>());
p_tile.get_thread_buffer() = cast_tile<PDataType>(p_compute).get_thread_buffer();
// printf("Tid: %02d, PCompute: %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf\n",
// get_thread_local_1d_id(),
// p_compute.get_thread_buffer()(number<0>{}),
// p_compute.get_thread_buffer()(number<1>{}),
// p_compute.get_thread_buffer()(number<2>{}),
// p_compute.get_thread_buffer()(number<3>{}),
// p_compute.get_thread_buffer()(number<4>{}),
// p_compute.get_thread_buffer()(number<5>{}),
// p_compute.get_thread_buffer()(number<6>{}),
// p_compute.get_thread_buffer()(number<7>{}));
// printf("Tid: %02d, p_tile: %04x %04x %04x %04x %04x %04x %04x %04x\n",
// get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<0>{})))),
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<1>{})))),
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<2>{})))),
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<3>{})))),
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<4>{})))),
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<5>{})))),
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<6>{})))),
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<7>{})))));
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -629,9 +700,64 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
});
});
// block_sync_lds_direct_load<k_vmem_insts>();
block_sync_lds_direct_load<k_vmem_insts>();
auto v_tile = load_tile_transpose(v_lds_read_window);
// printf("Tid: %02d, v_tile: %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
// %04x %04x %04x %04x %04x| %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
// %04x %04x %04x %04x %04x|\n",
// get_thread_local_1d_id(),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<0>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<1>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<2>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<3>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<4>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<5>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<6>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<7>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+0>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+1>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+2>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+3>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+4>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+5>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+6>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+7>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+0>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+1>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+2>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+3>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+4>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+5>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+6>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+7>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+0>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+1>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+2>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+3>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+4>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+5>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+6>{})))),
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+7>{})))));
// block_sync_lds();
// CK_PRINT<decltype(p_tile)>();
// ck_tile::tile_distribution_encoding<
// ck_tile::sequence<1>,
// ck_tile::tuple<ck_tile::sequence<1, 1, 16>, ck_tile::sequence<1, 2, 4, 4>>,
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 1>>,
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 2>>,
// ck_tile::sequence<1, 2, 2, 2>,
// ck_tile::sequence<0, 0, 1, 3>>;
// CK_PRINT<decltype(v_tile)>();
// ck_tile::tile_distribution_encoding<
// ck_tile::sequence<1>,
// ck_tile::tuple<ck_tile::sequence<4, 1, 16>, ck_tile::sequence<1, 2, 4, 4>>,
// ck_tile::tuple<ck_tile::sequence<0, 1>, ck_tile::sequence<2, 1>>,
// ck_tile::tuple<ck_tile::sequence<0, 1>, ck_tile::sequence<2, 2>>,
// ck_tile::sequence<1, 2, 2, 2>,
// ck_tile::sequence<0, 0, 1, 3>>;
gemm_1(
o_acc,
get_slice_tile(

View File

@@ -487,15 +487,18 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
{
return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::SaccDataType);
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
return NWarp > 1 ? MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::SaccDataType)
: 0;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) +
max(GetSmemSizeV<Problem>(), GetSmemSizeS<Problem>());
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) + GetSmemSizeS<Problem>() +
GetSmemSizeV<Problem>();
}
};