diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp
index 0da6501568..b63c269377 100644
--- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp
+++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp
@@ -71,7 +71,7 @@ auto create_args(int argc, char* argv[])
         .insert("batch_stride_b", "2097152", "Batch B stride")
         .insert("batch_stride_c", "524288", "Batch C stride")
         .insert("batch_count", "8", "Batch count")
-        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
+        .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
         .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
         .insert("warmup", "50", "number of iterations before benchmark the kernel")
         .insert("repeat", "100", "number of iterations to benchmark the kernel")
diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
index 3289a2836b..c446fa7428 100644
--- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
@@ -246,23 +246,9 @@ int run_batched_gemm_example_with_layouts(int argc,
     c_m_n_gpu_ref.SetZero();
     c_m_n_gpu_buf_ref.SetZero();
 
-    ADataType* d_A;
-    BDataType* d_B;
-    CDataType* d_C;
-
-    ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
-    ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
-    ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
-
-    ck_tile::hip_check_error(hipMemcpy(d_A,
-                                       a_m_k_dev_buf.GetDeviceBuffer(),
-                                       batch_count * M * K * sizeof(ADataType),
-                                       hipMemcpyHostToDevice));
-
-    ck_tile::hip_check_error(hipMemcpy(d_B,
-                                       b_k_n_dev_buf.GetDeviceBuffer(),
-                                       batch_count * N * K * sizeof(BDataType),
-                                       hipMemcpyHostToDevice));
+    ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
+    BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
+    CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
 
     ck_tile::reference_batched_gemm_gpu
diff --git a/test/ck_tile/batched_gemm/test_batched_gemm.cpp b/test/ck_tile/batched_gemm/test_batched_gemm.cpp
--- a/test/ck_tile/batched_gemm/test_batched_gemm.cpp
+++ b/test/ck_tile/batched_gemm/test_batched_gemm.cpp
@@ -16,11 +16,11 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 
 // clang-format off
 using KernelTypes = ::testing::Types<
-    // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
-    // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
-    //std::tuple< Col, Row, Row, F16, F16, F32, F16>,
-    std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
-    //std::tuple< Col, Col, Row, F16, F16, F32, F16>
+    // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
+    std::tuple< Row, Row, Row, F16, F16, F32, F16>,
+    std::tuple< Col, Row, Row, F16, F16, F32, F16>,
+    std::tuple< Row, Col, Row, F16, F16, F32, F16>,
+    std::tuple< Col, Col, Row, F16, F16, F32, F16>
     >;
 // clang-format on
 
diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
index 74338ba383..b2f965764d 100644
--- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
+++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
@@ -1,9 +1,74 @@
 #pragma once
 
+struct GemmParams
+{
+    int M;
+    int N;
+    int K;
+    int batchCount;
+};
+
+struct StrideConfig
+{
+    int strideA;
+    int strideB;
+    int strideC;
+    int batchStrideA;
+    int batchStrideB;
+    int batchStrideC;
+};
+
 TYPED_TEST(TestCkTileBatchedGemm, Basic)
 {
-    constexpr int M = 256;
-    constexpr int N = 256;
-    constexpr int K = 512;
-    this->Run(M, N, K);
+    std::vector<GemmParams> gemmParams{{256, 256, 256, 1},
+                                       {256, 256, 256, 2},
+                                       {256, 256, 512, 2},
+                                       {256, 256, 128, 2},
+                                       {256, 256, 64, 2},
+                                       {256, 256, 64, 3},
+                                       {256, 256, 64, 4},
+                                       {256, 256, 64, 8},
+                                       {256, 256, 64, 16}};
+
+    for(auto& params : gemmParams)
+    {
+        std::vector<StrideConfig> strideConfigs{{params.K,
+                                                 params.N,
+                                                 params.N,
+                                                 params.M * params.K,
+                                                 params.K * params.N,
+                                                 params.M * params.N},
+                                                {params.K,
+                                                 params.K,
+                                                 params.N,
+                                                 params.M * params.K,
+                                                 params.K * params.N,
+                                                 params.M * params.N},
+                                                {params.M,
+                                                 params.N,
+                                                 params.N,
+                                                 params.M * params.K,
+                                                 params.K * params.N,
+                                                 params.M * params.N},
+                                                {params.M,
+                                                 params.K,
+                                                 params.N,
+                                                 params.M * params.K,
+                                                 params.K * params.N,
+                                                 params.M * params.N}};
+
+        for(auto& conf : strideConfigs)
+        {
+            this->Run(params.M,
+                      params.N,
+                      params.K,
+                      conf.strideA,
+                      conf.strideB,
+                      conf.strideC,
+                      conf.batchStrideA,
+                      conf.batchStrideB,
+                      conf.batchStrideC,
+                      params.batchCount);
+        }
+    }
 }