diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp index 50d993dfd2..6e29db76d9 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -40,7 +40,7 @@ using B1DataType = F32; // using EDataType = F16; using EDataType = BF16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = EDataType; using D2DataType = F32; using DsDataType = ck::Tuple; @@ -126,7 +126,7 @@ static constexpr ck::index_t Scale_Block_K = 128; static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul -static constexpr bool MulRoutedWeight = false; +static constexpr bool MulRoutedWeight = true; #if 0 static constexpr ck::index_t MPerBlock = 32; @@ -466,7 +466,7 @@ int main(int argc, char* argv[]) Tensor b_e_n_k({experts, K, N * 2}); e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); // handle scale before ref. for(int t = 0; t < tokens; ++t) @@ -491,7 +491,7 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm1BlockScale; @@ -64,23 +64,25 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()(EDataType& e, const EDataType& c, const float& d2) const { - // (void) d2; - e = ck::type_convert(c * d2); + // for real kernel use + (void)d2; + e = ck::type_convert(c); } template <> __host__ __device__ constexpr void operator()(EDataType& e, const float& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d2) const { // for reference cpu e = ck::type_convert(c * d2); } - // template <> - // __host__ __device__ constexpr void - // operator()(float& e, const float& c, const float& d2) const - // { - // // for reference cpu - // e = ck::type_convert(c * d2); - // } }; void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) @@ -213,7 +215,7 @@ int main(int argc, char* argv[]) { // use default case } - else if(argc == 3) + else if(argc == 4) { // use default case do_verification = std::stoi(argv[1]); @@ -317,9 +319,9 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; @@ -467,7 +469,7 @@ int main(int argc, char* argv[]) Tensor a_t_k_k({tokens, topk, K}); Tensor b_e_n_k({experts, K, N}); - Tensor c_t_n({tokens, N}); + Tensor c_t_n({tokens, N}); for(int t = 0; t < tokens; ++t) { @@ -496,7 +498,7 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2BlockScale - __device__ void RunRead(const SrcDescs& src_descs, - const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, - Number thread_scratch_id = Number{}) - { - if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or - ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id); - } - } - template using is_tuple = decltype(std::declval().IsTuple()); @@ -188,18 +175,6 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter RunWrite(dst_descs, dst_bufs, scatter_offsets); } - template - __device__ void Run(const SrcDescs& src_descs, - const SrcBuffers& src_bufs, - const DstDescs& dst_descs, - DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) - { - RunRead(src_descs, src_bufs, scatter_weights); - RunWrite(dst_descs, dst_bufs, scatter_offsets); - } - template __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index eb7031d2b1..74a27578d8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -89,21 +89,21 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); GridwiseGemm::template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1206,7 +1206,7 @@ struct GridwiseMoeGemmBlockScale math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(math::integer_divide_ceil(problem.N , ScaleBlockN), + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); @@ -1265,10 +1265,11 @@ struct GridwiseMoeGemmBlockScale } gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); - const index_t expert_scale_stride = - __builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) * - math::integer_divide_ceil(problem.K, ScaleBlockK)); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockK)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -1461,22 +1462,24 @@ struct GridwiseMoeGemmBlockScale get_warp_local_1d_id() % NWave, 0, KPack / KGroup * (get_thread_local_1d_id() % warpSize))); - const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; + const BScaleType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2, - Sequence<0, 1>, - 1, - ScaleSliceSizeK, - 1, - false>( - b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + BScaleType, + decltype(b_scale_grid_desc_bn_ak), + decltype(b_scale_thread_desc), + Sequence, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, @@ -1577,39 +1580,36 @@ struct GridwiseMoeGemmBlockScale constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); - if constexpr(IsInputGemm) // gu fusion, elementwise - { - static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); - static_assert(N4 == 4); - const index_t n1 = get_warp_local_1d_id() / MWave; - const index_t n3 = threadIdx.x % get_warp_size() / NPerXdl; + static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); + static_assert(M0 * M1 * M2 == MPerBlock); + static_assert(N4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m2 = threadIdx.x % get_warp_size() % M2; - vector_type topk_weights; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk - const index_t n_pos = block_n_id * NPerBlock + n0 * N1 * N2 * N3 * N4 + - n1 * N2 * N3 * N4 + n2 * N3 * N4 + n3 * N4; - if constexpr(MulRoutedWeight) + float topk_weight; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + if constexpr(MulRoutedWeight) + { + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; + topk_weight = p_ds_grid[I0][m_pos]; + } + static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk + static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, n2 * N4 + n4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion, elementwise { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I0] + n_pos); - } - // if((blockIdx.x == 0) && (blockIdx.y == 0)){printf("m0:%d, n_pos:%d\n", static_cast(m0), n_pos);} - static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, n2 * N4 + n4)); - constexpr auto cidx = Number{}; - if constexpr(ActivationOperation == Activation::silu_and_mul) { float gate = c_thread_buf[cidx]; float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[n4]; - up = up * topk_weights.AsType()[n4]; + gate = gate * topk_weight; + up = up * topk_weight; } if constexpr(is_same_v, pk_i4_t>) { @@ -1625,8 +1625,8 @@ struct GridwiseMoeGemmBlockScale float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[n4]; - up = up * topk_weights.AsType()[n4]; + gate = gate * topk_weight; + up = up * topk_weight; } if constexpr(is_same_v, pk_i4_t>) { @@ -1636,11 +1636,18 @@ struct GridwiseMoeGemmBlockScale tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf(cidx) = gate * up; } - }); + } + else + { + if constexpr(MulRoutedWeight) + { + c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + } + } }); }); }); - } + }); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1853,7 +1860,6 @@ struct GridwiseMoeGemmBlockScale static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS StaticallyIndexedArray scatter_offsets; - StaticallyIndexedArray scatter_weights; //= for topk auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -1861,18 +1867,11 @@ struct GridwiseMoeGemmBlockScale static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < problem.NumTokens ? 1 : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else - { - const float* p_sorted_weights_2 = p_ds_grid[I0]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; }); block_sync_lds(); @@ -1893,8 +1892,7 @@ struct GridwiseMoeGemmBlockScale c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { @@ -2019,10 +2017,11 @@ struct GridwiseMoeGemmBlockScale } gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); - const index_t expert_scale_stride = - __builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N , ScaleBlockN) * (IsInputGemm ? 2 : 1) * - math::integer_divide_ceil(problem.K, ScaleBlockK)); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockK)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); @@ -2071,12 +2070,12 @@ struct GridwiseMoeGemmBlockScale IndexType, 1, BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -2123,8 +2122,7 @@ struct GridwiseMoeGemmBlockScale (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - - //scale + // scale constexpr index_t ScaleSliceSizeM = MXdlPerWave; constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); @@ -2160,7 +2158,7 @@ struct GridwiseMoeGemmBlockScale { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - scale_gather_offsets(m0) = static_cast(token_offset) * + scale_gather_offsets(m0) = static_cast(token_offset) * math::integer_divide_ceil(problem.K, ScaleBlockK); }); @@ -2222,22 +2220,24 @@ struct GridwiseMoeGemmBlockScale get_warp_local_1d_id() % NWave, 0, KPack / KGroup * (get_thread_local_1d_id() % warpSize))); - const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; + const BScaleType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2, - Sequence<0, 1>, - 1, - ScaleSliceSizeK, - 1, - false>( - b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + BScaleType, + decltype(b_scale_grid_desc_bn_ak), + decltype(b_scale_thread_desc), + Sequence, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, @@ -2316,7 +2316,7 @@ struct GridwiseMoeGemmBlockScale blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); // TODO: hacky, fix it! - //only used to get lengths + // only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); @@ -2329,39 +2329,36 @@ struct GridwiseMoeGemmBlockScale constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); - if constexpr(IsInputGemm) // gu fusion, elementwise - { - static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); - static_assert(N4 == 4); - const index_t n1 = get_warp_local_1d_id() / MWave; - const index_t n3 = threadIdx.x % get_warp_size() / NPerXdl; + static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); + static_assert(M0 * M1 * M2 == MPerBlock); + static_assert(N4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m2 = threadIdx.x % get_warp_size() % M2; - vector_type topk_weights; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk - const index_t n_pos = block_n_id * NPerBlock + n0 * N1 * N2 * N3 * N4 + - n1 * N2 * N3 * N4 + n2 * N3 * N4 + n3 * N4; - if constexpr(MulRoutedWeight) + float topk_weight; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + if constexpr(MulRoutedWeight) + { + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; + topk_weight = p_ds_grid[I0][m_pos]; + } + static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk + static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, n2 * N4 + n4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion, elementwise { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I0] + n_pos); - } - // if((blockIdx.x == 0) && (blockIdx.y == 0)){printf("m0:%d, n_pos:%d\n", static_cast(m0), n_pos);} - static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, n2 * N4 + n4)); - constexpr auto cidx = Number{}; - if constexpr(ActivationOperation == Activation::silu_and_mul) { float gate = c_thread_buf[cidx]; float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[n4]; - up = up * topk_weights.AsType()[n4]; + gate = gate * topk_weight; + up = up * topk_weight; } if constexpr(is_same_v, pk_i4_t>) { @@ -2377,8 +2374,8 @@ struct GridwiseMoeGemmBlockScale float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[n4]; - up = up * topk_weights.AsType()[n4]; + gate = gate * topk_weight; + up = up * topk_weight; } if constexpr(is_same_v, pk_i4_t>) { @@ -2388,11 +2385,19 @@ struct GridwiseMoeGemmBlockScale tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf(cidx) = gate * up; } - }); + } + else + { + if constexpr(MulRoutedWeight) + { + c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + } + } + }); }); }); - } + }); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -2491,11 +2496,8 @@ struct GridwiseMoeGemmBlockScale const auto ds_grid_buf = generate_tuple( [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -2605,7 +2607,6 @@ struct GridwiseMoeGemmBlockScale // make sure it's safe to write to LDS StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -2613,18 +2614,11 @@ struct GridwiseMoeGemmBlockScale static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < problem.NumTokens ? 1 : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else - { - const float* p_sorted_weights_2 = p_ds_grid[I0]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; }); block_sync_lds(); @@ -2645,8 +2639,7 @@ struct GridwiseMoeGemmBlockScale c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { @@ -2665,45 +2658,6 @@ struct GridwiseMoeGemmBlockScale I0, cde_lds_and_global_step); } - - // // print C - // printf("tid: %d, blkid: %d, " - // "c_thread_buf = <%1.f, %1.f, %1.f>\n " - // // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f," - // // "%1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n" - // , get_thread_local_1d_id(), block_m_id, - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<0>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<1>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<2>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<3>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<4>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<5>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<6>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<7>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<8>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<9>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<10>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<11>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<12>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<13>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<14>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<3>{}]); }); } } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 4cbc9c16f6..9b1ff3dbf8 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -262,143 +262,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter }); } - template = false> - __device__ void RunRead(const SrcDescs& src_descs, - const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, - Number thread_scratch_id = Number{}) - { - // loop over space-filling curve - static_for<0, src_num_access, 1>{}([&](auto iAccess) { - auto src_vectors = generate_vectors(); - auto elm_vectors = generate_vectors(); - - bool oob_val = true; - - // copy data from src_bufs into src_vectors - static_for<0, nSrc, 1>{}([&](auto i) { - using src_vector_t = typename remove_cvref_t::type; - - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], - src_coords_[i]); - - oob_val = oob_val & is_src_valid; - if(i.value == ScatterWeightIdx) - { - static_assert(SrcScalarPerVectors{}[Number{}] == 1, - "scatter weight dim, should only one vec"); - constexpr auto iScatter = - SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); - static_for<0, SrcScalarPerVector, 1>{}([&](auto j) { - src_vectors(i).template AsType()(j) = - scatter_weights(Number{}); - }); - } - else if constexpr(SrcScalarPerVectors{}[i] == 1) - { - auto data_types = SrcDatas{}; - using DataType = remove_cvref_t; - const auto tmp = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - static_for<0, SrcScalarPerVector, 1>{}( - [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); - } - else - { - src_vectors(i).template AsType()(I0) = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - } - }); - - constexpr auto get_elem_op_vec_len = []() { - if constexpr(is_detected::value) - { - if constexpr(decltype(element_op_)::is_pack8_invocable) - return math::min(8, SrcScalarPerVector); - } - if constexpr(is_detected::value) - { - if constexpr(decltype(element_op_)::is_pack4_invocable) - return math::min(4, SrcScalarPerVector); - } - if constexpr(is_detected::value) - { - if constexpr(decltype(element_op_)::is_pack2_invocable) - return math::min(2, SrcScalarPerVector); - } - return 1; - }; - - constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); - - // apply pointwise function - static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { - // get reference to src data - const auto src_data_refs = generate_tie( - // return type should be lvalue - [&](auto iSrc) -> const auto& { - using SrcData = remove_cvref_t>; - - using elem_op_vec_t = typename vector_type::type; - - return src_vectors[iSrc].template AsType()[i]; - }, - Number{}); - - // get reference to dst data - auto dst_data_refs = generate_tie( - // return type should be lvalue - [&](auto iDst) -> auto& { - using DstData = remove_cvref_t>; - - using elem_op_vec_t = typename vector_type::type; - - return elm_vectors(iDst).template AsType()(i); - }, - Number{}); - - // apply pointwise function - // pointwise function signature: - // element_op_(dst_data_refs[I0], - // dst_data_refs[I1], - // ..., - // src_data_refs[I0], - // src_data_refs[I1], - // ...) - unpack2(element_op_, dst_data_refs, src_data_refs); - }); - - elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; - oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; - - // move coordinate - if constexpr(iAccess.value != src_num_access - 1) - { - constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); - - static_for<0, nSrc, 1>{}([&](auto i) { - move_tensor_coordinate(src_descs[i], - src_coords_(i), - make_tensor_coordinate_step(src_descs[i], forward_step)); - }); - } - }); - - // move coordinate back to slice origin (or not) - static_for<0, nSrc, 1>{}([&](auto i) { - if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); - } - }); - } - #if 1 template __device__ void OOBCheck(Number thread_scratch_id = Number{}) @@ -608,22 +471,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter RunWrite(dst_descs, dst_bufs, scatter_offsets); } - template = false> - __device__ void Run(const SrcDescs& src_descs, - const SrcBuffers& src_bufs, - const DstDescs& dst_descs, - DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) - { - RunRead(src_descs, src_bufs, scatter_weights); - RunWrite(dst_descs, dst_bufs, scatter_offsets); - } - __device__ static constexpr auto GetSrcCoordinateResetStep() { if constexpr(src_num_access == 0)