mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
new layout sanity checked. not correct instruction generated
This commit is contained in:
@@ -600,6 +600,54 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
ck_tile::FillTrigValue<OGradDataType>{}(do_host);
|
||||
}
|
||||
else if(init_method == 3)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{1.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{1.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{1.f, 1.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistribution<OGradDataType>{1.f, 1.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 4)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{1.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 5)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{1.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{1.f, 1.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistribution<OGradDataType>{1.f, 1.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 6)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{1.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{1.f, 1.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistribution<OGradDataType>{1.f, 1.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 7)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{1.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{1.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{1.f, 1.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistribution<OGradDataType>{1.f, 1.f, seed}(do_host);
|
||||
}
|
||||
else if(init_method == 8)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{1.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{1.f, 1.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistribution<OGradDataType>{1.f, 1.f, seed}(do_host);
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
auto slopes = ck_tile::get_alibi_slopes<AccDataType>(nhead);
|
||||
@@ -619,9 +667,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
// for(int iM=0; iM<128; iM++){
|
||||
// for(int iK=0; iK<16; iK++){
|
||||
// printf("%04x ", *(reinterpret_cast<uint16_t*>(&(q_host(0, 0, iK, iM)))));
|
||||
// for(int iM=0; iM<128; iM++){
|
||||
// printf("%04x ", *(reinterpret_cast<uint16_t*>(&(q_host(0, 0, iK, iM)))));
|
||||
// if(iM%16==15){
|
||||
// printf("|");
|
||||
// }
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
@@ -76,6 +76,7 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
|
||||
sequence<SliceEnds...> slice_ends)
|
||||
{
|
||||
using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
|
||||
using SrcDistribution = remove_cvref_t<SrcStaticTileDistribution_>;
|
||||
|
||||
constexpr auto sliced_dstr_yidx_ylen =
|
||||
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
|
||||
@@ -84,9 +85,10 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
|
||||
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
|
||||
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
|
||||
|
||||
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
|
||||
static_assert(std::is_same_v<remove_cvref_t<decltype(sliced_dstr)>, SrcDistribution>, "wrong!");
|
||||
|
||||
dst_tile.set_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
dst_tile.set_y_sliced_thread_data(
|
||||
sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -186,7 +186,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 5000)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -246,7 +246,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 5000)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -305,7 +305,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 5000)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -360,7 +360,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 5000)
|
||||
{
|
||||
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
|
||||
<< std::endl;
|
||||
@@ -437,7 +437,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 5000)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
|
||||
@@ -495,7 +495,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 5000)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
|
||||
@@ -1277,14 +1277,18 @@ struct FmhaBwdDQDKDVKernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
|
||||
constexpr auto kSeq0 = 64;
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
// make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
make_tuple(number<kSeq0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
// make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
make_tuple(number<kSeq0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
|
||||
@@ -125,8 +125,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
// kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
// kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -138,7 +138,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
"wrong!");
|
||||
|
||||
// if (threadIdx.x == 0){
|
||||
// HotLoopScheduler::print();
|
||||
// // HotLoopScheduler::print();
|
||||
// }
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
@@ -175,19 +175,21 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
constexpr auto kSeq0 = 64;
|
||||
|
||||
KDataType* k_lds_ptr =
|
||||
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
make_tile_window(k_lds, make_tuple(number<kSeq0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
make_tuple(number<kSeq0>{}, number<kQKHeaddim>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
@@ -207,39 +209,46 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kSeq0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
make_tuple(number<kSeq0>{}, number<kVHeaddim>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
|
||||
auto kt_block_tile = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
|
||||
|
||||
KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
|
||||
|
||||
auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
|
||||
auto kt_lds_write = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
auto kt_lds_write_window = make_tile_window(
|
||||
kt_lds_write, make_tuple(number<kSeq0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_lds_read_window =
|
||||
make_tile_window(kt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
|
||||
make_tuple(number<kQKHeaddim>{}, number<kSeq0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeKTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Pre-Load KV into Registers
|
||||
#if 0
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
auto v_block_tile = load_tile(v_dram_window);
|
||||
|
||||
@@ -259,31 +268,108 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
|
||||
#if 0
|
||||
constexpr auto kSeq0 = 64;
|
||||
#elif 1
|
||||
// Looped data loading
|
||||
static_for<0, kN0 / kSeq0, 1>{}([&](auto i_n0) {
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
#if 0
|
||||
if(get_block_1d_id()==0 && get_thread_local_1d_id()<256){
|
||||
printf("iter: %01d, Tid: %03d, K_global_read: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x |\n",
|
||||
i_n0.value, get_thread_local_1d_id(),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<0>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<1>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<2>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<3>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<4>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<5>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<6>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<7>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<0 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<1 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<2 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<3 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<4 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<5 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<6 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<7 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<0 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<1 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<2 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<3 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<4 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<5 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<6 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<7 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<0 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<1 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<2 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<3 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<4 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<5 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<6 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_block_tile.get_thread_buffer()[number<7 + 24>{}])))
|
||||
);
|
||||
}
|
||||
#endif
|
||||
move_tile_window(k_dram_window, {kSeq0, 0});
|
||||
|
||||
store_tile(k_lds_write_window, k_block_tile);
|
||||
|
||||
shuffle_distributed_tensor(kt_block_tile, k_block_tile);
|
||||
shuffle_tile(kt_block_tile, k_block_tile);
|
||||
store_tile(kt_lds_write_window, kt_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto k_reg_tensor_slice = load_tile(k_lds_read_window);
|
||||
#if 0
|
||||
if(get_block_1d_id()==0 && get_thread_local_1d_id()<256){
|
||||
printf("iter: %01d, Tid: %03d, K_lds_read: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x |\n",
|
||||
i_n0.value, get_thread_local_1d_id(),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<0>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<1>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<2>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<3>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<4>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<5>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<6>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<7>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<0 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<1 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<2 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<3 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<4 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<5 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<6 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<7 + 8>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<0 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<1 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<2 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<3 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<4 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<5 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<6 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<7 + 16>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<0 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<1 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<2 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<3 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<4 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<5 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<6 + 24>{}]))),
|
||||
*(reinterpret_cast<const uint16_t*>(&(k_reg_tensor_slice.get_thread_buffer()[number<7 + 24>{}])))
|
||||
);
|
||||
}
|
||||
#endif
|
||||
set_slice_tile(k_reg_tensor,
|
||||
k_reg_tensor_slice,
|
||||
Sequence<i_n0*kSeq0, 0>{},
|
||||
Sequence<(i_n0+1)*kSeq0, kQKHeaddim>{});
|
||||
sequence<i_n0 * kSeq0, 0>{},
|
||||
sequence<(i_n0 + 1) * kSeq0, kQKHeaddim>{});
|
||||
|
||||
auto kt_reg_tensor_slice = load_tile(kt_lds_read_window);
|
||||
set_slice_tile(kt_reg_tensor,
|
||||
kt_reg_tensor_slice,
|
||||
Sequence<0, i_n0*kSeq0>{},
|
||||
Sequence<kQKHeaddim, (i_n0+1)*kSeq0>{});
|
||||
sequence<0, i_n0 * kSeq0>{},
|
||||
sequence<kQKHeaddim, (i_n0 + 1) * kSeq0>{});
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
@@ -298,10 +384,32 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto v_reg_tensor_slice = load_tile(v_lds_read_window);
|
||||
set_slice_tile(v_reg_tensor,
|
||||
v_reg_tensor_slice,
|
||||
Sequence<i_n0*kSeq0, 0>{},
|
||||
Sequence<(i_n0+1)*kSeq0, kVHeaddim>{});
|
||||
sequence<i_n0 * kSeq0, 0>{},
|
||||
sequence<(i_n0 + 1) * kSeq0, kVHeaddim>{});
|
||||
block_sync_lds();
|
||||
});
|
||||
#if 0
|
||||
if(get_block_1d_id()==0 && get_thread_local_1d_id()<256){
|
||||
printf("Tid: %03d, K: %04x %04x %04x %04x %04x %04x %04x %04x \n",
|
||||
get_thread_local_1d_id(),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(k_reg_tensor.get_thread_buffer()[number<0>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(k_reg_tensor.get_thread_buffer()[number<1>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(k_reg_tensor.get_thread_buffer()[number<2>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(k_reg_tensor.get_thread_buffer()[number<3>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(k_reg_tensor.get_thread_buffer()[number<4>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(k_reg_tensor.get_thread_buffer()[number<5>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(k_reg_tensor.get_thread_buffer()[number<6>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(k_reg_tensor.get_thread_buffer()[number<7>{}]))));
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
@@ -597,11 +705,32 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
move_tile_window(do_dram_window, {kM0, 0});
|
||||
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id()==0 && get_thread_local_1d_id()<64){
|
||||
printf("Tid: %02d, Q: %04x %04x %04x %04x %04x %04x %04x %04x\n",
|
||||
get_thread_local_1d_id(),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(q_reg_tensor.get_thread_buffer()[number<0>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(q_reg_tensor.get_thread_buffer()[number<1>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(q_reg_tensor.get_thread_buffer()[number<2>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(q_reg_tensor.get_thread_buffer()[number<3>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(q_reg_tensor.get_thread_buffer()[number<4>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(q_reg_tensor.get_thread_buffer()[number<5>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(q_reg_tensor.get_thread_buffer()[number<6>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(q_reg_tensor.get_thread_buffer()[number<7>{}]))));
|
||||
}
|
||||
#endif
|
||||
auto dot_reg_tensor = load_tile(dot_lds_read_window);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<0>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// HotLoopScheduler::template GemmStagedScheduler<0>();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -715,9 +844,54 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto qt_reg_tensor = load_tile(qt_lds_read_window);
|
||||
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#if 0
|
||||
if(get_block_1d_id()==0 && get_thread_local_1d_id()<64){
|
||||
printf("Tid: %02d, Pt: %04x %04x %04x %04x %04x %04x %04x %04x DoT: %04x %04x %04x %04x %04x %04x %04x %04x dv_acc: %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf\n",
|
||||
get_thread_local_1d_id(),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(pt_reg_tensor.get_thread_buffer()[number<0>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(pt_reg_tensor.get_thread_buffer()[number<1>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(pt_reg_tensor.get_thread_buffer()[number<2>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(pt_reg_tensor.get_thread_buffer()[number<3>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(pt_reg_tensor.get_thread_buffer()[number<4>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(pt_reg_tensor.get_thread_buffer()[number<5>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(pt_reg_tensor.get_thread_buffer()[number<6>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(pt_reg_tensor.get_thread_buffer()[number<7>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(dot_reg_tensor.get_thread_buffer()[number<0>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(dot_reg_tensor.get_thread_buffer()[number<1>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(dot_reg_tensor.get_thread_buffer()[number<2>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(dot_reg_tensor.get_thread_buffer()[number<3>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(dot_reg_tensor.get_thread_buffer()[number<4>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(dot_reg_tensor.get_thread_buffer()[number<5>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(dot_reg_tensor.get_thread_buffer()[number<6>{}]))),
|
||||
*(reinterpret_cast<const
|
||||
uint16_t*>(&(dot_reg_tensor.get_thread_buffer()[number<7>{}]))),
|
||||
dv_acc.get_thread_buffer()[number<0>{}],
|
||||
dv_acc.get_thread_buffer()[number<1>{}],
|
||||
dv_acc.get_thread_buffer()[number<2>{}],
|
||||
dv_acc.get_thread_buffer()[number<3>{}],
|
||||
dv_acc.get_thread_buffer()[number<4>{}],
|
||||
dv_acc.get_thread_buffer()[number<5>{}],
|
||||
dv_acc.get_thread_buffer()[number<6>{}],
|
||||
dv_acc.get_thread_buffer()[number<7>{}]);
|
||||
}
|
||||
#endif
|
||||
// HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dp_acc = SPGradBlockTileType{};
|
||||
|
||||
@@ -737,8 +911,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
store_tile(d_lds_write_window, d_block_tile);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<2>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// HotLoopScheduler::template GemmStagedScheduler<2>();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
auto ds = SPGradBlockTileType{};
|
||||
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
|
||||
@@ -848,8 +1022,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
lse = load_tile(lse_lds_read_window);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<3>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// HotLoopScheduler::template GemmStagedScheduler<3>();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE7 SGrad@K^T Gemm4
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc);
|
||||
@@ -875,7 +1049,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
d = load_tile(d_lds_read_window);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<4>();
|
||||
// HotLoopScheduler::template GemmStagedScheduler<4>();
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
@@ -1010,13 +1184,15 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
}
|
||||
}();
|
||||
|
||||
Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(p_gemm)>(
|
||||
pt_reg_tensor, p_gemm);
|
||||
auto dot_reg_tensor = load_tile(dot_lds_read_window);
|
||||
// Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor),
|
||||
// decltype(p_gemm)>(
|
||||
// pt_reg_tensor, p_gemm);
|
||||
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
|
||||
auto dot_reg_tensor = load_tile(dot_lds_read_window);
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dp_acc = SPGradBlockTileType{};
|
||||
@@ -1025,8 +1201,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<2>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// HotLoopScheduler::template GemmStagedScheduler<2>();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
auto ds = SPGradBlockTileType{};
|
||||
@@ -1085,8 +1261,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<3>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// HotLoopScheduler::template GemmStagedScheduler<3>();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc);
|
||||
@@ -1107,8 +1283,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
}
|
||||
});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// HotLoopScheduler::template GemmStagedScheduler<4>();
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
|
||||
@@ -22,6 +22,8 @@ namespace ck_tile {
|
||||
|
||||
struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
static constexpr index_t kKVSeq0 = 64;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
@@ -44,7 +46,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
false,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true,
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
|
||||
@@ -111,7 +114,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
|
||||
false,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true,
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
|
||||
@@ -391,7 +395,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
#elif 1
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kNPerBlock = kKVSeq0;
|
||||
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
@@ -401,7 +405,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
constexpr index_t kMGroup = kNPerBlock/16;
|
||||
constexpr index_t kMGroup = kNPerBlock / 16;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
@@ -437,7 +441,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
#elif 1
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kNPerBlock = kKVSeq0;
|
||||
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
@@ -447,7 +451,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
constexpr index_t kMGroup = kNPerBlock/16;
|
||||
constexpr index_t kMGroup = kNPerBlock / 16;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
@@ -715,11 +719,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MWarp = 2;
|
||||
constexpr index_t KWarp = 2;
|
||||
constexpr index_t KRow = 2;
|
||||
constexpr index_t MRow = 2;
|
||||
constexpr index_t KBit0 = 2;
|
||||
constexpr index_t KBit1 = 2;
|
||||
constexpr index_t KBit2 = 2;
|
||||
constexpr index_t KBit3 = 2;
|
||||
constexpr index_t KBit4 = 2;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t MPair = 2;
|
||||
constexpr index_t MRepeat = 2;
|
||||
@@ -734,6 +738,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
// M = 2^4
|
||||
// K = 2^7
|
||||
|
||||
// constexpr index_t kMWarps = 2;
|
||||
// constexpr index_t kKWarps = 2;
|
||||
// constexpr index_t kKRow = 2;
|
||||
// constexpr index_t kMRow = 2;
|
||||
// constexpr index_t kRowsize = 16;
|
||||
// constexpr index_t K1 = 2;
|
||||
// constexpr index_t kMPair = 2;
|
||||
// constexpr index_t kMRepeat = 2;
|
||||
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<sequence<>,
|
||||
// tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
|
||||
// sequence<kKWarps, kKRow, kRowsize, K1>>,
|
||||
// tuple<sequence<2, 1>, sequence<2, 1, 2>>,
|
||||
// tuple<sequence<0, 0>, sequence<1, 2, 2>>,
|
||||
// sequence<1, 1, 2>,
|
||||
// sequence<1, 3, 3>>{});
|
||||
|
||||
constexpr auto lds_16x128_block_desc_raw = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KWarp>{},
|
||||
number<MPair>{},
|
||||
@@ -742,39 +764,38 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
number<MRepeat>{},
|
||||
number<KBit1>{},
|
||||
number<KBit2>{},
|
||||
number<KBit4>{},
|
||||
number<MRow>{},
|
||||
number<KBit3>{},
|
||||
number<KBit0>{},
|
||||
number<K1>{}),
|
||||
make_tuple(
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 *
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 *
|
||||
KBit1*(MRepeat * MWarp * KRow * MPair + 1)>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4*(KBit2 * KBit1 * MRepeat * MWarp * KRow + 1)>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1 * MRepeat * MWarp>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1 * MRepeat>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow*(KBit2 * KBit1 * MRepeat * MWarp * KRow + 1)>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 * KBit1 * MRepeat * MWarp>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 * KBit1 * MRepeat>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 * KBit1>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow>{},
|
||||
number<K1 * KBit0 * KBit3>{},
|
||||
number<K1 * KBit0>{},
|
||||
number<K1>{},
|
||||
number<1>{}),
|
||||
number<K1>{},
|
||||
number<K1 * KBit0 * KBit3>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_16x128_block_desc = transform_tensor_descriptor(
|
||||
lds_16x128_block_desc_raw,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MWarp>{}, number<MRepeat>{}, number<MPair>{})),
|
||||
make_tuple(number<MWarp>{}, number<MRepeat>{}, number<MRow>{},number<MPair>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<KWarp>{},
|
||||
number<KRow>{},
|
||||
number<KBit4>{},
|
||||
number<KBit3>{},
|
||||
number<KBit2>{},
|
||||
number<KBit1>{},
|
||||
number<KBit0>{},
|
||||
number<K1>{}))),
|
||||
make_tuple(sequence<3, 4, 1>{}, sequence<0, 2, 7, 8, 6, 5, 9, 10>{}),
|
||||
make_tuple(sequence<3, 4, 7, 1>{}, sequence<0, 2, 8, 6, 5, 9, 10>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_16x128_block_desc;
|
||||
@@ -829,7 +850,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
number<MPair * MRow>{},
|
||||
number<MPair>{},
|
||||
number<1>{}),
|
||||
number<MPair>{},
|
||||
number<MPair * MRow * KGroup>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_16x128_trans_block_desc = transform_tensor_descriptor(
|
||||
@@ -849,6 +870,155 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
return lds_16x128_trans_block_desc;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make64x128LdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MWarp = 2;
|
||||
constexpr index_t KWarp = 2;
|
||||
constexpr index_t KRow = 2;
|
||||
constexpr index_t MRow = 2;
|
||||
constexpr index_t KBit0 = 2;
|
||||
constexpr index_t KBit1 = 2;
|
||||
constexpr index_t KBit2 = 2;
|
||||
constexpr index_t KBit3 = 2;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t MPair = 2;
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t MGroup = 4;
|
||||
|
||||
// K:HeadDim, M:Seq, 13 Dimensions Total
|
||||
// I W T I V
|
||||
// Total: 4*4*64*4*2 = 2^13
|
||||
|
||||
// I W I T W I T T T T T V
|
||||
// 4 2 2 2 2 2 2 2 2 2 2 2
|
||||
// MGroup, KWarp, MPair, KRow, MWarp, MRepeat, KBit<1, 2, 4, 3, 0>, K1
|
||||
// M = 2^6
|
||||
// K = 2^7
|
||||
|
||||
constexpr auto lds_64x128_block_desc_raw = make_naive_tensor_descriptor(
|
||||
make_tuple(number<MGroup>{},
|
||||
number<KWarp>{},
|
||||
number<MPair>{},
|
||||
number<KRow>{},
|
||||
number<MWarp>{},
|
||||
number<MRepeat>{},
|
||||
number<KBit1>{},
|
||||
number<KBit2>{},
|
||||
number<MRow>{},
|
||||
number<KBit3>{},
|
||||
number<KBit0>{},
|
||||
number<K1>{}),
|
||||
make_tuple(
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 *
|
||||
KBit1*(MRepeat * MWarp * KRow * MPair + 1) * KWarp>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 *
|
||||
KBit1*(MRepeat * MWarp * KRow * MPair + 1)>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow*(KBit2 * KBit1 * MRepeat * MWarp * KRow + 1)>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 * KBit1 * MRepeat * MWarp>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 * KBit1 * MRepeat>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2 * KBit1>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow * KBit2>{},
|
||||
number<K1 * KBit0 * KBit3 * MRow>{},
|
||||
number<K1 * KBit0 * KBit3>{},
|
||||
number<K1 * KBit0>{},
|
||||
number<K1>{},
|
||||
number<1>{}),
|
||||
number<K1 * KBit0 * KBit3>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_64x128_block_desc = transform_tensor_descriptor(
|
||||
lds_64x128_block_desc_raw,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<MGroup>{}, number<MWarp>{}, number<MRepeat>{}, number<MRow>{}, number<MPair>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<KWarp>{},
|
||||
number<KRow>{},
|
||||
number<KBit3>{},
|
||||
number<KBit2>{},
|
||||
number<KBit1>{},
|
||||
number<KBit0>{},
|
||||
number<K1>{}))),
|
||||
make_tuple(sequence<0, 4, 5, 8, 2>{}, sequence<1, 3, 9, 7, 6, 10, 11>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_64x128_block_desc;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make64x128TransLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MWarp = 2;
|
||||
constexpr index_t KWarp = 2;
|
||||
constexpr index_t KRow = 2;
|
||||
constexpr index_t MRow = 2;
|
||||
constexpr index_t KGroup = 2;
|
||||
constexpr index_t KBit0 = 2;
|
||||
constexpr index_t KBit1 = 2;
|
||||
constexpr index_t KBit2 = 2;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t MPair = 2;
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t MGroup = 4;
|
||||
|
||||
// K:HeadDim, M:Seq, 13 Dimensions Total
|
||||
// I W T I V
|
||||
// Total: 4* 4*64*4*2 = 2^13
|
||||
// I W I T W I T T T T T V
|
||||
// 4 2 2 2 2 2 2 2 2 2 2 2
|
||||
// MGroup, Kwarp, K1, KRow, MWarp, MRepeat, <KBit1, KBit2, KBit0>, KGroup, MRow, MPair
|
||||
// M = 2^6
|
||||
// K = 2^7
|
||||
|
||||
constexpr auto lds_64x128_trans_block_desc_raw = make_naive_tensor_descriptor(
|
||||
make_tuple(number<MGroup>{},
|
||||
number<KWarp>{},
|
||||
number<K1>{},
|
||||
number<KRow>{},
|
||||
number<MWarp>{},
|
||||
number<MRepeat>{},
|
||||
number<KBit1>{},
|
||||
number<KBit2>{},
|
||||
number<KBit0>{},
|
||||
number<KGroup>{},
|
||||
number<MRow>{},
|
||||
number<MPair>{}),
|
||||
make_tuple(number<MPair * MRow * KGroup * KBit0 * KBit2 *
|
||||
KBit1*(MRepeat * MWarp * KRow * K1 + 1) * KWarp>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2 *
|
||||
KBit1*(MRepeat * MWarp * KRow * K1 + 1)>{},
|
||||
// Padding
|
||||
number<MPair * MRow * KGroup *
|
||||
KBit0*(KBit2 * KBit1 * MRepeat * MWarp * KRow + 1)>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1 * MRepeat * MWarp>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1 * MRepeat>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2>{},
|
||||
number<MPair * MRow * KGroup * KBit0>{},
|
||||
number<MPair * MRow * KGroup>{},
|
||||
number<MPair * MRow>{},
|
||||
number<MPair>{},
|
||||
number<1>{}),
|
||||
number<MPair * MRow * KGroup>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_64x128_trans_block_desc = transform_tensor_descriptor(
|
||||
lds_64x128_trans_block_desc_raw,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(number<KWarp>{},
|
||||
number<KRow>{},
|
||||
number<KGroup>{},
|
||||
number<KBit2>{},
|
||||
number<KBit1>{},
|
||||
number<KBit0>{},
|
||||
number<K1>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<MGroup>{},
|
||||
number<MWarp>{},
|
||||
number<MRepeat>{},
|
||||
number<MRow>{},
|
||||
number<MPair>{}))),
|
||||
make_tuple(sequence<1, 3, 9, 7, 6, 8, 2>{}, sequence<0, 4, 5, 10, 11>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return lds_64x128_trans_block_desc;
|
||||
}
|
||||
|
||||
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, bool XorLdsLayout = true>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
|
||||
{
|
||||
@@ -1023,11 +1193,48 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
||||
#elif 1
|
||||
return Make64x128LdsBlockDescriptor();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
// constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kNPerBlock = kKVSeq0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto k_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
||||
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1065,12 +1272,48 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
|
||||
#elif 1
|
||||
return Make64x128LdsBlockDescriptor();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = kKVSeq0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1108,28 +1351,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = GetTransposedAlignmentK<Problem>();
|
||||
constexpr index_t N1 = get_warp_size() / K0;
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
constexpr index_t kKRow = 2;
|
||||
constexpr index_t kMRow = 2;
|
||||
constexpr index_t kRowsize = 16;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
constexpr index_t kMGroup = kKVSeq0 / 16;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>>{});
|
||||
tuple<sequence<kMGroup, kMWarps, kMRepeat, kMRow, kMPair>,
|
||||
sequence<kKWarps, kKRow, kRowsize, K1>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 3, 2>>,
|
||||
sequence<1, 1, 2, 1>,
|
||||
sequence<0, 2, 3, 4>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKLdsWriteBlockDescriptor()
|
||||
{
|
||||
#if 0
|
||||
// Hold all data
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
@@ -1138,6 +1383,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
|
||||
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
#elif 1
|
||||
return Make64x128TransLdsBlockDescriptor();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1156,6 +1404,38 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = kKVSeq0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto kt_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
|
||||
|
||||
return kt_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor()
|
||||
{
|
||||
@@ -1261,7 +1541,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
|
||||
{
|
||||
return Make16x128TransLdsBlockDescriptor();
|
||||
return Make16x128TransLdsBlockDescriptor();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -2047,6 +2327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
private:
|
||||
// Read64Seq per tile for KV
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
@@ -25,6 +25,11 @@ using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterate
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K32StaggeredK = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
true>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
@@ -84,6 +89,11 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaItera
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32StaggeredK = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
true>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
|
||||
@@ -73,7 +73,7 @@ struct WarpGemmAtrributeMfma
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, bool StaggeredK = false>
|
||||
struct WarpGemmAtrributeMfmaIterateK
|
||||
{
|
||||
static_assert(kKIter > 0, "wrong!");
|
||||
@@ -102,79 +102,161 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
if constexpr(!StaggeredK)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
else
|
||||
{
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<kKIter, Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
tuple<sequence<Impl::kAMLane>,
|
||||
sequence<kKIter, Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMBlock, Impl::kAMLane>,
|
||||
sequence<kKIter, Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 1, 1>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
if constexpr(!StaggeredK)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
else
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<kKIter, Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
|
||||
sequence<kKIter, Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 1, 1>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
tuple<sequence<Impl::kBNLane>,
|
||||
sequence<kKIter, Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 1, 0>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,8 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
bool SwizzleA = false,
|
||||
bool StaggeredK = false>
|
||||
struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
@@ -35,6 +36,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32StaggeredK; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
@@ -51,6 +53,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32StaggeredK; };
|
||||
|
||||
// fp8
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
@@ -72,7 +75,8 @@ template <typename AType,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false>
|
||||
bool SwizzleA = false,
|
||||
bool StaggeredK = false>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
@@ -80,6 +84,7 @@ using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA>::Type;
|
||||
SwizzleA,
|
||||
StaggeredK>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user