mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
tempsave, trload+asyncload done
This commit is contained in:
@@ -676,7 +676,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "exp" || init_method == "99")
|
||||
else if(init_method == "v1" || init_method == "97")
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
@@ -685,6 +685,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "kv1" || init_method == "98")
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{1.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{1.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "qkv1" || init_method == "99")
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{1.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{1.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{1.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
|
||||
|
||||
@@ -12,12 +12,14 @@
|
||||
|
||||
// Helpers for building the immediate passed to __builtin_amdgcn_s_waitcnt.
//
// CK_TILE_S_CNT_MAX has every counter field set to all-ones. From the masks below the
// fields are: vmcnt low bits [3:0], expcnt [6:4], lgkmcnt [11:8], vmcnt high bits [15:14].
//
// Each CK_TILE_*CNT(cnt) macro puts `cnt` into its own field and sets every *other* field
// to all-ones, so several constraints are combined with bitwise AND, e.g.
//   CK_TILE_S_CNT_MAX & CK_TILE_VMCNT(2) & CK_TILE_LGKMCNT(0)
// The immediately-invoked lambda only exists to host the compile-time range check.
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111

// vmcnt is a split 6-bit field: bits [3:0] stay in place and bits [5:4] are moved up by 10
// so that they land in bits [15:14] (the two top bits of CK_TILE_S_CNT_MAX). A larger shift
// would push them outside the 16-bit immediate and silently drop them.
#define CK_TILE_VMCNT(cnt)                                                   \
    ([]() { static_assert((cnt) < (1 << 6), "VMCNT only has 6 bits"); }(), \
     ((cnt)&0b1111) | (((cnt)&0b110000) << 10) | 0b0000'1111'0111'0000)

// expcnt occupies bits [6:4].
#define CK_TILE_EXPCNT(cnt)                                               \
    ([]() { static_assert((cnt) < (1 << 3), "EXP only has 3 bits"); }(), \
     ((cnt) << 4) | 0b1100'1111'0000'1111)

// lgkmcnt occupies bits [11:8].
#define CK_TILE_LGKMCNT(cnt)                                               \
    ([]() { static_assert((cnt) < (1 << 4), "LGKM only has 4 bits"); }(), \
     ((cnt) << 8) | 0b1100'0000'0111'1111)
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -127,7 +129,7 @@ template <index_t vmcnt>
|
||||
// Waits on the vector-memory counter, then synchronizes the work-group.
//
// `vmcnt` is the template parameter declared on the line above this function; it is encoded
// into the s_waitcnt immediate via CK_TILE_VMCNT. Masking with CK_TILE_S_CNT_MAX keeps the
// immediate inside the counter bits, and leaves expcnt/lgkmcnt at their maximum so this
// call does not wait on them.
// NOTE(review): presumably callers pass the number of direct-to-LDS loads that may remain
// in flight — confirm against the pipeline call sites.
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
    // We don't sync the lds insts here.
    // A single masked wait is enough; issuing the unmasked form as well only repeats it.
    __builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_VMCNT(vmcnt));
    __builtin_amdgcn_s_barrier();
}
|
||||
|
||||
|
||||
@@ -433,6 +433,8 @@ struct tile_window_with_static_distribution
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// printf("Tid: %02d, tr_load_idx: %d\n",
|
||||
// get_thread_local_1d_id(),bottom_tensor_thread_coord.get_offset());
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view()
|
||||
|
||||
@@ -788,7 +788,7 @@ struct FmhaFwdDecodeKernel
|
||||
amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(length, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
make_tuple(kargs.hdim_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
|
||||
@@ -349,8 +349,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
k_dram_window = make_tile_window(k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
// constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
// constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
|
||||
constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
do
|
||||
{
|
||||
@@ -370,12 +370,50 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
|
||||
async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
|
||||
// auto v_temp = load_tile(v_dram_window);
|
||||
// printf("Tid: %02d, v_temp: %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
|
||||
// %04x %04x %04x %04x %04x| %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
|
||||
// %04x %04x %04x %04x %04x|\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<8+7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<16+7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_temp.get_thread_buffer()(number<24+7>{})))));
|
||||
|
||||
// move V tile windows
|
||||
i_page_block_v =
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {kK1, 0});
|
||||
|
||||
// CK_PRINT<decltype(v_dram_window.get_num_of_access())>();
|
||||
// block_sync_lds_direct_load<v_vmem_insts>();
|
||||
block_sync_lds_direct_load<v_vmem_insts>();
|
||||
auto k_tile = load_tile(k_lds_read_window);
|
||||
|
||||
gemm_0(
|
||||
@@ -385,6 +423,17 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
get_slice_tile(
|
||||
k_tile, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kN0, k0_loops * kK0>{}));
|
||||
|
||||
// printf("Tid: %02d, SAcc: %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// s_acc.get_thread_buffer()(number<0>{}),
|
||||
// s_acc.get_thread_buffer()(number<1>{}),
|
||||
// s_acc.get_thread_buffer()(number<2>{}),
|
||||
// s_acc.get_thread_buffer()(number<3>{}),
|
||||
// s_acc.get_thread_buffer()(number<4>{}),
|
||||
// s_acc.get_thread_buffer()(number<5>{}),
|
||||
// s_acc.get_thread_buffer()(number<6>{}),
|
||||
// s_acc.get_thread_buffer()(number<7>{}));
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -592,6 +641,28 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
Policy::template MakePRegTileDistribution<Problem>());
|
||||
p_tile.get_thread_buffer() = cast_tile<PDataType>(p_compute).get_thread_buffer();
|
||||
|
||||
// printf("Tid: %02d, PCompute: %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// p_compute.get_thread_buffer()(number<0>{}),
|
||||
// p_compute.get_thread_buffer()(number<1>{}),
|
||||
// p_compute.get_thread_buffer()(number<2>{}),
|
||||
// p_compute.get_thread_buffer()(number<3>{}),
|
||||
// p_compute.get_thread_buffer()(number<4>{}),
|
||||
// p_compute.get_thread_buffer()(number<5>{}),
|
||||
// p_compute.get_thread_buffer()(number<6>{}),
|
||||
// p_compute.get_thread_buffer()(number<7>{}));
|
||||
|
||||
// printf("Tid: %02d, p_tile: %04x %04x %04x %04x %04x %04x %04x %04x\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(p_tile.get_thread_buffer()(number<7>{})))));
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -629,9 +700,64 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
// block_sync_lds_direct_load<k_vmem_insts>();
|
||||
block_sync_lds_direct_load<k_vmem_insts>();
|
||||
auto v_tile = load_tile_transpose(v_lds_read_window);
|
||||
|
||||
// printf("Tid: %02d, v_tile: %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
|
||||
// %04x %04x %04x %04x %04x| %04x %04x %04x %04x %04x %04x %04x %04x| %04x %04x %04x
|
||||
// %04x %04x %04x %04x %04x|\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<8+7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<16+7>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+0>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+1>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+2>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+3>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+4>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+5>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+6>{})))),
|
||||
// *(reinterpret_cast<uint16_t*>(&(v_tile.get_thread_buffer()(number<24+7>{})))));
|
||||
|
||||
// block_sync_lds();
|
||||
// CK_PRINT<decltype(p_tile)>();
|
||||
// ck_tile::tile_distribution_encoding<
|
||||
// ck_tile::sequence<1>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 1, 16>, ck_tile::sequence<1, 2, 4, 4>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 1>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 2>>,
|
||||
// ck_tile::sequence<1, 2, 2, 2>,
|
||||
// ck_tile::sequence<0, 0, 1, 3>>;
|
||||
// CK_PRINT<decltype(v_tile)>();
|
||||
// ck_tile::tile_distribution_encoding<
|
||||
// ck_tile::sequence<1>,
|
||||
// ck_tile::tuple<ck_tile::sequence<4, 1, 16>, ck_tile::sequence<1, 2, 4, 4>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0, 1>, ck_tile::sequence<2, 1>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0, 1>, ck_tile::sequence<2, 2>>,
|
||||
// ck_tile::sequence<1, 2, 2, 2>,
|
||||
// ck_tile::sequence<0, 0, 1, 3>>;
|
||||
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(
|
||||
|
||||
@@ -487,15 +487,18 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
|
||||
{
|
||||
return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::SaccDataType);
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
return NWarp > 1 ? MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::SaccDataType)
|
||||
: 0;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) +
|
||||
max(GetSmemSizeV<Problem>(), GetSmemSizeS<Problem>());
|
||||
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) + GetSmemSizeS<Problem>() +
|
||||
GetSmemSizeV<Problem>();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user