mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
temp save. change the memory layout based on basic unit of 16x128.
This commit is contained in:
@@ -38,8 +38,7 @@ using fmha_bwd_convert_dq_0 =
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
|
||||
|
||||
using convert_dq_trait_0 =
|
||||
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
|
||||
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
|
||||
@@ -67,6 +66,7 @@ using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
using fmha_warp_tile2_0 = ck_tile::sequence<16, 32, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
@@ -76,16 +76,16 @@ using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_warp_tile2_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_warp_tile2_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
@@ -619,6 +619,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
// for(int iM=0; iM<128; iM++){
|
||||
// for(int iK=0; iK<16; iK++){
|
||||
// printf("%04x ", *(reinterpret_cast<uint16_t*>(&(q_host(0, 0, iK, iM)))));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -1081,8 +1081,26 @@ struct sorted_sequence_histogram<h_idx, sequence<x>, sequence<r, rs...>>
|
||||
{
|
||||
h.template at<h_idx>() += 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_sequence_histogram<h_idx + 1, sequence<x>, sequence<rs...>>{}(h);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t h_idx, index_t x, index_t r>
|
||||
struct sorted_sequence_histogram<h_idx, sequence<x>, sequence<r>>
|
||||
{
|
||||
template <typename Histogram>
|
||||
constexpr auto operator()(Histogram& h)
|
||||
{
|
||||
if constexpr(x < r)
|
||||
{
|
||||
h.template at<h_idx>() += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename, index_t>
|
||||
|
||||
@@ -46,23 +46,22 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
return rh_major_minor_to_y_;
|
||||
};
|
||||
|
||||
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
|
||||
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
|
||||
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
|
||||
|
||||
constexpr auto y_dim_out_to_in = [&] {
|
||||
map<index_t, index_t> y_dim_out_to_in_;
|
||||
|
||||
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
|
||||
{
|
||||
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
|
||||
}
|
||||
|
||||
return y_dim_out_to_in_;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
|
||||
|
||||
using OutDstrEncode = typename decltype(out_tensor.get_tile_distribution())::DstrEncode;
|
||||
// using InDstrEncode = typename decltype(in_tensor.get_tile_distribution())::DstrEncode;
|
||||
|
||||
constexpr auto y_dim_out_to_in = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
constexpr index_t rh_major_out = OutDstrEncode::ys_to_rhs_major_[i];
|
||||
constexpr index_t rh_minor_out = OutDstrEncode::ys_to_rhs_minor_[i];
|
||||
|
||||
return number<rh_major_minor_to_y_in[{rh_major_out, rh_minor_out}]>{};
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
|
||||
// input and output vector dim in the order of input Y dims
|
||||
@@ -128,7 +127,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
constexpr auto idx_y_out_tmp = generate_tuple(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
number<NDimY>{});
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
|
||||
|
||||
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
|
||||
|
||||
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
dst_tile.set_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -258,6 +258,51 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
block_sync_lds();
|
||||
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
|
||||
#if 0
|
||||
constexpr auto kSeq0 = 64;
|
||||
// Looped data loading
|
||||
static_for<0, kN0 / kSeq0, 1>{}([&](auto i_n0) {
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {kSeq0, 0});
|
||||
|
||||
store_tile(k_lds_write_window, k_block_tile);
|
||||
|
||||
shuffle_distributed_tensor(kt_block_tile, k_block_tile);
|
||||
store_tile(kt_lds_write_window, kt_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto k_reg_tensor_slice = load_tile(k_lds_read_window);
|
||||
set_slice_tile(k_reg_tensor,
|
||||
k_reg_tensor_slice,
|
||||
Sequence<i_n0*kSeq0, 0>{},
|
||||
Sequence<(i_n0+1)*kSeq0, kQKHeaddim>{});
|
||||
|
||||
auto kt_reg_tensor_slice = load_tile(kt_lds_read_window);
|
||||
set_slice_tile(kt_reg_tensor,
|
||||
kt_reg_tensor_slice,
|
||||
Sequence<0, i_n0*kSeq0>{},
|
||||
Sequence<kQKHeaddim, (i_n0+1)*kSeq0>{});
|
||||
block_sync_lds();
|
||||
});
|
||||
|
||||
static_for<0, kN0 / kSeq0, 1>{}([&](auto i_n0) {
|
||||
auto v_block_tile = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {kSeq0, 0});
|
||||
|
||||
store_tile(v_lds_write_window, v_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
auto v_reg_tensor_slice = load_tile(v_lds_read_window);
|
||||
set_slice_tile(v_reg_tensor,
|
||||
v_reg_tensor_slice,
|
||||
Sequence<i_n0*kSeq0, 0>{},
|
||||
Sequence<(i_n0+1)*kSeq0, kVHeaddim>{});
|
||||
block_sync_lds();
|
||||
});
|
||||
#endif
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
auto q_dram_window =
|
||||
@@ -738,12 +783,61 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
const auto ds_gemm = cast_tile<GemmDataType>(ds);
|
||||
|
||||
Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
decltype(dst_reg_tensor),
|
||||
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
|
||||
// Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
// decltype(dst_reg_tensor),
|
||||
// decltype(ds_gemm)>(dst_reg_tensor,
|
||||
// ds_gemm);
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
// if(get_block_1d_id()==0 && get_thread_local_1d_id()<64 &&i_total_loops==0){
|
||||
// printf("Tid: %02d, Qt: %04x %04x %04x %04x %04x %04x %04x %04x DsT: %04x %04x
|
||||
// %04x %04x %04x %04x %04x %04x dk_acc: %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf
|
||||
// %.4lf\n",
|
||||
// get_thread_local_1d_id(),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<0>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<1>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<2>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<3>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<4>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<5>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<6>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<7>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<0>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<1>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<2>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<3>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<4>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<5>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<6>{}]))),
|
||||
// *(reinterpret_cast<const
|
||||
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<7>{}]))),
|
||||
// dk_acc.get_thread_buffer()[number<0>{}],
|
||||
// dk_acc.get_thread_buffer()[number<1>{}],
|
||||
// dk_acc.get_thread_buffer()[number<2>{}],
|
||||
// dk_acc.get_thread_buffer()[number<3>{}],
|
||||
// dk_acc.get_thread_buffer()[number<4>{}],
|
||||
// dk_acc.get_thread_buffer()[number<5>{}],
|
||||
// dk_acc.get_thread_buffer()[number<6>{}],
|
||||
// dk_acc.get_thread_buffer()[number<7>{}]);
|
||||
// }
|
||||
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -977,9 +1071,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
const auto ds_gemm = cast_tile<GemmDataType>(ds);
|
||||
|
||||
Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
decltype(dst_reg_tensor),
|
||||
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
|
||||
// Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
// decltype(dst_reg_tensor),
|
||||
// decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
|
||||
@@ -371,6 +371,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
@@ -389,11 +390,34 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
#elif 1
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
constexpr index_t kKRow = 2;
|
||||
constexpr index_t kMRow = 2;
|
||||
constexpr index_t kRowsize = 16;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
constexpr index_t kMGroup = kNPerBlock/16;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kMGroup, kMWarps, kMRepeat, kMRow, kMPair>,
|
||||
sequence<kKWarps, kKRow, kRowsize, K1>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 3, 2>>,
|
||||
sequence<1, 1, 1, 2>,
|
||||
sequence<0, 2, 4, 3>>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
@@ -412,52 +436,72 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
#elif 1
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
constexpr index_t kKRow = 2;
|
||||
constexpr index_t kMRow = 2;
|
||||
constexpr index_t kRowsize = 16;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
constexpr index_t kMGroup = kNPerBlock/16;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kMGroup, kMWarps, kMRepeat, kMRow, kMPair>,
|
||||
sequence<kKWarps, kKRow, kRowsize, K1>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 3, 2>>,
|
||||
sequence<1, 1, 1, 2>,
|
||||
sequence<0, 2, 4, 3>>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M1 = get_warp_size() / K0;
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = kMPerBlock / (M1 * M0);
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
constexpr index_t kKRow = 2;
|
||||
constexpr index_t kMRow = 2;
|
||||
constexpr index_t kRowsize = 16;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
|
||||
sequence<kKWarps, kKRow, kRowsize, K1>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2, 2>>,
|
||||
sequence<1, 1, 2>,
|
||||
sequence<1, 3, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M1 = get_warp_size() / K0;
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = kMPerBlock / (M1 * M0);
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
constexpr index_t kKRow = 2;
|
||||
constexpr index_t kMRow = 2;
|
||||
constexpr index_t kRowsize = 16;
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
|
||||
sequence<kKWarps, kKRow, kRowsize, K1>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2, 2>>,
|
||||
sequence<1, 1, 2>,
|
||||
sequence<1, 3, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
@@ -666,6 +710,145 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
return 16 / sizeof(GemmDataType);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make16x128LdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MWarp = 2;
|
||||
constexpr index_t KWarp = 2;
|
||||
constexpr index_t KRow = 2;
|
||||
constexpr index_t KBit0 = 2;
|
||||
constexpr index_t KBit1 = 2;
|
||||
constexpr index_t KBit2 = 2;
|
||||
constexpr index_t KBit3 = 2;
|
||||
constexpr index_t KBit4 = 2;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t MPair = 2;
|
||||
constexpr index_t MRepeat = 2;
|
||||
|
||||
// K:HeadDim, M:Seq, 11 Dimensions Total
|
||||
// W T I V
|
||||
// Total: 4*64*4*2 = 2^11
|
||||
|
||||
// W I T W I T T T T T V
|
||||
// 2 2 2 2 2 2 2 2 2 2 2
|
||||
// KWarp, MPair, KRow, MWarp, MRepeat, KBit<1, 2, 4, 3, 0>, K1
|
||||
// M = 2^4
|
||||
// K = 2^7
|
||||
|
||||
constexpr auto lds_16x128_block_desc_raw = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KWarp>{},
|
||||
number<MPair>{},
|
||||
number<KRow>{},
|
||||
number<MWarp>{},
|
||||
number<MRepeat>{},
|
||||
number<KBit1>{},
|
||||
number<KBit2>{},
|
||||
number<KBit4>{},
|
||||
number<KBit3>{},
|
||||
number<KBit0>{},
|
||||
number<K1>{}),
|
||||
make_tuple(
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 *
|
||||
KBit1*(MRepeat * MWarp * KRow * MPair + 1)>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4*(KBit2 * KBit1 * MRepeat * MWarp * KRow + 1)>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1 * MRepeat * MWarp>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1 * MRepeat>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4 * KBit2>{},
|
||||
number<K1 * KBit0 * KBit3 * KBit4>{},
|
||||
number<K1 * KBit0 * KBit3>{},
|
||||
number<K1 * KBit0>{},
|
||||
number<K1>{},
|
||||
number<1>{}),
|
||||
number<K1>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_16x128_block_desc = transform_tensor_descriptor(
|
||||
lds_16x128_block_desc_raw,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MWarp>{}, number<MRepeat>{}, number<MPair>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(number<KWarp>{},
|
||||
number<KRow>{},
|
||||
number<KBit4>{},
|
||||
number<KBit3>{},
|
||||
number<KBit2>{},
|
||||
number<KBit1>{},
|
||||
number<KBit0>{},
|
||||
number<K1>{}))),
|
||||
make_tuple(sequence<3, 4, 1>{}, sequence<0, 2, 7, 8, 6, 5, 9, 10>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_16x128_block_desc;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make16x128TransLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MWarp = 2;
|
||||
constexpr index_t KWarp = 2;
|
||||
constexpr index_t KRow = 2;
|
||||
constexpr index_t MRow = 2;
|
||||
constexpr index_t KGroup = 2;
|
||||
constexpr index_t KBit0 = 2;
|
||||
constexpr index_t KBit1 = 2;
|
||||
constexpr index_t KBit2 = 2;
|
||||
constexpr index_t K1 = 2;
|
||||
constexpr index_t MPair = 2;
|
||||
constexpr index_t MRepeat = 2;
|
||||
|
||||
// K:HeadDim, M:Seq, 11 Dimensions Total
|
||||
// W T I V
|
||||
// Total: 4*64*4*2 = 2^11
|
||||
// W I T W I T T T T T V
|
||||
// 2 2 2 2 2 2 2 2 2 2 2
|
||||
// Kwarp, K1, KRow, MWarp, MRepeat, <KBit1, KBit2, KBit0>, KGroup, MRow, MPair
|
||||
// M = 2^4
|
||||
// K = 2^7
|
||||
|
||||
constexpr auto lds_16x128_trans_block_desc_raw = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KWarp>{},
|
||||
number<K1>{},
|
||||
number<KRow>{},
|
||||
number<MWarp>{},
|
||||
number<MRepeat>{},
|
||||
number<KBit1>{},
|
||||
number<KBit2>{},
|
||||
number<KBit0>{},
|
||||
number<KGroup>{},
|
||||
number<MRow>{},
|
||||
number<MPair>{}),
|
||||
make_tuple(number<MPair * MRow * KGroup * KBit0 * KBit2 *
|
||||
KBit1*(MRepeat * MWarp * KRow * K1 + 1)>{},
|
||||
// Padding
|
||||
number<MPair * MRow * KGroup *
|
||||
KBit0*(KBit2 * KBit1 * MRepeat * MWarp * KRow + 1)>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1 * MRepeat * MWarp>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1 * MRepeat>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1>{},
|
||||
number<MPair * MRow * KGroup * KBit0 * KBit2>{},
|
||||
number<MPair * MRow * KGroup * KBit0>{},
|
||||
number<MPair * MRow * KGroup>{},
|
||||
number<MPair * MRow>{},
|
||||
number<MPair>{},
|
||||
number<1>{}),
|
||||
number<MPair>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_16x128_trans_block_desc = transform_tensor_descriptor(
|
||||
lds_16x128_trans_block_desc_raw,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(number<KWarp>{},
|
||||
number<KRow>{},
|
||||
number<KGroup>{},
|
||||
number<KBit2>{},
|
||||
number<KBit1>{},
|
||||
number<KBit0>{},
|
||||
number<K1>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<MWarp>{}, number<MRepeat>{}, number<MRow>{}, number<MPair>{}))),
|
||||
make_tuple(sequence<0, 2, 8, 6, 5, 7, 1>{}, sequence<3, 4, 9, 10>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return lds_16x128_trans_block_desc;
|
||||
}
|
||||
|
||||
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, bool XorLdsLayout = true>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
|
||||
{
|
||||
@@ -1008,12 +1191,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
|
||||
#elif 1
|
||||
return Make16x128LdsBlockDescriptor();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1051,36 +1238,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
|
||||
constexpr index_t N1 = get_warp_size() / K0;
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
constexpr index_t kKRow = 2;
|
||||
constexpr index_t kMRow = 2;
|
||||
constexpr index_t kRowsize = 16;
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>>{});
|
||||
tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
|
||||
sequence<kKWarps, kKRow, kRowsize, K1>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2, 2>>,
|
||||
sequence<1, 2, 1>,
|
||||
sequence<1, 3, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
|
||||
{
|
||||
// Hold full block data
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t kKPack = GetAlignmentQ<Problem>();
|
||||
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
|
||||
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
return Make16x128TransLdsBlockDescriptor();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1215,12 +1396,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor()
|
||||
{
|
||||
// Hold full block data
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
|
||||
#elif 1
|
||||
return Make16x128LdsBlockDescriptor();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1258,36 +1443,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
|
||||
constexpr index_t N1 = get_warp_size() / K0;
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t kMWarps = 2;
|
||||
constexpr index_t kKWarps = 2;
|
||||
constexpr index_t kKRow = 2;
|
||||
constexpr index_t kMRow = 2;
|
||||
constexpr index_t kRowsize = 16;
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t kMPair = 2;
|
||||
constexpr index_t kMRepeat = 2;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>>{});
|
||||
tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
|
||||
sequence<kKWarps, kKRow, kRowsize, K1>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2, 2>>,
|
||||
sequence<1, 2, 1>,
|
||||
sequence<1, 3, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
|
||||
{
|
||||
// Hold all data
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t kKPack = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
|
||||
|
||||
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
|
||||
return Make16x128TransLdsBlockDescriptor();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1916,18 +2094,26 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
static constexpr index_t D_VMEM_READ = 1;
|
||||
|
||||
// LDS Read
|
||||
// 16 * 128 / 64 / 4 = 8
|
||||
static constexpr index_t OGradT_LDS_READ =
|
||||
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
|
||||
// 16 * 128 / 64 / 4 = 8
|
||||
// kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
|
||||
// 16 * 128 / 64 / 8 = 4
|
||||
kM0 * kVHeaddim / get_warp_size() / 8;
|
||||
// 16 * 128 / 64 / 4 = 8
|
||||
static constexpr index_t QT_LDS_READ =
|
||||
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
// 16 * 128 / 64 / 4 = 8
|
||||
// kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
// 16 * 128 / 64 / 8 = 4
|
||||
kM0 * kVHeaddim / get_warp_size() / 8;
|
||||
// 16 * 32 / 64 / 8 = 1
|
||||
static constexpr index_t SGradT_LDS_READ_P1 =
|
||||
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
// 16 * 128 / 64 / 8 = 4
|
||||
static constexpr index_t Q_LDS_READ =
|
||||
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetSmemKPackQ<Problem>();
|
||||
// 16 * 128 / 64 / 8 = 4
|
||||
// kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetSmemKPackQ<Problem>();
|
||||
// 16 * 128 / 64 / 8 = 4
|
||||
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / 8;
|
||||
// 1
|
||||
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
// 16 * 96 / 64 / 8 = 3
|
||||
@@ -1935,7 +2121,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
// 16 * 128 / 64 / 8 = 4
|
||||
static constexpr index_t OGrad_LDS_READ =
|
||||
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetSmemKPackOGrad<Problem>();
|
||||
// kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetSmemKPackOGrad<Problem>();
|
||||
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / 8;
|
||||
// 1
|
||||
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
|
||||
|
||||
@@ -64,6 +64,10 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK
|
||||
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateNAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
// bf16
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
|
||||
@@ -120,6 +124,10 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterat
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
|
||||
4>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateNAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
// fp8
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
|
||||
|
||||
@@ -642,6 +642,219 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kNIter>
|
||||
struct WarpGemmAtrributeMfmaIterateNAndTransposedCDistribution
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
|
||||
|
||||
// swap A and B
|
||||
using ADataType = typename Impl::BDataType;
|
||||
using BDataType = typename Impl::ADataType;
|
||||
using CDataType = typename Impl::CDataType;
|
||||
|
||||
using AVecType = typename Impl::AVecType;
|
||||
|
||||
using BVecType =
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kNIter>;
|
||||
using CVecType =
|
||||
ext_vector_t<CDataType, vector_traits<typename Impl::CVecType>::vector_size * kNIter>;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM * kNIter;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kNIter; }
|
||||
|
||||
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
|
||||
"Multi-block on both M & N directions is not supported");
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// each N blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kAMBlock>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kNIter, Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
// each M blocks share the same data
|
||||
return tile_distribution_encoding<
|
||||
sequence<Impl::kBNBlock>,
|
||||
tuple<sequence<kNIter, Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<0, 2, 1>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
// single block to multi-block thread mapping
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kNIter, Impl::kAMBlock, Impl::kAMLane>,
|
||||
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
tuple<sequence<1, 2, 1>>,
|
||||
tuple<sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
|
||||
{
|
||||
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<kNIter, Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<2, 0>>,
|
||||
sequence<2, 2, 2>,
|
||||
sequence<0, 1, 3>>{};
|
||||
}
|
||||
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
|
||||
sequence<kNIter, Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<2, 0>>,
|
||||
sequence<2, 2, 2>,
|
||||
sequence<0, 1, 3>>{};
|
||||
}
|
||||
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<Impl::kCNLane>,
|
||||
sequence<kNIter,
|
||||
Impl::kCM0PerLane,
|
||||
Impl::kAMBlock * Impl::kCMLane,
|
||||
Impl::kCM1PerLane>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<2, 0>>,
|
||||
sequence<2, 2, 2>,
|
||||
sequence<0, 1, 3>>{};
|
||||
}
|
||||
}
|
||||
|
||||
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
|
||||
|
||||
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
|
||||
|
||||
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
|
||||
|
||||
template <bool post_nop_ = false>
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_c = thread_buffer<typename Impl::CVecType, kNIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kNIter>;
|
||||
// swap A and B, value and type
|
||||
// Bug: result not write back.
|
||||
static_for<0, kNIter, 1>{}([&](auto iNIter) {
|
||||
Impl{}(
|
||||
reinterpret_cast<buf_c&>(c_vec).template get_as<typename Impl::CVecType>()(iNIter),
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iNIter],
|
||||
a_vec,
|
||||
bool_constant<post_nop_>{});
|
||||
});
|
||||
|
||||
// if(get_thread_global_1d_id()==0){
|
||||
// printf("Enter here, ")
|
||||
// }
|
||||
}
|
||||
|
||||
template <index_t iNIter, bool post_nop_ = false>
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iNIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_c = thread_buffer<typename Impl::CVecType, kNIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kNIter>;
|
||||
|
||||
static_assert(iNIter < kNIter);
|
||||
// swap A and B, value and type
|
||||
// static_for<0, kNIter, 1>{}([&](auto iNIter) {
|
||||
Impl{}(reinterpret_cast<buf_c&>(c_vec).template get_as<typename Impl::CVecType>()(iNIter),
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iNIter],
|
||||
a_vec,
|
||||
bool_constant<post_nop_>{});
|
||||
//});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
using buf_c = thread_buffer<typename Impl::CVecType, kNIter>;
|
||||
buf_c c_buf;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kNIter>;
|
||||
|
||||
static_for<0, kNIter, 1>{}([&](auto iNIter) {
|
||||
auto c_vec_tmp = Impl{}(
|
||||
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()(
|
||||
iNIter),
|
||||
a_vec);
|
||||
|
||||
c_buf.template set_as<typename Impl::CVecType>(iNIter, c_vec_tmp);
|
||||
});
|
||||
|
||||
return c_buf.template get_as<CVecType>()[I0];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
|
||||
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
{
|
||||
|
||||
@@ -34,6 +34,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N32K16TransposedCDistribution; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
@@ -49,6 +50,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N32K16TransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
|
||||
Reference in New Issue
Block a user