This commit is contained in:
Bernard
2025-07-06 09:24:58 +00:00
parent 751d114b91
commit 470ae33980
10 changed files with 396 additions and 145 deletions

View File

@@ -108,7 +108,7 @@ else()
endif()
list(APPEND CMAKE_COMPILER_WARNINGS
-Wno-missing-field-initializers
-Wno-deprecated-declarations
# -Wno-deprecated-declarations
)
endif()
add_definitions(${CMAKE_COMPILER_WARNINGS})

View File

@@ -106,6 +106,10 @@ static void run(const ck_tile::stream_config& s, fmha_batch_decode_args a)
auto [kargs, grids] = fmha_batch_decode_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
printf("run into fmha_batch_decode_bf16_nlogits_nbias_nmask_lse_nsquant_pagedkv\\n");
printf("blocks: %d, %d, %d\\n", blocks.x, blocks.y, blocks.z);
printf("grids: %d, %d, %d\\n", grids.x, grids.y, grids.z);
printf("kBlockPerCu: %d\\n", kBlockPerCu);
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};

View File

@@ -62,6 +62,7 @@
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/debug.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"

View File

@@ -205,7 +205,7 @@ struct tile_scatter_gather_debug
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{}); // NumAccessPerCoord = 4
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
@@ -338,14 +338,14 @@ struct tile_scatter_gather_debug
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
auto idxx = bottom_tensor_thread_coord.get_index();
auto idxx = bottom_tensor_thread_coord.get_index(); // 16 * 128
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
bool_constant<oob_conditional_check>{});
// printf("bid %d tid %d coord_offset:%d %d %d, page_offset:%d v %f\n", blockIdx.x, threadIdx.x,
// bottom_tensor_thread_coord.get_offset(), idxx(I0), idxx(I1), page_offset, type_convert<float>(vec_value.template get_as<DataType>()[0]), type_convert<float>(vec_value.template get_as<DataType>()[4]));
// printf("bid %d tid %d coord_offset:%d %d %d, page_offset:%d v %f %f \n", blockIdx.x, threadIdx.x,
// bottom_tensor_thread_coord.get_offset(), idxx(I0), idxx(I1), page_offset, type_convert<float>(vec_value.template get_as<DataType>()[0]), type_convert<float>(vec_value.template get_as<DataType>()[0]), type_convert<float>(vec_value.template get_as<DataType>()[4]));
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
@@ -354,9 +354,9 @@ struct tile_scatter_gather_debug
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<NDimY>{}); // idx_ys = tuple<ck_tile::constant<0>, ck_tile::constant<2>, ck_tile::constant<4>>
constexpr index_t d =
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;

View File

