mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Grouped gemm test fix (#150)
* fixed test: return res; rand gemm shapes
* fixed return
[ROCm/composable_kernel commit: fe6ce55c24]
This commit is contained in:
@@ -66,7 +66,7 @@ static bool check_err(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
|
||||
bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
|
||||
{
|
||||
int group_count = 4;
|
||||
int group_count = rand() % 10 + 1;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
|
||||
@@ -77,9 +77,9 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
int M = 256 + 256 * i;
|
||||
int N = 128 + 128 * i;
|
||||
int K = 128 + 64 * i;
|
||||
int M = 256 + 256 * (rand() % 10);
|
||||
int N = 256 + 256 * (rand() % 10);
|
||||
int K = 128 + 128 * (rand() % 10);
|
||||
|
||||
int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M;
|
||||
int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
|
||||
@@ -132,8 +132,8 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
|
||||
c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
|
||||
gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
|
||||
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
}
|
||||
|
||||
for(int i = 0; i < gemm_shapes.size(); i++)
|
||||
@@ -181,6 +181,11 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
bool res = check_err(c_device_tensors[i], c_host_tensors[i]);
|
||||
@@ -210,4 +215,6 @@ int main()
|
||||
}
|
||||
|
||||
std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
|
||||
return res ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user