mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
changed all the scale outside except for uq
This commit is contained in:
@@ -245,13 +245,10 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
|
||||
|
||||
// TODO: need paired with tile_window_linear!
|
||||
// TODO: need call init_raw() before call this function!
|
||||
template <typename AToken_id, typename AQRes, typename DQRes, typename GQRes, typename SMQRes, typename ARes, typename ACoords, typename BRes, typename BCoords>
|
||||
template <typename Ascale, typename GQscale, typename ARes, typename ACoords, typename BRes, typename BCoords>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()( const AToken_id& row_ids_a_,
|
||||
const AQRes& res_aq,
|
||||
const DQRes& res_dq,
|
||||
const GQRes& res_gq,
|
||||
const SMQRes& res_smq,
|
||||
operator()( const Ascale& a_scale_,
|
||||
const GQscale& gq_scale_,
|
||||
const ARes& res_a,
|
||||
const ACoords& cached_coords_a,
|
||||
const BRes& res_b,
|
||||
@@ -263,7 +260,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
|
||||
{
|
||||
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 4 /*2x per dword*/); // 8
|
||||
static_assert(BCoords::size() == Repeat_N);
|
||||
static_assert(AToken_id::size() == Repeat_M);
|
||||
static_assert(Ascale::size() == Repeat_M);
|
||||
|
||||
auto a_sst = make_tile_window(
|
||||
@@ -372,10 +368,6 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
|
||||
register int v_z61 asm("v189") = 0;
|
||||
register int v_z62 asm("v190") = 0;
|
||||
register int v_z63 asm("v191") = 0;
|
||||
|
||||
index_t temp0 = static_cast<index_t>(row_ids_a_[number<0>{}]);
|
||||
index_t temp1 = static_cast<index_t>(row_ids_a_[number<1>{}]);
|
||||
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
@@ -449,13 +441,11 @@ struct Flatmm_32x512x256_1x4x1_16x16x64_int8 : public Flatmm_32x512x256_1x4x1_16
|
||||
[c61]"+v"(v_z61),
|
||||
[c62]"+v"(v_z62),
|
||||
[c63]"+v"(v_z63),
|
||||
[v_token_id0]"+v"(temp0),
|
||||
[v_token_id1]"+v"(temp1),
|
||||
[s_mem_]"+r"(smem)
|
||||
: [s_res_aq]"s"(res_aq),
|
||||
[s_res_dq]"s"(res_dq),
|
||||
[s_res_gq]"s"(res_gq),
|
||||
[s_res_smq]"s"(res_smq),
|
||||
: [a_scale0]"v"(a_scale_[0]),
|
||||
[a_scale1]"v"(a_scale_[1]),
|
||||
[gq_scale0]"v"(gq_scale_[0]),
|
||||
[gq_scale1]"v"(gq_scale_[1]),
|
||||
[s_res_a]"s"(res_a),
|
||||
// [s_res_a1]"s"(res_a[1]),
|
||||
// [s_res_a2]"s"(res_a[2]),
|
||||
|
||||
@@ -80,21 +80,25 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
|
||||
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
|
||||
template <typename DQRes,
|
||||
typename BRes,
|
||||
typename DQCoords,
|
||||
typename BCoords,
|
||||
typename ORes,
|
||||
typename OCoords,
|
||||
typename OFlags>
|
||||
// typename ScaleTensor>
|
||||
typename OFlags,
|
||||
typename ScaleTensor,
|
||||
typename YScaleTensor>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const DQRes& res_dq,
|
||||
const BRes& res_b,
|
||||
const DQCoords& cached_coords_dq,
|
||||
const BCoords& cached_coords_b,
|
||||
const ORes& res_o,
|
||||
const OCoords& cached_coords_o,
|
||||
const OFlags& o_flags, // this should be in sgpr
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t n, // loop along n dim
|
||||
// const ScaleTensor& scale_,
|
||||
const ScaleTensor& scale_,
|
||||
const YScaleTensor& smq_scale_,
|
||||
index_t tile_offset_dq,
|
||||
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
|
||||
index_t tile_offset_half_b, //splited load alone K in to 2 part
|
||||
@@ -108,9 +112,9 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
|
||||
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
|
||||
const index_t tile_stride_dq_bytes = tile_offset_dq * sizeof(DScaleDataType);
|
||||
|
||||
// static_assert(ScaleTensor::size() == 2);
|
||||
// float s0 = scale_[number<0>{}];
|
||||
// float s1 = scale_[number<1>{}];
|
||||
static_assert(ScaleTensor::size() == 2);
|
||||
float s0 = scale_[number<0>{}];
|
||||
float s1 = scale_[number<1>{}];
|
||||
|
||||
index_t loop_cnt = n ;
|
||||
|
||||
@@ -220,8 +224,10 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
|
||||
// [v_sld_y_os]"v"(sld_y_os),
|
||||
// [v_sfl_sld]"v"(sfl_sld),
|
||||
// [v_sfl_sst]"v"(sfl_sst),
|
||||
[smq_scale0]"s"(smq_scale_[0]),
|
||||
[smq_scale1]"s"(smq_scale_[1]),
|
||||
[s_res_dq]"s"(res_dq),
|
||||
[s_res_o0]"s"(res_o[0]),
|
||||
[s_res_o0]"s"(res_o[0]),
|
||||
[s_res_o1]"s"(res_o[1]),
|
||||
//[s_res_o2]"s"(res_o[2]),
|
||||
//[s_res_o3]"s"(res_o[3]),
|
||||
@@ -229,6 +235,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
|
||||
// [s_res_b1]"s"(res_b[1]),
|
||||
// [s_res_b2]"s"(res_b[2]),
|
||||
// [s_res_b3]"s"(res_b[3]),
|
||||
[v_os_dq]"v"(static_cast<index_t>(cached_coords_dq * sizeof(DScaleDataType))),
|
||||
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
|
||||
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
|
||||
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
|
||||
@@ -293,8 +300,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
|
||||
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
|
||||
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
|
||||
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
|
||||
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
|
||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
|
||||
"v1", "v2", "v3", "v4", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
|
||||
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
|
||||
"v56", "v57", "v64",
|
||||
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
|
||||
@@ -366,8 +372,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
|
||||
// [v_sfl_sld]"v"(sfl_sld),
|
||||
// [v_sfl_sst]"v"(sfl_sst),
|
||||
[s_res_dq]"s"(res_dq),
|
||||
[s_res_o0]"s"(res_o[0]),
|
||||
[s_res_o1]"s"(res_o[1]),
|
||||
[s_res_o0]"s"(res_o[0]),
|
||||
[s_res_o1]"s"(res_o[1]),
|
||||
//[s_res_o2]"s"(res_o[2]),
|
||||
//[s_res_o3]"s"(res_o[3]),
|
||||
[s_res_d]"s"(res_b),
|
||||
@@ -390,8 +396,8 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
|
||||
[s_tile_os_b_half]"s"(tile_offset_half_b_bytes),
|
||||
[s_tile_os_b]"s"(tile_stride_b_bytes),
|
||||
[s_tile_os_dq]"s"(tile_stride_dq_bytes),
|
||||
// [scale_0]"v"(s0),
|
||||
// [scale_1]"v"(s1),
|
||||
[scale_0]"v"(s0),
|
||||
[scale_1]"v"(s1),
|
||||
// [v_nan_lo]"v"(nan_lo),
|
||||
// [v_nan_hi]"v"(nan_hi),
|
||||
[s_execflag_0]"s"(o_flags[number<0>{}]),
|
||||
@@ -438,8 +444,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x
|
||||
"s55", "s56", "s57", "s58", "s59", "s60", "s61", "s62", "s63",
|
||||
"s64", "s65", "s66", "s67", "s68", "s69", "s70", "s71", "s72",
|
||||
"s73", "s74", "s75", "s76", "s77", "s78", "s79", "s80", // s86 as tmp
|
||||
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
|
||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
|
||||
"v1", "v2", "v3", "v4", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
|
||||
"v20", "v21", "v22", "v23", "v24", "v25", "v50", "v51", "v52", "v53", "v54", "v55",
|
||||
"v56", "v57", "v64",
|
||||
"v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73",
|
||||
|
||||
@@ -27,8 +27,12 @@
|
||||
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
|
||||
|
||||
#endif
|
||||
" v_and_b32 v0, 0x3f, v0 \n"
|
||||
" v_lshrrev_b32 v3, 6, v0 \n"
|
||||
" v_readfirstlane_b32 s7, v3 \n"
|
||||
" s_waitcnt vmcnt(24) \n"
|
||||
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], %[s_res_d], 0 offen\n"
|
||||
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], %[s_res_d], 0 offen offset:1024\n"
|
||||
" v_mul_f32 v54, v128, v128 \n"
|
||||
" v_mul_f32 v55, v129, v129 \n"
|
||||
" v_mul_f32 v56, v130, v130 \n"
|
||||
@@ -49,7 +53,6 @@
|
||||
" v_exp_f32 v55, v55 \n"
|
||||
" v_exp_f32 v56, v56 \n"
|
||||
" v_exp_f32 v57, v57 \n"
|
||||
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], %[s_res_d], 0 offen offset:1024\n"
|
||||
" v_add_f32 v54, v54, 1.0 \n"
|
||||
" v_add_f32 v55, v55, 1.0 \n"
|
||||
" v_add_f32 v56, v56, 1.0 \n"
|
||||
@@ -577,71 +580,71 @@
|
||||
" v_mul_f32 v189, v189, v55 \n"
|
||||
" v_mul_f32 v190, v190, v56 \n"
|
||||
" v_mul_f32 v191, v191, v57 \n"
|
||||
" v_mul_f32 v128, v18, v128 row_newbcast:0 \n"
|
||||
" v_mul_f32 v129, v18, v129 row_newbcast:1 \n"
|
||||
" v_mul_f32 v130, v18, v130 row_newbcast:2 \n"
|
||||
" v_mul_f32 v131, v18, v131 row_newbcast:3 \n"
|
||||
" v_mul_f32 v132, v18, v132 row_newbcast:0 \n"
|
||||
" v_mul_f32 v133, v18, v133 row_newbcast:1 \n"
|
||||
" v_mul_f32 v134, v18, v134 row_newbcast:2 \n"
|
||||
" v_mul_f32 v135, v18, v135 row_newbcast:3 \n"
|
||||
" v_mul_f32 v136, v18, v136 row_newbcast:4 \n"
|
||||
" v_mul_f32 v137, v18, v137 row_newbcast:5 \n"
|
||||
" v_mul_f32 v138, v18, v138 row_newbcast:6 \n"
|
||||
" v_mul_f32 v139, v18, v139 row_newbcast:7 \n"
|
||||
" v_mul_f32 v140, v18, v140 row_newbcast:4 \n"
|
||||
" v_mul_f32 v141, v18, v141 row_newbcast:5 \n"
|
||||
" v_mul_f32 v142, v18, v142 row_newbcast:6 \n"
|
||||
" v_mul_f32 v143, v18, v143 row_newbcast:7 \n"
|
||||
" v_mul_f32 v144, v18, v144 row_newbcast:8 \n"
|
||||
" v_mul_f32 v145, v18, v145 row_newbcast:9 \n"
|
||||
" v_mul_f32 v146, v18, v146 row_newbcast:10 \n"
|
||||
" v_mul_f32 v147, v18, v147 row_newbcast:11 \n"
|
||||
" v_mul_f32 v148, v18, v148 row_newbcast:8 \n"
|
||||
" v_mul_f32 v149, v18, v149 row_newbcast:9 \n"
|
||||
" v_mul_f32 v150, v18, v150 row_newbcast:10 \n"
|
||||
" v_mul_f32 v151, v18, v151 row_newbcast:11 \n"
|
||||
" v_mul_f32 v152, v18, v152 row_newbcast:12 \n"
|
||||
" v_mul_f32 v153, v18, v153 row_newbcast:13 \n"
|
||||
" v_mul_f32 v154, v18, v154 row_newbcast:14 \n"
|
||||
" v_mul_f32 v155, v18, v155 row_newbcast:15 \n"
|
||||
" v_mul_f32 v156, v18, v156 row_newbcast:12 \n"
|
||||
" v_mul_f32 v157, v18, v157 row_newbcast:13 \n"
|
||||
" v_mul_f32 v158, v18, v158 row_newbcast:14 \n"
|
||||
" v_mul_f32 v159, v18, v159 row_newbcast:15 \n"
|
||||
" v_mul_f32 v160, v19, v160 row_newbcast:0 \n"
|
||||
" v_mul_f32 v161, v19, v161 row_newbcast:1 \n"
|
||||
" v_mul_f32 v162, v19, v162 row_newbcast:2 \n"
|
||||
" v_mul_f32 v163, v19, v163 row_newbcast:3 \n"
|
||||
" v_mul_f32 v164, v19, v164 row_newbcast:0 \n"
|
||||
" v_mul_f32 v165, v19, v165 row_newbcast:1 \n"
|
||||
" v_mul_f32 v166, v19, v166 row_newbcast:2 \n"
|
||||
" v_mul_f32 v167, v19, v167 row_newbcast:3 \n"
|
||||
" v_mul_f32 v168, v19, v168 row_newbcast:4 \n"
|
||||
" v_mul_f32 v169, v19, v169 row_newbcast:5 \n"
|
||||
" v_mul_f32 v170, v19, v170 row_newbcast:6 \n"
|
||||
" v_mul_f32 v171, v19, v171 row_newbcast:7 \n"
|
||||
" v_mul_f32 v172, v19, v172 row_newbcast:4 \n"
|
||||
" v_mul_f32 v173, v19, v173 row_newbcast:5 \n"
|
||||
" v_mul_f32 v174, v19, v174 row_newbcast:6 \n"
|
||||
" v_mul_f32 v175, v19, v175 row_newbcast:7 \n"
|
||||
" v_mul_f32 v176, v19, v176 row_newbcast:8 \n"
|
||||
" v_mul_f32 v177, v19, v177 row_newbcast:9 \n"
|
||||
" v_mul_f32 v178, v19, v178 row_newbcast:10 \n"
|
||||
" v_mul_f32 v179, v19, v179 row_newbcast:11 \n"
|
||||
" v_mul_f32 v180, v19, v180 row_newbcast:8 \n"
|
||||
" v_mul_f32 v181, v19, v181 row_newbcast:9 \n"
|
||||
" v_mul_f32 v182, v19, v182 row_newbcast:10 \n"
|
||||
" v_mul_f32 v183, v19, v183 row_newbcast:11 \n"
|
||||
" v_mul_f32 v184, v19, v184 row_newbcast:12 \n"
|
||||
" v_mul_f32 v185, v19, v185 row_newbcast:13 \n"
|
||||
" v_mul_f32 v186, v19, v186 row_newbcast:14 \n"
|
||||
" v_mul_f32 v187, v19, v187 row_newbcast:15 \n"
|
||||
" v_mul_f32 v188, v19, v188 row_newbcast:12 \n"
|
||||
" v_mul_f32 v189, v19, v189 row_newbcast:13 \n"
|
||||
" v_mul_f32 v190, v19, v190 row_newbcast:14 \n"
|
||||
" v_mul_f32 v191, v19, v191 row_newbcast:15 \n"
|
||||
" buffer_load_dword v12, v5, %[s_res_dq], 0 offen \n"
|
||||
" v_mul_f32 v128, %[smq_scale0], v128 row_newbcast:0 \n"
|
||||
" v_mul_f32 v129, %[smq_scale0], v129 row_newbcast:1 \n"
|
||||
" v_mul_f32 v130, %[smq_scale0], v130 row_newbcast:2 \n"
|
||||
" v_mul_f32 v131, %[smq_scale0], v131 row_newbcast:3 \n"
|
||||
" v_mul_f32 v132, %[smq_scale0], v132 row_newbcast:0 \n"
|
||||
" v_mul_f32 v133, %[smq_scale0], v133 row_newbcast:1 \n"
|
||||
" v_mul_f32 v134, %[smq_scale0], v134 row_newbcast:2 \n"
|
||||
" v_mul_f32 v135, %[smq_scale0], v135 row_newbcast:3 \n"
|
||||
" v_mul_f32 v136, %[smq_scale0], v136 row_newbcast:4 \n"
|
||||
" v_mul_f32 v137, %[smq_scale0], v137 row_newbcast:5 \n"
|
||||
" v_mul_f32 v138, %[smq_scale0], v138 row_newbcast:6 \n"
|
||||
" v_mul_f32 v139, %[smq_scale0], v139 row_newbcast:7 \n"
|
||||
" v_mul_f32 v140, %[smq_scale0], v140 row_newbcast:4 \n"
|
||||
" v_mul_f32 v141, %[smq_scale0], v141 row_newbcast:5 \n"
|
||||
" v_mul_f32 v142, %[smq_scale0], v142 row_newbcast:6 \n"
|
||||
" v_mul_f32 v143, %[smq_scale0], v143 row_newbcast:7 \n"
|
||||
" v_mul_f32 v144, %[smq_scale0], v144 row_newbcast:8 \n"
|
||||
" v_mul_f32 v145, %[smq_scale0], v145 row_newbcast:9 \n"
|
||||
" v_mul_f32 v146, %[smq_scale0], v146 row_newbcast:10 \n"
|
||||
" v_mul_f32 v147, %[smq_scale0], v147 row_newbcast:11 \n"
|
||||
" v_mul_f32 v148, %[smq_scale0], v148 row_newbcast:8 \n"
|
||||
" v_mul_f32 v149, %[smq_scale0], v149 row_newbcast:9 \n"
|
||||
" v_mul_f32 v150, %[smq_scale0], v150 row_newbcast:10 \n"
|
||||
" v_mul_f32 v151, %[smq_scale0], v151 row_newbcast:11 \n"
|
||||
" v_mul_f32 v152, %[smq_scale0], v152 row_newbcast:12 \n"
|
||||
" v_mul_f32 v153, %[smq_scale0], v153 row_newbcast:13 \n"
|
||||
" v_mul_f32 v154, %[smq_scale0], v154 row_newbcast:14 \n"
|
||||
" v_mul_f32 v155, %[smq_scale0], v155 row_newbcast:15 \n"
|
||||
" v_mul_f32 v156, %[smq_scale0], v156 row_newbcast:12 \n"
|
||||
" v_mul_f32 v157, %[smq_scale0], v157 row_newbcast:13 \n"
|
||||
" v_mul_f32 v158, %[smq_scale0], v158 row_newbcast:14 \n"
|
||||
" v_mul_f32 v159, %[smq_scale0], v159 row_newbcast:15 \n"
|
||||
" v_mul_f32 v160, %[smq_scale1], v160 row_newbcast:0 \n"
|
||||
" v_mul_f32 v161, %[smq_scale1], v161 row_newbcast:1 \n"
|
||||
" v_mul_f32 v162, %[smq_scale1], v162 row_newbcast:2 \n"
|
||||
" v_mul_f32 v163, %[smq_scale1], v163 row_newbcast:3 \n"
|
||||
" v_mul_f32 v164, %[smq_scale1], v164 row_newbcast:0 \n"
|
||||
" v_mul_f32 v165, %[smq_scale1], v165 row_newbcast:1 \n"
|
||||
" v_mul_f32 v166, %[smq_scale1], v166 row_newbcast:2 \n"
|
||||
" v_mul_f32 v167, %[smq_scale1], v167 row_newbcast:3 \n"
|
||||
" v_mul_f32 v168, %[smq_scale1], v168 row_newbcast:4 \n"
|
||||
" v_mul_f32 v169, %[smq_scale1], v169 row_newbcast:5 \n"
|
||||
" v_mul_f32 v170, %[smq_scale1], v170 row_newbcast:6 \n"
|
||||
" v_mul_f32 v171, %[smq_scale1], v171 row_newbcast:7 \n"
|
||||
" v_mul_f32 v172, %[smq_scale1], v172 row_newbcast:4 \n"
|
||||
" v_mul_f32 v173, %[smq_scale1], v173 row_newbcast:5 \n"
|
||||
" v_mul_f32 v174, %[smq_scale1], v174 row_newbcast:6 \n"
|
||||
" v_mul_f32 v175, %[smq_scale1], v175 row_newbcast:7 \n"
|
||||
" v_mul_f32 v176, %[smq_scale1], v176 row_newbcast:8 \n"
|
||||
" v_mul_f32 v177, %[smq_scale1], v177 row_newbcast:9 \n"
|
||||
" v_mul_f32 v178, %[smq_scale1], v178 row_newbcast:10 \n"
|
||||
" v_mul_f32 v179, %[smq_scale1], v179 row_newbcast:11 \n"
|
||||
" v_mul_f32 v180, %[smq_scale1], v180 row_newbcast:8 \n"
|
||||
" v_mul_f32 v181, %[smq_scale1], v181 row_newbcast:9 \n"
|
||||
" v_mul_f32 v182, %[smq_scale1], v182 row_newbcast:10 \n"
|
||||
" v_mul_f32 v183, %[smq_scale1], v183 row_newbcast:11 \n"
|
||||
" v_mul_f32 v184, %[smq_scale1], v184 row_newbcast:12 \n"
|
||||
" v_mul_f32 v185, %[smq_scale1], v185 row_newbcast:13 \n"
|
||||
" v_mul_f32 v186, %[smq_scale1], v186 row_newbcast:14 \n"
|
||||
" v_mul_f32 v187, %[smq_scale1], v187 row_newbcast:15 \n"
|
||||
" v_mul_f32 v188, %[smq_scale1], v188 row_newbcast:12 \n"
|
||||
" v_mul_f32 v189, %[smq_scale1], v189 row_newbcast:13 \n"
|
||||
" v_mul_f32 v190, %[smq_scale1], v190 row_newbcast:14 \n"
|
||||
" v_mul_f32 v191, %[smq_scale1], v191 row_newbcast:15 \n"
|
||||
" buffer_load_dword v12, %[v_os_dq], %[s_res_dq], 0 offen \n"
|
||||
" v_mov_b32 v22, 0x358637bd \n"
|
||||
" v_mov_b32 v23, 0x358637bd \n"
|
||||
" v_max3_f32 v22, abs(v128), abs(v129), v22 \n"
|
||||
@@ -934,9 +937,42 @@
|
||||
" v_lshlrev_b32 v54, 1, v54 \n"
|
||||
" v_add_u32 v55, v54, v55 \n"
|
||||
" v_lshlrev_b32 v54, 2, v55 \n"
|
||||
" ds_read_b64 v[128:129], v54 offset:18688 \n"
|
||||
" ds_read_b64 v[130:131], v54 offset:18816 \n"
|
||||
" ds_read_b64 v[132:133], v54 offset:19712 \n"
|
||||
" ds_read_b64 v[134:135], v54 offset:19840 \n"
|
||||
" ds_read_b64 v[136:137], v54 offset:20736 \n"
|
||||
" ds_read_b64 v[138:139], v54 offset:20864 \n"
|
||||
" ds_read_b64 v[140:141], v54 offset:21760 \n"
|
||||
" ds_read_b64 v[142:143], v54 offset:21888 \n"
|
||||
" ds_read_b64 v[144:145], v54 offset:22784 \n"
|
||||
" ds_read_b64 v[146:147], v54 offset:22912 \n"
|
||||
" ds_read_b64 v[148:149], v54 offset:23808 \n"
|
||||
" ds_read_b64 v[150:151], v54 offset:23936 \n"
|
||||
" ds_read_b64 v[152:153], v54 offset:24832 \n"
|
||||
" ds_read_b64 v[154:155], v54 offset:24960 \n"
|
||||
" ds_read_b64 v[156:157], v54 offset:25856 \n"
|
||||
" ds_read_b64 v[158:159], v54 offset:25984 \n"
|
||||
" ds_read_b64 v[160:161], v54 offset:26880 \n"
|
||||
" ds_read_b64 v[162:163], v54 offset:27008 \n"
|
||||
" ds_read_b64 v[164:165], v54 offset:27904 \n"
|
||||
" ds_read_b64 v[166:167], v54 offset:28032 \n"
|
||||
" ds_read_b64 v[168:169], v54 offset:28928 \n"
|
||||
" ds_read_b64 v[170:171], v54 offset:29056 \n"
|
||||
" ds_read_b64 v[172:173], v54 offset:29952 \n"
|
||||
" ds_read_b64 v[174:175], v54 offset:30080 \n"
|
||||
" ds_read_b64 v[176:177], v54 offset:30976 \n"
|
||||
" ds_read_b64 v[178:179], v54 offset:31104 \n"
|
||||
" ds_read_b64 v[180:181], v54 offset:32000 \n"
|
||||
" ds_read_b64 v[182:183], v54 offset:32128 \n"
|
||||
" ds_read_b64 v[184:185], v54 offset:33024 \n"
|
||||
" ds_read_b64 v[186:187], v54 offset:33152 \n"
|
||||
" ds_read_b64 v[188:189], v54 offset:34048 \n"
|
||||
" ds_read_b64 v[190:191], v54 offset:34176 \n"
|
||||
|
||||
#undef _UK_MFMA_
|
||||
#undef _UK_PK_CVT_
|
||||
#undef _UK_ATOMIC_ADD_
|
||||
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,36 +5,20 @@
|
||||
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_INT8
|
||||
# define _UK_MFMA_ "v_mfma_i32_16x16x32_i8"
|
||||
#endif
|
||||
# define _DEQUAN_CVT_(a0,a1,a2,a3, b, c) \
|
||||
" v_cvt_f32_i32 a0, a0 \n" \
|
||||
" v_cvt_f32_i32 a1, a1 \n" \
|
||||
" v_cvt_f32_i32 a2, a2 \n" \
|
||||
" v_cvt_f32_i32 a3, a3 \n" \
|
||||
" v_mul_f32 a0, v15, a0 \n" \
|
||||
" v_mul_f32 a1, v15, a1 \n" \
|
||||
" v_mul_f32 a2, v15, a2 \n" \
|
||||
" v_mul_f32 a3, v15, a3 \n" \
|
||||
" v_mul_f32 a0, v17, a0 row_newbcast:12 \n" \
|
||||
" v_mul_f32 a1, v17, a1 row_newbcast:13 \n" \
|
||||
" v_mul_f32 a2, v17, a2 row_newbcast:14 \n" \
|
||||
" v_mul_f32 a3, v17, a3 row_newbcast:15 \n" \
|
||||
# define _DEQUAN_CVT_(a0,a1,a2,a3, xq, gq,brd0,brd1,brd2,brd3) \
|
||||
" v_cvt_f32_i32 " a0 ", " a0 " \n" \
|
||||
" v_cvt_f32_i32 " a1 ", " a1 " \n" \
|
||||
" v_cvt_f32_i32 " a2 ", " a2 " \n" \
|
||||
" v_cvt_f32_i32 " a3 ", " a3" \n" \
|
||||
" v_mul_f32 " a0 ", " xq ", " a0 " \n" \
|
||||
" v_mul_f32 " a1 ", " xq ", " a1 " \n" \
|
||||
" v_mul_f32 " a2 ", " xq ", " a2 " \n" \
|
||||
" v_mul_f32 " a3 ", " xq ", " a3 " \n" \
|
||||
" v_mul_f32 " a0 ", " gq ", " a0 " row_newbcast:" brd0 " \n" \
|
||||
" v_mul_f32 " a1 ", " gq ", " a1 " row_newbcast:" brd1 " \n" \
|
||||
" v_mul_f32 " a2 ", " gq ", " a2 " row_newbcast: " brd2 " \n" \
|
||||
" v_mul_f32 " a3 ", " gq ", " a3 " row_newbcast:" brd3 " \n"
|
||||
|
||||
";---------------------------------------------- \n"
|
||||
" v_lshrrev_b32 v54, 4, v0 \n"
|
||||
" v_lshlrev_b32 v55, 2, v54 \n"
|
||||
" v_and_b32 v54, 15, v0 \n"
|
||||
" v_lshrrev_b32 v56, 2, v54 \n"
|
||||
" v_lshlrev_b32 v56, 6, v56 \n"
|
||||
" v_add_u32 v55, v56, v55 \n"
|
||||
" v_and_b32 v54, 3, v0 \n"
|
||||
" v_add_u32 v55, v54, v55 \n"
|
||||
" v_lshlrev_b32 v10, 2, v55 \n"
|
||||
" v_add_u32 v11, 0x00000400, v10 \n"
|
||||
" s_mul_i32 s60, %[s_wave_id], 16 \n"
|
||||
" s_mul_i32 s60, s60, 4 \n"
|
||||
" v_add_u32 v10, s60, v10 \n"
|
||||
" v_add_u32 v11, s60, v11 \n"
|
||||
" v_mov_b32 v5, v10 \n"
|
||||
";---------------------------------------------- \n"
|
||||
" s_mov_b32 s57, 0x00000100 \n"
|
||||
" s_mov_b32 s58, 0x00001000 \n"
|
||||
@@ -53,27 +37,22 @@
|
||||
" v_mov_b32 v52, 0x7fff0000 \n"
|
||||
" v_mov_b32 v53, 0x00007fff \n"
|
||||
" s_waitcnt 0x0000 \n"
|
||||
";---------------------------------------------- \n"
|
||||
" v_lshrrev_b32 v54, 24, %[v_token_id0] \n"
|
||||
" v_mul_i32_i24 v54, s66, v54 \n"
|
||||
" v_and_b32 v55, 0x00ffffff, %[v_token_id0] \n"
|
||||
" v_add_u32 %[v_token_id0], v54, v55 \n"
|
||||
" v_lshrrev_b32 v54, 24, %[v_token_id1] \n"
|
||||
" v_mul_i32_i24 v54, s66, v54 \n"
|
||||
" v_and_b32 v55, 0x00ffffff, %[v_token_id1] \n"
|
||||
" v_add_u32 %[v_token_id1], v54, v55 \n"
|
||||
" v_lshlrev_b32 %[v_token_id0], 2, %[v_token_id0] \n"
|
||||
" v_lshlrev_b32 %[v_token_id1], 2, %[v_token_id1] \n"
|
||||
" buffer_load_dword v14, %[v_token_id0], %[s_res_aq], 0 offen \n"
|
||||
" buffer_load_dword v15, %[v_token_id1], %[s_res_aq], 0 offen \n"
|
||||
" buffer_load_dword v16, v10, %[s_res_gq], 0 offen \n"
|
||||
" buffer_load_dword v17, v11, %[s_res_gq], 0 offen \n"
|
||||
" buffer_load_dword v18, v10, %[s_res_smq], 0 offen \n"
|
||||
" buffer_load_dword v19, v11, %[s_res_smq], 0 offen \n"
|
||||
" buffer_load_dword v20, v8, s[40:43], 0 offen \n"
|
||||
" buffer_load_dword v21, v9, s[40:43], 0 offen \n"
|
||||
|
||||
" s_mov_b32 s80, 0 \n"
|
||||
" v_lshrrev_b32 v54, 4, v0 \n"
|
||||
" v_mul_i32_i24 v3, 34, v54 \n"
|
||||
" v_and_b32 v54, 15, v0 \n"
|
||||
" v_mul_i32_i24 v55, 2, v54 \n"
|
||||
" v_add_u32 v3, v55, v3 \n"
|
||||
" s_mul_i32 s60, s7, 0x00000088 \n"
|
||||
" v_add_u32 v3, s60, v3 \n"
|
||||
" v_lshlrev_b32 v3, 2, v3 \n"
|
||||
" v_lshrrev_b32 v54, 1, v0 \n"
|
||||
" v_mul_i32_i24 v4, 34, v54 \n"
|
||||
" v_and_b32 v55, 1, v0 \n"
|
||||
" v_add_u32 v4, v55, v4 \n"
|
||||
" s_mul_i32 s60, s7, 2 \n"
|
||||
" v_add_u32 v4, s60, v4 \n"
|
||||
" v_lshlrev_b32 v4, 2, v4 \n"
|
||||
";---------------------------------------------- \n"
|
||||
"; -- prefetch A0\n"
|
||||
"s_add_u32 m0, 0, %[s_m0_init] \n"
|
||||
@@ -570,198 +549,23 @@
|
||||
" s_branch label_start \n"
|
||||
" label_end : \n"
|
||||
";---------------------------------------------- \n"
|
||||
" v_cvt_f32_i32 v128, v128 \n"
|
||||
" v_cvt_f32_i32 v129, v129 \n"
|
||||
" v_cvt_f32_i32 v130, v130 \n"
|
||||
" v_cvt_f32_i32 v131, v131 \n"
|
||||
" v_mul_f32 v128, v14, v128 \n"
|
||||
" v_mul_f32 v129, v14, v129 \n"
|
||||
" v_mul_f32 v130, v14, v130 \n"
|
||||
" v_mul_f32 v131, v14, v131 \n"
|
||||
" v_mul_f32 v128, v16, v128 row_newbcast:0 \n"
|
||||
" v_mul_f32 v129, v16, v129 row_newbcast:1 \n"
|
||||
" v_mul_f32 v130, v16, v130 row_newbcast:2 \n"
|
||||
" v_mul_f32 v131, v16, v131 row_newbcast:3 \n"
|
||||
" v_cvt_f32_i32 v132, v132 \n"
|
||||
" v_cvt_f32_i32 v133, v133 \n"
|
||||
" v_cvt_f32_i32 v134, v134 \n"
|
||||
" v_cvt_f32_i32 v135, v135 \n"
|
||||
" v_mul_f32 v132, v15, v132 \n"
|
||||
" v_mul_f32 v133, v15, v133 \n"
|
||||
" v_mul_f32 v134, v15, v134 \n"
|
||||
" v_mul_f32 v135, v15, v135 \n"
|
||||
" v_mul_f32 v132, v16, v132 row_newbcast:0 \n"
|
||||
" v_mul_f32 v133, v16, v133 row_newbcast:1 \n"
|
||||
" v_mul_f32 v134, v16, v134 row_newbcast:2 \n"
|
||||
" v_mul_f32 v135, v16, v135 row_newbcast:3 \n"
|
||||
" v_cvt_f32_i32 v136, v136 \n"
|
||||
" v_cvt_f32_i32 v137, v137 \n"
|
||||
" v_cvt_f32_i32 v138, v138 \n"
|
||||
" v_cvt_f32_i32 v139, v139 \n"
|
||||
" v_mul_f32 v136, v14, v136 \n"
|
||||
" v_mul_f32 v137, v14, v137 \n"
|
||||
" v_mul_f32 v138, v14, v138 \n"
|
||||
" v_mul_f32 v139, v14, v139 \n"
|
||||
" v_mul_f32 v136, v16, v136 row_newbcast:4 \n"
|
||||
" v_mul_f32 v137, v16, v137 row_newbcast:5 \n"
|
||||
" v_mul_f32 v138, v16, v138 row_newbcast:6 \n"
|
||||
" v_mul_f32 v139, v16, v139 row_newbcast:7 \n"
|
||||
" v_cvt_f32_i32 v140, v140 \n"
|
||||
" v_cvt_f32_i32 v141, v141 \n"
|
||||
" v_cvt_f32_i32 v142, v142 \n"
|
||||
" v_cvt_f32_i32 v143, v143 \n"
|
||||
" v_mul_f32 v140, v15, v140 \n"
|
||||
" v_mul_f32 v141, v15, v141 \n"
|
||||
" v_mul_f32 v142, v15, v142 \n"
|
||||
" v_mul_f32 v143, v15, v143 \n"
|
||||
" v_mul_f32 v140, v16, v140 row_newbcast:4 \n"
|
||||
" v_mul_f32 v141, v16, v141 row_newbcast:5 \n"
|
||||
" v_mul_f32 v142, v16, v142 row_newbcast:6 \n"
|
||||
" v_mul_f32 v143, v16, v143 row_newbcast:7 \n"
|
||||
" v_cvt_f32_i32 v144, v144 \n"
|
||||
" v_cvt_f32_i32 v145, v145 \n"
|
||||
" v_cvt_f32_i32 v146, v146 \n"
|
||||
" v_cvt_f32_i32 v147, v147 \n"
|
||||
" v_mul_f32 v144, v14, v144 \n"
|
||||
" v_mul_f32 v145, v14, v145 \n"
|
||||
" v_mul_f32 v146, v14, v146 \n"
|
||||
" v_mul_f32 v147, v14, v147 \n"
|
||||
" v_mul_f32 v144, v16, v144 row_newbcast:8 \n"
|
||||
" v_mul_f32 v145, v16, v145 row_newbcast:9 \n"
|
||||
" v_mul_f32 v146, v16, v146 row_newbcast:10 \n"
|
||||
" v_mul_f32 v147, v16, v147 row_newbcast:11 \n"
|
||||
" v_cvt_f32_i32 v148, v148 \n"
|
||||
" v_cvt_f32_i32 v149, v149 \n"
|
||||
" v_cvt_f32_i32 v150, v150 \n"
|
||||
" v_cvt_f32_i32 v151, v151 \n"
|
||||
" v_mul_f32 v148, v15, v148 \n"
|
||||
" v_mul_f32 v149, v15, v149 \n"
|
||||
" v_mul_f32 v150, v15, v150 \n"
|
||||
" v_mul_f32 v151, v15, v151 \n"
|
||||
" v_mul_f32 v148, v16, v148 row_newbcast:8 \n"
|
||||
" v_mul_f32 v149, v16, v149 row_newbcast:9 \n"
|
||||
" v_mul_f32 v150, v16, v150 row_newbcast:10 \n"
|
||||
" v_mul_f32 v151, v16, v151 row_newbcast:11 \n"
|
||||
" v_cvt_f32_i32 v152, v152 \n"
|
||||
" v_cvt_f32_i32 v153, v153 \n"
|
||||
" v_cvt_f32_i32 v154, v154 \n"
|
||||
" v_cvt_f32_i32 v155, v155 \n"
|
||||
" v_mul_f32 v152, v14, v152 \n"
|
||||
" v_mul_f32 v153, v14, v153 \n"
|
||||
" v_mul_f32 v154, v14, v154 \n"
|
||||
" v_mul_f32 v155, v14, v155 \n"
|
||||
" v_mul_f32 v152, v16, v152 row_newbcast:12 \n"
|
||||
" v_mul_f32 v153, v16, v153 row_newbcast:13 \n"
|
||||
" v_mul_f32 v154, v16, v154 row_newbcast:14 \n"
|
||||
" v_mul_f32 v155, v16, v155 row_newbcast:15 \n"
|
||||
" v_cvt_f32_i32 v156, v156 \n"
|
||||
" v_cvt_f32_i32 v157, v157 \n"
|
||||
" v_cvt_f32_i32 v158, v158 \n"
|
||||
" v_cvt_f32_i32 v159, v159 \n"
|
||||
" v_mul_f32 v156, v15, v156 \n"
|
||||
" v_mul_f32 v157, v15, v157 \n"
|
||||
" v_mul_f32 v158, v15, v158 \n"
|
||||
" v_mul_f32 v159, v15, v159 \n"
|
||||
" v_mul_f32 v156, v16, v156 row_newbcast:12 \n"
|
||||
" v_mul_f32 v157, v16, v157 row_newbcast:13 \n"
|
||||
" v_mul_f32 v158, v16, v158 row_newbcast:14 \n"
|
||||
" v_mul_f32 v159, v16, v159 row_newbcast:15 \n"
|
||||
" v_cvt_f32_i32 v160, v160 \n"
|
||||
" v_cvt_f32_i32 v161, v161 \n"
|
||||
" v_cvt_f32_i32 v162, v162 \n"
|
||||
" v_cvt_f32_i32 v163, v163 \n"
|
||||
" v_mul_f32 v160, v14, v160 \n"
|
||||
" v_mul_f32 v161, v14, v161 \n"
|
||||
" v_mul_f32 v162, v14, v162 \n"
|
||||
" v_mul_f32 v163, v14, v163 \n"
|
||||
" v_mul_f32 v160, v17, v160 row_newbcast:0 \n"
|
||||
" v_mul_f32 v161, v17, v161 row_newbcast:1 \n"
|
||||
" v_mul_f32 v162, v17, v162 row_newbcast:2 \n"
|
||||
" v_mul_f32 v163, v17, v163 row_newbcast:3 \n"
|
||||
" v_cvt_f32_i32 v164, v164 \n"
|
||||
" v_cvt_f32_i32 v165, v165 \n"
|
||||
" v_cvt_f32_i32 v166, v166 \n"
|
||||
" v_cvt_f32_i32 v167, v167 \n"
|
||||
" v_mul_f32 v164, v15, v164 \n"
|
||||
" v_mul_f32 v165, v15, v165 \n"
|
||||
" v_mul_f32 v166, v15, v166 \n"
|
||||
" v_mul_f32 v167, v15, v167 \n"
|
||||
" v_mul_f32 v164, v17, v164 row_newbcast:0 \n"
|
||||
" v_mul_f32 v165, v17, v165 row_newbcast:1 \n"
|
||||
" v_mul_f32 v166, v17, v166 row_newbcast:2 \n"
|
||||
" v_mul_f32 v167, v17, v167 row_newbcast:3 \n"
|
||||
" v_cvt_f32_i32 v168, v168 \n"
|
||||
" v_cvt_f32_i32 v169, v169 \n"
|
||||
" v_cvt_f32_i32 v170, v170 \n"
|
||||
" v_cvt_f32_i32 v171, v171 \n"
|
||||
" v_mul_f32 v168, v14, v168 \n"
|
||||
" v_mul_f32 v169, v14, v169 \n"
|
||||
" v_mul_f32 v170, v14, v170 \n"
|
||||
" v_mul_f32 v171, v14, v171 \n"
|
||||
" v_mul_f32 v168, v17, v168 row_newbcast:4 \n"
|
||||
" v_mul_f32 v169, v17, v169 row_newbcast:5 \n"
|
||||
" v_mul_f32 v170, v17, v170 row_newbcast:6 \n"
|
||||
" v_mul_f32 v171, v17, v171 row_newbcast:7 \n"
|
||||
" v_cvt_f32_i32 v172, v172 \n"
|
||||
" v_cvt_f32_i32 v173, v173 \n"
|
||||
" v_cvt_f32_i32 v174, v174 \n"
|
||||
" v_cvt_f32_i32 v175, v175 \n"
|
||||
" v_mul_f32 v172, v15, v172 \n"
|
||||
" v_mul_f32 v173, v15, v173 \n"
|
||||
" v_mul_f32 v174, v15, v174 \n"
|
||||
" v_mul_f32 v175, v15, v175 \n"
|
||||
" v_mul_f32 v172, v17, v172 row_newbcast:4 \n"
|
||||
" v_mul_f32 v173, v17, v173 row_newbcast:5 \n"
|
||||
" v_mul_f32 v174, v17, v174 row_newbcast:6 \n"
|
||||
" v_mul_f32 v175, v17, v175 row_newbcast:7 \n"
|
||||
" v_cvt_f32_i32 v176, v176 \n"
|
||||
" v_cvt_f32_i32 v177, v177 \n"
|
||||
" v_cvt_f32_i32 v178, v178 \n"
|
||||
" v_cvt_f32_i32 v179, v179 \n"
|
||||
" v_mul_f32 v176, v14, v176 \n"
|
||||
" v_mul_f32 v177, v14, v177 \n"
|
||||
" v_mul_f32 v178, v14, v178 \n"
|
||||
" v_mul_f32 v179, v14, v179 \n"
|
||||
" v_mul_f32 v176, v17, v176 row_newbcast:8 \n"
|
||||
" v_mul_f32 v177, v17, v177 row_newbcast:9 \n"
|
||||
" v_mul_f32 v178, v17, v178 row_newbcast:10 \n"
|
||||
" v_mul_f32 v179, v17, v179 row_newbcast:11 \n"
|
||||
" v_cvt_f32_i32 v180, v180 \n"
|
||||
" v_cvt_f32_i32 v181, v181 \n"
|
||||
" v_cvt_f32_i32 v182, v182 \n"
|
||||
" v_cvt_f32_i32 v183, v183 \n"
|
||||
" v_mul_f32 v180, v15, v180 \n"
|
||||
" v_mul_f32 v181, v15, v181 \n"
|
||||
" v_mul_f32 v182, v15, v182 \n"
|
||||
" v_mul_f32 v183, v15, v183 \n"
|
||||
" v_mul_f32 v180, v17, v180 row_newbcast:8 \n"
|
||||
" v_mul_f32 v181, v17, v181 row_newbcast:9 \n"
|
||||
" v_mul_f32 v182, v17, v182 row_newbcast:10 \n"
|
||||
" v_mul_f32 v183, v17, v183 row_newbcast:11 \n"
|
||||
" v_cvt_f32_i32 v184, v184 \n"
|
||||
" v_cvt_f32_i32 v185, v185 \n"
|
||||
" v_cvt_f32_i32 v186, v186 \n"
|
||||
" v_cvt_f32_i32 v187, v187 \n"
|
||||
" v_mul_f32 v184, v14, v184 \n"
|
||||
" v_mul_f32 v185, v14, v185 \n"
|
||||
" v_mul_f32 v186, v14, v186 \n"
|
||||
" v_mul_f32 v187, v14, v187 \n"
|
||||
" v_mul_f32 v184, v17, v184 row_newbcast:12 \n"
|
||||
" v_mul_f32 v185, v17, v185 row_newbcast:13 \n"
|
||||
" v_mul_f32 v186, v17, v186 row_newbcast:14 \n"
|
||||
" v_mul_f32 v187, v17, v187 row_newbcast:15 \n"
|
||||
" v_cvt_f32_i32 v188, v188 \n"
|
||||
" v_cvt_f32_i32 v189, v189 \n"
|
||||
" v_cvt_f32_i32 v190, v190 \n"
|
||||
" v_cvt_f32_i32 v191, v191 \n"
|
||||
" v_mul_f32 v188, v15, v188 \n"
|
||||
" v_mul_f32 v189, v15, v189 \n"
|
||||
" v_mul_f32 v190, v15, v190 \n"
|
||||
" v_mul_f32 v191, v15, v191 \n"
|
||||
" v_mul_f32 v188, v17, v188 row_newbcast:12 \n"
|
||||
" v_mul_f32 v189, v17, v189 row_newbcast:13 \n"
|
||||
" v_mul_f32 v190, v17, v190 row_newbcast:14 \n"
|
||||
" v_mul_f32 v191, v17, v191 row_newbcast:15 \n"
|
||||
_DEQUAN_CVT_("%[c0]","%[c1]","%[c2]","%[c3]","%[a_scale0]"," %[gq_scale0]","0","1","2","3")
|
||||
_DEQUAN_CVT_("%[c4]","%[c5]","%[c6]","%[c7]","%[a_scale1]"," %[gq_scale0]","0","1","2","3")
|
||||
_DEQUAN_CVT_("%[c8]","%[c9]","%[c10]","%[c11]","%[a_scale0]"," %[gq_scale0]","4","5","6","7")
|
||||
_DEQUAN_CVT_("%[c12]","%[c13]","%[c14]","%[c15]","%[a_scale1]"," %[gq_scale0]","4","5","6","7")
|
||||
_DEQUAN_CVT_("%[c16]","%[c17]","%[c18]","%[c19]","%[a_scale0]"," %[gq_scale0]","8","9","10","11")
|
||||
_DEQUAN_CVT_("%[c20]","%[c21]","%[c22]","%[c23]","%[a_scale1]"," %[gq_scale0]","8","9","10","11")
|
||||
_DEQUAN_CVT_("%[c24]","%[c25]","%[c26]","%[c27]","%[a_scale0]"," %[gq_scale0]","12","13","14","15")
|
||||
_DEQUAN_CVT_("%[c28]","%[c29]","%[c30]","%[c31]","%[a_scale1]"," %[gq_scale0]","12","13","14","15")
|
||||
_DEQUAN_CVT_("%[c32]","%[c33]","%[c34]","%[c35]","%[a_scale0]"," %[gq_scale1]","0","1","2","3")
|
||||
_DEQUAN_CVT_("%[c36]","%[c37]","%[c38]","%[c39]","%[a_scale1]"," %[gq_scale1]","0","1","2","3")
|
||||
_DEQUAN_CVT_("%[c40]","%[c41]","%[c42]","%[c43]","%[a_scale0]"," %[gq_scale1]","4","5","6","7")
|
||||
_DEQUAN_CVT_("%[c44]","%[c45]","%[c46]","%[c47]","%[a_scale1]"," %[gq_scale1]","4","5","6","7")
|
||||
_DEQUAN_CVT_("%[c48]","%[c49]","%[c50]","%[c51]","%[a_scale0]"," %[gq_scale1]","8","9","10","11")
|
||||
_DEQUAN_CVT_("%[c52]","%[c53]","%[c54]","%[c55]","%[a_scale1]"," %[gq_scale1]","8","9","10","11")
|
||||
_DEQUAN_CVT_("%[c56]","%[c57]","%[c58]","%[c59]","%[a_scale0]"," %[gq_scale1]","12","13","14","15")
|
||||
_DEQUAN_CVT_("%[c60]","%[c61]","%[c62]","%[c63]","%[a_scale1]"," %[gq_scale1]","12","13","14","15")
|
||||
|
||||
#undef _UK_MFMA_
|
||||
#undef _DEQUAN_CVT_
|
||||
|
||||
|
||||
@@ -186,6 +186,50 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
|
||||
|
||||
return coords;
|
||||
}
|
||||
// TODO: this row id is before shuffle atomic, need use acc distribution
|
||||
//this calculation shared by G and SMQ
|
||||
CK_TILE_DEVICE auto GetColCoords_GQSMQ(index_t base_offset)
|
||||
{
|
||||
constexpr index_t MLanes = BlockShape::Warp_M1;
|
||||
constexpr index_t Repeat_N = 2;//different,this load is partitioned along N
|
||||
|
||||
// auto h_id = threadIdx.x / MLanes ;
|
||||
// auto r_id = threadIdx.x & 0xffff;
|
||||
// auto p_id = r_id/4;
|
||||
// auto q_is = threadIdx.x & 0x3;
|
||||
|
||||
array<index_t, Repeat_N> coords;
|
||||
static_for<0, Repeat_N, 1>{}([&](auto i) { coords.at(i) = base_coord + (threadIdx.x / MLanes) * 4 +
|
||||
(threadIdx.x & 0xffff)/4 * 64 +
|
||||
q_id +
|
||||
i * 256 ; });
|
||||
return coords;
|
||||
}
|
||||
//this calculation shared by G and SMQ
|
||||
CK_TILE_DEVICE auto GetGQScale(const COL_IDS coords,
|
||||
const GScaleDataType* g_scale_ptr)
|
||||
{
|
||||
constexpr index_t n_size = coords.size();
|
||||
|
||||
array<GScaleDataType, n_size> g_scale_value;
|
||||
static_for<0, n_size, 1>{}([&](auto i) {
|
||||
g_scale_value.at(i) = g_scale_ptr[coords[i]];
|
||||
});
|
||||
|
||||
return g_scale_value;
|
||||
}
|
||||
CK_TILE_DEVICE auto GetSMQScale(const COL_IDS coords,
|
||||
const YSmoothScaleDataType * y_scale_ptr)
|
||||
{
|
||||
constexpr index_t n_size = coords.size();
|
||||
array<YSmoothScaleDataType, n_size> y_scale_value;
|
||||
static_for<0, n_size, 1>{}([&](auto i) {
|
||||
y_scale_value.at(i) = y_scale_ptr[coords[i]];
|
||||
});
|
||||
|
||||
return y_scale_value;
|
||||
}
|
||||
|
||||
|
||||
template <typename Karg>
|
||||
CK_TILE_DEVICE auto operator()(const Karg& kargs,
|
||||
@@ -230,12 +274,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
|
||||
return (row_ids_a[i]) &0xffffff;
|
||||
},
|
||||
number<row_ids_a.size()>{});
|
||||
// auto token_id_mma = generate_tuple(
|
||||
// [&](auto i) {
|
||||
// return (row_ids_a_mma[i]) &0xffffff;
|
||||
// },
|
||||
// number<row_ids_a_mma.size()>{});
|
||||
//addr in fact
|
||||
auto a_coords = generate_tuple(
|
||||
[&](auto i) {
|
||||
return ((row_ids_a[i])&0xffffff) * kargs.stride_token +
|
||||
@@ -306,7 +344,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
|
||||
auto smq_win = [&]() {
|
||||
const YSmoothScaleDataType* smq_ptr = reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr) +
|
||||
static_cast<long_index_t>(expert_id) * smq_scale_expert_stride_0 +
|
||||
intermediate_tile_id * BlockShape::Block_N0;
|
||||
intermediate_tile_id * BlockShape::Block_K1;
|
||||
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr);//remember to add expert id for inline
|
||||
auto smq_view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
smq_ptr,
|
||||
@@ -346,15 +384,15 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
|
||||
auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
|
||||
//////gq
|
||||
auto dq_win = [&]() {
|
||||
const DScaleDataType* g_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) +
|
||||
static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1;
|
||||
const DScaleDataType* dq_ptr = reinterpret_cast<const DScaleDataType*>(kargs.d_scale_ptr) +
|
||||
static_cast<long_index_t>(expert_id) * d_scale_expert_stride_1;
|
||||
// const GDataType* g_ptr = reinterpret_cast<const GScaleDataType*>(kargs.d_scale_ptr)//remember to add expert_id as expert_idx
|
||||
auto g_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
g_ptr,
|
||||
auto dq_view_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
dq_ptr,
|
||||
make_tuple(kargs.hidden_size),
|
||||
number<1>{});
|
||||
|
||||
return g_view_;
|
||||
return dq_view_;
|
||||
}();
|
||||
|
||||
auto dq_res = dq_win.get_buffer_view().cached_buf_res_;
|
||||
@@ -400,15 +438,6 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
|
||||
generate_tuple([&](auto i) { return cmp_lt_to_exec(token_id[i], kargs.num_tokens); },
|
||||
number<row_ids_a.size()>{});
|
||||
|
||||
// auto bridge_sst_win = [&]() {
|
||||
// constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
|
||||
// constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
|
||||
// return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
|
||||
// reinterpret_cast<YDataType*>(smem), desc_),
|
||||
// desc_.get_lengths(),
|
||||
// {0, 0},
|
||||
// dist_);
|
||||
// }();
|
||||
auto o_res =
|
||||
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
|
||||
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
|
||||
@@ -417,16 +446,17 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
|
||||
auto w_scale = GetWeightScale(
|
||||
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
|
||||
auto a_scale = GetAScale(
|
||||
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.a_scale_ptr));
|
||||
|
||||
row_ids_a_mma, reinterpret_cast<const AScaleDataType*>(kargs.a_scale_ptr));
|
||||
auto gqsmq_coords = GetColCoords_GQSMQ(intermediated_tile_id * BlockShape::Block_K1);
|
||||
auto dq_coords = gqsmq_coords[0];//only one for this tiling
|
||||
auto gq_scale = GetGQScale(
|
||||
gqsmq_coords, reinterpret_cast<const GScaleDataType*>(kargs.g_scale_ptr + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
|
||||
auto smq_scale = GetSMQScale(
|
||||
gqsmq_coords, reinterpret_cast<const YSmoothScaleDataType*>(kargs.y_smooth_scale_ptr + static_cast<long_index_t>(expert_id) * shared_intermediate_size_0));
|
||||
auto uk_0 = Policy::template GetUK_0<Problem>();
|
||||
// auto acc_0= uk_0(
|
||||
uk_0(
|
||||
row_ids_a_mma,//fake token id, 2D index for X scale
|
||||
a_scale,
|
||||
dq_res,
|
||||
gq_res,
|
||||
smq_res,
|
||||
uk_0( a_scale,
|
||||
gq_scale,
|
||||
a_res,
|
||||
a_coords,
|
||||
g_res,
|
||||
@@ -457,6 +487,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
|
||||
auto uk_1 = Policy::template GetUK_1<Problem>();
|
||||
uk_1(dq_res,
|
||||
d_res,
|
||||
dq_coords,
|
||||
d_coords,
|
||||
o_res,
|
||||
o_coords,
|
||||
@@ -464,6 +495,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8
|
||||
smem,
|
||||
kargs.hidden_size, // total n number
|
||||
w_scale,
|
||||
smq_scale,
|
||||
BlockShape::Block_N1,
|
||||
shared_intermediate_size_1 * BlockShape::Block_N1 - kr_1 * BlockShape::Block_W1, // along N
|
||||
kr_1 * BlockShape::Block_W1,
|
||||
|
||||
Reference in New Issue
Block a user