@@ -738,8 +738,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
make_tuple(kargs.num_total_pages / 16, kargs.hdim_v, 16 / 8, 8),
make_tuple(kargs.hdim_v * 16, 16, 8, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
@@ -747,15 +747,15 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
v_dram_naive,
make_tuple(
make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_merge_transform(make_tuple(kargs.num_total_pages / 16, 16 / 8, 8))),
make_tuple(sequence<1>{}, sequence<0, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
// constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
sequence<kPadHeadDimV, true>{});
}
else
{

View File

@@ -219,7 +219,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemmPreshuffled<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemmPreshuffled<Problem>();
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -259,7 +259,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const index_t num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
const index_t num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); // [64, 0, 64 -> 1]
// check early exit if no work to do
if(num_total_loop <= 0)
@@ -297,22 +298,24 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
const auto VPageIndexDim = I1;
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
auto v_dist = Policy::template MakeVDramTileDistributionPreshuffled<Problem>();
auto v_coord = v_dist.calculate_index();
const auto VPageIndexDim = I1;
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I0];
statically_indexed_array<index_t, V_KRepeat> v_offsets;
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = kv_page_indices[v_coord[VPageIndexDim] + k0.value] * stride_v;
});
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
v_dist,
v_offsets,
VPageIndexDim);
// v_dram_block_window_tmp.get_bottom_tensor_view(): see dram_bottom_tensor_view_record
auto v_dram_window = make_tile_scatter_gather_debug(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp
.get_window_lengths(), // tuple<ck_tile::constant<128>, ck_tile::constant<64>>
{0, seqlen_k_start}, // TODO: hdim split?
v_dist,
v_offsets,
VPageIndexDim);
// store Q into LDS
__builtin_amdgcn_sched_barrier(0);
@@ -343,21 +346,25 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
static_assert(1 <= k1_loops);
auto k_dram_window = [&] {
auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled<Problem>();
auto k_coord = k_dist.calculate_index();
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled<Problem>();
auto k_coord = k_dist.calculate_index();
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
// constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I1];
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
statically_indexed_array<index_t, NRepeat> k_offsets;
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = kv_page_indices[k_coord[0] + kN0 / NRepeat * n0.value] / 16 * stride_k;
k_offsets[n0] =
kv_page_indices[k_coord[0] + kN0 / NRepeat * n0.value] / 16 * stride_k;
// printf("threadIdx.x %d, k_offsets[%d, %d] = %d\n", threadIdx.x, k_coord[0],
// k_coord[1], k_offsets[n0]);
});
return make_tile_scatter_gather_debug(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
return make_tile_scatter_gather_debug(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(), // [64, 128]
k_dram_block_window.get_window_origin(), //
k_dist,
k_offsets); // K DRAM tile window for
}();
// load the first tile of the first iteration and store to LDS
@@ -379,7 +386,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(blockIdx.x==0)
// printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x, type_convert<float>(k_block_tile(i_j_idx)));
// printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x,
// type_convert<float>(k_block_tile(i_j_idx)));
// });
// });
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -409,19 +417,21 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
// store_tile(
// k_lds_window,
// tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
k_block_tile = load_tile(k_dram_window); // global read i + 2
});
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(blockIdx.x==0)
// printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x, type_convert<float>(k_block_tile(i_j_idx)));
// printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x,
// type_convert<float>(k_block_tile(i_j_idx)));
// });
// });
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
// const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
const auto v_block_tile = load_tile(v_dram_window);
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = kv_page_indices[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
@@ -488,8 +498,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) {
x = x * scale_s; }, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#else
if constexpr(kHasLogitsSoftCap)
{
@@ -545,22 +554,23 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window = [&] {
auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled<Problem>();
auto k_coord = k_dist.calculate_index();
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled<Problem>();
auto k_coord = k_dist.calculate_index();
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
statically_indexed_array<index_t, NRepeat> k_offsets;
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] =
(kv_page_indices + kN0)[k_coord[0] + kN0 / NRepeat * n0.value] / 16 *
(kv_page_indices + kN0)[k_coord[0] + kN0 / NRepeat * n0.value] / 16 *
stride_k;
});
return make_tile_scatter_gather_debug(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
return make_tile_scatter_gather_debug(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
}();
// load the first tile of the first iteration and store to LDS
@@ -682,28 +692,30 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
});
});
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
move_tile_window(v_dram_window, {0, kK1});
// block_sync_lds();
// if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
// {
// auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
// Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
// shuffle_tile(v_shuffle_tmp, v_prefetch);
// store_tile(
// v_lds_window,
// tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
// }
// else
// {
// store_tile(v_lds_window,
// tile_elementwise_in(v_element_func, v_prefetch)); // store the
// prefetch
// }
move_tile_window(v_dram_window, {0, kK1});
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
// const auto v = load_tile(v_dram_window); // load next v
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = kv_page_indices[kK1 * 2 + i_k1.value * kK1 +
@@ -713,35 +725,152 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
v_dram_window.update_page_idx(v_offsets);
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
// gemm_1(o_acc,
// get_slice_tile(
// p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
// v_block_tile);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
// if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
// {
// auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
// Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
// shuffle_tile(v_shuffle_tmp, v);
// store_tile(v_lds_window,
// tile_elementwise_in(v_element_func,
// v_shuffle_tmp)); // store the prefetch
// }
// else
// {
// store_tile(v_lds_window,
// tile_elementwise_in(v_element_func, v)); // store next v
// }
move_tile_window(v_dram_window, {0, kK1});
const auto v_block_tile = load_tile(v_dram_window);
});
}
// tail
{
block_sync_lds();
auto temp = get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{});
// using a = ck_tile::static_distributed_tensor<
// unsigned short,
// ck_tile::tile_distribution<
// ck_tile::tensor_adaptor<
// ck_tile::tuple<
// ck_tile::replicate<ck_tile::tuple<ck_tile::constant<4>>>,
// ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<1>,
// ck_tile::constant<1>,
// ck_tile::constant<16>>,
// false>,
// ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<1>,
// ck_tile::constant<4>,
// ck_tile::constant<4>,
// ck_tile::constant<4>>,
// false>,
// ck_tile::merge_v2_magic_division<
// ck_tile::tuple<ck_tile::constant<1>, ck_tile::constant<4>>>,
// ck_tile::merge_v2_magic_division<
// ck_tile::tuple<ck_tile::constant<4>,
// ck_tile::constant<16>>>>,
// ck_tile::tuple<ck_tile::sequence<>,
// ck_tile::sequence<0>,
// ck_tile::sequence<1>,
// ck_tile::sequence<4, 2>,
// ck_tile::sequence<8, 5>>,
// ck_tile::tuple<ck_tile::sequence<2>,
// ck_tile::sequence<3, 4, 5>,
// ck_tile::sequence<6, 7, 8, 9>,
// ck_tile::sequence<10>,
// ck_tile::sequence<11>>,
// ck_tile::sequence<0, 1>,
// ck_tile::sequence<10, 11, 3, 6, 7, 9>>,
// ck_tile::tensor_descriptor<
// ck_tile::tuple<ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<1>,
// ck_tile::constant<1>,
// ck_tile::constant<4>,
// ck_tile::constant<4>>,
// false>>,
// ck_tile::tuple<ck_tile::sequence<0>>,
// ck_tile::tuple<ck_tile::sequence<1, 2, 3, 4>>,
// ck_tile::sequence<1, 2, 3, 4>,
// ck_tile::constant<16>,
// ck_tile::sequence<-1, -1, -1, -1, -1>,
// ck_tile::sequence<-1, -1, -1, -1, -1>>,
// ck_tile::tile_distribution_encoding<
// ck_tile::sequence<4>,
// ck_tile::tuple<ck_tile::sequence<1, 1, 16>,
// ck_tile::sequence<1, 4, 4, 4>>,
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 1>>,
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 2>>,
// ck_tile::sequence<1, 2, 2, 2>,
// ck_tile::sequence<0, 0, 1, 3>>,
// ck_tile::detail::tile_distribution_detail<
// ck_tile::tuple<ck_tile::sequence<2>,
// ck_tile::sequence<3, 4, 5>,
// ck_tile::sequence<6, 7, 8, 9>>>>>;
// using b = ck_tile::static_distributed_tensor<
// unsigned short,
// ck_tile::tile_distribution<
// ck_tile::tensor_adaptor<
// ck_tile::tuple<
// ck_tile::replicate<ck_tile::tuple<ck_tile::constant<1>>>,
// ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<4>,
// ck_tile::constant<2>,
// ck_tile::constant<16>>,
// false>,
// ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<2>,
// ck_tile::constant<2>,
// ck_tile::constant<2>,
// ck_tile::constant<8>>,
// false>,
// ck_tile::merge_v2_magic_division<ck_tile::tuple<ck_tile::constant<4>>>,
// ck_tile::merge_v2_magic_division<
// ck_tile::tuple<ck_tile::constant<2>,
// ck_tile::constant<2>,
// ck_tile::constant<16>>>>,
// ck_tile::tuple<ck_tile::sequence<>,
// ck_tile::sequence<0>,
// ck_tile::sequence<1>,
// ck_tile::sequence<3>,
// ck_tile::sequence<7, 8, 5>>,
// ck_tile::tuple<ck_tile::sequence<2>,
// ck_tile::sequence<3, 4, 5>,
// ck_tile::sequence<6, 7, 8, 9>,
// ck_tile::sequence<10>,
// ck_tile::sequence<11>>,
// ck_tile::sequence<0, 1>,
// ck_tile::sequence<10, 11, 6, 4, 9>>,
// ck_tile::tensor_descriptor<
// ck_tile::tuple<ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<2>,
// ck_tile::constant<2>,
// ck_tile::constant<8>>,
// false>>,
// ck_tile::tuple<ck_tile::sequence<0>>,
// ck_tile::tuple<ck_tile::sequence<1, 2, 3>>,
// ck_tile::sequence<1, 2, 3>,
// ck_tile::constant<32>,
// ck_tile::sequence<-1, -1, -1, -1>,
// ck_tile::sequence<-1, -1, -1, -1>>,
// ck_tile::tile_distribution_encoding<
// ck_tile::sequence<1>,
// ck_tile::tuple<ck_tile::sequence<4, 2, 16>, ck_tile::sequence<2, 2, 2, 8>>,
// ck_tile::tuple<ck_tile::sequence<1>, ck_tile::sequence<2, 2, 1>>,
// ck_tile::tuple<ck_tile::sequence<0>, ck_tile::sequence<1, 2, 2>>,
// ck_tile::sequence<2, 1, 2>,
// ck_tile::sequence<0, 1, 3>>,
// ck_tile::detail::tile_distribution_detail<
// ck_tile::tuple<ck_tile::sequence<2>,
// ck_tile::sequence<3, 4, 5>,
// ck_tile::sequence<6, 7, 8, 9>>>>>
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
v_lds_window);
v_block_tile);
block_sync_lds();
}
kv_page_indices += kN0;

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v3.hpp"
namespace ck_tile {
@@ -59,20 +60,20 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
constexpr index_t KPerThread = kMaxVecLoad; // 8
constexpr index_t KThreads = kKPerBlock / KPerThread; // 16
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; // 4
constexpr index_t NumWarps = kBlockSize / get_warp_size(); // 4
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); // 1
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>, // 1, 4, 4
sequence<KThreads, KPerThread>>, // 16, 8
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<1>, sequence<2, 0>>, // NumWarps, MThreadPerWarp, KThreads 4, 4, 16
sequence<1, 2>,
sequence<0, 1>>{});
sequence<0, 1>>{}); // MPerThread, KPerThread 1, 8
}
// template <typename Problem>
@@ -227,26 +228,58 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; // 64
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; // 128
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; // 32
constexpr index_t K2 = min(MaxVectorSize, ElemPerThread); //8
constexpr index_t K2 = min(MaxVectorSize, ElemPerThread); // 8
constexpr index_t K1 = 4;
constexpr index_t K0 = kKPerBlock / K1 / K2; // 2
constexpr index_t N2 = 16; // 8
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
constexpr index_t K0 = kKPerBlock / K1 / K2; // 4
constexpr index_t N2 = 16;
constexpr index_t N1 = kBlockSize / get_warp_size(); // 4
constexpr index_t N0 = kNPerBlock / (N2 * N1); // 1
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>, // 1, 4, 16, 4, 4, 8
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, K1, N2 : 4, 4, 16
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
sequence<0, 0, 2>>{}); // N0, K0, K2 : 1, 4, 8
}
/// Builds the static tile distribution used to read a pre-shuffled V tile from DRAM.
///
/// The (kN1 x kK1) block tile is factorised as
///   N = N0 * N1 * N2        (N2 fixed to 16, N0 = kBlockSize / warp size)
///   K = K0 * K1 * K2 * K3   (K1 = K2 = 2, K3 = per-thread vector length)
/// and mapped onto the encoding as
///   P0 <- N0,  P1 <- (K1, K2, N2),  Y <- (K0, N1, K3).
///
/// Both factorisations are verified at compile time so that a tile shape which
/// is not evenly divisible fails to build instead of silently producing a
/// distribution that does not cover the whole tile.
///
/// @tparam Problem  Pipeline problem type; must provide kBlockSize, VDataType and
///                  BlockFmhaShape::{VLayout, kN1, kK1}. Only row-major V is accepted.
/// @return The result of make_static_tile_distribution() for the encoding above.
///
/// The numeric "e.g." values in the comments below are the ones annotated by the
/// author for one particular configuration; they are not general.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistributionPreshuffled()
{
    using VLayout   = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
    using VDataType = remove_cvref_t<typename Problem::VDataType>;

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // e.g. 128
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // e.g. 64

    static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>, "wrong");

    constexpr index_t N2 = 16;

    // Elements each thread is responsible for; must be non-zero, otherwise the
    // vector length K3 below degenerates to 0.
    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // e.g. 32
    static_assert(0 < total_pixels);

    // A single vector access is capped at 16 bytes.
    constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
    constexpr index_t K3            = min(MaxVectorSize, total_pixels); // e.g. 8
    constexpr index_t K1            = 2;
    constexpr index_t K2            = 2;
    constexpr index_t K0            = kKPerBlock / K1 / K2 / K3; // e.g. 2

    constexpr index_t N0 = kBlockSize / get_warp_size(); // e.g. 4
    constexpr index_t N1 = kNPerBlock / (N2 * N0);       // e.g. 2

    // The integer divisions above truncate; make sure nothing was lost.
    static_assert(kKPerBlock == K0 * K1 * K2 * K3);
    static_assert(kNPerBlock == N0 * N1 * N2,
                  "kN1 must be a multiple of 16 * (kBlockSize / warp size)");

    return make_static_tile_distribution(
        tile_distribution_encoding<
            sequence<1>,
            tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2, K3>>, // e.g. 4, 2, 16 | 2, 2, 2, 8
            tuple<sequence<1>, sequence<2, 2, 1>>,
            tuple<sequence<0>, sequence<1, 2, 2>>, // P: N0 | K1, K2, N2 (e.g. 4 | 2, 2, 16)
            sequence<2, 1, 2>,
            sequence<0, 1, 3>>{}); // Y: K0, N1, K3 (e.g. 2, 2, 8)
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
@@ -320,6 +353,67 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy
static_assert(1 < Problem::kNumGemm0Warps);
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
// Returns the block-level GEMM object used for the second GEMM of the pipeline
// (the pipeline assigns the result to `gemm_1`).
//
// The GEMM problem is described by PDataType (A), VDataType (B) and OaccDataType (C),
// with block tile (kM0, kN1, kK1) and the Gemm1BlockWarps / Gemm1WarpTile shapes taken
// from Problem::BlockFmhaShape. The warp-level MFMA implementation is selected from the
// data types and the M extent of Gemm1WarpTile.
//
// NOTE(review): the custom policy type is named "ARegBSmem...V2" while the returned
// block GEMM is BlockGemmARegBRegCRegV3 - confirm this pairing is intended.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmPreshuffled()
{
using GemmProblem =
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kNumGemm1Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
// Pick the warp GEMM type at compile time; only its type is used below (via decltype).
// NOTE(review): no branch covers other type combinations (e.g. bf8_t, see the TODO);
// in that case the lambda returns nothing, so decltype(warp_gemm) would be void.
constexpr auto warp_gemm = []() {
// M extent of the warp tile; only 4, 16 and 32 are accepted.
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
// half_t x half_t -> float
if constexpr(std::is_same_v<typename Problem::PDataType, half_t> &&
std::is_same_v<typename Problem::VDataType, half_t> &&
std::is_same_v<typename Problem::OaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
// bf16_t x bf16_t -> float
else if constexpr(std::is_same_v<typename Problem::PDataType, bf16_t> &&
std::is_same_v<typename Problem::VDataType, bf16_t> &&
std::is_same_v<typename Problem::OaccDataType, float>)
{
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
// fp8_t x fp8_t -> float; only the 32-wide warp tile is supported here.
else if constexpr(std::is_same_v<typename Problem::PDataType, fp8_t> &&
std::is_same_v<typename Problem::VDataType, fp8_t> &&
std::is_same_v<typename Problem::OaccDataType, float>)
{
static_assert(WarpGemmM == 32);
// TODO: hard-coded here; otherwise it may produce an incorrect result
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
} // TODO - bf8_t
}();
// Bundle the data types, the block warp layout and the chosen warp GEMM into the policy.
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
decltype(warp_gemm)>;
// More than one warp must take part in this GEMM.
static_assert(1 < Problem::kNumGemm1Warps);
return BlockGemmARegBRegCRegV3<GemmProblem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -152,7 +152,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; // 16 * 128 / 256 = 8
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}
@@ -729,16 +729,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); //8
constexpr index_t K0 = kKPerBlock / K1; // 8
constexpr index_t N2 = get_warp_size() / K0; // 8
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
constexpr index_t N1 = kBlockSize / get_warp_size(); // 4
constexpr index_t N0 = kNPerBlock / (N2 * N1); // 4
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, // 4, 4, 8, 8, 8
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
tuple<sequence<1>, sequence<2, 0>>, // N1, N2, K0: 4, 8, 8
sequence<1, 2>,
sequence<0, 1>>{});
sequence<0, 1>>{}); // N0, K1 48
}
else
{
@@ -778,32 +778,32 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t N1 = GetAlignmentV<Problem>(); // 8
constexpr index_t N0 = kNPerBlock / N1; // 16
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // 32
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K3 = total_pixels / N1; // 4
constexpr index_t kKPack = GetSmemKPackV<Problem>(); // 8
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave // 2
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t K1 = get_warp_size() / (K2 * N0); // 2
constexpr index_t K0 = kBlockSize / get_warp_size(); // 4
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>, // 16, 8, 4, 2, 2, 4
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>, // K0, K1, N0, K2: 4, 2, 16, 2
sequence<2, 1>,
sequence<3, 1>>{});
sequence<3, 1>>{}); // K3, N1: 4, 8
}
else
{

View File

@@ -44,15 +44,16 @@ struct BlockGemmARegBRegCRegV2
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
// const const tuple<ck_tile::WarpGemmImpl<ck_tile::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<ck_tile::WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<ck_tile::WGAttrCtlEnum::Default_>, 2>>, int, int>
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MWarp = config.template at<1>(); // 1
constexpr index_t NWarp = config.template at<2>(); // 4
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // WG::kM, kN, kK: 16, 16, 32
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
@@ -60,13 +61,18 @@ struct BlockGemmARegBRegCRegV2
const index_t iNWarp = get_warp_id() % NWarp;
// if(threadIdx.x%64==0)
// printf("tid %d %d %d %d %d KIterPerWarp %d\n",threadIdx.x, MPerBlock, NPerBlock, NIterPerWarp, iNWarp, KIterPerWarp);
// tid 0 16 64 1 0 KIterPerWarp 4
// tid 64 16 64 1 1 KIterPerWarp 4
// tid 128 16 64 1 2 KIterPerWarp 4
// tid 192 16 64 1 3 KIterPerWarp 4
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
tuple<sequence<1, 1>>, // MWarp, NWarp 1, 4
sequence<1, 2>,
sequence<0, 0>>{};
sequence<0, 0>>{}; // MIterPerWarp, NIterPerWarp, 1, 1
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
@@ -123,9 +129,10 @@ struct BlockGemmARegBRegCRegV2
// printf("get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size());
// if(threadIdx.x==0)
// printf("get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size());
// printf("b get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size());
// auto &xx = b_warp_tensor.get_thread_buffer();
// printf("bid %d tid %d b %f %f %f %f\n", blockIdx.x, threadIdx.x, type_convert<float>(xx[number<0>{}]), type_convert<float>(xx[number<2>{}]), type_convert<float>(xx[number<4>{}]), type_convert<float>(xx[number<6>{}]));
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
@@ -196,19 +203,19 @@ struct BlockGemmARegBRegCRegV2
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MWarp = config.template at<1>(); // 1
constexpr index_t NWarp = config.template at<2>(); // 4
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 1
constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 4
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tile_distribution_encoding<sequence<NWarp>, // 4
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // 1, 1, 4
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>, // MWarp, NWarp 1, 4
sequence<1, 2>,
sequence<0, 0>>{};
sequence<0, 0>>{}; // MIterPerWarp, KIterPerWarp 1, 4
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});

View File

@@ -478,16 +478,32 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
{
// static constexpr index_t kM = 16;
// static constexpr index_t kN = 16;
// static constexpr index_t kK = 16;
// static constexpr index_t kAMBlock = 1;
// static constexpr index_t kBNBlock = 1;
// static constexpr index_t kAMLane = 16;
// static constexpr index_t kBNLane = 16;
// static constexpr index_t kABKLane = 4;
// static constexpr index_t kABKPerLane = 4;
// static constexpr index_t kCMLane = 4;
// static constexpr index_t kCNLane = 16;
// static constexpr index_t kCM0PerLane = 1;
// static constexpr index_t kCM1PerLane = 4;
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<Impl::kBNLane>, // 32/16 warp shape N
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, // 2/4, 8 instruction data layout
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
tuple<sequence<0, 0>>, // 4 16
sequence<2>,
sequence<1>>{};
sequence<1>>{}; // 8
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{