diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp index e21a5cb335..a5f4b75127 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp @@ -129,7 +129,7 @@ struct DeviceBatchedGemmXdl } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, M)); + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); } }(); @@ -158,7 +158,7 @@ struct DeviceBatchedGemmXdl } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, K)); + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); } }(); @@ -183,7 +183,7 @@ struct DeviceBatchedGemmXdl } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, M)); + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); } }();