Temp save: change the memory layout to be based on a basic unit of 16x128.

This commit is contained in:
aska-0096
2025-04-09 09:04:11 +00:00
parent 92e2c50fb8
commit e95db65de3
9 changed files with 634 additions and 105 deletions

View File

@@ -38,8 +38,7 @@ using fmha_bwd_convert_dq_0 =
using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 =
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
@@ -67,6 +66,7 @@ using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
using fmha_warp_tile2_0 = ck_tile::sequence<16, 32, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
@@ -76,16 +76,16 @@ using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_warp_tile2_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_warp_tile2_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
@@ -619,6 +619,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
// for(int iM=0; iM<128; iM++){
// for(int iK=0; iK<16; iK++){
// printf("%04x ", *(reinterpret_cast<uint16_t*>(&(q_host(0, 0, iK, iM)))));
// }
// printf("\n");
// }
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());

View File

@@ -1081,8 +1081,26 @@ struct sorted_sequence_histogram<h_idx, sequence<x>, sequence<r, rs...>>
{
h.template at<h_idx>() += 1;
}
else
{
sorted_sequence_histogram<h_idx + 1, sequence<x>, sequence<rs...>>{}(h);
}
}
};
// Specialization for a single remaining range boundary `r`.
// Increments histogram slot `h_idx` when `x < r`; otherwise leaves `h` untouched.
// There is no further boundary to try, so no recursion happens here.
template <index_t h_idx, index_t x, index_t r>
struct sorted_sequence_histogram<h_idx, sequence<x>, sequence<r>>
{
template <typename Histogram>
constexpr auto operator()(Histogram& h)
{
// Compile-time predicate: does the value fall below the only boundary left?
constexpr bool below_last_boundary = x < r;
if constexpr(below_last_boundary)
{
h.template at<h_idx>() += 1;
}
}
};
} // namespace detail
template <typename, index_t>

View File

@@ -46,23 +46,22 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
return rh_major_minor_to_y_;
};
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
constexpr auto y_dim_out_to_in = [&] {
map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
}
return y_dim_out_to_in_;
}();
//
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
using OutDstrEncode = typename decltype(out_tensor.get_tile_distribution())::DstrEncode;
// using InDstrEncode = typename decltype(in_tensor.get_tile_distribution())::DstrEncode;
constexpr auto y_dim_out_to_in = generate_sequence_v2(
[&](auto i) constexpr {
constexpr index_t rh_major_out = OutDstrEncode::ys_to_rhs_major_[i];
constexpr index_t rh_minor_out = OutDstrEncode::ys_to_rhs_minor_[i];
return number<rh_major_minor_to_y_in[{rh_major_out, rh_minor_out}]>{};
},
number<NDimY>{});
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
@@ -128,7 +127,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
constexpr auto idx_y_out_tmp = generate_tuple(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
number<NDimY>{});

View File

@@ -86,7 +86,7 @@ set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
dst_tile.set_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}
} // namespace ck_tile

View File

