From dd24786f78f978d6365fb03d7cef91efea33b31a Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 2 Jun 2025 12:23:01 +0000 Subject: [PATCH] Enable splitk for mxfp4; clang format; --- .../67_gemm_microscaling/gemm_mx_common.hpp | 57 +- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 13 +- ...peline_xdlops_b_preshuffle_gufusion_v1.hpp | 13 +- ...oe_blockscale_b_preshuffle_gufusion_v3.hpp | 82 +- ...s_moe_blockscale_b_preshuffle_selector.hpp | 92 +- ..._xdlops_moe_blockscale_b_preshuffle_v3.hpp | 29 +- ...gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp | 79 +- .../impl/device_moe_gemm_blockscale.hpp | 2 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 58 +- ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 58 +- .../threadwise_tensor_slice_transfer copy.hpp | 2101 ----------------- .../threadwise_tensor_slice_transfer.hpp | 93 - 12 files changed, 191 insertions(+), 2486 deletions(-) delete mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer copy.hpp diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 0205bf2668..ea3b731d19 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -108,7 +108,6 @@ bool parse_cmd_args(int argc, return true; } -#if 1 template void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) { @@ -146,8 +145,9 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2; - // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + - // k2 * MNXdlPack))); + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, + // 2-k))); + if constexpr(KLast) dst[outputIndex] = src[n * K + k]; else @@ -186,7 +186,6 @@ void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K } } } -#endif template >(a_m_k_scale.mData.data(), a_shuffled_scale.mData.data(), Scale_Padded_M, @@ -358,48 +356,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c int NPerXdl = 16; // Fixed 16 preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl); } -#endif - // printf("a_scale:\n"); - // for(ck::index_t i = 0; i < M; i++) - // { - // for(ck::index_t j = 0; j < K / ScaleBlockSize; j++) - // { - // // a_m_k_scale(i, j) = - // // ck::type_convert(static_cast(powf(2.0f, (j / 4) % 4))); - // // a_m_k_scale(i, j) =ck::type_convert(static_cast(1.0f)); - // // a_shuffled_scale(i, j) =ck::type_convert(static_cast(1.0f)); - // printf("%02x ", *reinterpret_cast(&a_m_k_scale(i, j))); - // } - // printf("\n"); - // } - // printf("b_scale:\n"); - // for(ck::index_t i = 0; i < N; i++) - // { - // for(ck::index_t j = 0; j < K / ScaleBlockSize; j++) - // { - // // // b_k_n_scale(j, i) = - // // // ck::type_convert(static_cast(powf(2.0f, (j / 4) % 4))); - // // b_k_n_scale(j, i) =ck::type_convert(static_cast(1.0f)); - // // b_shuffled_scale(j, i) =ck::type_convert(static_cast(1.0f)); - // printf("%02x ", *reinterpret_cast(&b_k_n_scale(j, i))); - // } - // printf("\n"); - // } - - // printf("a_shuffled_scale:\n"); - // for(ck::index_t i = 0; i < M * K / ScaleBlockSize; i++) - // { - // printf("%02x ", *reinterpret_cast(&(a_shuffled_scale.mData.data()[i]))); - // if(i % 64 == 63) - // printf("\n"); - // } - // printf("b_shuffled_scale:\n"); - // for(ck::index_t i = 0; i < N * K / ScaleBlockSize; i++) - // { - // printf("%02x ", *reinterpret_cast(&(b_shuffled_scale.mData.data()[i]))); - // if(i % 64 == 63) - // printf("\n"); - // } if(config.verbosity > 0) std::cout << "Device memory allocation..." << std::endl; @@ -524,9 +480,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c // << std::endl; // } - res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, - c_m_n_host_result, - "Error: Incorrect results!"); + res_verified = + res_verified && + ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1); if(config.verbosity > 0 && res_verified) std::cout << "Verification Successful!" << std::endl; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index c042350f3c..4f7b8e768c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -484,13 +484,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp index d24b9af006..fe89e700c4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -465,12 +465,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp index bf41ccbc19..62f0431691 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp @@ -254,16 +254,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< constexpr auto buffer_load_issue_point_b = 0; constexpr auto buffer_load_issue_point_interval_more = - num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1; + num_mfma_perstage / buffer_load_perstage_more + ? num_mfma_perstage / buffer_load_perstage_more + : 1; constexpr auto buffer_load_issue_point_interval_less = - num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1; + num_mfma_perstage / buffer_load_perstage_less + ? num_mfma_perstage / buffer_load_perstage_less + : 1; constexpr auto ds_write_issue_point = 0; constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; // B global read static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { // Scale load, 1B - if constexpr (i.value==0){ + if constexpr(i.value == 0) + { __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } // Scale load, 1A @@ -330,7 +335,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA // Scale load, 1A - if constexpr(imfma == 0){ + if constexpr(imfma == 0) + { __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } @@ -426,7 +432,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< StaticallyIndexedArray{}> b_thread_bufs_up; constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); - auto a_scale_thread_buf = make_static_buffer( + auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); @@ -483,12 +489,12 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< b_scale_thread_bufs(I0)); b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); - + b_scale_thread_copy_up.Run(b_scale_grid_desc, - b_scale_grid_buf_up, - b_scale_thread_desc, - make_tuple(I0, I0), - b_scale_thread_bufs_up(I0)); + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(I0)); b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); @@ -496,7 +502,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; }); static_for<0, MRepeat, 1>{}([&](auto m0) { - c_scale_thread_buf_up(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0]; + c_scale_thread_buf_up(m0) = + a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0]; }); // Local prefill A1 @@ -532,10 +539,10 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); b_scale_thread_copy_up.Run(b_scale_grid_desc, - b_scale_grid_buf_up, - b_scale_thread_desc, - make_tuple(I0, I0), - b_scale_thread_bufs_up(I0)); + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(I0)); b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); @@ -574,7 +581,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< }); }); - __builtin_amdgcn_sched_barrier(0); // main body @@ -609,13 +615,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< if constexpr(NumKBlockPerScale == 1) { - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<1>{})); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); } else { - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<0>{})); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -627,13 +633,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< b_scale_thread_copy_step); b_scale_thread_copy_up.Run(b_scale_grid_desc, - b_scale_grid_buf_up, - b_scale_thread_desc, - make_tuple(I0, I0), - b_scale_thread_bufs_up(local_read_buf)); + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(local_read_buf)); b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step); + b_scale_thread_copy_step); static_for<0, MRepeat, 1>{}([&](auto m0) { vector_type c_scale_thread_vec; @@ -676,8 +682,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< b_thread_vec_up.template AsType()(ik) = b_thread_bufs_up[mfma_reg_buf] - [Number{}]; + [Number{}]; }); using mfma_input_type = @@ -711,7 +717,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< .template AsType()(t) = __builtin_elementwise_fma( c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) .template AsType()[t], - c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_scale_thread_vec_up + .template AsType()[Number<0>{}], c_thread_buf_up.GetVectorTypeReference(Number{}) .template AsType()[t]); }); @@ -800,8 +807,10 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< }); static_for<0, MRepeat, 1>{}([&](auto m0) { - c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0]; - c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs_up[mfma_reg_buf][I0]; + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs[mfma_reg_buf][I0]; + c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs_up[mfma_reg_buf][I0]; }); HotLoopScheduler(); @@ -824,13 +833,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< b_block_origin_idx, b_thread_bufs(I1)); b_blockwise_copy_up.Run(b_grid_desc, - b_grid_buf_up, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs_up(I1)); + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { vector_type c_scale_thread_vec; c_scale_thread_vec.template AsType()(Number<0>{}) = c_scale_thread_buf[m0]; @@ -970,7 +979,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< static_for<0, MRepeat, 1>{}([&](auto m0) { c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; - c_scale_thread_buf_up(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0]; + c_scale_thread_buf_up(m0) = + a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0]; }); static_for<0, MRepeat, 1>{}([&](auto m0) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp index a04563f458..e07d4db9e6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp @@ -43,56 +43,56 @@ constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector() if constexpr(GUFusion) { return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< - BlkGemmPipeSche, - BlockSize, - ADataType, - BDataType, - ComputeDataType, - AccDataType, - ATileDesc, - BTileDesc, - AMmaTileDesc, - BMmaTileDesc, - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MScaleBlock, - NScaleBlock, - KScaleBlock, - MPerXDL, - NPerXDL, - MRepeat, - NRepeat, - KPack>{}; + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; } else { return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< - BlkGemmPipeSche, - BlockSize, - ADataType, - BDataType, - ComputeDataType, - AccDataType, - ATileDesc, - BTileDesc, - AMmaTileDesc, - BMmaTileDesc, - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MScaleBlock, - NScaleBlock, - KScaleBlock, - MPerXDL, - NPerXDL, - MRepeat, - NRepeat, - KPack>{}; + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; } } #if 0 diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp index 406e8737e7..3a130660f7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp @@ -254,16 +254,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< constexpr auto buffer_load_issue_point_b = 0; constexpr auto buffer_load_issue_point_interval_more = - num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1; + num_mfma_perstage / buffer_load_perstage_more + ? num_mfma_perstage / buffer_load_perstage_more + : 1; constexpr auto buffer_load_issue_point_interval_less = - num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1; + num_mfma_perstage / buffer_load_perstage_less + ? num_mfma_perstage / buffer_load_perstage_less + : 1; constexpr auto ds_write_issue_point = 0; constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; // B global read static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { // Scale load, 1B - if constexpr (i.value==0){ + if constexpr(i.value == 0) + { __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } // Scale load, 1A @@ -330,7 +335,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA // Scale load, 1A - if constexpr(imfma == 0){ + if constexpr(imfma == 0) + { __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } @@ -420,7 +426,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< StaticallyIndexedArray{}> b_thread_bufs; constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); - auto a_scale_thread_buf = make_static_buffer( + auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); @@ -586,13 +592,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< if constexpr(NumKBlockPerScale == 1) { - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<1>{})); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); } else { - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<0>{})); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); } b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_grid_buf, @@ -744,7 +750,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< }); static_for<0, MRepeat, 1>{}([&](auto m0) { - c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0]; + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs[mfma_reg_buf][I0]; }); HotLoopScheduler(); @@ -768,7 +775,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< b_thread_bufs(I1)); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { vector_type c_scale_thread_vec; c_scale_thread_vec.template AsType()(Number<0>{}) = c_scale_thread_buf[m0]; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp index 874928600f..7e11304e2f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp @@ -449,84 +449,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle(&(b_thread_bufs(I0)(Number<0>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<1>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<2>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<3>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<4>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<5>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<6>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<7>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<8 + 0>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<8 + 1>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<8 + 2>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<8 + 3>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<8 + 4>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<8 + 5>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<8 + 6>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<8 + 7>{}))), - get_thread_local_1d_id(), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 0>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 1>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 2>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 3>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 4>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 5>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 6>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 7>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 8 + 0>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 8 + 1>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 8 + 2>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 8 + 3>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 8 + 4>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 8 + 5>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 8 + 6>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<16 + 8 + 7>{}))), - get_thread_local_1d_id(), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 0>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 1>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 2>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 3>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 4>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 5>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 6>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 7>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 8 + 0>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 8 + 1>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 8 + 2>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 8 + 3>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 8 + 4>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 8 + 5>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 8 + 6>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 8 + 7>{}))), - get_thread_local_1d_id(), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 0>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 1>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 2>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 3>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 4>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 5>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 6>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 7>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 0>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 1>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 2>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 3>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 4>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 5>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 6>{}))), - *reinterpret_cast(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 7>{})))); -#endif + // Initialize C c_thread_buf.Clear(); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 1f25875193..44c2b638e6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -201,7 +201,7 @@ struct DeviceMoeGemmBlockScale index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const auto RunKernel = [&](const auto& kernel) { + const auto RunKernel = [&](const auto& kernel) { if(stream_config.flush_cache) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index b68a9c12ca..de22d65937 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -789,26 +789,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 } // Calculate A scale offset - if constexpr(is_same_v) - { - a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize); - } - else if constexpr(is_same_v) - { - a_scale_k_split_offset = - k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA; - } + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * MPerXdl; // Calculate B scale offset - if constexpr(is_same_v) - { - b_scale_k_split_offset = - k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; - } - else if constexpr(is_same_v) - { - b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); - } + b_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * NPerXdl; if(k_id < (karg.KBatch - 1)) { @@ -1850,17 +1836,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack const auto Padded_Scale_M = math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / (KXdlPack * 64 / MPerXdl), - 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( make_tuple(problem.N / (NXdlPack * NPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / (KXdlPack * 64 / NPerXdl), - 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); Run KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack const auto Padded_Scale_M = math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / (KXdlPack * 64 / MPerXdl), - 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( make_tuple(problem.N / (NXdlPack * NPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / (KXdlPack * 64 / NPerXdl), - 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); Run_2Lds; + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -806,7 +814,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle { if constexpr(!PermuteB) { - b_k_split_offset = k_id * karg.KRead; + b_k_split_offset = k_id * karg.KRead * NPerXdl; } else { @@ -816,26 +824,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle } // Calculate A scale offset - if constexpr(is_same_v) - { - a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize); - } - else if constexpr(is_same_v) - { - a_scale_k_split_offset = - k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA; - } + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * + MPerXdl / scale_pack_size_a; // Calculate B scale offset - if constexpr(is_same_v) - { - b_scale_k_split_offset = - k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; - } - else if constexpr(is_same_v) - { - b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); - } + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * + NPerXdl / scale_pack_size_b; if(k_id < (karg.KBatch - 1)) { @@ -1289,14 +1283,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - using mx_scale_t = e8m0_bexp_t; - static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); - static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); - static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, - "A scale pack data type too large!"); - static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, - "B scale pack data type too large!"); - template ::type = false> -struct ThreadwiseTensorSliceTransfer_v1r3 -{ - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc, - const Index& dst_slice_origin_idx, - const ElementwiseOperation& element_op) - : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)), - element_op_{element_op} - { - static_assert(SrcDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc need to known at compile-time"); - static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, - "wrong! Not divisible"); - } - - __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) - { - dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); - } - - template - __device__ void Run(const SrcDesc&, - const SrcSliceOriginIdx&, - const SrcBuffer& src_buf, - const DstDesc& dst_desc, - DstBuffer& dst_buf) - { - static_assert(SrcDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc need to known at compile-time"); - - static_assert(is_known_at_compile_time>::value, - "wrong! SrcSliceOrigin need to known at compile-time"); - - static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); - - // SrcDesc and src_slice_origin_idx are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? - static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, - "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); - typename vector_type_maker::type dst_vector; - using dst_vector_t = typename vector_type_maker::type::type; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - - static_for<0, num_access, 1>{}([&](auto idx_1d) { - constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - - // copy data from src_buf into dst_vector - // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve? - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - DstData v; - - // apply element-wise operation - element_op_(v, src_buf[Number{}]); - - dst_vector.template AsType()(i) = v; - }); - - const bool is_dst_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); - - // copy data from dst_vector into dst_buf - dst_buf.template Update( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector.template AsType()[Number<0>{}]); - - if constexpr(idx_1d.value != num_access - 1) - { - constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - - move_tensor_coordinate( - dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); - } - }); - - // move dst coordinate back to slice origin (or not) - if constexpr(DstResetCoordinateAfterRun) - { - const auto dst_reset_step = - make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - - move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); - } - } - - __device__ static constexpr auto GetDstCoordinateResetStep() - { - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } - } - - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, - const Index& dst_slice_origin_step_idx) - { - // if dst coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); - } - - private: - DstCoord dst_coord_; - const ElementwiseOperation element_op_; -}; // namespace ThreadwiseTensorSliceTransfer_v1r3 - -/** - * @brief Helper structure that facilitates transfer of source (grid) data to destination threads. - * - * @details The following assumptions are made: - * - For Source (Grid) Data: - * 1. The source tensor descriptor SrcDesc is not known at compile-time. - * 2. The source buffer is a dynamic buffer. - * 3. The source slice origin index src_slice_origin_idx is not known at compile-time. - * - For Destination (Thread) Data: - * 1. The destination tensor descriptor DstDesc is known at compile-time. - * 2. The destination buffer dst_buf is a static buffer. - * 3. The destination slice origin index dst_slice_origin_idx is known at compile-time. - * - * @tparam SrcData The data type of the source tensor. - * @tparam DstData The data type of the destination tensor. - * @tparam SrcDesc The descriptor type of the source tensor. - * @tparam DstDesc The descriptor type of the destination tensor. - * @tparam SliceLengths The lengths of the slice to be transferred. - * @tparam DimAccessOrder The order of dimension access for the space-filling curve. - * @tparam SrcVectorDim The dimension along which vectorized access is performed in the source - * tensor. - * @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor. - * @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source - * tensor. - * @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run - * or rolled back one step in MoveSrcSliceWindow - * @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for - * floating-point types). - * - */ -template ::type = false> -struct ThreadwiseTensorSliceTransfer_v2 -{ - static_assert((InvalidElementAsNaN && !ck::is_integral::value) || - (!InvalidElementAsNaN), - "Filling invalid element as NaN is only for floating point types"); - - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - - static constexpr index_t PackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, - const Index& src_slice_origin_idx) - : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)) - { - static_assert(DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc need to known at compile-time"); - static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, - "wrong! Not divisible"); - - if constexpr(is_same_v, pk_i4_t>) - { - static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); - } - } - - __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) - { - src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); - } - - template - __device__ void Run(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) - { - static_assert(DstDesc::IsKnownAtCompileTime(), - "wrong! DstDesc need to known at compile-time"); - - static_assert(is_known_at_compile_time>::value, - "wrong! DstSliceOrigin need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - // DstDesc and dst_slice_origin_idx are known at compile-time - constexpr auto dst_desc = remove_cvref_t{}; - constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - // loop over tensor and copy - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - - static_for<0, num_access, 1>{}([&](auto idx_1d) { - typename vector_type_maker::type src_vector; - - using src_vector_t = - typename vector_type_maker::type::type; - constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); - - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - - // copy data from src_buf into src_vector - src_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_coord_.GetOffset() / PackedSize, - is_src_valid); - - // copy data from src_vector into dst_buf - static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { - constexpr index_t dst_offset = - dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + - i * src_scalar_step_in_vector); - - if constexpr(InvalidElementAsNaN) - { - dst_buf(Number{}) = - is_src_valid - ? type_convert(src_vector.template AsType()[i]) - : NumericLimits::QuietNaN(); - } - else - { - dst_buf(Number{}) = - type_convert(src_vector.template AsType()[i]); - } - }); - - if constexpr(idx_1d.value != num_access - 1) - { - constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - - move_tensor_coordinate( - src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); - } - }); - - // move src coordinate back to slice origin (or not) - if constexpr(SrcResetCoordinateAfterRun) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_desc, src_coord_, src_reset_step); - } - } - - __device__ static constexpr auto GetSrcCoordinateResetStep() - { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } - } - - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx) - { - // if src coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template - __device__ void - MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step( - src_desc, adjusted_step_idx, src_move_slice_window_step_hack); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - private: - SrcCoord src_coord_; -}; // namespace ck - - -template ::type = false> -struct ThreadwiseTensorSliceTransfer_v2_gather -{ - static_assert((InvalidElementAsNaN && !ck::is_integral::value) || - (!InvalidElementAsNaN), - "Filling invalid element as NaN is only for floating point types"); - - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - - static constexpr index_t PackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(const SrcDesc& src_desc, - const Index& src_slice_origin_idx, - const StaticallyIndexedArray& scale_gather_offsets) - : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)), - scale_gather_offsets_(scale_gather_offsets) - { - static_assert(DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc need to known at compile-time"); - static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, - "wrong! Not divisible"); - - if constexpr(is_same_v, pk_i4_t>) - { - static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); - } - } - - __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) - { - auto adjusted_origin_idx = [&]() { - Index idx; - - static_for<0, nDim, 1>{}([&](auto i) { - idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number{}]; - }); - - return idx; - }(); - - src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx); - } - - template - __device__ void Run(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) - { - static_assert(DstDesc::IsKnownAtCompileTime(), - "wrong! DstDesc need to known at compile-time"); - - static_assert(is_known_at_compile_time>::value, - "wrong! DstSliceOrigin need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - // DstDesc and dst_slice_origin_idx are known at compile-time - constexpr auto dst_desc = remove_cvref_t{}; - constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - // loop over tensor and copy - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - - static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { - auto current_dst_origin = to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0); - - static_for<0, num_access, 1>{}([&](auto idx_1d) { - typename vector_type_maker::type src_vector; - - using src_vector_t = - typename vector_type_maker::type::type; - constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); - - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - - // copy data from src_buf into src_vector - src_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_coord_.GetOffset() / PackedSize + scale_gather_offsets_(gather_idx), - is_src_valid); - - // copy data from src_vector into dst_buf - static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { - constexpr index_t dst_offset = - dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + - i * src_scalar_step_in_vector); - auto full_dst_offset = dst_desc.CalculateOffset(current_dst_origin) + dst_offset; - - if constexpr(InvalidElementAsNaN) - { - dst_buf(full_dst_offset) = - is_src_valid - ? type_convert(src_vector.template AsType()[i]) - : NumericLimits::QuietNaN(); - } - else - { - dst_buf(full_dst_offset) = - type_convert(src_vector.template AsType()[i]); - } - }); - - if constexpr(idx_1d.value != num_access - 1) - { - constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - - move_tensor_coordinate( - src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); - } - }); - }); - - // move src coordinate back to slice origin (or not) - if constexpr(SrcResetCoordinateAfterRun) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_desc, src_coord_, src_reset_step); - } - } - - __device__ static constexpr auto GetSrcCoordinateResetStep() - { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } - } - - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx) - { - // if src coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template - __device__ void - MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step( - src_desc, adjusted_step_idx, src_move_slice_window_step_hack); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - private: - SrcCoord src_coord_; - StaticallyIndexedArray scale_gather_offsets_; -}; // namespace ck - - - -// Assume: -// 1. src_desc and dst_desc are not known at compile-time -// 2. SrcBuffer and DstBuffer are DynamicBuffer -// 3. src_slice_origin and dst_slice_origin are not known at compile-time, -// 4. Use thread buffer -template // control whether to move back dst coordinate after each - // RunWrite(), will be fused with MoveDstSliceWindow to - // save addr computation -struct ThreadwiseTensorSliceTransfer_v3 -{ - static constexpr index_t nDim = SliceLengths::Size(); - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc, - const Index& src_slice_origin, - const DstDesc& dst_desc, - const Index& dst_slice_origin) - : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), - dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) - { - static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, - "wrong! Not divisible"); - static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, - "wrong! Not divisible"); - } - - __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) - { - src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); - } - - __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) - { - dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); - } - - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) - { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, - "wrong!"); - - static_assert( - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - src_desc, forward_step_idx, src_step_hacks[I0][i]); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - src_desc, backward_step_idx, src_step_hacks[I1][i]); - }, - Number{}); - - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - vector_type_maker_t src_tmp_vector; - - using src_vector_t = typename decltype(src_tmp_vector)::type; - - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - - // copy data from src_buf to src_tmp_vector - src_tmp_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_coord_.GetOffset(), is_src_valid); - - // copy data from src_tmp_vector to buffer_ - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t buffer_offset = - buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); - - buffer_(Number{}) = src_tmp_vector.template AsType()[i]; - }); - - constexpr auto move_on_dim = [&]() constexpr - { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - } - (); - - // move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } - }); - }); - - // move src coordinate back to slice origin (or not) - if constexpr(SrcResetCoordinateAfterRun) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_desc, src_coord_, src_reset_step); - } - } - - template - __device__ void - RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) - { - static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or - DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, - "wrong!"); - - static_assert( - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - // src scalar per access on each dim - // TODO: don't use this - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - dst_desc, forward_step_idx, dst_step_hacks[I0][i]); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step( - dst_desc, backward_step_idx, dst_step_hacks[I1][i]); - }, - Number{}); - - // loop over tensor and copy - static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] - : ordered_dst_access_lengths[i] - 1 - - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - vector_type_maker_t dst_tmp_vector; - - // copy data from buffer_ to dst_tmp_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t buffer_offset = - buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); - - dst_tmp_vector.template AsType()(i) = - type_convert(buffer_[Number{}]); - }); - - using dst_vector_t = typename decltype(dst_tmp_vector)::type; - - // copy data from dst_tmp_vector to dst_buf - const bool is_dst_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); - - dst_buf.template Set( - dst_coord_.GetOffset(), - is_dst_valid, - dst_tmp_vector.template AsType()[Number<0>{}]); - - constexpr auto move_on_dim = [&]() constexpr - { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - } - (); - - // move - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); - } - } - }); - }); - - // move dst coordinate back to slice origin (or not) - if constexpr(DstResetCoordinateAfterRun) - { - const auto dst_reset_step = - make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - - move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); - } - } - - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) - { - constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto src_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - RunRead(src_desc, src_buf, src_step_hacks); - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) - { - constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto dst_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - RunWrite(dst_desc, dst_buf, dst_step_hacks); - } - - __device__ static constexpr auto GetSrcCoordinateResetStep() - { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); - - return reset_src_data_step; - } - - __device__ static constexpr auto GetDstCoordinateResetStep() - { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; - } - - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template - __device__ void - MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step( - src_desc, adjusted_step_idx, src_move_slice_window_step_hack); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, - const Index& dst_slice_origin_step_idx) - { - // if dst coord was not reset by RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); - } - - private: - static constexpr auto buffer_desc_ = - make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); - - static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - - StaticBuffer buffer_; - - SrcCoord src_coord_; - DstCoord dst_coord_; -}; - -// Assume: -// 1. src: -// 1. SrcDesc is known at compile-time -// 2. SrcBuffer is DynamicBuffer -// 3. src_ref_idx is known at run-time -// 4. SrcRefToOriginDisplacement is known at compile-time -// 5. use #-step -// 2. dst: -// 1. DstDesc is known at compile-time -// 2. DstBuffer is StaticBuffer -// 3. DstOriginIdx is known at compile-time -// 4. use direct address calculation -// 3. vector access on src -template ::type = false> -struct ThreadwiseTensorSliceTransfer_v4 -{ - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - - static constexpr index_t PackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) - : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, - "wrong! Not divisible"); - - if constexpr(is_same_v, pk_i4_t>) - { - static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); - } - } - - template - __device__ void Run(const SrcDesc&, - const SrcRefToOriginDisplacement&, - const SrcBuffer& src_buf, - const DstDesc&, - const DstOriginIdx&, - DstBuffer& dst_buf) const - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); - - static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); - - // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto dst_desc = remove_cvref_t{}; - - // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time - constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); - constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); - - // scalar per access of each dim - constexpr auto src_scalar_per_access = generate_sequence_v2( - [&](auto i) constexpr { - if constexpr(i == SrcVectorDim) - { - return Number{}; - } - else - { - return Number<1>{}; - } - }, - Number{}); - - // scalar step (if steping on SrcVectorDim) of each dim - constexpr auto src_scalar_step_in_vector = generate_sequence_v2( - [&](auto i) constexpr { - if constexpr(i == SrcVectorDim) - { - return Number<1>{}; - } - else - { - return Number<0>{}; - } - }, - Number{}); - - constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - static_ford{}([&](auto ordered_access_idx) { -#if 0 - // TODO: unable to compile - // position in slice window - constexpr auto data_to_origin_disp_idx = - container_reorder_given_old2new(ordered_access_idx, dim_access_order) * - src_scalar_per_access; -#else - // position in slice window - constexpr auto data_to_origin_disp_idx = - ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; -#endif - // src coordinate - constexpr auto src_ref_to_data_disp_idx = - src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - - constexpr auto src_ref_to_data_disp_coord_step = - make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); - - auto src_data_coord = src_ref_coord_; - - move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); - - vector_type_maker_t src_tmp_vector; - - using src_vector_t = typename decltype(src_tmp_vector)::type; - - const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( - src_desc, src_data_coord); - - // copy data from src_buf into src_tmp_vector - if constexpr(SrcBuffer::IsDynamicBuffer()) - { - src_tmp_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_data_coord.GetOffset() / PackedSize, - is_src_valid); - } - else if constexpr(SrcBuffer::IsStaticBuffer()) - { - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_ref_to_origin_disp_idx + data_to_origin_disp_idx + - i * src_scalar_step_in_vector); - - src_tmp_vector.template AsType()(i) = src_buf[Number{}]; - }); - } - - if constexpr(is_same, pk_i4_t>::value) - { - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - - constexpr index_t pack_size = 8; - - static_assert(SrcScalarPerVector % pack_size == 0, ""); - - using src_v_t = typename vector_type_maker_t::type; - using dst_v_t = typename vector_type_maker_t::type; - - static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { - ck::tensor_operation::element_wise::PassThroughPack8{}( - dst_tmp_vector.template AsType()(i), - src_tmp_vector.template AsType()[i]); - }); - - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); - } - else if constexpr(is_same, f8_t>::value && - is_same, half_t>::value && - SrcScalarPerVector % 2 == 0) - { - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - - constexpr index_t pack_size = 2; - - using dst_v_t = typename vector_type_maker_t::type; - using src_v_t = typename vector_type_maker_t::type; - static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { - ck::tensor_operation::element_wise::PassThroughPack2{}( - dst_tmp_vector.template AsType()(i), - src_tmp_vector.template AsType()[i]); - }); - - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); - } - else - { - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - - // TODO: if SrcData and DstData are vetor type, then static_cast may not compile - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - dst_tmp_vector.template AsType()(i) = - type_convert(src_tmp_vector.template AsType()[i]); - }); - - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); - } - }); - } - - // Fuse scale - template - __device__ void Run(const SrcDesc&, - const SrcRefToOriginDisplacement&, - const SrcBuffer& src_buf, - const DstData& scale, - const DstDesc&, - const DstOriginIdx&, - DstBuffer& dst_buf) const - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! SrcDesc and DstDesc need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); - - static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); - - // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto dst_desc = remove_cvref_t{}; - - // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time - constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); - constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); - - // scalar per access of each dim - constexpr auto src_scalar_per_access = generate_sequence_v2( - [&](auto i) constexpr { - if constexpr(i == SrcVectorDim) - { - return Number{}; - } - else - { - return Number<1>{}; - } - }, - Number{}); - - // scalar step (if steping on SrcVectorDim) of each dim - constexpr auto src_scalar_step_in_vector = generate_sequence_v2( - [&](auto i) constexpr { - if constexpr(i == SrcVectorDim) - { - return Number<1>{}; - } - else - { - return Number<0>{}; - } - }, - Number{}); - - constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - static_ford{}([&](auto ordered_access_idx) { -#if 0 - // TODO: unable to compile - // position in slice window - constexpr auto data_to_origin_disp_idx = - container_reorder_given_old2new(ordered_access_idx, dim_access_order) * - src_scalar_per_access; -#else - // position in slice window - constexpr auto data_to_origin_disp_idx = - ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; -#endif - // src coordinate - constexpr auto src_ref_to_data_disp_idx = - src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - - constexpr auto src_ref_to_data_disp_coord_step = - make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); - - auto src_data_coord = src_ref_coord_; - - move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); - - vector_type_maker_t src_tmp_vector; - - using src_vector_t = typename decltype(src_tmp_vector)::type; - - const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( - src_desc, src_data_coord); - - // copy data from src_buf into src_tmp_vector - if constexpr(SrcBuffer::IsDynamicBuffer()) - { - src_tmp_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_data_coord.GetOffset() / PackedSize, - is_src_valid); - } - else if constexpr(SrcBuffer::IsStaticBuffer()) - { - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_ref_to_origin_disp_idx + data_to_origin_disp_idx + - i * src_scalar_step_in_vector); - - src_tmp_vector.template AsType()(i) = src_buf[Number{}]; - }); - } - - if constexpr(is_same, pk_i4_t>::value) - { - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - vector_type scale_vector; - scale_vector.template AsType()(Number<0>{}) = scale; - scale_vector.template AsType()(Number<1>{}) = scale; - - constexpr index_t pack_size = 8; - - static_assert(SrcScalarPerVector % pack_size == 0, ""); - - using src_v_t = typename vector_type_maker_t::type; - using dst_v_t = typename vector_type_maker_t::type; - using scale_v_t = typename vector_type_maker_t::type; - - static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { - ck::tensor_operation::element_wise::DequantPack8{}( - dst_tmp_vector.template AsType()(i), - src_tmp_vector.template AsType()[i], - scale_vector.template AsType()[Number<0>{}]); - }); - - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); - } - else if constexpr(is_same, f8_t>::value && - is_same, half_t>::value && - SrcScalarPerVector % 2 == 0) - { - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - - constexpr index_t pack_size = 2; - - using dst_v_t = typename vector_type_maker_t::type; - using src_v_t = typename vector_type_maker_t::type; - static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { - ck::tensor_operation::element_wise::PassThroughPack2{}( - dst_tmp_vector.template AsType()(i), - src_tmp_vector.template AsType()[i]); - }); - - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); - } - else - { - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - - // TODO: if SrcData and DstData are vetor type, then static_cast may not compile - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - dst_tmp_vector.template AsType()(i) = - type_convert(src_tmp_vector.template AsType()[i]); - }); - - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); - } - }); - } - - template - __device__ void MoveSrcSliceWindow(const SrcDesc&, - const SrcSliceMoveStepIdx& src_slice_move_step_idx) - { - constexpr auto src_desc = SrcDesc{}; - - const auto src_slice_move_step_iter = - make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); - - move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); - } - __device__ void SetSrcCoord(const Index& src_ref_idx) - { - src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx); - } - - private: - SrcCoord src_ref_coord_; -}; - -/** - * @brief Threadwise data transfer - * - * Do NOT involve any tensor coordinates with StaticBuffer - * - */ -template ::type = false> -struct ThreadwiseTensorSliceTransfer_StaticToStatic -{ - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - static constexpr index_t PackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic( - const ElementwiseOperation& element_op) - : element_op_{element_op} - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! Desc need to known at compile-time"); - - static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, - "wrong! Not divisible"); - } - - template - __device__ void Run(const SrcDesc&, - const SrcSliceOriginIdx&, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) const - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! Desc need to known at compile-time"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! SliceOrigin need to known at compile-time"); - - static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), - "wrong! Buffer need to be StaticBuffer"); - - // SrcDesc and src_slice_origin_idx are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto dst_desc = remove_cvref_t{}; - constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); - constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); - - // scalar per access on each dim - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, - "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - - if constexpr(is_same, pk_i4_t>::value) - { - static_for<0, num_access, 1>{}([&](auto idx_1d) { - typename vector_type_maker::type - src_tmp_vector; - - constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - src_tmp_vector.template AsType()(i) = src_buf[Number{}]; - }); - - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - - constexpr index_t pack_size = 8; - - static_assert(DstScalarPerVector % pack_size == 0, ""); - - using src_v_t = typename vector_type_maker_t::type; - using dst_v_t = typename vector_type_maker_t::type; - - static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) { - ck::tensor_operation::element_wise::PassThroughPack8{}( - dst_tmp_vector.template AsType()(i), - src_tmp_vector.template AsType()[i]); - }); - - // copy data from dst_tmp_vector into dst_buf - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); - }); - } - else - { - static_for<0, num_access, 1>{}([&](auto idx_1d) { - constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - DstData v; - - // apply element-wise operation - element_op_(v, src_buf[Number{}]); - - // apply type convert - dst_buf(Number{}) = v; - }); - }); - } - } - - ElementwiseOperation element_op_; -}; - -// Specialized for gfx11 -// A single Wave32 is composed by double row -// Data exchange allowed between these two rows -// This RowLane Dst buf will be filled from two Src buf -// SrcA: From specific thread buffer hold by This RowLane on This Row -// SrcB: From specific thread buffer hold by This RowLane on The other Row -template ::type = false> -struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow -{ - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx) - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! Desc need to known at compile-time"); - - static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, - "wrong! Not divisible"); - ignore = src_idx; - } - - template - __device__ void Run(const SrcDesc&, - const SrcSliceOriginIdx&, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) const - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! Desc need to known at compile-time"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! SliceOrigin need to known at compile-time"); - - static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), - "wrong! Buffer need to be StaticBuffer"); - - // SrcDesc and src_slice_origin_idx are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto dst_desc = remove_cvref_t{}; - constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); - constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); - - // scalar per access on each dim - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, - "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - - static_for<0, num_access, 1>{}([&](auto idx_1d) { - constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - // src_desc error, non constexpr, caused by merge transform - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - SrcData v_this_row, v_theother_row; - // int type temp value due to intrinsic requirement - int temp = 0; - - // apply element-wise operation - element_op_(v_this_row, src_buf[Number{}]); - - // apply intra-row permute. - if constexpr(IntraRowSwizzlePerm) - { - temp = __builtin_amdgcn_permlane16( - temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); - v_this_row = type_convert_sp(temp); - } - - // apply inter-row permute. - temp = __builtin_amdgcn_permlanex16(temp, - type_convert_sp(v_this_row), - LowEightRowlaneIdx, - HighEightRowLaneIdx, - 1, - 0); - v_theother_row = type_convert_sp(temp); - - if(get_thread_local_1d_id() % 32 < 16) - { - // apply type convert - dst_buf(Number{}) = type_convert_sp(v_this_row); - dst_buf(Number{}) = - type_convert_sp(v_theother_row); - } - else - { - // apply type convert - dst_buf(Number{}) = - type_convert_sp(v_this_row); - dst_buf(Number{}) = type_convert_sp(v_theother_row); - } - }); - }); - } - ElementwiseOperation element_op_{}; -}; - -// Specialized for gfx12 -template ::type = false> -struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow -{ - static constexpr index_t nDim = SliceLengths::Size(); - - using Index = MultiIndex; - - __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! Desc need to known at compile-time"); - - static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, - "wrong! Not divisible"); - ignore = src_idx; - } - - template - __device__ void Run(const SrcDesc&, - const SrcSliceOriginIdx&, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) const - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! Desc need to known at compile-time"); - - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! SliceOrigin need to known at compile-time"); - - static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), - "wrong! Buffer need to be StaticBuffer"); - - // SrcDesc and src_slice_origin_idx are known at compile-time - constexpr auto src_desc = remove_cvref_t{}; - constexpr auto dst_desc = remove_cvref_t{}; - constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); - constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); - - // scalar per access on each dim - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, - "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - - static_for<0, num_access, 1>{}([&](auto idx_1d) { - constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); - - // copy data from src_buf into dst_vector - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - // src_desc error, non constexpr, caused by merge transform - constexpr index_t src_offset = src_desc.CalculateOffset( - src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - - SrcData v_this_row; - // int type temp value due to intrinsic requirement - int temp = 0; - - // apply element-wise operation - element_op_(v_this_row, src_buf[Number{}]); - - // apply intra-row permute. - if constexpr(IntraRowSwizzlePerm) - { - temp = __builtin_amdgcn_permlane16( - temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); - v_this_row = type_convert_sp(temp); - } - - // apply type convert - dst_buf(Number{}) = type_convert_sp(v_this_row); - }); - }); - } - ElementwiseOperation element_op_{}; -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index af19e6126b..063bd71de6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -362,99 +362,6 @@ struct ThreadwiseTensorSliceTransfer_v2 } } - template - __device__ void RunPrint(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const DstDesc&, - const DstSliceOriginIdx&, - DstBuffer& dst_buf) - { - static_assert(DstDesc::IsKnownAtCompileTime(), - "wrong! DstDesc need to known at compile-time"); - - static_assert(is_known_at_compile_time>::value, - "wrong! DstSliceOrigin need to known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - // DstDesc and dst_slice_origin_idx are known at compile-time - constexpr auto dst_desc = remove_cvref_t{}; - constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_scalar_step_in_vector = - generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); - - using SpaceFillingCurve = SpaceFillingCurve>; - - // loop over tensor and copy - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - - static_for<0, num_access, 1>{}([&](auto idx_1d) { - typename vector_type_maker::type src_vector; - - using src_vector_t = - typename vector_type_maker::type::type; - constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); - - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - - printf("Tid: %03d, Ascale read gmem src_data_coord.GetOffset() = %d\n", - get_thread_local_1d_id(), - src_coord_.GetOffset()); - // copy data from src_buf into src_vector - src_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_coord_.GetOffset() / PackedSize, - is_src_valid); - - // copy data from src_vector into dst_buf - static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { - constexpr index_t dst_offset = - dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + - i * src_scalar_step_in_vector); - - if constexpr(InvalidElementAsNaN) - { - dst_buf(Number{}) = - is_src_valid - ? type_convert(src_vector.template AsType()[i]) - : NumericLimits::QuietNaN(); - } - else - { - dst_buf(Number{}) = - type_convert(src_vector.template AsType()[i]); - } - }); - - if constexpr(idx_1d.value != num_access - 1) - { - constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); - - move_tensor_coordinate( - src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); - } - }); - - // move src coordinate back to slice origin (or not) - if constexpr(SrcResetCoordinateAfterRun) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_desc, src_coord_, src_reset_step); - } - } - __device__ static constexpr auto GetSrcCoordinateResetStep() { constexpr auto src_scalar_per_access = generate_sequence(