diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index c2fff84bf8..a9aa53e071 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -128,7 +128,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; -#if 1 +#if 0 using GridwiseGemmPipe = GridwiseGemmPipeline_v1; #else using GridwiseGemmPipe = GridwiseGemmPipeline_v2; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index dad9006385..7a5af95fd9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -7,6 +7,7 @@ #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" +#include "gridwise_gemm_pipeline_v2.hpp" namespace ck { @@ -120,7 +121,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using ThisThreadBlock = ThisThreadBlock; +#if 0 using GridwiseGemmPipe = GridwiseGemmPipeline_v1; +#else + using GridwiseGemmPipe = GridwiseGemmPipeline_v2; +#endif __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index 7069f1d8c3..f432ac37a1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -10,6 +10,7 @@ #include "blockwise_tensor_slice_transfer_v6r2.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" +#include "gridwise_gemm_pipeline_v2.hpp" namespace ck { @@ -136,7 +137,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 using ThisThreadBlock = ThisThreadBlock; +#if 0 using GridwiseGemmPipe = GridwiseGemmPipeline_v1; +#else + using GridwiseGemmPipe = GridwiseGemmPipeline_v2; +#endif __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() {