From e57d5da9bbee60a258f566bdee4e46ae963a9ae1 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Mon, 28 Mar 2022 16:46:21 -0500 Subject: [PATCH] Grouped gemm test fix (#150) * fixed test: return res; rand gemm shapes * fixed return [ROCm/composable_kernel commit: fe6ce55c2449f3758dd9b7b9418a669ae74fc311] --- test/grouped_gemm/grouped_gemm_fp16.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp index 9b3d2901ee..1568f4935f 100644 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -66,7 +66,7 @@ static bool check_err(const Tensor& ref, const Tensor& result) bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) { - int group_count = 4; + int group_count = rand() % 10 + 1; // GEMM shape std::vector gemm_shapes; @@ -77,9 +77,9 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) for(int i = 0; i < group_count; i++) { - int M = 256 + 256 * i; - int N = 128 + 128 * i; - int K = 128 + 64 * i; + int M = 256 + 256 * (rand() % 10); + int N = 256 + 256 * (rand() % 10); + int K = 128 + 128 * (rand() % 10); int AStride = std::is_same::value ? K : M; int BStride = std::is_same::value ? N : K; @@ -132,8 +132,8 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) c_device_tensors.emplace_back(Tensor(f_host_tensor_descriptor( gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); - a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); } for(int i = 0; i < gemm_shapes.size(); i++) @@ -181,6 +181,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) b_element_op, c_element_op); + if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get())) + { + return false; + } + ref_invoker.Run(ref_argument); bool res = check_err(c_device_tensors[i], c_host_tensors[i]); @@ -210,4 +215,6 @@ int main() } std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + return res ? 0 : 1; }