mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Fix mock token id, support g1u1/g1u0 through same inline code block (#1808)
* fix mock token id * prepare host for g1u1 * reformat inline-asm * restructure uk_0 * restructure gate_up * done * change default to init=1 * update readme * fix a bug in interleave pipeline * rcp for silu
This commit is contained in:
@@ -234,10 +234,153 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return 32 * (128 + 8) * sizeof(bf16_t);
|
||||
// return 32 * (128 + 8) * sizeof(bf16_t);
|
||||
return MakeLdsLoadDesc_A().get_element_space_size() * sizeof(bf16_t) * 2; // 2 lds buffers
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
|
||||
[s_loop_cnt]"+s"(loop_cnt), \
|
||||
[v_acc_0]"+v"(v_acc[0]), \
|
||||
[v_acc_1]"+v"(v_acc[1]), \
|
||||
[v_acc_2]"+v"(v_acc[2]), \
|
||||
[v_acc_3]"+v"(v_acc[3]), \
|
||||
[v_acc_4]"+v"(v_acc[4]), \
|
||||
[v_acc_5]"+v"(v_acc[5]), \
|
||||
[v_acc_6]"+v"(v_acc[6]), \
|
||||
[v_acc_7]"+v"(v_acc[7]), \
|
||||
[v_acc_8]"+v"(v_acc[8]), \
|
||||
[v_acc_9]"+v"(v_acc[9]), \
|
||||
[v_acc_10]"+v"(v_acc[10]), \
|
||||
[v_acc_11]"+v"(v_acc[11]), \
|
||||
[v_acc_12]"+v"(v_acc[12]), \
|
||||
[v_acc_13]"+v"(v_acc[13]), \
|
||||
[v_acc_14]"+v"(v_acc[14]), \
|
||||
[v_acc_15]"+v"(v_acc[15]), \
|
||||
[s_mem_]"+r"(smem)
|
||||
|
||||
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
|
||||
[s_loop_cnt]"+s"(loop_cnt), \
|
||||
[v_acc_0]"+v"(v_acc[0]), \
|
||||
[v_acc_1]"+v"(v_acc[1]), \
|
||||
[v_acc_2]"+v"(v_acc[2]), \
|
||||
[v_acc_3]"+v"(v_acc[3]), \
|
||||
[v_acc_4]"+v"(v_acc[4]), \
|
||||
[v_acc_5]"+v"(v_acc[5]), \
|
||||
[v_acc_6]"+v"(v_acc[6]), \
|
||||
[v_acc_7]"+v"(v_acc[7]), \
|
||||
[v_acc_8]"+v"(v_acc[8]), \
|
||||
[v_acc_9]"+v"(v_acc[9]), \
|
||||
[v_acc_10]"+v"(v_acc[10]), \
|
||||
[v_acc_11]"+v"(v_acc[11]), \
|
||||
[v_acc_12]"+v"(v_acc[12]), \
|
||||
[v_acc_13]"+v"(v_acc[13]), \
|
||||
[v_acc_14]"+v"(v_acc[14]), \
|
||||
[v_acc_15]"+v"(v_acc[15]), \
|
||||
[v_acc_16]"+v"(v_acc[16]), \
|
||||
[v_acc_17]"+v"(v_acc[17]), \
|
||||
[v_acc_18]"+v"(v_acc[18]), \
|
||||
[v_acc_19]"+v"(v_acc[19]), \
|
||||
[v_acc_20]"+v"(v_acc[20]), \
|
||||
[v_acc_21]"+v"(v_acc[21]), \
|
||||
[v_acc_22]"+v"(v_acc[22]), \
|
||||
[v_acc_23]"+v"(v_acc[23]), \
|
||||
[v_acc_24]"+v"(v_acc[24]), \
|
||||
[v_acc_25]"+v"(v_acc[25]), \
|
||||
[v_acc_26]"+v"(v_acc[26]), \
|
||||
[v_acc_27]"+v"(v_acc[27]), \
|
||||
[v_acc_28]"+v"(v_acc[28]), \
|
||||
[v_acc_29]"+v"(v_acc[29]), \
|
||||
[v_acc_30]"+v"(v_acc[30]), \
|
||||
[v_acc_31]"+v"(v_acc[31]), \
|
||||
[s_mem_]"+r"(smem)
|
||||
|
||||
#define _EXPAND_ASM_ARGS_IN \
|
||||
[s_res_a0]"s"(res_a[0]), \
|
||||
[s_res_a1]"s"(res_a[1]), \
|
||||
[s_res_a2]"s"(res_a[2]), \
|
||||
[s_res_a3]"s"(res_a[3]), \
|
||||
[s_res_b0]"s"(res_b[0]), \
|
||||
[s_res_b1]"s"(res_b[1]), \
|
||||
[s_res_b2]"s"(res_b[2]), \
|
||||
[s_res_b3]"s"(res_b[3]), \
|
||||
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
|
||||
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
|
||||
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
|
||||
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
|
||||
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
|
||||
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
|
||||
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
|
||||
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
|
||||
\
|
||||
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
|
||||
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
|
||||
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
|
||||
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
|
||||
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
|
||||
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
|
||||
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
|
||||
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
|
||||
\
|
||||
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
|
||||
[s_m0_init]"s"(m0_init_value), \
|
||||
[s_size_per_issue]"s"(size_per_issue), \
|
||||
[smem_sz]"n"(smem_buf_size), \
|
||||
[sld_os_0]"n"(sld_os[number<0>{}].value), \
|
||||
[sld_os_1]"n"(sld_os[number<1>{}].value), \
|
||||
[sld_os_2]"n"(sld_os[number<2>{}].value), \
|
||||
[sld_os_3]"n"(sld_os[number<3>{}].value), \
|
||||
[sld_os_4]"n"(sld_os[number<4>{}].value), \
|
||||
[sld_os_5]"n"(sld_os[number<5>{}].value), \
|
||||
[sld_os_6]"n"(sld_os[number<6>{}].value), \
|
||||
[sld_os_7]"n"(sld_os[number<7>{}].value), \
|
||||
[s_tile_os_a]"s"(tile_offset_a_bytes), \
|
||||
[s_tile_os_b]"s"(tile_offset_b_bytes)
|
||||
|
||||
#define _EXPAND_ASM_ARGS_CLOBBER \
|
||||
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
|
||||
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
|
||||
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
|
||||
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
|
||||
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
|
||||
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
|
||||
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
|
||||
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
|
||||
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
|
||||
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
|
||||
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
|
||||
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
|
||||
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
|
||||
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
|
||||
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
|
||||
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
|
||||
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
|
||||
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
|
||||
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
|
||||
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
|
||||
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
|
||||
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
|
||||
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
|
||||
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
|
||||
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
|
||||
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
|
||||
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
|
||||
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
|
||||
"a252", "a253", "a254", "a255", \
|
||||
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
|
||||
"s86", \
|
||||
"v64", "v65", "v66", "v67", "v68", "v69", \
|
||||
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
|
||||
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
|
||||
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
|
||||
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
|
||||
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
|
||||
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
|
||||
"v124", "v125", "v126", "v127"
|
||||
// clang-format on
|
||||
|
||||
struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
|
||||
{
|
||||
using ADataType = bf16_t;
|
||||
@@ -245,7 +388,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
|
||||
|
||||
// TODO: need paired with tile_window_linear!
|
||||
// TODO: need call init_raw() before call this function!
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
|
||||
// Is2B: originally for B matrix we have 2 prefetch buffers. If set this to true
|
||||
// we can support A matric serve 2 B matrix, B0/B1, each B0/B1 still have same tile size
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const ARes& res_a,
|
||||
const ACoords& cached_coords_a,
|
||||
@@ -254,7 +399,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t k,
|
||||
index_t tile_offset_a, // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b) // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b,
|
||||
bool_constant<Is2B> = {}) // for each tile, the offset to move for each unroll
|
||||
{
|
||||
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
|
||||
static_assert(BCoords::size() == Repeat_N);
|
||||
@@ -299,129 +445,78 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16
|
||||
|
||||
index_t loop_cnt = k / Block_K;
|
||||
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
if constexpr(Is2B)
|
||||
{
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[32]{.0f};
|
||||
|
||||
// B nr->kr
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#define CK_TILE_FLATMM_UK_2B 1
|
||||
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
: [s_loop_cnt]"+s"(loop_cnt),
|
||||
[v_acc_0]"+v"(v_acc[0]),
|
||||
[v_acc_1]"+v"(v_acc[1]),
|
||||
[v_acc_2]"+v"(v_acc[2]),
|
||||
[v_acc_3]"+v"(v_acc[3]),
|
||||
[v_acc_4]"+v"(v_acc[4]),
|
||||
[v_acc_5]"+v"(v_acc[5]),
|
||||
[v_acc_6]"+v"(v_acc[6]),
|
||||
[v_acc_7]"+v"(v_acc[7]),
|
||||
[v_acc_8]"+v"(v_acc[8]),
|
||||
[v_acc_9]"+v"(v_acc[9]),
|
||||
[v_acc_10]"+v"(v_acc[10]),
|
||||
[v_acc_11]"+v"(v_acc[11]),
|
||||
[v_acc_12]"+v"(v_acc[12]),
|
||||
[v_acc_13]"+v"(v_acc[13]),
|
||||
[v_acc_14]"+v"(v_acc[14]),
|
||||
[v_acc_15]"+v"(v_acc[15]),
|
||||
[s_mem_]"+r"(smem)
|
||||
: [s_res_a0]"s"(res_a[0]),
|
||||
[s_res_a1]"s"(res_a[1]),
|
||||
[s_res_a2]"s"(res_a[2]),
|
||||
[s_res_a3]"s"(res_a[3]),
|
||||
[s_res_b0]"s"(res_b[0]),
|
||||
[s_res_b1]"s"(res_b[1]),
|
||||
[s_res_b2]"s"(res_b[2]),
|
||||
[s_res_b3]"s"(res_b[3]),
|
||||
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
|
||||
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
|
||||
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
|
||||
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
|
||||
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
|
||||
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
|
||||
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
|
||||
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
|
||||
|
||||
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
|
||||
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
|
||||
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
|
||||
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
|
||||
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
|
||||
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
|
||||
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
|
||||
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
|
||||
|
||||
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
|
||||
[s_m0_init]"s"(m0_init_value),
|
||||
[s_size_per_issue]"s"(size_per_issue),
|
||||
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
|
||||
[sld_os_0]"n"(sld_os[number<0>{}].value),
|
||||
[sld_os_1]"n"(sld_os[number<1>{}].value),
|
||||
[sld_os_2]"n"(sld_os[number<2>{}].value),
|
||||
[sld_os_3]"n"(sld_os[number<3>{}].value),
|
||||
[sld_os_4]"n"(sld_os[number<4>{}].value),
|
||||
[sld_os_5]"n"(sld_os[number<5>{}].value),
|
||||
[sld_os_6]"n"(sld_os[number<6>{}].value),
|
||||
[sld_os_7]"n"(sld_os[number<7>{}].value),
|
||||
[s_tile_os_a]"s"(tile_offset_a_bytes),
|
||||
[s_tile_os_b]"s"(tile_offset_b_bytes)
|
||||
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
|
||||
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
|
||||
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
|
||||
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
|
||||
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
|
||||
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
|
||||
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
|
||||
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
|
||||
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
|
||||
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
|
||||
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
|
||||
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
|
||||
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
|
||||
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
|
||||
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
|
||||
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
|
||||
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
|
||||
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
|
||||
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
|
||||
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
|
||||
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
|
||||
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
|
||||
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
|
||||
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
|
||||
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
|
||||
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
|
||||
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
|
||||
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
|
||||
"a252", "a253", "a254", "a255",
|
||||
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
|
||||
"s86", // s86 as tmp
|
||||
"v64", "v65", "v66", "v67", "v68", "v69",
|
||||
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
|
||||
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
|
||||
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
|
||||
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
|
||||
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
|
||||
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
|
||||
"v124", "v125", "v126", "v127"
|
||||
);
|
||||
// clang-format on
|
||||
: _EXPAND_ASM_ARGS_OUT_TWO_ACC
|
||||
: _EXPAND_ASM_ARGS_IN,
|
||||
[s_res_b4]"s"(res_b[4]),
|
||||
[s_res_b5]"s"(res_b[5]),
|
||||
[s_res_b6]"s"(res_b[6]),
|
||||
[s_res_b7]"s"(res_b[7])
|
||||
: _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
// return local scratch
|
||||
auto c = make_tuple(MakeCBlockTile(), MakeCBlockTile());
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
else
|
||||
{
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
: _EXPAND_ASM_ARGS_OUT_ONE_ACC
|
||||
: _EXPAND_ASM_ARGS_IN
|
||||
: _EXPAND_ASM_ARGS_CLOBBER
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -432,7 +527,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
|
||||
|
||||
// TODO: need paired with tile_window_linear!
|
||||
// TODO: need call init_raw() before call this function!
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const ARes& res_a,
|
||||
const ACoords& cached_coords_a,
|
||||
@@ -441,7 +536,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t k,
|
||||
index_t tile_offset_a, // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b) // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b, // for each tile, the offset to move for each unroll
|
||||
bool_constant<Is2B> = {})
|
||||
{
|
||||
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
|
||||
static_assert(BCoords::size() == Repeat_N);
|
||||
@@ -486,130 +582,82 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16
|
||||
|
||||
index_t loop_cnt = k / Block_K;
|
||||
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
if constexpr(Is2B)
|
||||
{
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[32]{.0f};
|
||||
|
||||
// B nr->kr
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
|
||||
#define CK_TILE_FLATMM_UK_2B 1
|
||||
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
: [s_loop_cnt]"+s"(loop_cnt),
|
||||
[v_acc_0]"+v"(v_acc[0]),
|
||||
[v_acc_1]"+v"(v_acc[1]),
|
||||
[v_acc_2]"+v"(v_acc[2]),
|
||||
[v_acc_3]"+v"(v_acc[3]),
|
||||
[v_acc_4]"+v"(v_acc[4]),
|
||||
[v_acc_5]"+v"(v_acc[5]),
|
||||
[v_acc_6]"+v"(v_acc[6]),
|
||||
[v_acc_7]"+v"(v_acc[7]),
|
||||
[v_acc_8]"+v"(v_acc[8]),
|
||||
[v_acc_9]"+v"(v_acc[9]),
|
||||
[v_acc_10]"+v"(v_acc[10]),
|
||||
[v_acc_11]"+v"(v_acc[11]),
|
||||
[v_acc_12]"+v"(v_acc[12]),
|
||||
[v_acc_13]"+v"(v_acc[13]),
|
||||
[v_acc_14]"+v"(v_acc[14]),
|
||||
[v_acc_15]"+v"(v_acc[15]),
|
||||
[s_mem_]"+r"(smem)
|
||||
: [s_res_a0]"s"(res_a[0]),
|
||||
[s_res_a1]"s"(res_a[1]),
|
||||
[s_res_a2]"s"(res_a[2]),
|
||||
[s_res_a3]"s"(res_a[3]),
|
||||
[s_res_b0]"s"(res_b[0]),
|
||||
[s_res_b1]"s"(res_b[1]),
|
||||
[s_res_b2]"s"(res_b[2]),
|
||||
[s_res_b3]"s"(res_b[3]),
|
||||
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
|
||||
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
|
||||
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
|
||||
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
|
||||
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
|
||||
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
|
||||
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
|
||||
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
|
||||
|
||||
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
|
||||
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
|
||||
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
|
||||
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
|
||||
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
|
||||
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
|
||||
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
|
||||
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
|
||||
|
||||
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
|
||||
[s_m0_init]"s"(m0_init_value),
|
||||
[s_size_per_issue]"s"(size_per_issue),
|
||||
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
|
||||
[sld_os_0]"n"(sld_os[number<0>{}].value),
|
||||
[sld_os_1]"n"(sld_os[number<1>{}].value),
|
||||
[sld_os_2]"n"(sld_os[number<2>{}].value),
|
||||
[sld_os_3]"n"(sld_os[number<3>{}].value),
|
||||
[sld_os_4]"n"(sld_os[number<4>{}].value),
|
||||
[sld_os_5]"n"(sld_os[number<5>{}].value),
|
||||
[sld_os_6]"n"(sld_os[number<6>{}].value),
|
||||
[sld_os_7]"n"(sld_os[number<7>{}].value),
|
||||
[s_tile_os_a]"s"(tile_offset_a_bytes),
|
||||
[s_tile_os_b]"s"(tile_offset_b_bytes)
|
||||
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
|
||||
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
|
||||
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
|
||||
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
|
||||
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
|
||||
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
|
||||
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
|
||||
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
|
||||
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
|
||||
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
|
||||
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
|
||||
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
|
||||
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
|
||||
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
|
||||
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
|
||||
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
|
||||
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
|
||||
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
|
||||
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
|
||||
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
|
||||
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
|
||||
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
|
||||
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
|
||||
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
|
||||
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
|
||||
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
|
||||
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
|
||||
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
|
||||
"a252", "a253", "a254", "a255",
|
||||
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
|
||||
"s86", // s86 as tmp
|
||||
"v64", "v65", "v66", "v67", "v68", "v69",
|
||||
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
|
||||
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
|
||||
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
|
||||
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
|
||||
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
|
||||
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
|
||||
"v124", "v125", "v126", "v127"
|
||||
);
|
||||
// clang-format on
|
||||
: _EXPAND_ASM_ARGS_OUT_TWO_ACC
|
||||
: _EXPAND_ASM_ARGS_IN,
|
||||
[s_res_b4]"s"(res_b[4]),
|
||||
[s_res_b5]"s"(res_b[5]),
|
||||
[s_res_b6]"s"(res_b[6]),
|
||||
[s_res_b7]"s"(res_b[7])
|
||||
: _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
// return local scratch
|
||||
auto c = make_tuple(MakeCBlockTile(), MakeCBlockTile());
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
|
||||
c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
else
|
||||
{
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
|
||||
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
: _EXPAND_ASM_ARGS_OUT_ONE_ACC
|
||||
: _EXPAND_ASM_ARGS_IN
|
||||
: _EXPAND_ASM_ARGS_CLOBBER
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
#undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
|
||||
#undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
|
||||
#undef _EXPAND_ASM_ARGS_IN
|
||||
#undef _EXPAND_ASM_ARGS_CLOBBER
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -65,7 +65,8 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
|
||||
// in LDS we need store as
|
||||
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
|
||||
// y y wave-id lid/16 lid%16 v
|
||||
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
|
||||
constexpr index_t nbufs = 2;
|
||||
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t) * nbufs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -173,7 +174,6 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
:[smem_]"+r"(smem),
|
||||
[s_loop_cnt]"+s"(loop_cnt),
|
||||
[c0]"+v" (v_c0),
|
||||
@@ -418,7 +418,6 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x
|
||||
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
|
||||
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
:[smem_]"+r"(smem),
|
||||
[s_loop_cnt]"+s"(loop_cnt),
|
||||
[c0]"+v" (v_c0),
|
||||
|
||||
@@ -477,7 +477,7 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
|
||||
"a252", "a253", "a254", "a255",
|
||||
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
|
||||
"s36", "s37","s59","s80",
|
||||
"s36", "s37", "s56", "s59", "s60", "s80",
|
||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
|
||||
"v50", "v54", "v55",
|
||||
"v64","v65","v66","v67","v68","v69","v70","v71",
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// clang-format off
|
||||
|
||||
// define the CK_TILE_** macro before include this file to change kernel variation
|
||||
// we will undef everything defined in this file
|
||||
|
||||
#ifndef CK_TILE_FLATMM_UK_MFMA
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#endif
|
||||
@@ -816,3 +823,5 @@
|
||||
#undef _UK_MFMA_
|
||||
#undef _UK_PK_CVT_
|
||||
#undef _UK_ATOMIC_ADD_
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
// clang-format on
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user