mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
temp
This commit is contained in:
@@ -108,7 +108,7 @@ else()
|
||||
endif()
|
||||
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||
-Wno-missing-field-initializers
|
||||
-Wno-deprecated-declarations
|
||||
# -Wno-deprecated-declarations
|
||||
)
|
||||
endif()
|
||||
add_definitions(${CMAKE_COMPILER_WARNINGS})
|
||||
|
||||
@@ -106,6 +106,10 @@ static void run(const ck_tile::stream_config& s, fmha_batch_decode_args a)
|
||||
auto [kargs, grids] = fmha_batch_decode_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
printf("run into fmha_batch_decode_bf16_nlogits_nbias_nmask_lse_nsquant_pagedkv\\n");
|
||||
printf("blocks: %d, %d, %d\\n", blocks.x, blocks.y, blocks.z);
|
||||
printf("grids: %d, %d, %d\\n", grids.x, grids.y, grids.z);
|
||||
printf("kBlockPerCu: %d\\n", kBlockPerCu);
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
}};
|
||||
|
||||
@@ -62,6 +62,7 @@
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/debug.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
|
||||
@@ -205,7 +205,7 @@ struct tile_scatter_gather_debug
|
||||
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
|
||||
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{}); // NumAccessPerCoord = 4
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
@@ -338,14 +338,14 @@ struct tile_scatter_gather_debug
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
// read from bottom tensor
|
||||
auto idxx = bottom_tensor_thread_coord.get_index();
|
||||
auto idxx = bottom_tensor_thread_coord.get_index(); // 16 * 128
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
// printf("bid %d tid %d coord_offset:%d %d %d, page_offset:%d v %f\n", blockIdx.x, threadIdx.x,
|
||||
// bottom_tensor_thread_coord.get_offset(), idxx(I0), idxx(I1), page_offset, type_convert<float>(vec_value.template get_as<DataType>()[0]), type_convert<float>(vec_value.template get_as<DataType>()[4]));
|
||||
// printf("bid %d tid %d coord_offset:%d %d %d, page_offset:%d v %f %f \n", blockIdx.x, threadIdx.x,
|
||||
// bottom_tensor_thread_coord.get_offset(), idxx(I0), idxx(I1), page_offset, type_convert<float>(vec_value.template get_as<DataType>()[0]), type_convert<float>(vec_value.template get_as<DataType>()[0]), type_convert<float>(vec_value.template get_as<DataType>()[4]));
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
@@ -354,9 +354,9 @@ struct tile_scatter_gather_debug
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<NDimY>{}); // idx_ys = tuple<ck_tile::constant<0>, ck_tile::constant<2>, ck_tile::constant<4>>
|
||||
|
||||
constexpr index_t d =
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
|
||||
@@ -738,8 +738,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
make_tuple(kargs.num_total_pages / 16, kargs.hdim_v, 16 / 8, 8),
|
||||
make_tuple(kargs.hdim_v * 16, 16, 8, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -747,15 +747,15 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel
|
||||
v_dram_naive,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_merge_transform(make_tuple(kargs.num_total_pages / 16, 16 / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
// constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
sequence<kPadHeadDimV, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -219,7 +219,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemmPreshuffled<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemmPreshuffled<Problem>();
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -259,7 +259,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
const index_t num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); // [64, 0, 64 -> 1]
|
||||
|
||||
// check early exit if no work to do
|
||||
if(num_total_loop <= 0)
|
||||
@@ -297,22 +298,24 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
|
||||
auto v_coord = v_dist.calculate_index();
|
||||
const auto VPageIndexDim = I1;
|
||||
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
|
||||
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
|
||||
auto v_dist = Policy::template MakeVDramTileDistributionPreshuffled<Problem>();
|
||||
auto v_coord = v_dist.calculate_index();
|
||||
const auto VPageIndexDim = I1;
|
||||
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
|
||||
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I0];
|
||||
statically_indexed_array<index_t, V_KRepeat> v_offsets;
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = kv_page_indices[v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
});
|
||||
auto v_dram_window =
|
||||
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
v_dist,
|
||||
v_offsets,
|
||||
VPageIndexDim);
|
||||
// v_dram_block_window_tmp.get_bottom_tensor_view(): see dram_bottom_tensor_view_record
|
||||
auto v_dram_window = make_tile_scatter_gather_debug(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp
|
||||
.get_window_lengths(), // tuple<ck_tile::constant<128>, ck_tile::constant<64>>
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
v_dist,
|
||||
v_offsets,
|
||||
VPageIndexDim);
|
||||
|
||||
// store Q into LDS
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -343,21 +346,25 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
static_assert(1 <= k1_loops);
|
||||
|
||||
auto k_dram_window = [&] {
|
||||
auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled<Problem>();
|
||||
auto k_coord = k_dist.calculate_index();
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled<Problem>();
|
||||
auto k_coord = k_dist.calculate_index();
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
// constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I1];
|
||||
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = kv_page_indices[k_coord[0] + kN0 / NRepeat * n0.value] / 16 * stride_k;
|
||||
k_offsets[n0] =
|
||||
kv_page_indices[k_coord[0] + kN0 / NRepeat * n0.value] / 16 * stride_k;
|
||||
// printf("threadIdx.x %d, k_offsets[%d, %d] = %d\n", threadIdx.x, k_coord[0],
|
||||
// k_coord[1], k_offsets[n0]);
|
||||
});
|
||||
|
||||
return make_tile_scatter_gather_debug(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
k_dist,
|
||||
k_offsets); // K DRAM tile window for
|
||||
return make_tile_scatter_gather_debug(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(), // [64, 128]
|
||||
k_dram_block_window.get_window_origin(), //
|
||||
k_dist,
|
||||
k_offsets); // K DRAM tile window for
|
||||
}();
|
||||
|
||||
// load the first tile of the first iteration and store to LDS
|
||||
@@ -379,7 +386,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// if(blockIdx.x==0)
|
||||
// printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x, type_convert<float>(k_block_tile(i_j_idx)));
|
||||
// printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x,
|
||||
// type_convert<float>(k_block_tile(i_j_idx)));
|
||||
// });
|
||||
// });
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
@@ -409,19 +417,21 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
// store_tile(
|
||||
// k_lds_window,
|
||||
// tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
});
|
||||
|
||||
|
||||
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// if(blockIdx.x==0)
|
||||
// printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x, type_convert<float>(k_block_tile(i_j_idx)));
|
||||
// printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x,
|
||||
// type_convert<float>(k_block_tile(i_j_idx)));
|
||||
// });
|
||||
// });
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
// const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
const auto v_block_tile = load_tile(v_dram_window);
|
||||
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = kv_page_indices[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
@@ -488,8 +498,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) {
|
||||
x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#else
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
@@ -545,22 +554,23 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
|
||||
k_dram_window = [&] {
|
||||
auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled<Problem>();
|
||||
auto k_coord = k_dist.calculate_index();
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled<Problem>();
|
||||
auto k_coord = k_dist.calculate_index();
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] =
|
||||
(kv_page_indices + kN0)[k_coord[0] + kN0 / NRepeat * n0.value] / 16 *
|
||||
(kv_page_indices + kN0)[k_coord[0] + kN0 / NRepeat * n0.value] / 16 *
|
||||
stride_k;
|
||||
});
|
||||
|
||||
return make_tile_scatter_gather_debug(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
k_dist,
|
||||
k_offsets); // K DRAM tile window for
|
||||
return make_tile_scatter_gather_debug(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
k_dist,
|
||||
k_offsets); // K DRAM tile window for
|
||||
}();
|
||||
|
||||
// load the first tile of the first iteration and store to LDS
|
||||
@@ -682,28 +692,30 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
store_tile(
|
||||
v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
// block_sync_lds();
|
||||
// if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
// {
|
||||
// auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
// Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
// shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
// store_tile(
|
||||
// v_lds_window,
|
||||
// tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// store_tile(v_lds_window,
|
||||
// tile_elementwise_in(v_element_func, v_prefetch)); // store the
|
||||
// prefetch
|
||||
// }
|
||||
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
// const auto v = load_tile(v_dram_window); // load next v
|
||||
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = kv_page_indices[kK1 * 2 + i_k1.value * kK1 +
|
||||
@@ -713,35 +725,152 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
|
||||
// gemm_1(o_acc,
|
||||
// get_slice_tile(
|
||||
// p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
// v_block_tile);
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v);
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v)); // store next v
|
||||
}
|
||||
// if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
// {
|
||||
// auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
// Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
// shuffle_tile(v_shuffle_tmp, v);
|
||||
// store_tile(v_lds_window,
|
||||
// tile_elementwise_in(v_element_func,
|
||||
// v_shuffle_tmp)); // store the prefetch
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// store_tile(v_lds_window,
|
||||
// tile_elementwise_in(v_element_func, v)); // store next v
|
||||
// }
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
const auto v_block_tile = load_tile(v_dram_window);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
auto temp = get_slice_tile(
|
||||
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{});
|
||||
|
||||
// using a = ck_tile::static_distributed_tensor<
|
||||
// unsigned short,
|
||||
// ck_tile::tile_distribution<
|
||||
// ck_tile::tensor_adaptor<
|
||||
// ck_tile::tuple<
|
||||
// ck_tile::replicate<ck_tile::tuple<ck_tile::constant<4>>>,
|
||||
// ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<1>,
|
||||
// ck_tile::constant<1>,
|
||||
// ck_tile::constant<16>>,
|
||||
// false>,
|
||||
// ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<1>,
|
||||
// ck_tile::constant<4>,
|
||||
// ck_tile::constant<4>,
|
||||
// ck_tile::constant<4>>,
|
||||
// false>,
|
||||
// ck_tile::merge_v2_magic_division<
|
||||
// ck_tile::tuple<ck_tile::constant<1>, ck_tile::constant<4>>>,
|
||||
// ck_tile::merge_v2_magic_division<
|
||||
// ck_tile::tuple<ck_tile::constant<4>,
|
||||
// ck_tile::constant<16>>>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<>,
|
||||
// ck_tile::sequence<0>,
|
||||
// ck_tile::sequence<1>,
|
||||
// ck_tile::sequence<4, 2>,
|
||||
// ck_tile::sequence<8, 5>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<2>,
|
||||
// ck_tile::sequence<3, 4, 5>,
|
||||
// ck_tile::sequence<6, 7, 8, 9>,
|
||||
// ck_tile::sequence<10>,
|
||||
// ck_tile::sequence<11>>,
|
||||
// ck_tile::sequence<0, 1>,
|
||||
// ck_tile::sequence<10, 11, 3, 6, 7, 9>>,
|
||||
// ck_tile::tensor_descriptor<
|
||||
// ck_tile::tuple<ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<1>,
|
||||
// ck_tile::constant<1>,
|
||||
// ck_tile::constant<4>,
|
||||
// ck_tile::constant<4>>,
|
||||
// false>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 2, 3, 4>>,
|
||||
// ck_tile::sequence<1, 2, 3, 4>,
|
||||
// ck_tile::constant<16>,
|
||||
// ck_tile::sequence<-1, -1, -1, -1, -1>,
|
||||
// ck_tile::sequence<-1, -1, -1, -1, -1>>,
|
||||
// ck_tile::tile_distribution_encoding<
|
||||
// ck_tile::sequence<4>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 1, 16>,
|
||||
// ck_tile::sequence<1, 4, 4, 4>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 1>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 0>, ck_tile::sequence<2, 2>>,
|
||||
// ck_tile::sequence<1, 2, 2, 2>,
|
||||
// ck_tile::sequence<0, 0, 1, 3>>,
|
||||
// ck_tile::detail::tile_distribution_detail<
|
||||
// ck_tile::tuple<ck_tile::sequence<2>,
|
||||
// ck_tile::sequence<3, 4, 5>,
|
||||
// ck_tile::sequence<6, 7, 8, 9>>>>>;
|
||||
|
||||
// using b = ck_tile::static_distributed_tensor<
|
||||
// unsigned short,
|
||||
// ck_tile::tile_distribution<
|
||||
// ck_tile::tensor_adaptor<
|
||||
// ck_tile::tuple<
|
||||
// ck_tile::replicate<ck_tile::tuple<ck_tile::constant<1>>>,
|
||||
// ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<4>,
|
||||
// ck_tile::constant<2>,
|
||||
// ck_tile::constant<16>>,
|
||||
// false>,
|
||||
// ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<2>,
|
||||
// ck_tile::constant<2>,
|
||||
// ck_tile::constant<2>,
|
||||
// ck_tile::constant<8>>,
|
||||
// false>,
|
||||
// ck_tile::merge_v2_magic_division<ck_tile::tuple<ck_tile::constant<4>>>,
|
||||
// ck_tile::merge_v2_magic_division<
|
||||
// ck_tile::tuple<ck_tile::constant<2>,
|
||||
// ck_tile::constant<2>,
|
||||
// ck_tile::constant<16>>>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<>,
|
||||
// ck_tile::sequence<0>,
|
||||
// ck_tile::sequence<1>,
|
||||
// ck_tile::sequence<3>,
|
||||
// ck_tile::sequence<7, 8, 5>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<2>,
|
||||
// ck_tile::sequence<3, 4, 5>,
|
||||
// ck_tile::sequence<6, 7, 8, 9>,
|
||||
// ck_tile::sequence<10>,
|
||||
// ck_tile::sequence<11>>,
|
||||
// ck_tile::sequence<0, 1>,
|
||||
// ck_tile::sequence<10, 11, 6, 4, 9>>,
|
||||
// ck_tile::tensor_descriptor<
|
||||
// ck_tile::tuple<ck_tile::unmerge<ck_tile::tuple<ck_tile::constant<2>,
|
||||
// ck_tile::constant<2>,
|
||||
// ck_tile::constant<8>>,
|
||||
// false>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1, 2, 3>>,
|
||||
// ck_tile::sequence<1, 2, 3>,
|
||||
// ck_tile::constant<32>,
|
||||
// ck_tile::sequence<-1, -1, -1, -1>,
|
||||
// ck_tile::sequence<-1, -1, -1, -1>>,
|
||||
// ck_tile::tile_distribution_encoding<
|
||||
// ck_tile::sequence<1>,
|
||||
// ck_tile::tuple<ck_tile::sequence<4, 2, 16>, ck_tile::sequence<2, 2, 2, 8>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<1>, ck_tile::sequence<2, 2, 1>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0>, ck_tile::sequence<1, 2, 2>>,
|
||||
// ck_tile::sequence<2, 1, 2>,
|
||||
// ck_tile::sequence<0, 1, 3>>,
|
||||
// ck_tile::detail::tile_distribution_detail<
|
||||
// ck_tile::tuple<ck_tile::sequence<2>,
|
||||
// ck_tile::sequence<3, 4, 5>,
|
||||
// ck_tile::sequence<6, 7, 8, 9>>>>>
|
||||
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
|
||||
v_lds_window);
|
||||
v_block_tile);
|
||||
block_sync_lds();
|
||||
}
|
||||
kv_page_indices += kN0;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v3.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -59,20 +60,20 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
constexpr index_t KPerThread = kMaxVecLoad; // 8
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread; // 16
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; // 4
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size(); // 4
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); // 1
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>, // 1, 4, 4
|
||||
sequence<KThreads, KPerThread>>, // 16, 8
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // NumWarps, MThreadPerWarp, KThreads 4, 4, 16
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
sequence<0, 1>>{}); // MPerThread, KPerThread 1, 8
|
||||
}
|
||||
|
||||
// template <typename Problem>
|
||||
@@ -227,26 +228,58 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; // 64
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; // 128
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; // 32
|
||||
|
||||
constexpr index_t K2 = min(MaxVectorSize, ElemPerThread); //8
|
||||
constexpr index_t K2 = min(MaxVectorSize, ElemPerThread); // 8
|
||||
constexpr index_t K1 = 4;
|
||||
constexpr index_t K0 = kKPerBlock / K1 / K2; // 2
|
||||
constexpr index_t N2 = 16; // 8
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
constexpr index_t K0 = kKPerBlock / K1 / K2; // 4
|
||||
constexpr index_t N2 = 16;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size(); // 4
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1); // 1
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>, // 1, 4, 16, 4, 4, 8
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N1, K1, N2 : 4, 4, 16
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
sequence<0, 0, 2>>{}); // N0, K0, K2 : 1, 4, 8
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistributionPreshuffled()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64
|
||||
|
||||
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>, "wrong");
|
||||
|
||||
constexpr index_t N2 = 16;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // 32
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
constexpr index_t K3 = min(MaxVectorSize, total_pixels); // 8
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t K0 = kKPerBlock / K1 / K2 / K3; // 2
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size(); // 4
|
||||
constexpr index_t N1 = kNPerBlock / (N2 * N0); // 2
|
||||
|
||||
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2, K3>>, // 4, 2, 16, 2, 2, 2, 8
|
||||
tuple<sequence<1>, sequence<2, 2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 2, 2>>, // N0, K1, K2, N2 : 4, 2, 2, 16
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 1, 3>>{}); // K0, N1, K3 : 2, 2, 8
|
||||
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
@@ -320,6 +353,67 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy
|
||||
static_assert(1 < Problem::kNumGemm0Warps);
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmPreshuffled()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::kNumGemm1Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::PDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::VDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::OaccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::PDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::VDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::OaccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::PDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::VDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::OaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it may incorrect result
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
static_assert(1 < Problem::kNumGemm1Warps);
|
||||
return BlockGemmARegBRegCRegV3<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -152,7 +152,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
|
||||
// this should align with MakeQDramTileDistribution()
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; // 16 * 128 / 256 = 8
|
||||
static_assert(0 < ElemPerThread);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
@@ -729,16 +729,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); //8
|
||||
constexpr index_t K0 = kKPerBlock / K1; // 8
|
||||
constexpr index_t N2 = get_warp_size() / K0; // 8
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size(); // 4
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1); // 4
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, // 4, 4, 8, 8, 8
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // N1, N2, K0: 4, 8, 8
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
sequence<0, 1>>{}); // N0, K1: 4,8
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -778,32 +778,32 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>(); // 8
|
||||
constexpr index_t N0 = kNPerBlock / N1; // 16
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // 32
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K3 = total_pixels / N1; // 4
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>(); // 8
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave // 2
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0); // 2
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size(); // 4
|
||||
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>, // 16, 8, 4, 2, 2, 4
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>, // K0, K1, N0, K2: 4, 2, 16, 2
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
sequence<3, 1>>{}); // K3, N1: 4, 8
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -44,15 +44,16 @@ struct BlockGemmARegBRegCRegV2
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
// const const tuple<ck_tile::WarpGemmImpl<ck_tile::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<ck_tile::WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<ck_tile::WGAttrCtlEnum::Default_>, 2>>, int, int>
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr index_t MWarp = config.template at<1>(); // 1
|
||||
constexpr index_t NWarp = config.template at<2>(); // 4
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // WG::kM, kN, kK: 16, 16, 32
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
@@ -60,13 +61,18 @@ struct BlockGemmARegBRegCRegV2
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
// if(threadIdx.x%64==0)
|
||||
// printf("tid %d %d %d %d %d KIterPerWarp %d\n",threadIdx.x, MPerBlock, NPerBlock, NIterPerWarp, iNWarp, KIterPerWarp);
|
||||
|
||||
// tid 0 16 64 1 0 KIterPerWarp 4
|
||||
// tid 64 16 64 1 1 KIterPerWarp 4
|
||||
// tid 128 16 64 1 2 KIterPerWarp 4
|
||||
// tid 192 16 64 1 3 KIterPerWarp 4
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
tuple<sequence<1, 1>>, // MWarp, NWarp 1, 4
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
sequence<0, 0>>{}; // MIterPerWarp, NIterPerWarp, 1, 1
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
@@ -123,9 +129,10 @@ struct BlockGemmARegBRegCRegV2
|
||||
// printf("get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size());
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size());
|
||||
// printf("b get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size());
|
||||
// auto &xx = b_warp_tensor.get_thread_buffer();
|
||||
// printf("bid %d tid %d b %f %f %f %f\n", blockIdx.x, threadIdx.x, type_convert<float>(xx[number<0>{}]), type_convert<float>(xx[number<2>{}]), type_convert<float>(xx[number<4>{}]), type_convert<float>(xx[number<6>{}]));
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
@@ -196,19 +203,19 @@ struct BlockGemmARegBRegCRegV2
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr index_t MWarp = config.template at<1>(); // 1
|
||||
constexpr index_t NWarp = config.template at<2>(); // 4
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 1
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 4
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tile_distribution_encoding<sequence<NWarp>, // 4
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // 1, 1, 4
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>, // MWarp, NWarp 1, 4
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
sequence<0, 0>>{}; // MIterPerWarp, KIterPerWarp 1, 4
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
@@ -478,16 +478,32 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
// static constexpr index_t kM = 16;
|
||||
// static constexpr index_t kN = 16;
|
||||
// static constexpr index_t kK = 16;
|
||||
// static constexpr index_t kAMBlock = 1;
|
||||
// static constexpr index_t kBNBlock = 1;
|
||||
|
||||
// static constexpr index_t kAMLane = 16;
|
||||
// static constexpr index_t kBNLane = 16;
|
||||
// static constexpr index_t kABKLane = 4;
|
||||
// static constexpr index_t kABKPerLane = 4;
|
||||
|
||||
// static constexpr index_t kCMLane = 4;
|
||||
// static constexpr index_t kCNLane = 16;
|
||||
// static constexpr index_t kCM0PerLane = 1;
|
||||
// static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<Impl::kBNLane>, // 32/16 warp shape N
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, // 2/4, 8 instruction data layout
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
tuple<sequence<0, 0>>, // 4, 16
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
sequence<1>>{}; // 8
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user