diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp index 8a9c932f4c..f27fc73b3b 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp @@ -19,7 +19,8 @@ template + typename CBlockClusterAdaptor, + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -37,14 +38,14 @@ __global__ void __shared__ FloatAB p_shared_block[shared_block_size]; - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER template + typename CBlockClusterAdaptor, + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -81,14 +83,14 @@ __global__ void __shared__ FloatAB p_shared_block[shared_block_size]; - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); } #endif @@ -102,7 +104,7 @@ template {}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -173,13 +175,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -220,7 +222,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) return false; - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; // check M01, N01 @@ -248,6 +250,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 return grid_size; } + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + __host__ __device__ static constexpr auto MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) { @@ -258,13 +267,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if constexpr(ABlockLdsExtraM) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -273,13 +282,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -338,6 +347,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -376,13 +386,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if constexpr(ABlockLdsExtraM) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -390,8 +400,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if constexpr(ABlockLdsExtraM) { return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple(Number{} * Number{} * K1, + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, Number{} * K1, K1, I1)); @@ -399,7 +409,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number<1>{}, Number{}, Number{}, K1), max_lds_align); } }(); @@ -408,13 +418,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -422,8 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple(Number{} * Number{} * K1, + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, Number{} * K1, K1, I1)); @@ -431,7 +441,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number<1>{}, Number{}, Number{}, K1), max_lds_align); } }(); @@ -439,7 +449,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4, + Sequence<1, K0PerBlock, MPerBlock, K1>, ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, @@ -466,7 +476,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4, + Sequence<1, K0PerBlock, NPerBlock, K1>, BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -491,8 +501,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlock, NPerBlock] is in LDS + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check @@ -518,8 +528,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 FloatAB* p_a_block = p_shared_block; FloatAB* p_b_block = p_shared_block + a_block_space_size; - constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; @@ -546,31 +556,35 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // main body index_t k_block_data_begin = 0; - - do + if constexpr(HasMainKBlockLoop) { - a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_step_hack); - a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_blockwise_copy.RunRead( + a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - block_sync_lds(); + block_sync_lds(); - b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_blockwise_copy.RunRead( + b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - block_sync_lds(); + block_sync_lds(); - a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); - k_block_data_begin += KPerBlock; - } while(k_block_data_begin < (K0 - KPerBlock)); + k_block_data_begin += K0PerBlock; + } while(k_block_data_begin < (K0 - K0PerBlock)); + } // tail { diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp index ce674758ac..8207e2cb28 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp @@ -95,13 +95,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_ const auto GemmN = Y * X * C; const auto GemmKTotal = N * Ho * Wo; - const auto GemmK = GemmKTotal / GemmK1; - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); - const index_t GemmK0 = BatchLen * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp index 579c7a1200..6381ce8bb4 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -123,13 +123,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ const auto GemmN = K; const auto GemmKTotal = N * Ho * Wo; - const auto GemmK = GemmKTotal / GemmK1; - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); - const index_t GemmK0 = BatchLen * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp index 29b404f7d0..603f872662 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -107,8 +107,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -291,13 +291,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ const auto GemmN = Y * X * C; const auto GemmKTotal = N * Ho * Wo; - const auto GemmK = GemmKTotal / GemmK1; - const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); - const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); - const index_t GemmK0 = BatchLen * GemmKPerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp index 65c4f62367..30ecb02de1 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp @@ -156,27 +156,58 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, std::cout << "gridSize : " << grid_size << std::endl; } - const auto kernel = kernel_gemm_xdlops_v2r4, - remove_reference_t, - remove_reference_t, - remove_reference_t>; + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r4, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true>; + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false>; + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); + } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc)); @@ -189,20 +220,58 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); - float ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_c_grid, - cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space( - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r4, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true>; + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false>; + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } #endif return ave_time; }