@@ -258,6 +258,51 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds();
auto v_reg_tensor = load_tile(v_lds_read_window);
#if 0
constexpr auto kSeq0 = 64;
// Looped data loading
static_for<0, kN0 / kSeq0, 1>{}([&](auto i_n0) {
auto k_block_tile = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kSeq0, 0});
store_tile(k_lds_write_window, k_block_tile);
shuffle_distributed_tensor(kt_block_tile, k_block_tile);
store_tile(kt_lds_write_window, kt_block_tile);
block_sync_lds();
auto k_reg_tensor_slice = load_tile(k_lds_read_window);
set_slice_tile(k_reg_tensor,
k_reg_tensor_slice,
Sequence<i_n0*kSeq0, 0>{},
Sequence<(i_n0+1)*kSeq0, kQKHeaddim>{});
auto kt_reg_tensor_slice = load_tile(kt_lds_read_window);
set_slice_tile(kt_reg_tensor,
kt_reg_tensor_slice,
Sequence<0, i_n0*kSeq0>{},
Sequence<kQKHeaddim, (i_n0+1)*kSeq0>{});
block_sync_lds();
});
static_for<0, kN0 / kSeq0, 1>{}([&](auto i_n0) {
auto v_block_tile = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kSeq0, 0});
store_tile(v_lds_write_window, v_block_tile);
block_sync_lds();
auto v_reg_tensor_slice = load_tile(v_lds_read_window);
set_slice_tile(v_reg_tensor,
v_reg_tensor_slice,
Sequence<i_n0*kSeq0, 0>{},
Sequence<(i_n0+1)*kSeq0, kVHeaddim>{});
block_sync_lds();
});
#endif
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto q_dram_window =
@@ -738,12 +783,61 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// STAGE 6, SGrad^T@Q^T Gemm3
const auto ds_gemm = cast_tile<GemmDataType>(ds);
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
// Policy::template SGradTFromGemm2CToGemm3A<Problem,
// decltype(dst_reg_tensor),
// decltype(ds_gemm)>(dst_reg_tensor,
// ds_gemm);
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
// if(get_block_1d_id()==0 && get_thread_local_1d_id()<64 &&i_total_loops==0){
// printf("Tid: %02d, Qt: %04x %04x %04x %04x %04x %04x %04x %04x DsT: %04x %04x
// %04x %04x %04x %04x %04x %04x dk_acc: %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf
// %.4lf\n",
// get_thread_local_1d_id(),
// *(reinterpret_cast<const
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<0>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<1>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<2>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<3>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<4>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<5>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<6>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(qt_reg_tensor.get_thread_buffer()[number<7>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<0>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<1>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<2>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<3>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<4>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<5>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<6>{}]))),
// *(reinterpret_cast<const
// uint16_t*>(&(dst_reg_tensor.get_thread_buffer()[number<7>{}]))),
// dk_acc.get_thread_buffer()[number<0>{}],
// dk_acc.get_thread_buffer()[number<1>{}],
// dk_acc.get_thread_buffer()[number<2>{}],
// dk_acc.get_thread_buffer()[number<3>{}],
// dk_acc.get_thread_buffer()[number<4>{}],
// dk_acc.get_thread_buffer()[number<5>{}],
// dk_acc.get_thread_buffer()[number<6>{}],
// dk_acc.get_thread_buffer()[number<7>{}]);
// }
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();
@@ -977,9 +1071,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// STAGE 6, SGrad^T@Q^T Gemm3
const auto ds_gemm = cast_tile<GemmDataType>(ds);
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
// Policy::template SGradTFromGemm2CToGemm3A<Problem,
// decltype(dst_reg_tensor),
// decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, ds_gemm);

View File

