From 5d00b37e6be1ab233acc4a33bbb635aefe6d993a Mon Sep 17 00:00:00 2001 From: shengnxu Date: Tue, 7 Jan 2025 07:42:05 +0000 Subject: [PATCH] fix loop cnt and half d buffer size --- ...latmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp | 2 +- ...k_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc | 15 ++++++--------- .../fused_moegemm_pipeline_flatmm_uk_int8.hpp | 18 ++++++++++-------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp index a22d568a08..001d7fb59b 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x256x512_1x4x1_16x16x64_int8.hpp @@ -72,7 +72,7 @@ struct FlatmmSn_32x256x512_1x4x1_16x16x64_Base struct FlatmmSn_32x256x512_1x4x1_16x16x64_int8 : public FlatmmSn_32x256x512_1x4x1_16x16x64_Base { using BDataType = int8_t; - using ODataType = int8_t; + using ODataType = bf16_t; using DScaleDataType = float_t; // TODO: need paired with tile_window_linear! diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc index 057a545ffc..d040c65cfa 100644 --- a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x256x512_1x4x1_16x16x32_int8_2.inc @@ -205,7 +205,7 @@ " v_mfma_i32_16x16x32_i8 v[220:223], acc[124:125], v[188:189], v[220:223] \n" " v_mfma_i32_16x16x32_i8 v[220:223], acc[126:127], v[190:191], v[220:223] \n" " s_add_u32 s60, 0x00000200, s80 \n" -" s_cmp_lt_u32 s60, s81 \n" +" s_cmp_lt_u32 s60, %[s_loop_cnt] \n" " s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0 \n" " s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0 \n" " s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0 \n" @@ -528,10 +528,10 @@ " s_mov_b64 exec, %[s_execflag_7] \n" " global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n" " s_mov_b64 exec, s[36:37] \n" -" s_add_u32 %[s_res_o0], s59, %[s_res_o0] \n" +" s_add_u32 %[s_res_o0], %[s_tile_os_o], %[s_res_o0] \n" " s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n" " s_addk_i32 s80, 0x0100 \n" -" s_cmp_lt_i32 s80, s81 \n" +" s_cmp_lt_i32 s80, %[s_loop_cnt] \n" " s_cbranch_scc0 label_end_gemm2 \n" " s_waitcnt vmcnt(41) \n" " s_barrier \n" @@ -702,7 +702,7 @@ " v_mfma_i32_16x16x32_i8 v[252:255], acc[252:253], v[188:189], v[252:255] \n" " v_mfma_i32_16x16x32_i8 v[252:255], acc[254:255], v[190:191], v[252:255] \n" " s_add_u32 s60, 0x00000200, s80 \n" -" s_cmp_lt_u32 s60, s81 \n" +" s_cmp_lt_u32 s60, %[s_loop_cnt] \n" " s_cselect_b32 %[s_tile_os_b], %[s_tile_os_b], 0 \n" " s_cselect_b32 %[s_tile_os_b_half], %[s_tile_os_b_half], 0 \n" " s_cselect_b32 %[s_tile_os_dq], %[s_tile_os_dq], 0 \n" @@ -1025,10 +1025,10 @@ " s_mov_b64 exec, %[s_execflag_7] \n" " global_atomic_pk_add_bf16 %[v_os_o7], v79, [%[s_res_o0],%[s_res_o1]] inst_offset:256 \n" " s_mov_b64 exec, s[36:37] \n" -" s_add_u32 %[s_res_o0], s59, %[s_res_o0] \n" +" s_add_u32 %[s_res_o0], %[s_tile_os_o], %[s_res_o0] \n" " s_addc_u32 %[s_res_o1], 0, %[s_res_o1] \n" " s_addk_i32 s80, 0x0100 \n" -" s_cmp_lt_i32 s80, s81 \n" +" s_cmp_lt_i32 s80, %[s_loop_cnt] \n" " s_cbranch_scc0 label_end_gemm2 \n" " s_branch label_startgemm2 \n" " label_end_gemm2: \n" @@ -1037,6 +1037,3 @@ #undef _UK_MFMA_ #undef _UK_PK_CVT_ #undef _UK_ATOMIC_ADD_ - - - diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp index 7926760e1c..bd16c9d365 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk_int8.hpp @@ -372,8 +372,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 Nl_; // Kr0_ * Kr1_ * W_; return generate_tuple( [&](auto i) { - constexpr auto i_nr_ = number{}; - return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + + // constexpr auto i_nr_ = number{}; + return i * shared_intermediate_size_1 * Nw_ * Nl_ + base_os_; }, number{}); @@ -382,7 +382,7 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 auto o_coords = generate_tuple( [&](auto i) { return token_id[i] * kargs.stride_token + - threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO; + threadIdx.x % (BlockShape::Block_N1/2 / kAlignmentO) * kAlignmentO; }, number{}); @@ -420,11 +420,13 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 BlockShape::Block_K0, // tile offset for B matrix each unroll BlockShape::Block_Kr0 * BlockShape::Block_W0); // tile offset for B matrix each unroll - if(hipBlockIdx_x == 0 && hipBlockIdx_y == 0 && hipBlockIdx_z == 0 && - hipThreadIdx_x == 5) + if(hipBlockIdx_x == 1 && hipBlockIdx_y == 1 && hipBlockIdx_z == 0 && + hipThreadIdx_x == 64) { printf("\ngemm0 done\n"); - + // printf("\n wg 1 1, wave 1, row_coords_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_coords_a[number<0>{}],row_coords_a[number<1>{}],row_coords_a[number<2>{}],row_coords_a[number<3>{}], row_coords_a[number<4>{}],row_coords_a[number<5>{}],row_coords_a[number<6>{}],row_coords_a[number<7>{}]); + // printf("\n -------------- -row_ids_a 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", row_ids_a[number<0>{}],row_ids_a[number<1>{}],row_ids_a[number<2>{}],row_ids_a[number<3>{}], row_ids_a[number<4>{}],row_ids_a[number<5>{}],row_ids_a[number<6>{}],row_ids_a[number<7>{}]); + printf("\n -------------- - token_id 0: %x 1: %x, 2: %x, 3:%x, 5: %x 6: %x, 7: %x, 8:%x,, \n", token_id[number<0>{}],token_id[number<1>{}],token_id[number<2>{}],token_id[number<3>{}], token_id[number<4>{}],token_id[number<5>{}],token_id[number<6>{}],token_id[number<7>{}]); } // sweep_tile( // acc_0, @@ -457,8 +459,8 @@ struct FusedMoeGemmPipeline_FlatmmUk_int8 w_scale, smq_scale, BlockShape::Block_N1, - shared_intermediate_size_1 * BlockShape::Block_N1 - kr_1 * BlockShape::Block_W1, // along N - kr_1 * BlockShape::Block_W1, + shared_intermediate_size_1 * BlockShape::Block_N1 - 256 * 16, // along N + 256 * 16, BlockShape::Block_N1); // along N } };