[CK_TILE] Fix Batched GEMM Example GPU verification (#2800)

Added more batched GEMM test cases

[ROCm/composable_kernel commit: e82ccbdaf7]
This commit is contained in:
aledudek
2025-09-09 09:30:57 +02:00
committed by GitHub
parent ea352f2510
commit a3273fef14
4 changed files with 79 additions and 37 deletions

View File

@@ -71,7 +71,7 @@ auto create_args(int argc, char* argv[])
.insert("batch_stride_b", "2097152", "Batch B stride")
.insert("batch_stride_c", "524288", "Batch C stride")
.insert("batch_count", "8", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")

View File

@@ -246,23 +246,9 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType,
@@ -284,15 +270,6 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C,
batch_count);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
@@ -16,11 +16,11 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
std::tuple< Col, Col, Row, F16, F16, F32, F16>
>;
// clang-format on

View File

@@ -1,9 +1,74 @@
#pragma once
/// Problem size of a single batched GEMM test case.
///
/// Plain aggregate: brace-initialise in declaration order, i.e.
/// `{M, N, K, batchCount}`.
struct GemmParams
{
    int M;          ///< M dimension of the GEMM problem.
    int N;          ///< N dimension of the GEMM problem.
    int K;          ///< K dimension of the GEMM problem.
    int batchCount; ///< Number of GEMMs in the batch.
};
/// Stride set used for a single batched GEMM run.
///
/// Plain aggregate: brace-initialise in declaration order, i.e.
/// `{strideA, strideB, strideC, batchStrideA, batchStrideB, batchStrideC}`.
struct StrideConfig
{
    int strideA;      ///< Stride of the A matrix.
    int strideB;      ///< Stride of the B matrix.
    int strideC;      ///< Stride of the C matrix.
    int batchStrideA; ///< Stride between consecutive A matrices of the batch.
    int batchStrideB; ///< Stride between consecutive B matrices of the batch.
    int batchStrideC; ///< Stride between consecutive C matrices of the batch.
};
// Typed test: drives the fixture's Run() over a table of problem sizes and,
// for each size, over four stride configurations.
TYPED_TEST(TestCkTileBatchedGemm, Basic)
{
// NOTE(review): in this diff rendering the next four statements look like the
// pre-change body (one fixed-size run) shown next to its replacement -- confirm
// whether they are still part of the file.
constexpr int M = 256;
constexpr int N = 256;
constexpr int K = 512;
this->Run(M, N, K);
// Each entry is {M, N, K, batchCount}. M and N stay at 256 while K and the
// batch count vary.
std::vector<GemmParams> gemmParams{{256, 256, 256, 1},
{256, 256, 256, 2},
{256, 256, 512, 2},
{256, 256, 128, 2},
{256, 256, 64, 2},
{256, 256, 64, 3},
{256, 256, 64, 4},
{256, 256, 64, 8},
{256, 256, 64, 16}};
for(auto& params : gemmParams)
{
// Each entry is {strideA, strideB, strideC, batchStrideA, batchStrideB,
// batchStrideC}. The four entries combine strideA in {K, M} with strideB in
// {N, K}; strideC is always N and the batch strides are always M*K, K*N, M*N.
// Presumably one entry per row/column-major combination of A and B -- TODO
// confirm against Run().
// NOTE(review): all four entries are run for every typed layout, so some
// strides may not match the layout under test (e.g. strideA = M with K > M);
// verify that Run() handles or overrides such values.
std::vector<StrideConfig> strideConfigs{{params.K,
params.N,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.K,
params.K,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.M,
params.N,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.M,
params.K,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N}};
for(auto& conf : strideConfigs)
{
// One fixture run per (problem size, stride configuration) pair.
this->Run(params.M,
params.N,
params.K,
conf.strideA,
conf.strideB,
conf.strideC,
conf.batchStrideA,
conf.batchStrideB,
conf.batchStrideC,
params.batchCount);
}
}
}