@@ -371,6 +371,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
@@ -389,11 +390,34 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
#elif 1
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kMWarps = 2;
constexpr index_t kKWarps = 2;
constexpr index_t kKRow = 2;
constexpr index_t kMRow = 2;
constexpr index_t kRowsize = 16;
constexpr index_t K1 = 2;
constexpr index_t kMPair = 2;
constexpr index_t kMRepeat = 2;
constexpr index_t kMGroup = kNPerBlock/16;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<kMGroup, kMWarps, kMRepeat, kMRow, kMPair>,
sequence<kKWarps, kKRow, kRowsize, K1>>,
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
tuple<sequence<0, 1>, sequence<1, 3, 2>>,
sequence<1, 1, 1, 2>,
sequence<0, 2, 4, 3>>{});
#endif
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution()
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
@@ -412,52 +436,72 @@ struct BlockFmhaBwdPipelineDefaultPolicy
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
#elif 1
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kMWarps = 2;
constexpr index_t kKWarps = 2;
constexpr index_t kKRow = 2;
constexpr index_t kMRow = 2;
constexpr index_t kRowsize = 16;
constexpr index_t K1 = 2;
constexpr index_t kMPair = 2;
constexpr index_t kMRepeat = 2;
constexpr index_t kMGroup = kNPerBlock/16;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<kMGroup, kMWarps, kMRepeat, kMRow, kMPair>,
sequence<kKWarps, kKRow, kRowsize, K1>>,
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
tuple<sequence<0, 1>, sequence<1, 3, 2>>,
sequence<1, 1, 1, 2>,
sequence<0, 2, 4, 3>>{});
#endif
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M1 = get_warp_size() / K0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
constexpr index_t kMWarps = 2;
constexpr index_t kKWarps = 2;
constexpr index_t kKRow = 2;
constexpr index_t kMRow = 2;
constexpr index_t kRowsize = 16;
constexpr index_t K1 = 2;
constexpr index_t kMPair = 2;
constexpr index_t kMRepeat = 2;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
sequence<kKWarps, kKRow, kRowsize, K1>>,
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
tuple<sequence<0, 0>, sequence<1, 2, 2>>,
sequence<1, 1, 2>,
sequence<1, 3, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M1 = get_warp_size() / K0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
constexpr index_t kMWarps = 2;
constexpr index_t kKWarps = 2;
constexpr index_t kKRow = 2;
constexpr index_t kMRow = 2;
constexpr index_t kRowsize = 16;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t kMPair = 2;
constexpr index_t kMRepeat = 2;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
sequence<kKWarps, kKRow, kRowsize, K1>>,
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
tuple<sequence<0, 0>, sequence<1, 2, 2>>,
sequence<1, 1, 2>,
sequence<1, 3, 3>>{});
}
template <typename Problem, typename BlockGemm>
@@ -666,6 +710,145 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return 16 / sizeof(GemmDataType);
}
// Builds a compile-time LDS tensor descriptor from eleven length-2 raw dimensions and
// then merges them into a 2-D view (upper dim 0 = "M", upper dim 1 = "K").
//
// The raw descriptor is created with explicit strides; the two outermost strides carry
// a `+ 1` term, so they are larger than the dense product of the inner lengths
// (presumably row padding to avoid LDS bank conflicts -- TODO confirm).
//
// NOTE(review): the M merge below combines MWarp, MRepeat, MPair (2*2*2 = 8) and the
// K merge combines eight factors of 2 (= 256), whereas the comments in the body and the
// function name describe M = 2^4 and K = 2^7. Please confirm whether `KBit4` is meant
// to be an M-side factor.
CK_TILE_HOST_DEVICE static constexpr auto Make16x128LdsBlockDescriptor()
{
// Every factor of the layout has length 2.
constexpr index_t MWarp = 2;
constexpr index_t KWarp = 2;
constexpr index_t KRow = 2;
constexpr index_t KBit0 = 2;
constexpr index_t KBit1 = 2;
constexpr index_t KBit2 = 2;
constexpr index_t KBit3 = 2;
constexpr index_t KBit4 = 2;
constexpr index_t K1 = 2;
constexpr index_t MPair = 2;
constexpr index_t MRepeat = 2;
// K:HeadDim, M:Seq, 11 Dimensions Total
// W T I V  (presumably Warp / Thread / Iteration / Vector -- TODO confirm)
// Total: 4*64*4*2 = 2^11
// W I T W I T T T T T V
// 2 2 2 2 2 2 2 2 2 2 2
// KWarp, MPair, KRow, MWarp, MRepeat, KBit<1, 2, 4, 3, 0>, K1
// M = 2^4
// K = 2^7
// Raw dimension order (index 0..10):
//   KWarp, MPair, KRow, MWarp, MRepeat, KBit1, KBit2, KBit4, KBit3, KBit0, K1
constexpr auto lds_16x128_block_desc_raw = make_naive_tensor_descriptor(
make_tuple(number<KWarp>{},
number<MPair>{},
number<KRow>{},
number<MWarp>{},
number<MRepeat>{},
number<KBit1>{},
number<KBit2>{},
number<KBit4>{},
number<KBit3>{},
number<KBit0>{},
number<K1>{}),
make_tuple(
// dim 0 (KWarp): 64 * (16 + 1) = 1088
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 *
KBit1*(MRepeat * MWarp * KRow * MPair + 1)>{},
// dim 1 (MPair): 16 * (32 + 1) = 528
number<K1 * KBit0 * KBit3 * KBit4*(KBit2 * KBit1 * MRepeat * MWarp * KRow + 1)>{},
// dims 2..10: dense strides 256, 128, 64, 32, 16, 8, 4, 2, 1
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1 * MRepeat * MWarp>{},
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1 * MRepeat>{},
number<K1 * KBit0 * KBit3 * KBit4 * KBit2 * KBit1>{},
number<K1 * KBit0 * KBit3 * KBit4 * KBit2>{},
number<K1 * KBit0 * KBit3 * KBit4>{},
number<K1 * KBit0 * KBit3>{},
number<K1 * KBit0>{},
number<K1>{},
number<1>{}),
number<K1>{},
number<1>{});
// Merge the raw dimensions into the 2-D view:
//   upper dim 0 <- raw dims {3, 4, 1}                 = (MWarp, MRepeat, MPair)
//   upper dim 1 <- raw dims {0, 2, 7, 8, 6, 5, 9, 10} = (KWarp, KRow, KBit4..KBit0, K1)
constexpr auto lds_16x128_block_desc = transform_tensor_descriptor(
lds_16x128_block_desc_raw,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MWarp>{}, number<MRepeat>{}, number<MPair>{})),
make_merge_transform_v3_division_mod(make_tuple(number<KWarp>{},
number<KRow>{},
number<KBit4>{},
number<KBit3>{},
number<KBit2>{},
number<KBit1>{},
number<KBit0>{},
number<K1>{}))),
make_tuple(sequence<3, 4, 1>{}, sequence<0, 2, 7, 8, 6, 5, 9, 10>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_16x128_block_desc;
}
// Builds the "Trans" variant of the LDS descriptor: eleven length-2 raw dimensions with
// explicit strides, merged into a 2-D view.
//
// The merged lengths are M = MWarp*MRepeat*MRow*MPair = 16 and
// K = KWarp*KRow*KGroup*KBit2*KBit1*KBit0*K1 = 128. Note the output mapping: the first
// merge (K) is assigned to upper dim 1 and the second merge (M) to upper dim 0.
// The innermost raw dimension here is MPair (stride 1), not K1.
CK_TILE_HOST_DEVICE static constexpr auto Make16x128TransLdsBlockDescriptor()
{
// Every factor of the layout has length 2.
constexpr index_t MWarp = 2;
constexpr index_t KWarp = 2;
constexpr index_t KRow = 2;
constexpr index_t MRow = 2;
constexpr index_t KGroup = 2;
constexpr index_t KBit0 = 2;
constexpr index_t KBit1 = 2;
constexpr index_t KBit2 = 2;
constexpr index_t K1 = 2;
constexpr index_t MPair = 2;
constexpr index_t MRepeat = 2;
// K:HeadDim, M:Seq, 11 Dimensions Total
// W T I V  (presumably Warp / Thread / Iteration / Vector -- TODO confirm)
// Total: 4*64*4*2 = 2^11
// W I T W I T T T T T V
// 2 2 2 2 2 2 2 2 2 2 2
// Kwarp, K1, KRow, MWarp, MRepeat, <KBit1, KBit2, KBit0>, KGroup, MRow, MPair
// M = 2^4
// K = 2^7
// Raw dimension order (index 0..10):
//   KWarp, K1, KRow, MWarp, MRepeat, KBit1, KBit2, KBit0, KGroup, MRow, MPair
constexpr auto lds_16x128_trans_block_desc_raw = make_naive_tensor_descriptor(
make_tuple(number<KWarp>{},
number<K1>{},
number<KRow>{},
number<MWarp>{},
number<MRepeat>{},
number<KBit1>{},
number<KBit2>{},
number<KBit0>{},
number<KGroup>{},
number<MRow>{},
number<MPair>{}),
// dim 0 (KWarp): 64 * (16 + 1) = 1088
make_tuple(number<MPair * MRow * KGroup * KBit0 * KBit2 *
KBit1*(MRepeat * MWarp * KRow * K1 + 1)>{},
// Padding
// dim 1 (K1): 16 * (32 + 1) = 528
number<MPair * MRow * KGroup *
KBit0*(KBit2 * KBit1 * MRepeat * MWarp * KRow + 1)>{},
// dims 2..10: dense strides 256, 128, 64, 32, 16, 8, 4, 2, 1
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1 * MRepeat * MWarp>{},
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1 * MRepeat>{},
number<MPair * MRow * KGroup * KBit0 * KBit2 * KBit1>{},
number<MPair * MRow * KGroup * KBit0 * KBit2>{},
number<MPair * MRow * KGroup * KBit0>{},
number<MPair * MRow * KGroup>{},
number<MPair * MRow>{},
number<MPair>{},
number<1>{}),
number<MPair>{},
number<1>{});
// Merge the raw dimensions into the 2-D view:
//   first merge  <- raw dims {0, 2, 8, 6, 5, 7, 1} = (KWarp, KRow, KGroup, KBit2, KBit1, KBit0, K1)
//   second merge <- raw dims {3, 4, 9, 10}         = (MWarp, MRepeat, MRow, MPair)
// The first merge lands in upper dim 1 and the second in upper dim 0.
constexpr auto lds_16x128_trans_block_desc = transform_tensor_descriptor(
lds_16x128_trans_block_desc_raw,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(number<KWarp>{},
number<KRow>{},
number<KGroup>{},
number<KBit2>{},
number<KBit1>{},
number<KBit0>{},
number<K1>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<MWarp>{}, number<MRepeat>{}, number<MRow>{}, number<MPair>{}))),
make_tuple(sequence<0, 2, 8, 6, 5, 7, 1>{}, sequence<3, 4, 9, 10>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return lds_16x128_trans_block_desc;
}
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, bool XorLdsLayout = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{
@@ -1008,12 +1191,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
#if 0
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
#elif 1
return Make16x128LdsBlockDescriptor();
#endif
}
template <typename Problem>
@@ -1051,36 +1238,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t kMWarps = 2;
constexpr index_t kKWarps = 2;
constexpr index_t kKRow = 2;
constexpr index_t kMRow = 2;
constexpr index_t kRowsize = 16;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t kMPair = 2;
constexpr index_t kMRepeat = 2;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
sequence<kKWarps, kKRow, kRowsize, K1>>,
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
tuple<sequence<0, 0>, sequence<1, 2, 2>>,
sequence<1, 2, 1>,
sequence<1, 3, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
{
// Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetAlignmentQ<Problem>();
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
return Make16x128TransLdsBlockDescriptor();
}
template <typename Problem>
@@ -1215,12 +1396,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor()
{
// Hold full block data
#if 0
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
#elif 1
return Make16x128LdsBlockDescriptor();
#endif
}
template <typename Problem>
@@ -1258,36 +1443,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t kMWarps = 2;
constexpr index_t kKWarps = 2;
constexpr index_t kKRow = 2;
constexpr index_t kMRow = 2;
constexpr index_t kRowsize = 16;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t kMPair = 2;
constexpr index_t kMRepeat = 2;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
tuple<sequence<kMWarps, kMRepeat, kMRow, kMPair>,
sequence<kKWarps, kKRow, kRowsize, K1>>,
tuple<sequence<2, 1>, sequence<2, 1, 2>>,
tuple<sequence<0, 0>, sequence<1, 2, 2>>,
sequence<1, 2, 1>,
sequence<1, 3, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetAlignmentOGrad<Problem>();
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
return Make16x128TransLdsBlockDescriptor();
}
template <typename Problem>
@@ -1916,18 +2094,26 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t D_VMEM_READ = 1;
// LDS Read
// 16 * 128 / 64 / 4 = 8
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
// 16 * 128 / 64 / 4 = 8
// kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
// 16 * 128 / 64 / 8 = 4
kM0 * kVHeaddim / get_warp_size() / 8;
// 16 * 128 / 64 / 4 = 8
// QT_LDS_READ: (kM0 * kQKHeaddim) tile elements divided by the warp size and by a
// hard-coded factor of 8.
// The Q^T tile spans the QK head dimension, so the element count uses kQKHeaddim, as
// the previous formula did. The hard-coded revision had copied kVHeaddim from
// OGradT_LDS_READ, which is only correct when both head dimensions are equal.
static constexpr index_t QT_LDS_READ =
// Previous formula, kept for reference: 16 * 128 / 64 / 4 = 8
// kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
// Current formula: 16 * 128 / 64 / 8 = 4
kM0 * kQKHeaddim / get_warp_size() / 8;
// 16 * 32 / 64 / 8 = 1
static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
// 16 * 128 / 64 / 8 = 4
static constexpr index_t Q_LDS_READ =
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetSmemKPackQ<Problem>();
// 16 * 128 / 64 / 8 = 4
// kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetSmemKPackQ<Problem>();
// 16 * 128 / 64 / 8 = 4
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / 8;
// 1
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// 16 * 96 / 64 / 8 = 3
@@ -1935,7 +2121,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
// 16 * 128 / 64 / 8 = 4
static constexpr index_t OGrad_LDS_READ =
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetSmemKPackOGrad<Problem>();
// kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetSmemKPackOGrad<Problem>();
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / 8;
// 1
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);

