diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp index 57ba31549e..8e5d229d08 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -32,6 +32,9 @@ struct DeviceBatchedGemm : public BaseOperator ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB, + ck::index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp index 881bc976fb..bbc359ee18 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp @@ -341,6 +341,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm(a_grid_desc_k0_m_k1_.GetElementSpaceSize()), - type_convert(b_grid_desc_k0_n_k1_.GetElementSpaceSize()), - type_convert(c_grid_desc_m_n_.GetElementSpaceSize())}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC}, block_2_ctile_map_{ GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, M01_{M01}, @@ -543,6 +543,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm::value) { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({row * stride, stride, 1})); + std::vector({batch_stride, stride, 1})); } else { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({col * stride, 1, stride})); + std::vector({batch_stride, 1, stride})); } }; - Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); - Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{})); Tensor c_g_m_n_host_result( - f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{})); Tensor c_g_m_n_device_result( - f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{})); std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; @@ -150,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification, StrideA, StrideB, StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{}, diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp index 45ec352e72..90042c37bd 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[]) const int DefaultStrideB = ck::is_same_v ? N : K; const int DefaultStrideC = ck::is_same_v ? N : M; + const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; + const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; + const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int BatchStrideA = (ck::is_same_v ? M : K) * StrideA_; + const int BatchStrideB = (ck::is_same_v ? K : N) * StrideB_; + const int BatchStrideC = (ck::is_same_v ? M : N) * StrideC_; + bool pass = ck::profiler:: profile_batched_gemm_impl( do_verification, @@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[]) M, N, K, - (StrideA < 0) ? DefaultStrideA : StrideA, - (StrideB < 0) ? DefaultStrideB : StrideB, - (StrideC < 0) ? DefaultStrideC : StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + StrideA_, + StrideB_, + StrideC_, BatchCount); return pass ? 0 : 1; diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp index 24ebabcadf..7fc1f24f5f 100644 --- a/test/batched_gemm/batched_gemm_fp16.cpp +++ b/test/batched_gemm/batched_gemm_fp16.cpp @@ -25,19 +25,19 @@ int main() pass = pass && ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, BatchCount); + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, BatchCount); + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, BatchCount); + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, BatchCount); + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1;