Mirror of https://github.com/ROCm/composable_kernel.git

Commit: Merge remote-tracking branch 'origin/develop' into users/yiding12/fmha-bwd-workspace
@@ -22,6 +22,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 * Added FP8 block scale quantization for FMHA forward kernel.
 * Added gfx11 support for FMHA.
 * Added microscaling (MX) FP8/FP4 support on gfx950 for FMHA forward kernel ("qr" pipeline only).
+* Added FP8 per-tensor quantization support for FMHA forward V3 pipeline on gfx950.

 ### Changed
@@ -206,22 +206,14 @@ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fw
 """

 FMHA_FWD_API_FOOTER_TEMPLATE = """
 float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{
-    const std::string device_name = ck_tile::get_device_name();
-
-    const bool is_swa = (traits.mask_type != mask_enum::no_mask) and
-                        ((0 < args.window_size_left) or (0 < args.window_size_right));
-    const bool can_dispatch_v3 =
-        (device_name.compare(0, 6, "gfx950") == 0) and
-        (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
-        traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
-        (not traits.has_lse) and (not traits.has_dropout) and
-        (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
-        (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
-    if ({F_is_v3_enabled} and can_dispatch_v3) {{
-        return fmha_fwd_v3(traits, args, config);
-    }} else {{
-        return fmha_fwd_v2(traits, args, config);
-    }}
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunreachable-code"
+    if ({F_is_v3_enabled}) {{
+        float r = fmha_fwd_v3(traits, args, config);
+        if (r >= 0) return r;
+    }}
+#pragma clang diagnostic pop
+    return fmha_fwd_v2(traits, args, config);
 }}
 """
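For reference, here is roughly what the new footer template renders to once `{F_is_v3_enabled}` is substituted with `true` (the doubled braces collapse to single braces during `str.format`). This is an illustrative sketch of the generated dispatcher, not captured generator output; it relies on the convention, visible in the template itself, that fmha_fwd_v3 returns a negative value when it cannot service the given traits/args:

// Hypothetical render of FMHA_FWD_API_FOOTER_TEMPLATE with F_is_v3_enabled=true.
float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunreachable-code"
    if (true) {
        float r = fmha_fwd_v3(traits, args, config);
        if (r >= 0) return r; // v3 handled the problem; otherwise fall through to v2
    }
#pragma clang diagnostic pop
    return fmha_fwd_v2(traits, args, config);
}

When the literal is `false` instead, the trailing `return fmha_fwd_v2(...)` is the only reachable statement, which is why the constant-condition `if` is wrapped in the `-Wunreachable-code` suppression. Note the device/trait gating that the old code did on the host (`can_dispatch_v3`) is now delegated to fmha_fwd_v3 itself via its return value.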
@@ -1059,10 +1051,11 @@ class KernelComponentFactoryGfx950(
     def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
         result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
         if dtype in cls._DT_FP16_BF16:
-            # add tile for qr_async_trload_v3
-            if (128, 128) in result.keys():
-                result[(128, 128)].append(
-                    FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1))  # fmt: skip
+            # # add tile for qr_async_trload_v3 (bf16/fp16 V3 not ready)
+            # if (128, 128) in result.keys():
+            #     result[(128, 128)].append(
+            #         FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1))  # fmt: skip
+            pass
         elif dtype in cls._DT_MXFP8:
             return {
                 # bm0, bn0, bk0, bn1, bk1,
@@ -1075,6 +1068,10 @@ class KernelComponentFactoryGfx950(
                 (128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
                 (256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
             } # fmt: skip
+        elif dtype in cls._DT_FP8BF16:
+            if (128, 128) in result.keys():
+                result[(128, 128)].append(
+                    FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1))  # fmt: skip
         return result

     @classmethod
@@ -1105,12 +1102,19 @@ class KernelComponentFactoryGfx950(
                 pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink))  # fmt: skip
                 pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink))  # fmt: skip

-            # qr_async_trload_v3 only supports hdim=hdim_v=128 for now
-            if (hdim, hdim_v) == (128, 128):
-                # qr_async_trload_v3 only supports (generic) causal mask
-                for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
-                    pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
-                        F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f"))  # fmt: skip
+            # # qr_async_trload_v3 bf16/fp16 not ready
+            # if (hdim, hdim_v) == (128, 128):
+            #     for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
+            #         pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
+            #             F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f"))  # fmt: skip
+        elif dtype in cls._DT_FP8BF16:
+            # qr_async_trload_v3 only supports (generic) causal mask
+            for logits, qscale, mask in itertools.product(
+                ["t", "f"],
+                ["no", "pertensor"],
+                ["no", "causal"],
+            ):
+                pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f"))  # fmt: skip

         elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4:
             # no need dropout kernels
@@ -1494,8 +1498,8 @@ def write_fwd_api(
         FMHA_FWD_API_FOOTER_TEMPLATE.format(
             F_is_v3_enabled=BOOL_MAP[
                 # NOTE: enable v3 pipelines when ready
-                # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
-                False
+                0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
+                # False
             ]
         ),
     ]
@@ -844,6 +844,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
         return FmhaKernel::MakeKargs(args.q_ptr,
                                      args.k_ptr,
                                      args.v_ptr,
+                                     args.q_descale_ptr,
+                                     args.k_descale_ptr,
+                                     args.v_descale_ptr,
                                      nullptr, // lse_ptr
                                      args.o_ptr,
                                      args.seqstart_q_ptr,
@@ -877,6 +880,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
         return FmhaKernel::MakeKargs(args.q_ptr,
                                      args.k_ptr,
                                      args.v_ptr,
+                                     args.q_descale_ptr,
+                                     args.k_descale_ptr,
+                                     args.v_descale_ptr,
                                      nullptr, // lse_ptr
                                      args.o_ptr,
                                      args.seqlen_q,
@@ -1209,7 +1209,8 @@ enum LLVMSchedGroupMask : int32_t
     DS       = 1 << 7,
     DS_READ  = 1 << 8,
     DS_WRITE = 1 << 9,
-    ALL      = (DS_WRITE << 1) - 1,
+    TRANS    = 1 << 10,
+    ALL      = (TRANS << 1) - 1,
 };

 CK_TILE_HOST_DEVICE static constexpr auto get_max_mem_vec_inst_width()
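Since TRANS is bit 10, (TRANS << 1) - 1 evaluates to 0x7FF, so ALL now covers bits 0 through 10, including the new transpose group. A self-contained check of that arithmetic; the enum here is trimmed to the entries visible in this hunk, and the groups below bit 7 are assumed to exist elsewhere:

#include <cstdint>

// Trimmed, illustrative copy of the enum; only the entries shown in the
// hunk above are taken from the source.
enum LLVMSchedGroupMask : int32_t
{
    DS       = 1 << 7,
    DS_READ  = 1 << 8,
    DS_WRITE = 1 << 9,
    TRANS    = 1 << 10,          // new group added by this change
    ALL      = (TRANS << 1) - 1, // 0x7FF: every group bit 0..10
};

static_assert((ALL & TRANS) != 0, "ALL must now include the TRANS group");
static_assert(ALL == 0x7FF, "(TRANS << 1) - 1 sets bits 0 through 10");

int main() { return 0; }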
@@ -27,6 +27,8 @@ inline constexpr bool
     kattr_no_packed_fp32_ops_v<T, std::void_t<decltype(T::kattr_no_packed_fp32_ops)>> =
         T::kattr_no_packed_fp32_ops;

+// TODO: rename to something more specific (e.g. kernel_attr_no_packed_fp32) since
+// kernel_attr<bool> only controls the no-packed-fp32-ops flag, not a general attribute bag.
 template <bool no_packed_fp32_ops>
 struct kernel_attr
 {
@@ -35,6 +37,32 @@ struct kernel_attr
     static constexpr bool kattr_no_packed_fp32_ops = no_packed_fp32_ops;
 };

+// Compose an architecture tag with kernel attributes.
+// Inherits ArchTag for symbol mangling and adds attribute flags.
+//   kernel_attr_for<gfx950_t>                    -> gfx950_t (identity)
+//   kernel_attr_for<gfx950_t, kernel_attr<true>> -> unique type with attribute
+namespace detail {
+template <typename ArchTag, typename... Attrs>
+struct kernel_attr_for_impl : ArchTag, Attrs...
+{
+};
+
+template <typename ArchTag, typename... Attrs>
+struct kernel_attr_for_helper
+{
+    using type = kernel_attr_for_impl<ArchTag, Attrs...>;
+};
+
+template <typename ArchTag>
+struct kernel_attr_for_helper<ArchTag>
+{
+    using type = ArchTag;
+};
+} // namespace detail
+
+template <typename ArchTag, typename... Attrs>
+using kernel_attr_for = typename detail::kernel_attr_for_helper<ArchTag, Attrs...>::type;
+
 #if CK_TILE_USE_LAUNCH_BOUNDS
 #define KENTRY_LAUNCH_BOUNDS __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
 #else
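A compilable sketch of how the alias behaves, re-creating the machinery above in isolation (gfx950_t here is a stand-in struct for illustration, not the real ck_tile arch tag):

#include <type_traits>

struct gfx950_t {}; // stand-in arch tag

template <bool no_packed_fp32_ops>
struct kernel_attr
{
    static constexpr bool kattr_no_packed_fp32_ops = no_packed_fp32_ops;
};

namespace detail {
template <typename ArchTag, typename... Attrs>
struct kernel_attr_for_impl : ArchTag, Attrs... {};

template <typename ArchTag, typename... Attrs>
struct kernel_attr_for_helper { using type = kernel_attr_for_impl<ArchTag, Attrs...>; };

template <typename ArchTag> // empty attribute pack: collapse to the bare tag
struct kernel_attr_for_helper<ArchTag> { using type = ArchTag; };
} // namespace detail

template <typename ArchTag, typename... Attrs>
using kernel_attr_for = typename detail::kernel_attr_for_helper<ArchTag, Attrs...>::type;

// With no attributes the alias is an identity, so existing code keyed on the
// arch tag type is unaffected...
static_assert(std::is_same_v<kernel_attr_for<gfx950_t>, gfx950_t>);
// ...while adding an attribute yields a distinct type that still is-a gfx950_t
// and carries the flag detected by kattr_no_packed_fp32_ops_v.
using tagged = kernel_attr_for<gfx950_t, kernel_attr<true>>;
static_assert(!std::is_same_v<tagged, gfx950_t>);
static_assert(std::is_base_of_v<gfx950_t, tagged>);
static_assert(tagged::kattr_no_packed_fp32_ops);

int main() { return 0; }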
@@ -187,11 +187,11 @@ struct EpilogueGraph
                     Context& context) const
     {
         // For each iteration, process all epilogues in order
-        static_for<0, Steps, 1>{}([&](auto iAccess) {
-            static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) {
-                epilogues.template get<I.value>()(
-                    out_window, acc_tile, aux_windows, p_smem, context, iAccess);
-            });
+        static_ford<sequence<Steps, sizeof...(EpilogueTypes)>>{}([&](auto iI) {
+            constexpr auto iAccess = number<iI[number<0>{}]>{};
+            constexpr auto I       = number<iI[number<1>{}]>{};
+            epilogues.template get<I.value>()(
+                out_window, acc_tile, aux_windows, p_smem, context, iAccess);
         });
     }
 };
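The same static_for to static_ford rewrite recurs throughout the rest of this diff, so it is worth seeing in isolation: static_ford over a 2D extent visits exactly the index pairs of the nested loops, but as one flat compile-time unrolled sequence rather than nested closures. A minimal sketch with hand-rolled stand-ins (the real ck_tile sequence/number types differ, but the flattening idea is the same):

#include <cstdio>
#include <utility>

// Hand-rolled stand-in for ck_tile's static_for.
template <typename F, int... Is>
void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...); // unrolled at compile time
}

template <int N, typename F>
void static_for(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

// Stand-in for static_ford over a 2D extent <A, B>: one flat loop whose
// linear index is split back into (a, b) in row-major order.
template <int A, int B, typename F>
void static_ford_2d(F&& f)
{
    static_for<A * B>([&](auto i) {
        f(std::integral_constant<int, decltype(i)::value / B>{},
          std::integral_constant<int, decltype(i)::value % B>{});
    });
}

int main()
{
    // nested form
    static_for<2>([&](auto a) {
        static_for<3>([&](auto b) { std::printf("nested (%d,%d)\n", a(), b()); });
    });
    // flattened form visits identical pairs in identical order
    static_ford_2d<2, 3>([&](auto a, auto b) { std::printf("ford   (%d,%d)\n", a(), b()); });
}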
@@ -92,29 +92,29 @@ struct BlockFlatmmASmemBSmemCRegV1
         constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

         // hot loop:
-        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                // read A warp tensor from A block window
-                const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
-
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    // read C warp tensor from C block tensor
-                    CWarpTensor c_warp_tensor;
-
-                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
-                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                    // warp GEMM
-                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
-
-                    // write C warp tensor into C block tensor
-                    c_block_tensor.set_y_sliced_thread_data(
-                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                        c_warp_tensor.get_thread_buffer());
-                    __builtin_amdgcn_sched_barrier(0x7F6);
-                });
-            });
+        static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+            constexpr auto kIter = number<km[number<0>{}]>{};
+            constexpr auto mIter = number<km[number<1>{}]>{};
+            // read A warp tensor from A block window
+            const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
+
+            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                // read C warp tensor from C block tensor
+                CWarpTensor c_warp_tensor;
+
+                c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
+                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+
+                // warp GEMM
+                WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
+
+                // write C warp tensor into C block tensor
+                c_block_tensor.set_y_sliced_thread_data(
+                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                    c_warp_tensor.get_thread_buffer());
+                __builtin_amdgcn_sched_barrier(0x7F6);
+            });
         });
     }
@@ -1105,15 +1105,14 @@ struct MoeFlatmmKernel
         statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;

         if constexpr(!BMXFP4_Pipeline)
-            static_for<0, MRepeat, 1>{}([&](auto mIter) {
-                static_for<0, kM0, 1>{}([&](auto m0) {
-                    static_for<0, kM2, 1>{}([&](auto m2) {
-                        const auto row_idx =
-                            coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
-                        scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
-                            row_to_token_idx(row_idx);
-                    });
-                });
+            static_ford<sequence<MRepeat, kM0, kM2>>{}([&](auto mmm) {
+                constexpr auto mIter = number<mmm[number<0>{}]>{};
+                constexpr auto m0    = number<mmm[number<1>{}]>{};
+                constexpr auto m2    = number<mmm[number<2>{}]>{};
+                const auto row_idx =
+                    coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
+                scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
+                    row_to_token_idx(row_idx);
             });

         constexpr int DynamicTileOffsetFlag = 0;
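The destination index mIter * kM0 * kM2 + m0 * kM2 + m2 is the ordinary row-major linearization of the three loop indices, which is exactly the order static_ford<sequence<MRepeat, kM0, kM2>> visits them, so every slot of scale_m_offsets is written exactly once. A toy check (the extents are made-up illustration values, not the kernel's real tile parameters):

#include <cassert>

int main()
{
    constexpr int MRepeat = 2, kM0 = 3, kM2 = 4; // illustration values only
    int linear = 0;
    for(int mIter = 0; mIter < MRepeat; ++mIter)
        for(int m0 = 0; m0 < kM0; ++m0)
            for(int m2 = 0; m2 < kM2; ++m2)
            {
                // row-major linearization matches the flat visit order
                assert(mIter * kM0 * kM2 + m0 * kM2 + m2 == linear);
                ++linear;
            }
    return 0;
}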
@@ -1426,19 +1425,19 @@ struct MoeFlatmmKernel
         statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
             c_scatter_valids;
         auto c_coord = dram_tile_distribution.calculate_index();
-        static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
-            static_for<0, MPerThread, 1>{}([&](auto m0) {
-                auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
-                auto fused_token =
-                    kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
+        static_ford<sequence<NumMEpiTile, MPerThread>>{}([&](auto mm) {
+            constexpr auto mIter = number<mm[number<0>{}]>{};
+            constexpr auto m0    = number<mm[number<1>{}]>{};
+            auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
+            auto fused_token =
+                kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]

-                index_t scatter_token_id = fused_token & token_id_mask;
-                c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
-                if constexpr(IsInputGemm)
-                    scatter_token_id =
-                        scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
-                c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
-            });
+            index_t scatter_token_id = fused_token & token_id_mask;
+            c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
+            if constexpr(IsInputGemm)
+                scatter_token_id =
+                    scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
+            c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
         });

 //===----------------------------------------------------------------------===//
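The comment topk-idx[31:24] + token_idx[23:0] describes the packed entries of p_sorted_token_ids. A sketch of the decode, assuming token_id_mask = 0xFFFFFF and token_id_offset = 24 (consistent with the comment; the actual definitions sit outside this hunk):

#include <cassert>
#include <cstdint>

int main()
{
    // Assumed constants matching the bit-layout comment above.
    constexpr uint32_t token_id_mask   = 0xFFFFFF; // token index: low 24 bits
    constexpr uint32_t token_id_offset = 24;       // top-k slot: high 8 bits

    // Pack a toy entry: token 1234 selected as a token's 3rd expert choice.
    uint32_t fused_token = (3u << token_id_offset) | 1234u;

    uint32_t token_idx = fused_token & token_id_mask;    // -> 1234
    uint32_t topk_idx  = fused_token >> token_id_offset; // -> 3
    assert(token_idx == 1234 && topk_idx == 3);

    // For the input GEMM the scatter row becomes token*TopK + topk slot,
    // mirroring scatter_token_id * kargs.TopK + (fused_token >> token_id_offset).
    constexpr uint32_t TopK = 8; // illustration value
    uint32_t scatter_row = token_idx * TopK + topk_idx;
    assert(scatter_row == 1234 * 8 + 3);
    return 0;
}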
@@ -606,16 +606,16 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
                        MIterPerWarp>
             a_warp_windows_pong;

-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
-                a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
+        static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
+            constexpr auto mIter = number<mk[number<0>{}]>{};
+            constexpr auto kIter = number<mk[number<1>{}]>{};
+            a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
+            a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;

-                move_tile_window(a_warp_windows_ping(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-                move_tile_window(a_warp_windows_pong(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
+            move_tile_window(a_warp_windows_ping(mIter)(kIter),
+                             {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
+            move_tile_window(a_warp_windows_pong(mIter)(kIter),
+                             {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
         });

         // Block GEMM
@@ -656,15 +656,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
         move_tile_window(a_copy_dram_window, {0, kKPerBlock});

         // prefetch B
-        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+        static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
+            constexpr auto nIter = number<nk[number<0>{}]>{};
+            constexpr auto kIter = number<nk[number<1>{}]>{};
+            b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

-                move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+            move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                             {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});

-                b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
-            });
+            b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
         });
         // move B window to next flat K
         move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -701,15 +701,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
         while(iCounter > 0)
         {
             // prefetch B(2i+1)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});

-                    b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
-                });
+                b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
             });

             // Prefill A(2i+1)
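The hunks around here form a ping/pong software pipeline: while GEMM 2i consumes the ping buffers, B(2i+1) is prefetched into the pong buffers, and the roles swap every iteration so memory latency hides behind compute. A toy CPU-side sketch of the same scheme (all names and sizes are invented for illustration; on the GPU the fetch genuinely overlaps the compute):

#include <cstdio>

constexpr int kChunks    = 6;
constexpr int kChunkSize = 4;

void fetch(int chunk, int* dst) // stand-in for the B prefetch
{
    for(int j = 0; j < kChunkSize; ++j)
        dst[j] = chunk * kChunkSize + j;
}

int compute(const int* src) // stand-in for the GEMM on one chunk
{
    int acc = 0;
    for(int j = 0; j < kChunkSize; ++j)
        acc += src[j];
    return acc;
}

int main()
{
    int ping[kChunkSize], pong[kChunkSize];
    int acc = 0;

    fetch(0, ping); // prologue: fill the ping buffer
    for(int i = 0; i < kChunks; ++i)
    {
        int* cur  = (i % 2 == 0) ? ping : pong;
        int* next = (i % 2 == 0) ? pong : ping;
        if(i + 1 < kChunks)
            fetch(i + 1, next); // prefetch chunk i+1 while "computing" chunk i
        acc += compute(cur);
    }
    std::printf("acc = %d\n", acc); // 0+1+...+23 = 276
    return 0;
}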
@@ -722,44 +722,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
             move_tile_window(a_copy_dram_window, {0, kKPerBlock});

             // GEMM 2i
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        // read C warp tensor from C block tensor
-                        CWarpTensor c_warp_tensor;
-
-                        c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                        // warp GEMM
-                        WG{}(c_warp_tensor,
-                             a_warp_tensor(number<AwarpIter>{}),
-                             b_warp_tensor_ping(nIter)(kIter));
-
-                        // write C warp tensor into C block tensor
-                        c_block_tile.set_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                            c_warp_tensor.get_thread_buffer());
-                    });
-                    // preload next A from lds
-                    if constexpr((kIter * MIterPerWarp + mIter) <
-                                 (KIterPerWarp * MIterPerWarp - m_preload))
-                    {
-                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                        a_warp_tensor(number<AwarpIter>{}) =
-                            load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
-                    }
-
-                    // barrier
-                    if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                    {
-                        block_sync_lds();
-                    }
-                });
-            });
+            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+                constexpr auto kIter = number<km[number<0>{}]>{};
+                constexpr auto mIter = number<km[number<1>{}]>{};
+                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // read C warp tensor from C block tensor
+                    CWarpTensor c_warp_tensor;
+
+                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+
+                    // warp GEMM
+                    WG{}(c_warp_tensor,
+                         a_warp_tensor(number<AwarpIter>{}),
+                         b_warp_tensor_ping(nIter)(kIter));
+
+                    // write C warp tensor into C block tensor
+                    c_block_tile.set_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                        c_warp_tensor.get_thread_buffer());
+                });
+                // preload next A from lds
+                if constexpr((kIter * MIterPerWarp + mIter) <
+                             (KIterPerWarp * MIterPerWarp - m_preload))
+                {
+                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                    a_warp_tensor(number<AwarpIter>{}) =
+                        load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
+                }

+                // barrier
+                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                {
+                    block_sync_lds();
+                }
+            });

             // move B window to next flat K
@@ -776,15 +776,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
             // Next K

             // prefetch B(2i+2)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});

-                    b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
-                });
+                b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
             });

             // Prefill A(2i+2)
@@ -797,43 +797,43 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
             move_tile_window(a_copy_dram_window, {0, kKPerBlock});

             // GEMM 2i+1
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        // read C warp tensor from C block tensor
-                        CWarpTensor c_warp_tensor;
-                        c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                        // warp GEMM
-                        WG{}(c_warp_tensor,
-                             a_warp_tensor(number<AwarpIter>{}),
-                             b_warp_tensor_pong(nIter)(kIter));
-
-                        // write C warp tensor into C block tensor
-                        c_block_tile.set_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                            c_warp_tensor.get_thread_buffer());
-                    });
-                    // preload next A from lds
-                    if constexpr((kIter * MIterPerWarp + mIter) <
-                                 (KIterPerWarp * MIterPerWarp - m_preload))
-                    {
-                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                        a_warp_tensor(number<AwarpIter>{}) =
-                            load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
-                    }
-
-                    // barrier
-                    if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                    {
-                        block_sync_lds();
-                    }
-                });
-            });
+            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+                constexpr auto kIter = number<km[number<0>{}]>{};
+                constexpr auto mIter = number<km[number<1>{}]>{};
+                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // read C warp tensor from C block tensor
+                    CWarpTensor c_warp_tensor;
+                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

+                    // warp GEMM
+                    WG{}(c_warp_tensor,
+                         a_warp_tensor(number<AwarpIter>{}),
+                         b_warp_tensor_pong(nIter)(kIter));

+                    // write C warp tensor into C block tensor
+                    c_block_tile.set_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                        c_warp_tensor.get_thread_buffer());
+                });
+                // preload next A from lds
+                if constexpr((kIter * MIterPerWarp + mIter) <
+                             (KIterPerWarp * MIterPerWarp - m_preload))
+                {
+                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                    a_warp_tensor(number<AwarpIter>{}) =
+                        load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
+                }

+                // barrier
+                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                {
+                    block_sync_lds();
+                }
+            });

             // move B window to next flat K
@@ -854,15 +854,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
         if constexpr(TailNum == TailNumber::Even)
         {
             // prefetch B(loopK)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});

-                    b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
-                });
+                b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
             });

             // Prefill A(loopK)
@@ -870,44 +870,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
             store_tile(a_copy_lds_window_pong, a_block_tile_tmp);

             // GEMM loopK-1
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        // read C warp tensor from C block tensor
-                        CWarpTensor c_warp_tensor;
-
-                        c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                        // warp GEMM
-                        WG{}(c_warp_tensor,
-                             a_warp_tensor(number<AwarpIter>{}),
-                             b_warp_tensor_ping(nIter)(kIter));
-
-                        // write C warp tensor into C block tensor
-                        c_block_tile.set_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                            c_warp_tensor.get_thread_buffer());
-                    });
-                    // preload next A from lds
-                    if constexpr((kIter * MIterPerWarp + mIter) <
-                                 (KIterPerWarp * MIterPerWarp - m_preload))
-                    {
-                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                        a_warp_tensor(number<AwarpIter>{}) =
-                            load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
-                    }
-
-                    // barrier
-                    if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                    {
-                        block_sync_lds();
-                    }
-                });
-            });
+            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+                constexpr auto kIter = number<km[number<0>{}]>{};
+                constexpr auto mIter = number<km[number<1>{}]>{};
+                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // read C warp tensor from C block tensor
+                    CWarpTensor c_warp_tensor;

+                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

+                    // warp GEMM
+                    WG{}(c_warp_tensor,
+                         a_warp_tensor(number<AwarpIter>{}),
+                         b_warp_tensor_ping(nIter)(kIter));

+                    // write C warp tensor into C block tensor
+                    c_block_tile.set_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                        c_warp_tensor.get_thread_buffer());
+                });
+                // preload next A from lds
+                if constexpr((kIter * MIterPerWarp + mIter) <
+                             (KIterPerWarp * MIterPerWarp - m_preload))
+                {
+                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                    a_warp_tensor(number<AwarpIter>{}) =
+                        load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
+                }

+                // barrier
+                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                {
+                    block_sync_lds();
+                }
+            });

             static_for<0, m_preload, 1>{}([&](auto loadIter) {
@@ -920,86 +920,86 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
             Last2ndHotLoopScheduler();

             // GEMM loopK
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        // read C warp tensor from C block tensor
-                        CWarpTensor c_warp_tensor;
-
-                        c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                        // warp GEMM
-                        WG{}(c_warp_tensor,
-                             a_warp_tensor(number<AwarpIter>{}),
-                             b_warp_tensor_pong(nIter)(kIter));
-
-                        // write C warp tensor into C block tensor
-                        c_block_tile.set_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                            c_warp_tensor.get_thread_buffer());
-                    });
-                    if constexpr((kIter * MIterPerWarp + mIter) <
-                                 (KIterPerWarp * MIterPerWarp - m_preload))
-                    {
-                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                        a_warp_tensor(number<AwarpIter>{}) =
-                            load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
-                    }
-                    // barrier
-                    if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                    {
-                        block_sync_lds();
-                    }
-                });
-            });
+            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+                constexpr auto kIter = number<km[number<0>{}]>{};
+                constexpr auto mIter = number<km[number<1>{}]>{};
+                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // read C warp tensor from C block tensor
+                    CWarpTensor c_warp_tensor;

+                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

+                    // warp GEMM
+                    WG{}(c_warp_tensor,
+                         a_warp_tensor(number<AwarpIter>{}),
+                         b_warp_tensor_pong(nIter)(kIter));

+                    // write C warp tensor into C block tensor
+                    c_block_tile.set_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                        c_warp_tensor.get_thread_buffer());
+                });
+                if constexpr((kIter * MIterPerWarp + mIter) <
+                             (KIterPerWarp * MIterPerWarp - m_preload))
+                {
+                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                    a_warp_tensor(number<AwarpIter>{}) =
+                        load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
+                }
+                // barrier
+                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                {
+                    block_sync_lds();
+                }
+            });
             LastHotLoopScheduler();
         }
         else if constexpr(TailNum == TailNumber::Odd)
         {
             // GEMM loopK
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        // read C warp tensor from C block tensor
-                        CWarpTensor c_warp_tensor;
-
-                        c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                        // warp GEMM
-                        WG{}(c_warp_tensor,
-                             a_warp_tensor(number<AwarpIter>{}),
-                             b_warp_tensor_ping(nIter)(kIter));
-
-                        // write C warp tensor into C block tensor
-                        c_block_tile.set_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                            c_warp_tensor.get_thread_buffer());
-                    });
-                    // preload next A from lds
-                    if constexpr((kIter * MIterPerWarp + mIter) <
-                                 (KIterPerWarp * MIterPerWarp - m_preload))
-                    {
-                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                        a_warp_tensor(number<AwarpIter>{}) =
-                            load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
-                    }
-
-                    // barrier
-                    if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                    {
-                        block_sync_lds();
-                    }
-                });
-            });
+            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+                constexpr auto kIter = number<km[number<0>{}]>{};
+                constexpr auto mIter = number<km[number<1>{}]>{};
+                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // read C warp tensor from C block tensor
+                    CWarpTensor c_warp_tensor;

+                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

+                    // warp GEMM
+                    WG{}(c_warp_tensor,
+                         a_warp_tensor(number<AwarpIter>{}),
+                         b_warp_tensor_ping(nIter)(kIter));

+                    // write C warp tensor into C block tensor
+                    c_block_tile.set_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                        c_warp_tensor.get_thread_buffer());
+                });
+                // preload next A from lds
+                if constexpr((kIter * MIterPerWarp + mIter) <
+                             (KIterPerWarp * MIterPerWarp - m_preload))
+                {
+                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                    a_warp_tensor(number<AwarpIter>{}) =
+                        load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
+                }

+                // barrier
+                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                {
+                    block_sync_lds();
+                }
+            });
             LastHotLoopScheduler();
         }

[File diff suppressed because it is too large.]
@@ -529,22 +529,22 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
                        MIterPerWarp>
             a_warp_windows_pong;

-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
+        static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
+            constexpr auto mIter = number<mk[number<0>{}]>{};
+            constexpr auto kIter = number<mk[number<1>{}]>{};
+            a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;

-                move_tile_window(a_warp_windows_ping(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
+            move_tile_window(a_warp_windows_ping(mIter)(kIter),
+                             {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
         });

-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
+        static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
+            constexpr auto mIter = number<mk[number<0>{}]>{};
+            constexpr auto kIter = number<mk[number<1>{}]>{};
+            a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;

-                move_tile_window(a_warp_windows_pong(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
+            move_tile_window(a_warp_windows_pong(mIter)(kIter),
+                             {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
         });

         // Block GEMM
@@ -592,26 +592,26 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
             2;

         // prefetch B
-        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+        static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
+            constexpr auto nIter = number<nk[number<0>{}]>{};
+            constexpr auto kIter = number<nk[number<1>{}]>{};
+            b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

-                if constexpr(!IsGateUpMode)
-                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-                else
-                {
-                    if constexpr(nIter % 2 == 0)
-                        move_tile_window(
-                            b_flat_dram_windows(nIter)(kIter),
-                            {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-                    else
-                        move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                         {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
-                                          kIter * KFlatPerBlockPerIter});
-                }
-                b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
-            });
+            if constexpr(!IsGateUpMode)
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+            else
+            {
+                if constexpr(nIter % 2 == 0)
+                    move_tile_window(
+                        b_flat_dram_windows(nIter)(kIter),
+                        {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                else
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
+                                      kIter * KFlatPerBlockPerIter});
+            }
+            b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
         });
         // move B window to next flat K
         move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -648,28 +648,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
         while(iCounter > 0)
         {
             // prefetch B(2i+1)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

-                    if constexpr(!IsGateUpMode)
-                        move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                         {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-                    else
-                    {
-                        if constexpr(nIter % 2 == 0)
-                            move_tile_window(
-                                b_flat_dram_windows(nIter)(kIter),
-                                {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-                        else
-                            move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                             {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
-                                              kIter * KFlatPerBlockPerIter});
-                    }
-                    b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
-                });
+                if constexpr(!IsGateUpMode)
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                else
+                {
+                    if constexpr(nIter % 2 == 0)
+                        move_tile_window(
+                            b_flat_dram_windows(nIter)(kIter),
+                            {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                    else
+                        move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                         {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
+                                          kIter * KFlatPerBlockPerIter});
+                }
+                b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
             });

             // Prefill A(2i+1)
@@ -682,44 +681,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
             move_tile_window(a_copy_dram_window, {0, kKPerBlock});

             // GEMM 2i
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        // read C warp tensor from C block tensor
-                        CWarpTensor c_warp_tensor;
-
-                        c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                        // warp GEMM
-                        WG{}(c_warp_tensor,
-                             a_warp_tensor(number<AwarpIter>{}),
-                             b_warp_tensor_ping(nIter)(kIter));
-
-                        // write C warp tensor into C block tensor
-                        c_block_tile.set_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                            c_warp_tensor.get_thread_buffer());
-                    });
-                    // preload next A from lds
-                    if constexpr((kIter * MIterPerWarp + mIter) <
-                                 (KIterPerWarp * MIterPerWarp - m_preload))
-                    {
-                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                        a_warp_tensor(number<AwarpIter>{}) =
-                            load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
-                    }
-
-                    // barrier
-                    if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                    {
-                        block_sync_lds();
-                    }
-                });
-            });
+            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+                constexpr auto kIter = number<km[number<0>{}]>{};
+                constexpr auto mIter = number<km[number<1>{}]>{};
+                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // read C warp tensor from C block tensor
+                    CWarpTensor c_warp_tensor;

+                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

+                    // warp GEMM
+                    WG{}(c_warp_tensor,
+                         a_warp_tensor(number<AwarpIter>{}),
+                         b_warp_tensor_ping(nIter)(kIter));

+                    // write C warp tensor into C block tensor
+                    c_block_tile.set_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                        c_warp_tensor.get_thread_buffer());
+                });
+                // preload next A from lds
+                if constexpr((kIter * MIterPerWarp + mIter) <
+                             (KIterPerWarp * MIterPerWarp - m_preload))
+                {
+                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                    a_warp_tensor(number<AwarpIter>{}) =
+                        load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
+                }

+                // barrier
+                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                {
+                    block_sync_lds();
+                }
+            });

             // move B window to next flat K
@@ -736,28 +735,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
             // Next K

            // prefetch B(2i+2)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

-                    if constexpr(!IsGateUpMode)
-                        move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                         {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-                    else
-                    {
-                        if constexpr(nIter % 2 == 0)
-                            move_tile_window(
-                                b_flat_dram_windows(nIter)(kIter),
-                                {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-                        else
-                            move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                             {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
-                                              kIter * KFlatPerBlockPerIter});
-                    }
-                    b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
-                });
+                if constexpr(!IsGateUpMode)
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                else
+                {
+                    if constexpr(nIter % 2 == 0)
+                        move_tile_window(
+                            b_flat_dram_windows(nIter)(kIter),
+                            {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                    else
+                        move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                         {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
+                                          kIter * KFlatPerBlockPerIter});
+                }
+                b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
             });

             // Prefill A(2i+2)
@@ -770,43 +768,43 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
             move_tile_window(a_copy_dram_window, {0, kKPerBlock});

             // GEMM 2i+1
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        // read C warp tensor from C block tensor
-                        CWarpTensor c_warp_tensor;
-                        c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                        // warp GEMM
-                        WG{}(c_warp_tensor,
-                             a_warp_tensor(number<AwarpIter>{}),
-                             b_warp_tensor_pong(nIter)(kIter));
-
-                        // write C warp tensor into C block tensor
-                        c_block_tile.set_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                            c_warp_tensor.get_thread_buffer());
-                    });
-                    // preload next A from lds
-                    if constexpr((kIter * MIterPerWarp + mIter) <
-                                 (KIterPerWarp * MIterPerWarp - m_preload))
-                    {
-                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                        a_warp_tensor(number<AwarpIter>{}) =
-                            load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
-                    }
-
-                    // barrier
-                    if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                    {
-                        block_sync_lds();
-                    }
-                });
-            });
+            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+                constexpr auto kIter = number<km[number<0>{}]>{};
+                constexpr auto mIter = number<km[number<1>{}]>{};
+                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // read C warp tensor from C block tensor
+                    CWarpTensor c_warp_tensor;
+                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

+                    // warp GEMM
+                    WG{}(c_warp_tensor,
+                         a_warp_tensor(number<AwarpIter>{}),
+                         b_warp_tensor_pong(nIter)(kIter));

+                    // write C warp tensor into C block tensor
+                    c_block_tile.set_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                        c_warp_tensor.get_thread_buffer());
+                });
+                // preload next A from lds
+                if constexpr((kIter * MIterPerWarp + mIter) <
+                             (KIterPerWarp * MIterPerWarp - m_preload))
+                {
+                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                    a_warp_tensor(number<AwarpIter>{}) =
+                        load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
+                }

+                // barrier
+                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                {
+                    block_sync_lds();
+                }
+            });

             // move B window to next flat K
@@ -827,28 +825,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
         if constexpr(TailNum == TailNumber::Even)
        {
             // prefetch B(loopK)
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
+                constexpr auto kIter = number<kn[number<0>{}]>{};
+                constexpr auto nIter = number<kn[number<1>{}]>{};
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

-                    if constexpr(!IsGateUpMode)
-                        move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                         {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-                    else
-                    {
-                        if constexpr(nIter % 2 == 0)
-                            move_tile_window(
-                                b_flat_dram_windows(nIter)(kIter),
-                                {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-                        else
-                            move_tile_window(b_flat_dram_windows(nIter)(kIter),
-                                             {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
-                                              kIter * KFlatPerBlockPerIter});
-                    }
-                    b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
-                });
+                if constexpr(!IsGateUpMode)
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                else
+                {
+                    if constexpr(nIter % 2 == 0)
+                        move_tile_window(
+                            b_flat_dram_windows(nIter)(kIter),
+                            {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+                    else
+                        move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                         {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
+                                          kIter * KFlatPerBlockPerIter});
+                }
+                b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
             });

             // Prefill A(loopK)
@@ -856,44 +853,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
             store_tile(a_copy_lds_window_pong, a_block_tile_tmp);

             // GEMM loopK-1
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                    constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
-                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                        // read C warp tensor from C block tensor
-                        CWarpTensor c_warp_tensor;
-
-                        c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
-
-                        // warp GEMM
-                        WG{}(c_warp_tensor,
-                             a_warp_tensor(number<AwarpIter>{}),
-                             b_warp_tensor_ping(nIter)(kIter));
-
-                        // write C warp tensor into C block tensor
-                        c_block_tile.set_y_sliced_thread_data(
-                            merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
-                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                            c_warp_tensor.get_thread_buffer());
-                    });
-                    // preload next A from lds
-                    if constexpr((kIter * MIterPerWarp + mIter) <
-                                 (KIterPerWarp * MIterPerWarp - m_preload))
-                    {
-                        constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
-                        constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
-                        a_warp_tensor(number<AwarpIter>{}) =
-                            load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
-                    }
-
-                    // barrier
-                    if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
-                    {
-                        block_sync_lds();
-                    }
-                });
-            });
+            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
+                constexpr auto kIter = number<km[number<0>{}]>{};
+                constexpr auto mIter = number<km[number<1>{}]>{};
+                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    // read C warp tensor from C block tensor
+                    CWarpTensor c_warp_tensor;

+                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

+                    // warp GEMM
+                    WG{}(c_warp_tensor,
+                         a_warp_tensor(number<AwarpIter>{}),
+                         b_warp_tensor_ping(nIter)(kIter));

+                    // write C warp tensor into C block tensor
+                    c_block_tile.set_y_sliced_thread_data(
+                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
+                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                        c_warp_tensor.get_thread_buffer());
+                });
+                // preload next A from lds
+                if constexpr((kIter * MIterPerWarp + mIter) <
+                             (KIterPerWarp * MIterPerWarp - m_preload))
+                {
+                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
+                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
+                    a_warp_tensor(number<AwarpIter>{}) =
+                        load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
+                }

+                // barrier
+                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
+                {
+                    block_sync_lds();
+                }
+            });

             static_for<0, m_preload, 1>{}([&](auto loadIter) {
@@ -906,86 +903,86 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
            Last2ndHotLoopScheduler();

            // GEMM loopK
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
                constexpr auto kIter = number<km[number<0>{}]>{};
                constexpr auto mIter = number<km[number<1>{}]>{};
                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WG{}(c_warp_tensor,
                         a_warp_tensor(number<AwarpIter>{}),
                         b_warp_tensor_pong(nIter)(kIter));

                    // write C warp tensor into C block tensor
                    c_block_tile.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
                if constexpr((kIter * MIterPerWarp + mIter) <
                             (KIterPerWarp * MIterPerWarp - m_preload))
                {
                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
                    a_warp_tensor(number<AwarpIter>{}) =
                        load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
                }
                // barrier
                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
                {
                    block_sync_lds();
                }
            });
            LastHotLoopScheduler();
        }
        else if constexpr(TailNum == TailNumber::Odd)
        {
            // GEMM loopK
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
                constexpr auto kIter = number<km[number<0>{}]>{};
                constexpr auto mIter = number<km[number<1>{}]>{};
                constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WG{}(c_warp_tensor,
                         a_warp_tensor(number<AwarpIter>{}),
                         b_warp_tensor_ping(nIter)(kIter));

                    // write C warp tensor into C block tensor
                    c_block_tile.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
                // preload next A from lds
                if constexpr((kIter * MIterPerWarp + mIter) <
                             (KIterPerWarp * MIterPerWarp - m_preload))
                {
                    constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
                    constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
                    a_warp_tensor(number<AwarpIter>{}) =
                        load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
                }

                // barrier
                if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
                {
                    block_sync_lds();
                }
            });
            LastHotLoopScheduler();
        }
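These hunks repeatedly replace nested static_for loops with a single static_ford over a multi-dimensional compile-time extent, flattening the iteration space into one unrolled loop and recovering each index from the flattened coordinate. A minimal stand-alone sketch of the same flattening idea (plain C++20, independent of the ck_tile implementation; the helper name flat_ford is mine):

#include <utility>

// Visit every (k, m) pair of a K x M compile-time grid with one flat loop,
// analogous to what static_ford<sequence<K, M>> does in the kernels above.
template <int K, int M, typename F>
constexpr void flat_ford(F&& f)
{
    [&]<std::size_t... I>(std::index_sequence<I...>) {
        (f(std::integral_constant<int, static_cast<int>(I / M)>{},  // kIter
           std::integral_constant<int, static_cast<int>(I % M)>{}), // mIter
         ...);
    }(std::make_index_sequence<K * M>{});
}
// Usage: flat_ford<2, 3>([](auto k, auto m) { /* k() and m() are constexpr */ });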

@@ -486,13 +486,13 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
        auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
        static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
            constexpr auto mIter = number<mn[number<0>{}]>{};
            constexpr auto nIter = number<mn[number<1>{}]>{};
            c_block_tile.set_y_sliced_thread_data(
                merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                c_warp_tensors(mIter)(nIter).get_thread_buffer());
        });
        return c_block_tile;
    }
@@ -643,24 +643,23 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
        });

        // prefetch Scale A
        static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
            static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
        static_ford<sequence<MPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
            constexpr auto impack = number<ii[number<0>{}]>{};
            constexpr auto ikpack = number<ii[number<1>{}]>{};
            scale_a_tile_tensor_ping(impack)(ikpack) =
                load_tile_with_offset(scale_a_dram_window,
                                      impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
        });
        // move Scale A window to next K
        move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});

        // prefetch Scale B
        static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
            static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
        static_ford<sequence<NPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
            constexpr auto inpack = number<ii[number<0>{}]>{};
            constexpr auto ikpack = number<ii[number<1>{}]>{};
            scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
                scale_b_dram_window, inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
        });
        // move Scale B window to next K
        move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
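The scale-window step encodes the microscaling granularity: the divisor 32 * KXdlPack suggests one e8m0 scale per 32 K elements, with KXdlPack scales packed per load. A worked sketch with assumed tile numbers (the concrete values are illustrative, not taken from the kernel):

// Hypothetical numbers: kKPerBlock = 256 K elements per block,
// one e8m0 scale per 32 K elements, KXdlPack = 2 packed scales per load.
constexpr int kKPerBlock = 256;
constexpr int kMxBlock   = 32;
constexpr int KXdlPack   = 2;
constexpr int kScaleStep = kKPerBlock / (kMxBlock * KXdlPack);
static_assert(kScaleStep == 4); // window advances 4 packed-scale columns per K block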
@@ -698,34 +697,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
        // MAIN LOOP
        auto main_body_implx2 = [&]() mutable {
            // prefetch B(2i+1)
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
                constexpr auto kIter = number<kn[number<0>{}]>{};
                constexpr auto nIter = number<kn[number<1>{}]>{};
                b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
                    b_flat_dram_window,
                    b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);

                // move B window to next flat K
                if constexpr(kIter == KIterPerWarp - 1)
                    b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
                        tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
            });

            // prefetch Scale A and Scale B (2i+1)
            static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
                static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
            static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
                constexpr auto ikpack = number<ii[number<0>{}]>{};
                constexpr auto impack = number<ii[number<1>{}]>{};
                scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
                    scale_a_dram_window,
                    impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
            });

            static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
                static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
            static_ford<sequence<KPackIterPerWarp, NPackIterPerWarp>>{}([&](auto ii) {
                constexpr auto ikpack = number<ii[number<0>{}]>{};
                constexpr auto inpack = number<ii[number<1>{}]>{};
                scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
                    scale_b_dram_window,
                    inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
            });

            // GEMM 2i
@@ -788,34 +787,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
            ////////////////////////////// Next K //////////////////////////////

            // prefetch B(2i+2)
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
                constexpr auto kIter = number<kn[number<0>{}]>{};
                constexpr auto nIter = number<kn[number<1>{}]>{};
                b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
                    b_flat_dram_window,
                    b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);

                // move B window to next flat K
                if constexpr(kIter == KIterPerWarp - 1)
                    b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
                        tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
            });

            // prefetch Scale A and Scale B (2i+2)
            static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
                static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
            static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
                constexpr auto ikpack = number<ii[number<0>{}]>{};
                constexpr auto impack = number<ii[number<1>{}]>{};
                scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
                    scale_a_dram_window,
                    impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
            });

            static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
                static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
            static_ford<sequence<KPackIterPerWarp, NPackIterPerWarp>>{}([&](auto ii) {
                constexpr auto ikpack = number<ii[number<0>{}]>{};
                constexpr auto inpack = number<ii[number<1>{}]>{};
                scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
                    scale_b_dram_window,
                    inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
            });

            // GEMM 2i+1
@@ -888,28 +887,28 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
        if constexpr(TailNum == TailNumber::Even)
        {
            // prefetch B(loopK)
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
                constexpr auto kIter = number<kn[number<0>{}]>{};
                constexpr auto nIter = number<kn[number<1>{}]>{};
                b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
                    b_flat_dram_window,
                    b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
            });

            // prefetch Scale A and Scale B (2i+1)
            static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
                static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
            static_ford<sequence<MPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
                constexpr auto impack = number<ii[number<0>{}]>{};
                constexpr auto ikpack = number<ii[number<1>{}]>{};
                scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
                    scale_a_dram_window,
                    impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
            });
            static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
                static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
            static_ford<sequence<NPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
                constexpr auto inpack = number<ii[number<0>{}]>{};
                constexpr auto ikpack = number<ii[number<1>{}]>{};
                scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
                    scale_b_dram_window,
                    inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
            });

            // GEMM loopK-1

@@ -484,20 +484,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
            kargs.init_logits_soft_cap(logits_soft_cap);
        }

        // Check that the maximum offset won't overflow.
        if constexpr(kPageBlockSize < FmhaPipeline::kN0)
        {
            if(num_total_pages > 1)
            {
                assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
                           static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
                       "KV cache K offset overflow: exceed int32 max");
                assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
                           static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
                       "KV cache V offset overflow: exceed int32 max");
            }
        }

        return kargs;
    }
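The guard above caps paged-KV offsets at int32 range: the largest page index multiplied by the batch stride must still fit in the kernel's 32-bit index_t. A standalone sketch of the same check (the helper name fits_in_int32 is mine, not part of the kernel):

#include <cstdint>
#include <limits>

// Returns true when the largest page offset (pages - 1) * stride still fits
// in a 32-bit signed index; mirrors the asserts in MakeKargs above.
inline bool fits_in_int32(int64_t num_total_pages, int64_t batch_stride)
{
    if(num_total_pages <= 1)
        return true;
    return (num_total_pages - 1) * batch_stride <=
           static_cast<int64_t>(std::numeric_limits<int32_t>::max());
}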

@@ -651,20 +637,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
            kargs.init_logits_soft_cap(logits_soft_cap);
        }

        // Check that the maximum offset won't overflow.
        if constexpr(kPageBlockSize < FmhaPipeline::kN0)
        {
            if(num_total_pages > 1)
            {
                assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_k <=
                           static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
                       "KV cache K offset overflow: exceed int32 max");
                assert(static_cast<int64_t>(num_total_pages - 1) * batch_stride_v <=
                           static_cast<int64_t>(std::numeric_limits<index_t>::max()) &&
                       "KV cache V offset overflow: exceed int32 max");
            }
        }

        return kargs;
    }

@@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel
    using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
    using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
    using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
    using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
@@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
    static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
    static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
    static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;

    using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
    using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel
        float logits_soft_cap_rcp;
    };

    struct FmhaFwdCommonQScaleKargs
    {
        const void* q_descale_ptr = nullptr;
        const void* k_descale_ptr = nullptr;
        const void* v_descale_ptr = nullptr;
    };
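The kernel argument structs below mix optional feature blocks in through std::conditional_t, falling back to distinct FmhaFwdEmptyKargs<N> bases so an unused feature costs no storage (the empty-base optimization needs each empty base to be a distinct type). A minimal sketch of the pattern with made-up feature flags, not the kernel's actual types:

#include <type_traits>

template <int>
struct Empty {}; // distinct empty bases so EBO can apply to each slot

struct MaskArgs  { int window_left, window_right; };
struct ScaleArgs { const void* q_descale; };

template <bool HasMask, bool HasScale>
struct Kargs : std::conditional_t<HasMask, MaskArgs, Empty<0>>,
               std::conditional_t<HasScale, ScaleArgs, Empty<1>>
{
    int seqlen_q; // unconditional members follow the conditional bases
};
// On typical ABIs, Kargs<false, false> is the size of a single int:
// the empty bases contribute no storage.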

    struct FmhaFwdBatchModeKargs
        : FmhaFwdCommonKargs,
          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
          std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
                             FmhaFwdCommonQScaleKargs,
                             FmhaFwdEmptyKargs<2>>,
          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
    {
        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
@@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel
        : FmhaFwdCommonKargs,
          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
          std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
                             FmhaFwdCommonQScaleKargs,
                             FmhaFwdEmptyKargs<2>>,
          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
@@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* q_descale_ptr,
              const void* k_descale_ptr,
              const void* v_descale_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t seqlen_q,
@@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel
                  nhead_stride_o}, // args for common karg
                 {}, // placeholder for mask
                 {}, // placeholder for lse
                 {}, // placeholder for qscale
                 {}, // placeholder for logits_soft_cap
                 batch_stride_q,
                 batch_stride_k,
@@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
        {
            kargs.q_descale_ptr = q_descale_ptr;
            kargs.k_descale_ptr = k_descale_ptr;
            kargs.v_descale_ptr = v_descale_ptr;
        }
        if constexpr(kHasLogitsSoftCap)
        {
            kargs.init_logits_soft_cap(logits_soft_cap);
@@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* q_descale_ptr,
              const void* k_descale_ptr,
              const void* v_descale_ptr,
              void* lse_ptr,
              void* o_ptr,
              const void* seqstart_q_ptr,
@@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel
                  nhead_stride_o}, // args for common karg
                 {}, // placeholder for mask
                 {}, // placeholder for lse
                 {}, // placeholder for qscale
                 {}, // placeholder for logits_soft_cap
                 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                 reinterpret_cast<const int32_t*>(seqstart_k_ptr),
@@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel
            kargs.lse_ptr = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
        }
        if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
        {
            kargs.q_descale_ptr = q_descale_ptr;
            kargs.k_descale_ptr = k_descale_ptr;
            kargs.v_descale_ptr = v_descale_ptr;
        }
        if constexpr(kHasLogitsSoftCap)
        {
            kargs.init_logits_soft_cap(logits_soft_cap);
@@ -437,8 +472,19 @@ struct FmhaFwdV3Kernel
    {
        using namespace ck_tile;

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];
        // Notice: When using double buffering, make sure both buffers are in the same array.
        // This prevents the compiler from using separate VGPRs to store the base address
        // and enables the use of immediate offsets in load/store instructions.
        constexpr auto smem_size_kv =
            FmhaPipeline::Policy::template GetSmemSizeKV<typename FmhaPipeline::Problem>();
        __shared__ char smem_k[2][smem_size_kv];
        __shared__ char smem_v[2][smem_size_kv];

        auto* smem_k0 = reinterpret_cast<KDataType*>(smem_k[0]);
        auto* smem_k1 = reinterpret_cast<KDataType*>(smem_k[1]);
        auto* smem_v0 = reinterpret_cast<VDataType*>(smem_v[0]);
        auto* smem_v1 = reinterpret_cast<VDataType*>(smem_v[1]);

        // divide problem
        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
@@ -640,32 +686,88 @@ struct FmhaFwdV3Kernel
            return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();

        const float scale_s = [&] {
            if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
            {
                float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
                float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
                return kargs.scale_s * q_descale * k_descale;
            }
            else
            {
                return kargs.scale_s;
            }
        }();
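Folding the Q and K per-tensor descales into the softmax scale works because the descale factors are scalars: S = scale_s * (q_descale * Q_q)(k_descale * K_q)^T = (scale_s * q_descale * k_descale) * Q_q K_q^T, so one fused multiplier replaces element-wise descaling of Q and K. A worked sketch with assumed values (none taken from the kernel):

// Illustrative values only: softmax scale 1/sqrt(128), arbitrary descales.
float softmax_scale = 0.0883883f; // 1/sqrt(128)
float q_descale     = 0.02f;
float k_descale     = 0.015f;
float scale_s       = softmax_scale * q_descale * k_descale; // ~2.65e-5
// One scalar multiply per S tile instead of rescaling every Q and K element.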

        AttentionVariant variant;
        const auto variant_params = [&] {
            if constexpr(kHasLogitsSoftCap)
            {
                return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
                    mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
            }
            else
            {
                return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
                return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
            }
        }();

        BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};

        auto o_acc_tile = [&]() {
            return FmhaPipeline{}(q_dram_window,
                                  k_dram_window,
                                  v_dram_window,
                                  lse_dram_window,
                                  mask,
                                  kargs.scale_s,
                                  variant,
                                  variant_params,
                                  block_indices,
                                  smem_ptr);
            if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
            {
                float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
                float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
                float scale_o = v_descale / scale_p;

                auto o_acc_element_func = [&]() {
                    if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
                        return make_composes(
                            ck_tile::saturates<ck_tile::fp8_t>{},
                            ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
                    else
                        return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
                }();

                return FmhaPipeline{}(
                    q_dram_window,
                    identity{}, // q_element_func
                    k_dram_window,
                    identity{}, // k_element_func
                    v_dram_window,
                    identity{}, // v_element_func
                    lse_dram_window,
                    identity{}, // lse_element_func
                    identity{}, // s_acc_element_func
                    scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
                    o_acc_element_func,
                    mask,
                    scale_s,
                    variant,
                    variant_params,
                    block_indices,
                    smem_k0,
                    smem_k1,
                    smem_v0,
                    smem_v1);
            }
            else
            {
                return FmhaPipeline{}(q_dram_window,
                                      k_dram_window,
                                      v_dram_window,
                                      lse_dram_window,
                                      mask,
                                      scale_s,
                                      variant,
                                      variant_params,
                                      block_indices,
                                      smem_k0,
                                      smem_k1,
                                      smem_v0,
                                      smem_v1);
            }
        }();
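The FP8 epilogue above stretches the softmax probabilities P by scale_p (the fp8 maximum, so values in [0, 1] use the full fp8 range before the P*V GEMM) and undoes the stretch afterwards together with the V descale, since O = (P * scale_p) * V_q * (v_descale / scale_p) = P * (v_descale * V_q). A round-trip sketch with assumed numbers (e4m3 max of 448 and the sample values are illustrative):

float p         = 0.37f;   // one softmax probability
float scale_p   = 448.0f;  // numeric<PDataType>::max() for fp8 e4m3
float v_descale = 0.05f;   // per-tensor V descale
float v_q       = 96.0f;   // a quantized V element
float p_fp8     = p * scale_p;  // value actually fed into the fp8 GEMM
float o_acc     = p_fp8 * v_q;  // GEMM accumulates in float
float o         = o_acc * (v_descale / scale_p); // == p * (v_descale * v_q) == 1.776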

        // O DRAM and O DRAM window

@@ -1706,22 +1706,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
        static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
            constexpr auto kIter = number<km[number<0>{}]>{};
            constexpr auto mIter = number<km[number<1>{}]>{};
            p_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
                merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

#if defined(__gfx11__)
            PermuteWarpGemmCToA(pt_warp_tensor, p_warp_tensor);
#else
            pt_warp_tensor.get_thread_buffer() = p_warp_tensor.get_thread_buffer();
#endif
            pt_out.set_y_sliced_thread_data(
                merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
                pt_warp_tensor.get_thread_buffer());
        });
    }
    else
@@ -1763,22 +1763,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
        static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
            constexpr auto kIter = number<km[number<0>{}]>{};
            constexpr auto mIter = number<km[number<1>{}]>{};
            ds_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
                merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

#if defined(__gfx11__)
            PermuteWarpGemmCToA(dst_warp_tensor, ds_warp_tensor);
#else
            dst_warp_tensor.get_thread_buffer() = ds_warp_tensor.get_thread_buffer();
#endif
            dst_out.set_y_sliced_thread_data(
                merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
                dst_warp_tensor.get_thread_buffer());
        });
    }
    else

File diff suppressed because it is too large
@@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy
            typename Problem::BlockFmhaShape::Gemm0BlockWarps,
            typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

        constexpr auto warp_gemm = []() {
            if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                         std::is_same_v<typename Problem::KDataType, half_t> &&
        constexpr auto warp_gemm = [] {
            if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                         std::is_same_v<typename Problem::KDataType, fp8_t> &&
                         std::is_same_v<typename Problem::SaccDataType, float>)
            {
                // Use SwizzleB variant to get 8 contiguous K positions per lane,
                // matching the V tile distribution for PV GEMM
                return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
            }
            else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                              std::is_same_v<typename Problem::KDataType, half_t> &&
                              std::is_same_v<typename Problem::SaccDataType, float>)
            {
                /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
                /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
@@ -310,9 +318,8 @@ struct BlockFmhaV3PipelineDefaultPolicy
    static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4;  // 4 dwords
    static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords

    template <typename Problem, ck_tile::index_t IBuf = 0>
    CK_TILE_DEVICE static constexpr auto
    MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
    {
        using namespace ck_tile;

@@ -323,7 +330,6 @@ struct BlockFmhaV3PipelineDefaultPolicy
        constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
        constexpr index_t WarpSize = ck_tile::get_warp_size();

        [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
        constexpr index_t kPad =
            kKLdsPadInBytes /
@@ -339,31 +345,28 @@ struct BlockFmhaV3PipelineDefaultPolicy
        constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
            make_tuple(number<NumIssues>{},  // n0
                       number<LaneGroups>{}, // n1
                       number<NumWarps>{},   // n2
                       number<LanesPerK>{},  // k0
                       number<KVector>{}),   // k1
            make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                       number<kKPerBlock>{},
                       number<WarpSize * KVector + kPad>{},
                       number<KVector>{},
                       number<1>{}),
            number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
            number<KVector>{},
            number<1>{});
        constexpr auto k_lds_block_desc_0 =
            make_naive_tensor_descriptor(make_tuple(number<NumIssues>{},  // n0
                                                    number<LaneGroups>{}, // n1
                                                    number<NumWarps>{},   // n2
                                                    number<LanesPerK>{},  // k0
                                                    number<KVector>{}),   // k1
                                         make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                                                    number<kKPerBlock>{},
                                                    number<WarpSize * KVector + kPad>{},
                                                    number<KVector>{},
                                                    number<1>{}),
                                         number<KVector>{},
                                         number<1>{});

        // TODO this layout is hard coded, and will be used in async copy buffer view load
        // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
        // CRITICAL: Must match Load descriptor merge pattern (NumIssues, LaneGroups, NumWarps)
        constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
            k_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<NumIssues>{}),
                       make_pass_through_transform(number<NumWarps>{}),
                       make_merge_transform(make_tuple(
                           number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
            make_tuple(make_merge_transform(make_tuple(
                           number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
                       make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
            make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return k_lds_block_desc_issues_warps_lanes;
    }
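The rewritten store descriptor merges (NumIssues, LaneGroups, NumWarps) into one row dimension and (LanesPerK, KVector) into one column dimension, so the store-side coordinate math matches the load descriptor's merge pattern. A sketch of the row-major linearization such a merge performs (plain index arithmetic with assumed extents, not ck_tile code):

// Flatten a (n0, n1, n2) coordinate the way make_merge_transform does:
// idx = (n0 * LaneGroups + n1) * NumWarps + n2, and likewise for (k0, k1).
constexpr int LaneGroups = 2, NumWarps = 4, KVector = 8; // assumed extents
constexpr int merged_row(int n0, int n1, int n2)
{
    return (n0 * LaneGroups + n1) * NumWarps + n2;
}
constexpr int merged_col(int k0, int k1) { return k0 * KVector + k1; }
static_assert(merged_row(1, 0, 0) == LaneGroups * NumWarps); // row-major merge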

@@ -458,9 +461,8 @@ struct BlockFmhaV3PipelineDefaultPolicy
        return max(SingleKSize, SingleVSize);
    }

    template <typename Problem, ck_tile::index_t IBuf = 0>
    CK_TILE_DEVICE static constexpr auto
    MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeVLdsStoreBlockDescriptor()
    {
        using namespace ck_tile;

@@ -471,7 +473,6 @@ struct BlockFmhaV3PipelineDefaultPolicy
        constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
        constexpr index_t WarpSize = ck_tile::get_warp_size();

        [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
        constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
        constexpr index_t kPad =
            kVLdsPadInBytes /
@@ -487,31 +488,27 @@ struct BlockFmhaV3PipelineDefaultPolicy
        constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
            make_tuple(number<NumIssues>{},  // n0
                       number<LaneGroups>{}, // n1
                       number<NumWarps>{},   // n2
                       number<LanesPerK>{},  // k0
                       number<KVector>{}),   // k1
            make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                       number<kKPerBlock>{},
                       number<WarpSize * KVector + kPad>{},
                       number<KVector>{},
                       number<1>{}),
            number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
            number<KVector>{},
            number<1>{});
        constexpr auto v_lds_block_desc_0 =
            make_naive_tensor_descriptor(make_tuple(number<NumIssues>{},  // n0
                                                    number<LaneGroups>{}, // n1
                                                    number<NumWarps>{},   // n2
                                                    number<LanesPerK>{},  // k0
                                                    number<KVector>{}),   // k1
                                         make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                                                    number<kKPerBlock>{},
                                                    number<WarpSize * KVector + kPad>{},
                                                    number<KVector>{},
                                                    number<1>{}),
                                         number<KVector>{},
                                         number<1>{});

        // TODO this layout is hard coded, and will be used in async copy buffer view load
        // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
        constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
            v_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<NumIssues>{}),
                       make_pass_through_transform(number<NumWarps>{}),
                       make_merge_transform(make_tuple(
                           number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
            make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
            make_tuple(make_merge_transform(make_tuple(
                           number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
                       make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
            make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return v_lds_block_desc_issues_warps_lanes;
    }

@@ -213,38 +213,38 @@ struct BlockGemmARegBRegCRegV1
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
        static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
            constexpr auto kIter = number<km[number<0>{}]>{};
            constexpr auto mIter = number<km[number<1>{}]>{};
            // read A warp tensor from A Block window
            AWarpTensor a_warp_tensor;
            a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                // read B warp tensor from B block tensor
                BWarpTensor b_warp_tensor;
                b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

                // read C warp tensor from C block tensor
                using c_iter_idx =
                    std::conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
                CWarpTensor c_warp_tensor;
                c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                // warp GEMM
                WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                // write C warp tensor into C block tensor
                c_block_tensor.set_y_sliced_thread_data(
                    merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                    c_warp_tensor.get_thread_buffer());
            });
        });
    }
@@ -323,73 +323,69 @@ struct BlockGemmARegBRegCRegV1

        // hot loop with MX scaling and pre-packed int32_t scales:
        // Outer loops iterate over pack groups (scale tile indices)
        static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
            static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
        static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
            constexpr auto ikpack = number<ii[number<0>{}]>{};
            constexpr auto impack = number<ii[number<1>{}]>{};
            // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
            auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
                sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
            const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);

            static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
                // Get pre-packed int32_t B scale
                auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
                    sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
                const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);

                // Inner loops: issue MFMAs within the pack group using OpSel
                static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                    static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                static_ford<sequence<KXdlPack, MXdlPack>>{}([&](auto jj) {
                    constexpr auto ikxdl = number<jj[number<0>{}]>{};
                    constexpr auto imxdl = number<jj[number<1>{}]>{};
                    constexpr auto kIter = ikpack * KXdlPack + ikxdl;
                    constexpr auto mIter = impack * MXdlPack + imxdl;

                    // read A warp tensor from A block tensor
                    AWarpTensor a_warp_tensor;
                    a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                    // OpSel for A: selects byte within packed int32_t
                    constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;

                    static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
                        constexpr auto nIter = inpack * NXdlPack + inxdl;

                        // read B warp tensor from B block tensor
                        BWarpTensor b_warp_tensor;
                        b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
                            merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                            merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

                        // OpSel for B: selects byte within packed int32_t
                        constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;

                        // read C warp tensor from C block tensor
                        using c_iter_idx = std::conditional_t<TransposeC,
                                                              sequence<nIter, mIter>,
                                                              sequence<mIter, nIter>>;
                        CWarpTensor c_warp_tensor;
                        c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                            merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                        // warp GEMM with MX scaling using pre-packed scale and OpSel
                        WarpGemm{}.template operator()<kOpSelA, kOpSelB>(c_warp_tensor,
                                                                         a_warp_tensor,
                                                                         b_warp_tensor,
                                                                         a_scale_packed,
                                                                         b_scale_packed);

                        // write C warp tensor into C block tensor
                        c_block_tensor.set_y_sliced_thread_data(
                            merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                            c_warp_tensor.get_thread_buffer());
                    });
                });
            });
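The packed-scale path above keeps MXdlPack x KXdlPack e8m0 scales inside one int32_t and lets the MFMA's OpSel field pick the byte, so no per-MFMA shift instructions are issued. A sketch of the byte selection OpSel performs, written as plain C++ for illustration rather than the hardware path:

#include <cstdint>

// Extract the e8m0 scale byte that OpSel index `opsel` (0..3) would select
// from a packed int32_t; mirrors kOpSelA = ikxdl * MXdlPack + imxdl above.
inline uint8_t select_scale_byte(int32_t packed, int opsel)
{
    return static_cast<uint8_t>((static_cast<uint32_t>(packed) >> (8 * opsel)) & 0xFFu);
}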
|
||||
|
||||
@@ -250,74 +250,74 @@ struct BlockGemmARegBRegCRegV2
|
||||
// hot loop:
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
static_ford<sequence<MIterPerWarp, NIterPerWarp, KIterPerWarp>>{}([&](auto mnk) {
|
||||
constexpr auto mIter = number<mnk[number<0>{}]>{};
|
||||
constexpr auto nIter = number<mnk[number<1>{}]>{};
|
||||
constexpr auto kIter = number<mnk[number<2>{}]>{};
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,13 +109,13 @@ struct BlockGemmARegBSmemCRegOneWarpV1
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
|
||||
constexpr auto nIter = number<nk[number<0>{}]>{};
|
||||
constexpr auto kIter = number<nk[number<1>{}]>{};
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
#endif
|
||||
|
||||
@@ -141,35 +141,35 @@ struct BlockGemmARegBSmemCRegOneWarpV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;

a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

@@ -116,13 +116,13 @@ struct BlockGemmARegBSmemCRegV1
NIterPerWarp>
b_warp_windows;

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif

@@ -148,35 +148,35 @@ struct BlockGemmARegBSmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;

a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

@@ -103,13 +103,13 @@ struct BlockGemmARegBSmemCRegV2
NIterPerWarp>
b_warp_windows;

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif

@@ -135,36 +135,36 @@ struct BlockGemmARegBSmemCRegV2
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;

a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

@@ -90,13 +90,13 @@ struct BlockGemmARegBSmemCRegV2R1
NIterPerWarp>
b_warp_windows;

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});

// check C-block-distribution
@@ -126,43 +126,43 @@ struct BlockGemmARegBSmemCRegV2R1
NIterPerWarp>
b_warp_tensors;

static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
});
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
});

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = b_warp_tensors(nIter)(kIter);
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
// read B warp tensor from B Block window
const auto b_warp_tensor = b_warp_tensors(nIter)(kIter);

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;

a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});


@@ -116,13 +116,13 @@ struct BlockGemmASmemBRegCRegV1
MIterPerWarp>
a_warp_windows;

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;

move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif

@@ -148,34 +148,34 @@ struct BlockGemmASmemBRegCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A Block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;

b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

@@ -85,13 +85,13 @@ struct BlockGemmASmemBSmemCRegV1
MIterPerWarp>
a_warp_windows;

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
constexpr auto mIter = number<mk[number<0>{}]>{};
constexpr auto kIter = number<mk[number<1>{}]>{};
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;

move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif

@@ -120,13 +120,13 @@ struct BlockGemmASmemBSmemCRegV1
NIterPerWarp>
b_warp_windows;

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
constexpr auto nIter = number<nk[number<0>{}]>{};
constexpr auto kIter = number<nk[number<1>{}]>{};
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
#endif

@@ -138,31 +138,31 @@ struct BlockGemmASmemBSmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

@@ -165,61 +165,60 @@ struct BlockGemmMxARegBSmemCRegV1
uniform_sequence_gen_t<BScaleWarpDstr::NDimY, 0>{};

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
auto b_warp_window = b_warp_window_tmp;
move_tile_window(
b_warp_window,
{nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)});
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_window);
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
constexpr auto kIter = number<kn[number<0>{}]>{};
constexpr auto nIter = number<kn[number<1>{}]>{};
auto b_warp_window = b_warp_window_tmp;
move_tile_window(
b_warp_window,
{nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)});
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_window);

BScaleWarpTensor b_scale_warp_tensor;
BScaleWarpTensor b_scale_warp_tensor;

b_scale_warp_tensor.get_thread_buffer() =
b_scale_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter / NIterPack, nIter % NIterPack, kIter>{},
b_scale_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths));
b_scale_warp_tensor.get_thread_buffer() = b_scale_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter / NIterPack, nIter % NIterPack, kIter>{},
b_scale_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths));

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;

a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

AScaleWarpTensor a_scale_warp_tensor;
AScaleWarpTensor a_scale_warp_tensor;

a_scale_warp_tensor.get_thread_buffer() =
a_scale_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_scale_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths));
a_scale_warp_tensor.get_thread_buffer() =
a_scale_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_scale_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths));

// warp GEMM
WarpGemm{}.template operator()<0, 0>(
c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
int32_t(a_scale_warp_tensor.get_thread_buffer()[0]),
int32_t(b_scale_warp_tensor.get_thread_buffer()[0]));
// warp GEMM
WarpGemm{}.template operator()<0, 0>(
c_warp_tensor,
a_warp_tensor,
b_warp_tensor,
int32_t(a_scale_warp_tensor.get_thread_buffer()[0]),
int32_t(b_scale_warp_tensor.get_thread_buffer()[0]));

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

@@ -239,39 +239,39 @@ struct BlockUniversalGemmAsBsCr
"C block tensor data type!");

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;

a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;

b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}
@@ -392,63 +392,59 @@ struct BlockUniversalGemmAsBsCr
0); // Prevents instruction reordering across this boundary
}

static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
static_ford<sequence<KInnerLoopIter, MIterPerWarp>>{}([&](auto km) {
constexpr auto kInnerIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;

a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;

b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kInnerIter>{},
b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kInnerIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard, because the barrier from
// blockwise_gemm is moved here; B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts. It is
// performed near the end of the MAC cluster to minimize lgkmcnt
// penalty
if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == MIterPerWarp - 1 &&
nIter.value == NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard, because the barrier from
// blockwise_gemm is moved here; B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts. It is
// performed near the end of the MAC cluster to minimize lgkmcnt
// penalty
if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == MIterPerWarp - 1 &&
nIter.value == NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());

if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
if constexpr(kInnerIter.value == 0 && mIter.value == 0 && nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});


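A note on the scheduling intrinsics in the hunk above, summarized from how they are used here (the precise semantics are defined by the AMDGPU compiler backend, so treat this as a hedged summary rather than a specification): __builtin_amdgcn_sched_barrier(0) prevents the compiler from moving any instruction across that point, which is why each block_sync_lds() is fenced on both sides, and __builtin_amdgcn_s_setprio(1) raises the wave's hardware priority so the MFMA cluster issued after the first iteration is favored by the scheduler.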
@@ -156,55 +156,54 @@ struct BlockWeightPreshuffleASmemBRegCReg
uniform_sequence_gen_t<BFlatDistribution::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// declare B and C warp tensors
BWarpTensor b_warp_tensor;
CWarpTensor c_warp_tensor;
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// declare B and C warp tensors
BWarpTensor b_warp_tensor;
CWarpTensor c_warp_tensor;

b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{},
typename sequence_split<decltype(b_block_y_index_zeros),
2>::right_type{}),
merge_sequences(
sequence<1, 1>{},
typename sequence_split<decltype(b_block_y_lengths), 2>::right_type{}));
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(
sequence<nIter, kIter>{},
typename sequence_split<decltype(b_block_y_index_zeros), 2>::right_type{}),
merge_sequences(
sequence<1, 1>{},
typename sequence_split<decltype(b_block_y_lengths), 2>::right_type{}));

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WarpGemm{}(
c_warp_tensor, preloaded_a_warp_tensor(number<AwarpIter>{}), b_warp_tensor);
// warp GEMM
WarpGemm{}(
c_warp_tensor, preloaded_a_warp_tensor(number<AwarpIter>{}), b_warp_tensor);

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());

__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);

load_tile(preloaded_a_warp_tensor(number<AwarpIter>{}),
a_load_windows[number<AkIter>{}][number<AmIter>{}]);
}

// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);

load_tile(preloaded_a_warp_tensor(number<AwarpIter>{}),
a_load_windows[number<AkIter>{}][number<AmIter>{}]);
}

// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
}
};

@@ -88,28 +88,28 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIter = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
// read A warp tensor from A block window
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));

// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
}

@@ -369,6 +369,13 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;

using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
2>>;

@@ -170,6 +170,8 @@ template<WGAttrNumAccessEnum I> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32,

template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<EDouble>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, true, false, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, true, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<EDouble>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<EDouble>; };


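For reference, the new specializations above are selected purely by template arguments; a minimal, illustrative lookup using only the names declared in this hunk (the static_assert assumes <type_traits> is available):

// Illustrative only: resolves to the non-transposed fp8 warp GEMM added above.
using WG = typename Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false>::Type;
static_assert(std::is_same_v<WG, WarpGemmMfma_f32_32x32x32_fp8_fp8<>>);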
@@ -210,45 +210,45 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
c_acc;

auto zero_accumulators = [&] {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
}); // make sure WG::CWarpTensor exposes a clear/zero
static_ford<sequence<MIterPerWarp, NIterPerWarp, (WG::kM * WG::kN) / warp_size>>{}(
[&](auto mni) {
constexpr auto mIter = number<mni[number<0>{}]>{};
constexpr auto nIter = number<mni[number<1>{}]>{};
constexpr auto i = number<mni[number<2>{}]>{};
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
});
});
};
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
zero_accumulators();
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);

load_and_convert_tile<UnaryOpSize>(
a_warp_tensor(number<AwarpIter>{}),
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
static_ford<sequence<KIterPerQScale, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIterInQScale = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);

load_and_convert_tile<UnaryOpSize>(
a_warp_tensor(number<AwarpIter>{}),
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(aq_block_tensor);

@@ -127,105 +127,103 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
c_acc;

auto zero_accumulators = [&] {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
}); // make sure WG::CWarpTensor exposes a clear/zero
static_ford<sequence<MIterPerWarp, NIterPerWarp, (WG::kM * WG::kN) / warp_size>>{}(
[&](auto mni) {
constexpr auto mIter = number<mni[number<0>{}]>{};
constexpr auto nIter = number<mni[number<1>{}]>{};
constexpr auto i = number<mni[number<2>{}]>{};
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
});
});
};
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
zero_accumulators();
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_ford<sequence<KIterPerQScale, MIterPerWarp>>{}([&](auto km) {
constexpr auto kIterInQScale = number<km[number<0>{}]>{};
constexpr auto mIter = number<km[number<1>{}]>{};
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
// warp GEMM
WG{}(c_acc(mIter)(nIter),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor(nIter)(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0x7F6);
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted
if constexpr((mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
constexpr auto mIter = number<mn[number<0>{}]>{};
constexpr auto nIter = number<mn[number<1>{}]>{};
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};

if constexpr(BPreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

if constexpr(std::is_same_v<BQDataType, float>)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}

// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));

float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}
});

// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));

float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
}
else
{
index_t reg_offset = [&]() {
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else
{
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float scale_reg_f = cvt_scale_to_fp32(scale_reg);

static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * scale_reg_f;
});
}
});
});
}

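The cross-lane gather in the hunk above relies on ds_bpermute, which lets every lane of a wave pull a 32-bit value out of another lane's register; the index operand is byte-addressed, which is why the lane index is shifted left by 2. A minimal sketch under the assumption of an AMDGPU device compiler where __builtin_amdgcn_ds_bpermute and __lane_id() are available (broadcast_from_neighbor is a hypothetical helper, not part of this diff):

// Illustrative only: every lane reads src from its neighbor lane (lane_id ^ 1).
__device__ int broadcast_from_neighbor(int src) {
    int lane = __lane_id();
    // ds_bpermute takes a byte address, so scale the lane index by 4.
    return __builtin_amdgcn_ds_bpermute((lane ^ 1) << 2, src);
}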
@@ -290,121 +290,115 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
constexpr auto warp_size = get_warp_size();

// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
constexpr auto mIter = number<mn[number<0>{}]>{};
constexpr auto nIter = number<mn[number<1>{}]>{};
CWarpTensor c_warp_tensor;

static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;

AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});

constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
// a_scale
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);

if constexpr(BPreshuffleQuant)
if constexpr(kIterInQScale == 0)
{
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >
(NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
return kQScale;
}
else
{
return nIter;
}
}();

auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;

auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}

// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));

float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
gathered_scale_reg);

static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >=
(NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN *
Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();

auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f =
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);

static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});

constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
// a_scale
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);

if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
return kQScale;
}
else
{
return nIter;
}
}();

auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;

auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

if constexpr(std::is_same_v<BQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}

// cross lane ops to get the value of scale_reg.
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));

float b_scale_reg_f = Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
gathered_scale_reg);

static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
b_scale_reg_f);
});
}
else
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
|
||||
return (nIter * NWarp * WarpGemm::kN) /
|
||||
GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock +
|
||||
kQScale;
|
||||
else
|
||||
{
|
||||
return nIter * Traits::KQPerBlock + kQScale;
|
||||
}
|
||||
}();
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_reg_f =
|
||||
Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
|
||||
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
float a_scale_reg_f = aq_picker.template pick<c_row>();
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * a_scale_reg_f *
|
||||
b_scale_reg_f);
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
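Every hunk in this commit applies the same mechanical rewrite: two nested static_for loops are fused into a single static_ford over a sequence of the two trip counts, so both induction variables arrive as one compile-time multi-index and the lambda nesting stays flat. A minimal, self-contained sketch of the idea, assuming nothing from ck_tile (static_ford_2d and its helper are hypothetical stand-ins for ck_tile::static_ford and ck_tile::sequence):

#include <cstddef>
#include <type_traits>
#include <utility>

// Visit (i, j) = (I / N, I % N) for every flattened index I in [0, M*N),
// fully unrolled at compile time via a fold expression.
template <std::size_t N, typename F, std::size_t... Is>
constexpr void static_ford_2d_impl(F& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is / N>{},
       std::integral_constant<std::size_t, Is % N>{}),
     ...);
}

template <std::size_t M, std::size_t N, typename F>
constexpr void static_ford_2d(F f)
{
    static_ford_2d_impl<N>(f, std::make_index_sequence<M * N>{});
}

int main()
{
    int sum = 0;
    // Equivalent to `for(i < 2) for(j < 3)`, but each (i, j) is a distinct
    // compile-time constant, exactly like the mIter/nIter numbers above.
    static_ford_2d<2, 3>([&](auto i, auto j) { sum += int(i) * 3 + int(j); });
    return sum == 15 ? 0 : 1; // sum over the 2x3 index space is 15
}

The observable behavior is unchanged; the payoff is one lambda instantiation per loop nest instead of one per loop level, which keeps nesting depth down in these heavily unrolled hot loops.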
@@ -268,54 +268,51 @@ struct AQuantBlockUniversalGemmAsBsCr
constexpr auto warp_size = get_warp_size();

// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
        CWarpTensor c_warp_tensor;
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
    constexpr auto mIter = number<mn[number<0>{}]>{};
    constexpr auto nIter = number<mn[number<1>{}]>{};
    CWarpTensor c_warp_tensor;

        // for every column in AQ
        static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
            // for every warp corresponding to a quantization scale
            static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
                constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
    // for every column in AQ
    static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
        // for every warp corresponding to a quantization scale
        static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
            constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;

                AWarpTensor a_warp_tensor;
                a_warp_tensor.get_thread_buffer() =
                    a_warp_tile_.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
            AWarpTensor a_warp_tensor;
            a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
                merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                BWarpTensor b_warp_tensor;
                b_warp_tensor.get_thread_buffer() =
                    b_warp_tile_.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
            BWarpTensor b_warp_tensor;
            b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
                merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

                if constexpr(kIterInQScale == 0)
                {
                    c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
                }
                else
                {
                    WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
                }
            });
            if constexpr(kIterInQScale == 0)
            {
                c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
            }
            else
            {
                WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
            }
        });

            constexpr auto tbuf_offset =
                number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
                           merge_sequences(sequence<mIter, nIter>{},
                                           c_warp_y_index_zeros)) /
                       CBlockTensor::PackedSize>{};
        constexpr auto tbuf_offset =
            number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
                       merge_sequences(sequence<mIter, nIter>{},
                                       c_warp_y_index_zeros)) /
                   CBlockTensor::PackedSize>{};

            AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
                aq_block_tensor);
        AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
            aq_block_tensor);

            static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
                [&](auto c_row) {
                    float scale_reg_f = aq_picker.template pick<c_row>();
        static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}([&](auto c_row) {
            float scale_reg_f = aq_picker.template pick<c_row>();

                    c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
                        (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
                });
            c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
                (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
        });
    });
});
@@ -290,57 +290,55 @@ struct BQuantBlockUniversalGemmAsBsCr
using SrcVectorRawType = ext_vector_t<BDataTypeRaw, UnaryOpSize_ / BPackedSize>;
using DstVectorType = ext_vector_t<ComputeDataType, UnaryOpSize_>;

static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
    static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
        // B scale register offset
        constexpr index_t reg_offset = [&]() {
            if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
                return ((nIter * NWarp * WarpGemm::kN) /
                        GemmTraits::BQuantGroupSize::kN) *
                           Traits::KQPerBlock +
                       kQScale;
            else
            {
                return nIter * Traits::KQPerBlock + kQScale;
            }
        }();
static_ford<sequence<NIterPerWarp, Traits::QScalesPerBlockRow>>{}([&](auto nk) {
    constexpr auto nIter = number<nk[number<0>{}]>{};
    constexpr auto kQScale = number<nk[number<1>{}]>{};
    // B scale register offset
    constexpr index_t reg_offset = [&]() {
        if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
            return ((nIter * NWarp * WarpGemm::kN) / GemmTraits::BQuantGroupSize::kN) *
                       Traits::KQPerBlock +
                   kQScale;
        else
        {
            return nIter * Traits::KQPerBlock + kQScale;
        }
    }();

        // Get B scale from thread buffer
        auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
        float b_scale_f = float(scale_reg);
    // Get B scale from thread buffer
    auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
    float b_scale_f = float(scale_reg);

        static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
            constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
            // Thread buffers
            using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data(
                merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
            using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data(
                merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
    static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
        constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
        // Thread buffers
        using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data(
            merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
            merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
        using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data(
            merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
            merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));

            BWarpThreadBuffer b_warp_thread_buffer;
            BLDSThreadBuffer b_lds_thread_buffer;
        BWarpThreadBuffer b_warp_thread_buffer;
        BLDSThreadBuffer b_lds_thread_buffer;

            // Load thread buffer from tile (LDS type)
            b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
                merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
        // Load thread buffer from tile (LDS type)
        b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
            merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
            merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

            // Apply scale to B thread buffer and cast
            static_for<0, thread_buffer_size, 1>{}([&](auto i) {
                elementwise_op(
                    b_warp_thread_buffer.template get_as<DstVectorType>()(i),
                    b_lds_thread_buffer.template get_as<SrcVectorRawType>()[i],
                    b_scale_f);
            });

            // Store B thread buffer to tile (MMA type)
            b_warp_tile_.set_y_sliced_thread_data(
                merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
                b_warp_thread_buffer);
        // Apply scale to B thread buffer and cast
        static_for<0, thread_buffer_size, 1>{}([&](auto i) {
            elementwise_op(b_warp_thread_buffer.template get_as<DstVectorType>()(i),
                           b_lds_thread_buffer.template get_as<SrcVectorRawType>()[i],
                           b_scale_f);
        });

        // Store B thread buffer to tile (MMA type)
        b_warp_tile_.set_y_sliced_thread_data(
            merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
            merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
            b_warp_thread_buffer);
    });
});
}
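The hunk above fuses B dequantization into the LDS-to-register conversion: elementwise_op multiplies the group's scale b_scale_f in while casting the raw B values to the compute type, so no separate dequant pass over the tile is needed. A hedged scalar sketch of that fuse-scale-into-cast step, assuming int8 raw data and float compute type in place of the kernel's vectorized buffer types:

#include <array>
#include <cstdint>

// Convert one packed group of raw B values to the compute type and apply the
// group's dequantization scale in the same pass.
template <std::size_t N>
std::array<float, N> dequant_group(const std::array<std::int8_t, N>& raw, float b_scale)
{
    std::array<float, N> out{};
    for (std::size_t i = 0; i < N; ++i)
        out[i] = static_cast<float>(raw[i]) * b_scale; // scale fused into the cast
    return out;
}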
@@ -361,113 +359,107 @@ struct BQuantBlockUniversalGemmAsBsCr
constexpr auto warp_size = get_warp_size();

// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
        CWarpTensor c_warp_tensor;
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
    constexpr auto mIter = number<mn[number<0>{}]>{};
    constexpr auto nIter = number<mn[number<1>{}]>{};
    CWarpTensor c_warp_tensor;

        static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
            static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
                constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
    static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
        static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
            constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;

                AWarpTensor a_warp_tensor;
                a_warp_tensor.get_thread_buffer() =
                    a_warp_tile_.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
            AWarpTensor a_warp_tensor;
            a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
                merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                BWarpTensor b_warp_tensor;
                b_warp_tensor.get_thread_buffer() =
                    b_warp_tile_.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
            BWarpTensor b_warp_tensor;
            b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
                merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

            if constexpr(kIterInQScale == 0)
            {
                c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
            }
            else
            {
                WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
            }
        });

                constexpr auto tbuf_offset =
                    number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
                               merge_sequences(sequence<mIter, nIter>{},
                                               c_warp_y_index_zeros)) /
                           CBlockTensor::PackedSize>{};

                if constexpr(BPreshuffleQuant)
                    if constexpr(kIterInQScale == 0)
                    {
                        constexpr index_t reg_offset = [&]() {
                            if constexpr(GemmTraits::BQuantGroupSize::kN >
                                             (NWarp * WarpGemm::kN) &&
                                         Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
                            {
                                return kQScale; // prefill: one quant group per block
                            }
                            else
                            {
                                return nIter; // decode or multiple groups per warp
                            }
                        }();

                        auto pull_from_lane =
                            (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;

                        auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
                        // cross lane ops
                        uint32_t scale_reg_dword;

                        if constexpr(std::is_same_v<BQDataType, float>)
                        {
                            scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
                        }
                        else
                        {
                            scale_reg_dword = static_cast<uint32_t>(scale_reg);
                        }

                        // cross lane ops to get the value of scale_reg.
                        int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
                            pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));

                        float scale_reg_f =
                            Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
                                gathered_scale_reg);

                        static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
                            [&](auto c_row) {
                                c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
                                    (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
                            });
                        c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
                    }
                    else
                    {
                        // Multiply bquant with accumulated C
                        constexpr index_t reg_offset = [&]() {
                            if constexpr(GemmTraits::BQuantGroupSize::kN >=
                                         (NWarp * WarpGemm::kN))
                                return (nIter * NWarp * WarpGemm::kN) /
                                           GemmTraits::BQuantGroupSize::kN *
                                           Traits::KQPerBlock +
                                       kQScale;
                            else
                            {
                                return nIter * Traits::KQPerBlock + kQScale;
                            }
                        }();

                        auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
                        float scale_reg_f =
                            Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
                        static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
                            [&](auto c_row) {
                                c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
                                    (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
                            });
                        WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
                    }
            });

        constexpr auto tbuf_offset =
            number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
                       merge_sequences(sequence<mIter, nIter>{},
                                       c_warp_y_index_zeros)) /
                   CBlockTensor::PackedSize>{};

        if constexpr(BPreshuffleQuant)
        {
            constexpr index_t reg_offset = [&]() {
                if constexpr(GemmTraits::BQuantGroupSize::kN > (NWarp * WarpGemm::kN) &&
                             Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
                {
                    return kQScale; // prefill: one quant group per block
                }
                else
                {
                    return nIter; // decode or multiple groups per warp
                }
            }();

            auto pull_from_lane =
                (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;

            auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
            // cross lane ops
            uint32_t scale_reg_dword;

            if constexpr(std::is_same_v<BQDataType, float>)
            {
                scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
            }
            else
            {
                scale_reg_dword = static_cast<uint32_t>(scale_reg);
            }

            // cross lane ops to get the value of scale_reg.
            int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
                pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));

            float scale_reg_f = Base::cvt_scale_to_fp32<typename Traits::BQDataType>(
                gathered_scale_reg);

            static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
                [&](auto c_row) {
                    c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
                        (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
                });
        }
        else
        {
            // Multiply bquant with accumulated C
            constexpr index_t reg_offset = [&]() {
                if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
                    return (nIter * NWarp * WarpGemm::kN) /
                               GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock +
                           kQScale;
                else
                {
                    return nIter * Traits::KQPerBlock + kQScale;
                }
            }();

            auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
            float scale_reg_f =
                Base::cvt_scale_to_fp32<typename Traits::BQDataType>(scale_reg);
            static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
                [&](auto c_row) {
                    c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
                        (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
                });
        }
    });
});
}
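In the BPreshuffleQuant path above, the B scale a lane needs lives in another lane's registers, so it is broadcast with __builtin_amdgcn_ds_bpermute; the builtin takes a byte address, which is why pull_from_lane is shifted left by 2. A standalone HIP sketch of the primitive, assuming a wave64 target and a deliberately simple gather (each lane reads its group-of-four leader) rather than the kernel's actual pull_from_lane layout:

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void gather_from_leader(const float* in, float* out)
{
    const int lane = threadIdx.x;                // one wavefront assumed
    int src = __builtin_bit_cast(int, in[lane]); // value this lane holds
    // ds_bpermute "pulls" the dword held by lane (addr >> 2); the index is a
    // byte address, so the source lane id is shifted left by 2.
    int got = __builtin_amdgcn_ds_bpermute((lane & ~3) << 2, src);
    out[lane] = __builtin_bit_cast(float, got);
}

int main()
{
    const int n = 64;
    float *in = nullptr, *out = nullptr;
    hipMallocManaged(&in, n * sizeof(float));
    hipMallocManaged(&out, n * sizeof(float));
    for(int i = 0; i < n; ++i) in[i] = float(i);
    gather_from_leader<<<1, n>>>(in, out);
    hipDeviceSynchronize();
    printf("%.1f %.1f\n", out[0], out[7]); // expect 0.0 and 4.0 (lane 7 reads lane 4)
    hipFree(in);
    hipFree(out);
    return 0;
}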
@@ -288,22 +288,22 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
             MIterPerWarp>
    a_warp_windows_pong;

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
    static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
        a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
    constexpr auto mIter = number<mk[number<0>{}]>{};
    constexpr auto kIter = number<mk[number<1>{}]>{};
    a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;

        move_tile_window(a_warp_windows_ping(mIter)(kIter),
                         {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
    });
    move_tile_window(a_warp_windows_ping(mIter)(kIter),
                     {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
    static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
        a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
    constexpr auto mIter = number<mk[number<0>{}]>{};
    constexpr auto kIter = number<mk[number<1>{}]>{};
    a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;

        move_tile_window(a_warp_windows_pong(mIter)(kIter),
                         {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
    });
    move_tile_window(a_warp_windows_pong(mIter)(kIter),
                     {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});

// Block GEMM
@@ -366,16 +366,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(a_copy_dram_window, {0, kKPerBlock});

// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
    static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
        b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
    constexpr auto nIter = number<nk[number<0>{}]>{};
    constexpr auto kIter = number<nk[number<1>{}]>{};
    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

        move_tile_window(b_flat_dram_windows(nIter)(kIter),
                         {nIter * flatNPerWarp, kIter * flatKPerWarp});
    move_tile_window(b_flat_dram_windows(nIter)(kIter),
                     {nIter * flatNPerWarp, kIter * flatKPerWarp});

        load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
                                            b_flat_dram_windows(nIter)(kIter));
    });
    load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
                                        b_flat_dram_windows(nIter)(kIter));
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -448,15 +448,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                 bq_block_tile,
                 a_warp_windows_ping);
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
        b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
    constexpr auto kIter = number<kn[number<0>{}]>{};
    constexpr auto nIter = number<kn[number<1>{}]>{};
    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

        move_tile_window(b_flat_dram_windows(nIter)(kIter),
                         {nIter * flatNPerWarp, kIter * flatKPerWarp});
        load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
                                            b_flat_dram_windows(nIter)(kIter));
    });
    move_tile_window(b_flat_dram_windows(nIter)(kIter),
                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
    load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
                                        b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
aq_block_tile_2 = load_tile(aq_copy_dram_window);
@@ -473,15 +473,15 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Next K

// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
        b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
    constexpr auto kIter = number<kn[number<0>{}]>{};
    constexpr auto nIter = number<kn[number<1>{}]>{};
    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

        move_tile_window(b_flat_dram_windows(nIter)(kIter),
                         {nIter * flatNPerWarp, kIter * flatKPerWarp});
        load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
                                            b_flat_dram_windows(nIter)(kIter));
    });
    move_tile_window(b_flat_dram_windows(nIter)(kIter),
                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
    load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
                                        b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
aq_block_tile = load_tile(aq_copy_dram_window);
@@ -520,16 +520,16 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
if constexpr(TailNum == TailNumber::Even)
{
    // prefetch B(loopK)
    static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
    static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
        constexpr auto kIter = number<kn[number<0>{}]>{};
        constexpr auto nIter = number<kn[number<1>{}]>{};
        b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

            move_tile_window(b_flat_dram_windows(nIter)(kIter),
                             {nIter * flatNPerWarp, kIter * flatKPerWarp});
        move_tile_window(b_flat_dram_windows(nIter)(kIter),
                         {nIter * flatNPerWarp, kIter * flatKPerWarp});

            load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
                                                b_flat_dram_windows(nIter)(kIter));
        });
        load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
                                            b_flat_dram_windows(nIter)(kIter));
    });
    aq_block_tile_2 = load_tile(aq_copy_dram_window);
    bq_block_tile_2 = load_tile(bq_copy_dram_window);
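The WPABQuantBPipelineAgBgCrV2 hunks above (and the WPQuantBPipelineAgBgCrV2 hunks that follow) all sit inside one ping/pong software pipeline: tile 2i is consumed from one buffer while tile 2i+1 streams into the other, with Even/Odd tail handling for the final iteration. A schematic host-side sketch of the control flow only, where load_k_tile and consume_tile are hypothetical stand-ins for the tile-window loads and the block GEMM:

#include <utility>
#include <vector>

void pipelined_loop(int num_k_tiles,
                    std::vector<float>& ping,
                    std::vector<float>& pong,
                    void (*load_k_tile)(int, std::vector<float>&),
                    void (*consume_tile)(const std::vector<float>&))
{
    load_k_tile(0, ping); // prologue: prefetch the first K tile
    for(int k = 0; k < num_k_tiles; ++k)
    {
        if(k + 1 < num_k_tiles)
            load_k_tile(k + 1, pong); // prefetch next tile into the idle buffer
        consume_tile(ping);           // GEMM on the current tile
        std::swap(ping, pong);        // ping/pong role swap
    }
}

The static_ford rewrite does not change this schedule; it only flattens how the per-warp windows inside each prefetch step are enumerated.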
@@ -275,22 +275,22 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
             MIterPerWarp>
    a_warp_windows_pong;

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
    static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
        a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
    constexpr auto mIter = number<mk[number<0>{}]>{};
    constexpr auto kIter = number<mk[number<1>{}]>{};
    a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;

        move_tile_window(a_warp_windows_ping(mIter)(kIter),
                         {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
    });
    move_tile_window(a_warp_windows_ping(mIter)(kIter),
                     {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});

static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
    static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
        a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
    constexpr auto mIter = number<mk[number<0>{}]>{};
    constexpr auto kIter = number<mk[number<1>{}]>{};
    a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;

        move_tile_window(a_warp_windows_pong(mIter)(kIter),
                         {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
    });
    move_tile_window(a_warp_windows_pong(mIter)(kIter),
                     {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});

// Block GEMM
@@ -337,16 +337,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(a_copy_dram_window, {0, kKPerBlock});

// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
    static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
        b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
    constexpr auto nIter = number<nk[number<0>{}]>{};
    constexpr auto kIter = number<nk[number<1>{}]>{};
    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

        move_tile_window(b_flat_dram_windows(nIter)(kIter),
                         {nIter * flatNPerWarp, kIter * flatKPerWarp});
    move_tile_window(b_flat_dram_windows(nIter)(kIter),
                     {nIter * flatNPerWarp, kIter * flatKPerWarp});

        load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
                                            b_flat_dram_windows(nIter)(kIter));
    });
    load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
                                        b_flat_dram_windows(nIter)(kIter));
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -424,15 +424,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
                 bq_block_tile,
                 a_warp_windows_ping);
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
        b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
    constexpr auto kIter = number<kn[number<0>{}]>{};
    constexpr auto nIter = number<kn[number<1>{}]>{};
    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

        move_tile_window(b_flat_dram_windows(nIter)(kIter),
                         {nIter * flatNPerWarp, kIter * flatKPerWarp});
        load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
                                            b_flat_dram_windows(nIter)(kIter));
    });
    move_tile_window(b_flat_dram_windows(nIter)(kIter),
                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
    load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
                                        b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -461,15 +461,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// Next K

// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
        b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
    constexpr auto kIter = number<kn[number<0>{}]>{};
    constexpr auto nIter = number<kn[number<1>{}]>{};
    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

        move_tile_window(b_flat_dram_windows(nIter)(kIter),
                         {nIter * flatNPerWarp, kIter * flatKPerWarp});
        load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
                                            b_flat_dram_windows(nIter)(kIter));
    });
    move_tile_window(b_flat_dram_windows(nIter)(kIter),
                     {nIter * flatNPerWarp, kIter * flatKPerWarp});
    load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
                                        b_flat_dram_windows(nIter)(kIter));
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -518,16 +518,16 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(TailNum == TailNumber::Even)
{
    // prefetch B(loopK)
    static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
    static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
        constexpr auto kIter = number<kn[number<0>{}]>{};
        constexpr auto nIter = number<kn[number<1>{}]>{};
        b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;

            move_tile_window(b_flat_dram_windows(nIter)(kIter),
                             {nIter * flatNPerWarp, kIter * flatKPerWarp});
        move_tile_window(b_flat_dram_windows(nIter)(kIter),
                         {nIter * flatNPerWarp, kIter * flatKPerWarp});

            load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
                                                b_flat_dram_windows(nIter)(kIter));
        });
        load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
                                            b_flat_dram_windows(nIter)(kIter));
    });
    bq_block_tile_2 = load_tile(bq_copy_dram_window);
@@ -303,11 +303,11 @@ struct BlockNormReduceCrossWarpSync
index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps;
smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
    static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
        all_scratch[i_0 * num_reduce_warps + i_1] =
            smem_ptr[i_0 * num_warps + local_smem_os + i_1];
    });
static_ford<sequence<thread_buf_size, num_reduce_warps>>{}([&](auto ii) {
    constexpr auto i_0 = number<ii[number<0>{}]>{};
    constexpr auto i_1 = number<ii[number<1>{}]>{};
    all_scratch[i_0 * num_reduce_warps + i_1] =
        smem_ptr[i_0 * num_warps + local_smem_os + i_1];
});
block_sync_lds(); // TODO: we don't need sync here
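Both cross-warp reduce hunks perform the same LDS gather: partials are laid out as [thread_buf_size][num_warps], and each group of num_reduce_warps warps copies its contiguous slice into a compact per-thread scratch array before reducing it. The same indexing, modeled on the host (a sketch, not the device code):

#include <cstddef>

void gather_partials(const float* smem,   // [thread_buf_size][num_warps] layout
                     float* scratch,      // [thread_buf_size][num_reduce_warps]
                     std::size_t thread_buf_size,
                     std::size_t num_warps,
                     std::size_t num_reduce_warps,
                     std::size_t warp_id)
{
    const std::size_t local_warp_id = warp_id / num_reduce_warps;
    const std::size_t local_smem_os = local_warp_id * num_reduce_warps;
    for(std::size_t i0 = 0; i0 < thread_buf_size; ++i0)
        for(std::size_t i1 = 0; i1 < num_reduce_warps; ++i1)
            scratch[i0 * num_reduce_warps + i1] =
                smem[i0 * num_warps + local_smem_os + i1];
}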
@@ -631,17 +631,17 @@ struct BlockReduce2dLinearCrossWarpSync
            IndexDataType> all_indices;

// Load data from shared memory
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
    static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
        all_scratch[i_0 * num_reduce_warps + i_1] =
            smem_ptr[i_0 * num_warps + local_smem_os + i_1];
static_ford<sequence<thread_buf_size, num_reduce_warps>>{}([&](auto ii) {
    constexpr auto i_0 = number<ii[number<0>{}]>{};
    constexpr auto i_1 = number<ii[number<1>{}]>{};
    all_scratch[i_0 * num_reduce_warps + i_1] =
        smem_ptr[i_0 * num_warps + local_smem_os + i_1];

        if constexpr(kProcessIndex)
        {
            all_indices[i_0 * num_reduce_warps + i_1] =
                smem_indices[i_0 * num_warps + local_smem_os + i_1];
        }
    });
    if constexpr(kProcessIndex)
    {
        all_indices[i_0 * num_reduce_warps + i_1] =
            smem_indices[i_0 * num_warps + local_smem_os + i_1];
    }
});
block_sync_lds(); // TODO: we don't need sync here