diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 5d9746b114..230a59df40 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -315,18 +315,32 @@ int main(int argc, char* argv[]) d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 2: - // a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{1.0, 1.0}); - // b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.0, 1.0}); - ck::utils::FillConstant{ck::type_convert(ck::float2_t(1.0f))}( - a0_t_k_k); - ck::utils::FillConstant{ck::type_convert(ck::float2_t(1.0f))}( - b0_e_n_k); + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove + d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove + d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; default: a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); @@ -372,13 +386,16 @@ int main(int argc, char* argv[]) { for(int k = 0; k < K; ++k) { + auto f4x2 = a0_t_k_k(t, tk, k).data; if(k % 2 == 0) { - printf("%f ", ck::type_convert(a0_t_k_k(t, tk, k).data >> 4 & 0xf)); + ck::f4_t f4 = (f4x2 >> 4) & 0xf; + printf("%f ", ck::type_convert(f4)); } else { - printf("%f ", ck::type_convert(a0_t_k_k(t, tk, k).data & 0xf)); + ck::f4_t f4 = (f4x2 >> 0) & 0xf; + printf("%f ", ck::type_convert(f4)); } } printf("\n"); @@ -407,13 +424,16 @@ int main(int argc, char* argv[]) { for(int k = 0; k < K; ++k) { + auto f4x2 = b0_e_n_k(e, k, n).data; if(k % 2 == 0) { - printf("%f ", ck::type_convert(b0_e_n_k(e, k, n).data >> 4 & 0xf)); + ck::f4_t f4 = f4x2 >> 4 & 0xf; + printf("%f ", ck::type_convert(f4)); } else { - printf("%f ", ck::type_convert(b0_e_n_k(e, k, n).data & 0xf)); + ck::f4_t f4 = f4x2 >> 0 & 0xf; + printf("%f ", ck::type_convert(f4)); } } printf("\n"); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp index 713ba1049b..a2b8318512 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp @@ -279,34 +279,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - // auto a_scale_thread_buf_copy = - // make_static_buffer( - // a_scale_thread_desc_copy.GetElementSpaceSize()); - // a_scale_thread_copy.Run(a_scale_grid_desc, - // a_scale_grid_buf, - // a_scale_thread_desc_copy, - // make_tuple(I0, I0), - // a_scale_thread_buf_copy); - - a_scale_thread_bufs(I0)(Number{}) = - type_convert(1.0f); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(I0)); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + make_multi_index(0, ScalesPerKBlockSize, 0)); // Prefetch b_scales to buf 0 static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -314,17 +295,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp{}([&](auto s) { constexpr auto b_scale_offset = b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - // auto b_scale_thread_buf_copy = - // make_static_buffer( - // b_scale_thread_desc_copy.GetElementSpaceSize()); - // b_scale_thread_copy.Run(b_scale_grid_desc, - // b_scale_grid_buf, - // b_scale_thread_desc_copy, - // make_tuple(I0, I0), - // b_scale_thread_buf_copy); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); b_scale_thread_bufs(I0)(Number{}) = - type_convert(1.0f); + b_scale_thread_buf_copy[Number<0>{}]; b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); @@ -337,7 +318,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - // auto a_scale_thread_buf_copy = - // make_static_buffer( - // a_scale_thread_desc_copy.GetElementSpaceSize()); - // a_scale_thread_copy.Run(a_scale_grid_desc, - // a_scale_grid_buf, - // a_scale_thread_desc_copy, - // make_tuple(I0, I0), - // a_scale_thread_buf_copy); - - a_scale_thread_bufs(I1)(Number{}) = - type_convert(1.0f); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(I1)); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + make_multi_index(0, ScalesPerKBlockSize, 0)); // Prefetch b_scales to buf 1 static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -384,17 +346,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp{}([&](auto s) { constexpr auto b_scale_offset = b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - // auto b_scale_thread_buf_copy = - // make_static_buffer( - // b_scale_thread_desc_copy.GetElementSpaceSize()); - // b_scale_thread_copy.Run(b_scale_grid_desc, - // b_scale_grid_buf, - // b_scale_thread_desc_copy, - // make_tuple(I0, I0), - // b_scale_thread_buf_copy); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); b_scale_thread_bufs(I1)(Number{}) = - type_convert(1.0f); + b_scale_thread_buf_copy[Number<0>{}]; b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); @@ -538,35 +500,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - // auto a_scale_thread_buf_copy = - // make_static_buffer( - // a_scale_thread_desc_copy.GetElementSpaceSize()); - // a_scale_thread_copy.Run(a_scale_grid_desc, - // a_scale_grid_buf, - // a_scale_thread_desc_copy, - // make_tuple(I0, I0), - // a_scale_thread_buf_copy); - - a_scale_thread_bufs(mfma_reg_buf)(Number{}) = - type_convert(1.0f); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0, I0), + a_scale_thread_bufs(mfma_reg_buf)); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0)); // Prefetch b_scales static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -574,17 +516,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp{}([&](auto s) { constexpr auto b_scale_offset = b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - // auto b_scale_thread_buf_copy = - // make_static_buffer( - // b_scale_thread_desc_copy.GetElementSpaceSize()); - // b_scale_thread_copy.Run(b_scale_grid_desc, - // b_scale_grid_buf, - // b_scale_thread_desc_copy, - // make_tuple(I0, I0), - // b_scale_thread_buf_copy); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); b_scale_thread_bufs(mfma_reg_buf)(Number{}) = - type_convert(1.0f); + b_scale_thread_buf_copy[Number<0>{}]; b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index ad818d7a2b..0a88eb6d73 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -211,6 +211,7 @@ struct GridwiseMoeGemmMX static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave; // static constexpr index_t NumTokens = 1; static constexpr index_t SortedTileSize = MPerBlock; @@ -512,9 +513,7 @@ struct GridwiseMoeGemmMX __host__ __device__ static constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) { - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - - return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); } template @@ -942,8 +941,6 @@ struct GridwiseMoeGemmMX __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = make_naive_tensor_descriptor_packed( make_tuple(I1, @@ -1249,8 +1246,9 @@ struct GridwiseMoeGemmMX const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, - math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1)); + math::integer_divide_ceil(problem.K, ScaleBlockSize), + 1), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1, 1)); const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( make_tuple(problem.K, math::integer_divide_ceil(problem.K, ScaleBlockSize)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1)); @@ -1431,20 +1429,40 @@ struct GridwiseMoeGemmMX auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; - auto a_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true>( - a_scale_grid_desc_am_ak, - make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k)); + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray scale_gather_offsets; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { + const index_t fused_token = + p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scale_gather_offsets(m0) = + token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize); + }); + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, 1>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + 1, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true, + MXdlPerWave, + KRepeat>( + a_scale_grid_desc_am_ak, make_multi_index(0, 0, thread_offset_k), scale_gather_offsets); + + // B scale load auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; auto b_scale_thread_copy = @@ -1537,8 +1555,6 @@ struct GridwiseMoeGemmMX NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -2255,8 +2271,6 @@ struct GridwiseMoeGemmMX NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index a234c581e0..41665e7d45 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -424,6 +424,248 @@ struct ThreadwiseTensorSliceTransfer_v2 SrcCoord src_coord_; }; // namespace ck +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v2_gather +{ + static_assert((InvalidElementAsNaN && !ck::is_integral::value) || + (!InvalidElementAsNaN), + "Filling invalid element as NaN is only for floating point types"); + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t> || + is_same_v, f4x2_pk_t>) + return 2; + else + return 1; + }(); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather( + const SrcDesc& src_desc, + const Index& src_slice_origin_idx, + const StaticallyIndexedArray& scale_gather_offsets) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)), + scale_gather_offsets_(scale_gather_offsets) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + + if constexpr(is_same_v, pk_i4_t> || + is_same_v, f4x2_pk_t>) + { + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + auto adjusted_origin_idx = [&]() { + Index idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number{}]; }); + + return idx; + }(); + + src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over tensor and copy + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { // MRepeate + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto current_dst_origin = + to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, k0, 0); + MoveSrcSliceWindow(src_desc, make_multi_index(0, 0, 0)); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + typename vector_type_maker::type + src_vector; + + using src_vector_t = + typename vector_type_maker::type::type; + constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, + src_coord_); + + // copy data from src_buf into src_vector + src_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset() / PackedSize + + scale_gather_offsets_(gather_idx), + is_src_valid); + + // copy data from src_vector into dst_buf + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + + src_data_idx + i * src_scalar_step_in_vector); + constexpr auto full_dst_offset = + dst_desc.CalculateOffset(current_dst_origin) + dst_offset; + + if constexpr(InvalidElementAsNaN) + { + dst_buf(full_dst_offset) = + is_src_valid ? type_convert( + src_vector.template AsType()[i]) + : NumericLimits::QuietNaN(); + } + else + { + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); + } + }); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate(src_desc, + src_coord_, + make_tensor_coordinate_step(src_desc, forward_step)); + } + }); + }); + MoveSrcSliceWindow(src_desc, make_multi_index(0, -KRepeat, 0)); + }); + + // printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n", + // blockIdx.y, + // threadIdx.x, + // dst_buf(Number<0>{})); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + StaticallyIndexedArray scale_gather_offsets_; +}; // namespace ck + // Assume: // 1. src_desc and dst_desc are not known at compile-time // 2. SrcBuffer and DstBuffer are DynamicBuffer diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 4d75c12052..5f5d6e7ffc 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -762,7 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> using arg_type = int32x8_t; -#if 1 +#if 0 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},