[CK_TILE] Fix Batched GEMM Example GPU verification (#2800)

Added more batched GEMM test cases
This commit is contained in:
aledudek
2025-09-09 09:30:57 +02:00
committed by GitHub
parent 75570d0fa8
commit e82ccbdaf7
4 changed files with 79 additions and 37 deletions

View File

@@ -71,7 +71,7 @@ auto create_args(int argc, char* argv[])
.insert("batch_stride_b", "2097152", "Batch B stride")
.insert("batch_stride_c", "524288", "Batch C stride")
.insert("batch_count", "8", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")

View File

@@ -246,23 +246,9 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType,
@@ -284,15 +270,6 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C,
batch_count);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
@@ -16,11 +16,11 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
// Type list handed to ::testing::Types. Each std::tuple entry is ordered as
// <ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType>.
// NOTE(review): this span appears to interleave the pre-change entries (the
// commented-out tuples plus the single active Row/Col/Row tuple whose trailing
// comma is commented out) with the post-change list of four tuples. Read
// literally, there is no separator between that active tuple and the next one
// — confirm which lines are meant to remain.
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
std::tuple< Col, Col, Row, F16, F16, F32, F16>
>;
// clang-format on

View File

@@ -1,9 +1,74 @@
#pragma once
/// Problem sizes for one batched GEMM test case.
///
/// Plain aggregate; brace-initialize in member order: {M, N, K, batchCount}.
/// Every member has a default initializer so that a default-constructed
/// instance never carries indeterminate values.
struct GemmParams
{
    int M{0};          ///< M problem size.
    int N{0};          ///< N problem size.
    int K{0};          ///< K problem size.
    int batchCount{0}; ///< Number of batches.
};
/// Stride set for one batched GEMM test case.
///
/// Plain aggregate; brace-initialize in member order:
/// {strideA, strideB, strideC, batchStrideA, batchStrideB, batchStrideC}.
/// Every member has a default initializer so that a default-constructed
/// instance never carries indeterminate values.
struct StrideConfig
{
    int strideA{0};      ///< Stride used for the A operand.
    int strideB{0};      ///< Stride used for the B operand.
    int strideC{0};      ///< Stride used for the C operand.
    int batchStrideA{0}; ///< Batch stride used for the A operand.
    int batchStrideB{0}; ///< Batch stride used for the B operand.
    int batchStrideC{0}; ///< Batch stride used for the C operand.
};
// "Basic" typed test of the TestCkTileBatchedGemm fixture: walks a list of
// problem sizes / batch counts and, for each one, calls the fixture's Run()
// once per stride configuration.
TYPED_TEST(TestCkTileBatchedGemm, Basic)
{
// NOTE(review): the next four lines look like the pre-change test body that the
// diff view shows alongside the new code — confirm whether the three-argument
// Run() call is still meant to be here.
constexpr int M = 256;
constexpr int N = 256;
constexpr int K = 512;
this->Run(M, N, K);
// Each entry is {M, N, K, batchCount}.
std::vector<GemmParams> gemmParams{{256, 256, 256, 1},
{256, 256, 256, 2},
{256, 256, 512, 2},
{256, 256, 128, 2},
{256, 256, 64, 2},
{256, 256, 64, 3},
{256, 256, 64, 4},
{256, 256, 64, 8},
{256, 256, 64, 16}};
for(auto& params : gemmParams)
{
// Each entry is {strideA, strideB, strideC, batchStrideA, batchStrideB, batchStrideC}.
// The four entries combine strideA in {K, M} with strideB in {N, K}; strideC is
// always N and the batch strides are always M*K, K*N and M*N.
// NOTE(review): presumably one entry per A/B row-/column-major layout
// combination, yet every entry is run for every typed instantiation — confirm
// against the fixture's Run().
std::vector<StrideConfig> strideConfigs{{params.K,
params.N,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.K,
params.K,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.M,
params.N,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.M,
params.K,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N}};
// Run the current problem size once for every stride configuration.
for(auto& conf : strideConfigs)
{
this->Run(params.M,
params.N,
params.K,
conf.strideA,
conf.strideB,
conf.strideC,
conf.batchStrideA,
conf.batchStrideB,
conf.batchStrideC,
params.batchCount);
}
}
}