View File

@@ -64,6 +64,10 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK
WarpGemmAttributeMfmaImplF16F16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// fp16 warp GEMM built from the M16N16K16 MFMA impl, wrapped by the
// IterateN + transposed-C attribute with an iteration count of 2
// (presumably what turns N16 into the N32 of the alias name -- TODO confirm).
using WarpGemmMfmaF16F16F32M16N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateNAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
@@ -120,6 +124,10 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterat
WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4<WGAttrCtlEnum::Default_>,
4>>;
// bf16 warp GEMM built from the M16N16K16 MFMA impl, wrapped by the
// IterateN + transposed-C attribute with an iteration count of 2
// (presumably what turns N16 into the N32 of the alias name -- TODO confirm).
using WarpGemmMfmaBf16Bf16F32M16N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateNAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
// fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<

View File

@@ -642,6 +642,219 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
}
};
// Warp-GEMM attribute that wraps an MFMA implementation (`Impl`), swaps the roles of
// its A and B operands, and invokes it `kNIter` times, each time on a different slice
// of the B operand and of the C accumulator.
//
// Resulting warp-level sizes, as defined below:
//   kM = Impl::kN, kN = Impl::kM * kNIter, kK = Impl::kK.
//
// Template parameters:
//   WarpGemmAttributeMfmaImpl_  the underlying MFMA implementation type
//   kNIter                      how many times the implementation is repeated
template <typename WarpGemmAttributeMfmaImpl_, index_t kNIter>
struct WarpGemmAtrributeMfmaIterateNAndTransposedCDistribution
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
// swap A and B
using ADataType = typename Impl::BDataType;
using BDataType = typename Impl::ADataType;
using CDataType = typename Impl::CDataType;
// NOTE(review): the data types above are swapped, but AVecType aliases Impl::AVecType
// and BVecType is sized from Impl::BVecType (not swapped). This is only consistent
// when Impl's A and B vector types coincide -- please confirm for every Impl used.
using AVecType = typename Impl::AVecType;
// B and C vectors hold kNIter times as many elements as the Impl's own vectors.
using BVecType =
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kNIter>;
using CVecType =
ext_vector_t<CDataType, vector_traits<typename Impl::CVecType>::vector_size * kNIter>;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM * kNIter;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
// Number of Impl invocations performed by the full operator() overloads.
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kNIter; }
// At least one of the two block counts must be 1; the encoding getters below only
// handle those three combinations.
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
"Multi-block on both M & N directions is not supported");
// Distribution encoding for the A operand. It is built from Impl's B-side lane count
// (kBNLane) because A and B are swapped; it carries no kNIter factor.
CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<sequence<>,
tuple<sequence<Impl::kBNBlock, Impl::kBNLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// each N blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kAMBlock>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 0>>,
sequence<2>,
sequence<1>>{};
}
}
// Distribution encoding for the B operand. It is built from Impl's A-side lane count
// (kAMLane), with kNIter as the leading factor of the first dimension.
CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kNIter, Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 1>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
// each M blocks share the same data
return tile_distribution_encoding<
sequence<Impl::kBNBlock>,
tuple<sequence<kNIter, Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<0, 2, 1>>,
tuple<sequence<0, 0, 1>>,
sequence<1, 2>,
sequence<0, 1>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
// single block to multi-block thread mapping
return tile_distribution_encoding<sequence<>,
tuple<sequence<kNIter, Impl::kAMBlock, Impl::kAMLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<1, 2, 1>>,
tuple<sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{};
}
}
// Distribution encoding for the C accumulator. kNIter is the leading factor of the
// second dimension, ahead of kCM0PerLane, the lane factor and kCM1PerLane.
CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
{
if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<kNIter, Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<2, 0>>,
sequence<2, 2, 2>,
sequence<0, 1, 3>>{};
}
else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNBlock * Impl::kCNLane>,
sequence<kNIter, Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<2, 0>>,
sequence<2, 2, 2>,
sequence<0, 1, 3>>{};
}
else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<kNIter,
Impl::kCM0PerLane,
Impl::kAMBlock * Impl::kCMLane,
Impl::kCM1PerLane>>,
tuple<sequence<2, 1>>,
tuple<sequence<2, 0>>,
sequence<2, 2, 2>,
sequence<0, 1, 3>>{};
}
}
using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding());
using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// Accumulating form over all iterations. For each iNIter in [0, kNIter) it calls Impl
// with the iNIter-th Impl::CVecType slice of c_vec, the iNIter-th Impl::BVecType slice
// of b_vec (passed in Impl's first operand position) and the whole a_vec (passed in
// Impl's second operand position).
template <bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// View the wide C / B vectors as kNIter consecutive Impl-sized vectors.
using buf_c = thread_buffer<typename Impl::CVecType, kNIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kNIter>;
// swap A and B, value and type
// NOTE(author): known bug -- the result is reportedly not written back to c_vec.
// TODO confirm whether get_as() on the non-const buffer yields a reference or a copy.
static_for<0, kNIter, 1>{}([&](auto iNIter) {
Impl{}(
reinterpret_cast<buf_c&>(c_vec).template get_as<typename Impl::CVecType>()(iNIter),
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iNIter],
a_vec,
bool_constant<post_nop_>{});
});
// if(get_thread_global_1d_id()==0){
// printf("Enter here, ")
// }
}
// Accumulating form for a single, compile-time selected iteration `iNIter`: performs
// exactly one Impl call on the iNIter-th slices of c_vec and b_vec.
template <index_t iNIter, bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iNIter>,
bool_constant<post_nop_> = {}) const
{
using buf_c = thread_buffer<typename Impl::CVecType, kNIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kNIter>;
static_assert(iNIter < kNIter);
// swap A and B, value and type
// static_for<0, kNIter, 1>{}([&](auto iNIter) {
Impl{}(reinterpret_cast<buf_c&>(c_vec).template get_as<typename Impl::CVecType>()(iNIter),
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iNIter],
a_vec,
bool_constant<post_nop_>{});
//});
}
// Non-accumulating form: runs Impl once per iteration on the matching b_vec slice and
// a_vec, stores each partial result in a local buffer, and returns that buffer
// reinterpreted as one CVecType.
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
constexpr auto I0 = number<0>{};
using buf_c = thread_buffer<typename Impl::CVecType, kNIter>;
buf_c c_buf;
using buf_b = thread_buffer<typename Impl::BVecType, kNIter>;
static_for<0, kNIter, 1>{}([&](auto iNIter) {
// Operands are swapped here as well: the b_vec slice goes first, a_vec second.
auto c_vec_tmp = Impl{}(
reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()(
iNIter),
a_vec);
c_buf.template set_as<typename Impl::CVecType>(iNIter, c_vec_tmp);
});
return c_buf.template get_as<CVecType>()[I0];
}
};
template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
{

View File

@@ -34,6 +34,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// fp16, sizes 16/32/16 with the trailing flag set to true (presumably M/N/K and TransposeC -- TODO confirm): selects the IterateN transposed-C warp GEMM.
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N32K16TransposedCDistribution; };
// bf16
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
@@ -49,6 +50,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// bf16, sizes 16/32/16 with the trailing flag set to true (presumably M/N/K and TransposeC -- TODO confirm): selects the IterateN transposed-C warp GEMM.
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N32K16TransposedCDistribution; };
// fp8
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };