[CK_TILE] Fix Batched GEMM Example GPU verification (#2800)

Added more batched GEMM test cases

[ROCm/composable_kernel commit: e82ccbdaf7]
This commit is contained in:
aledudek
2025-09-09 09:30:57 +02:00
committed by GitHub
parent f8c8263798
commit 9cc281007e
4 changed files with 79 additions and 37 deletions

View File

@@ -71,7 +71,7 @@ auto create_args(int argc, char* argv[])
.insert("batch_stride_b", "2097152", "Batch B stride")
.insert("batch_stride_c", "524288", "Batch C stride")
.insert("batch_count", "8", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")

View File

@@ -246,23 +246,9 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType,
@@ -284,15 +270,6 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C,
batch_count);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
@@ -16,11 +16,11 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
std::tuple< Col, Col, Row, F16, F16, F32, F16>
>;
// clang-format on

View File

@@ -1,9 +1,74 @@
#pragma once
// Problem-size descriptor for one batched GEMM test case (C = A * B per batch).
struct GemmParams
{
int M; // GEMM M dimension (rows of C); used below as M*K / M*N batch strides
int N; // GEMM N dimension (cols of C)
int K; // GEMM K (reduction) dimension
int batchCount; // number of GEMM problems in the batch, forwarded to Run()
};
// Stride descriptor for one layout combination of a batched GEMM test case.
// strideA/B/C are the leading-dimension strides of the A, B and C matrices
// (row-major vs column-major is encoded by choosing K/M for A, N/K for B —
// see the four configs built in the test body); batchStrideA/B/C are the
// element offsets between consecutive batches.
struct StrideConfig
{
int strideA; // leading-dimension stride of A (K for row-major, M for column-major — TODO confirm against layout tuple order)
int strideB; // leading-dimension stride of B (N or K depending on layout)
int strideC; // leading-dimension stride of C (always N here: C is row-major)
int batchStrideA; // elements between batches of A; set to M*K below (dense packing)
int batchStrideB; // elements between batches of B; set to K*N below
int batchStrideC; // elements between batches of C; set to M*N below
};
// Typed test: runs the batched GEMM kernel over a sweep of problem sizes,
// batch counts, and stride/layout configurations. The fixture's Run()
// (declared elsewhere) performs the launch and result verification.
TYPED_TEST(TestCkTileBatchedGemm, Basic)
{
// NOTE(review): the four lines below look like the pre-change test body
// retained by the diff rendering (fixed 256x256x512, Run(M, N, K)) rather
// than part of the new parameterized sweep — confirm against the raw patch.
constexpr int M = 256;
constexpr int N = 256;
constexpr int K = 512;
this->Run(M, N, K);
// Problem-size sweep: {M, N, K, batchCount}. Varies K (64..512) and
// batchCount (1..16) at fixed M = N = 256.
std::vector<GemmParams> gemmParams{{256, 256, 256, 1},
{256, 256, 256, 2},
{256, 256, 512, 2},
{256, 256, 128, 2},
{256, 256, 64, 2},
{256, 256, 64, 3},
{256, 256, 64, 4},
{256, 256, 64, 8},
{256, 256, 64, 16}};
for(auto& params : gemmParams)
{
// Four stride configurations per problem size. Only the leading-dimension
// strides of A (K vs M) and B (N vs K) differ between configs — presumably
// matching the four Row/Col layout combinations instantiated in KernelTypes
// (TODO confirm the config-to-layout-tuple correspondence). The batch
// strides are identical in all four: densely packed M*K / K*N / M*N.
std::vector<StrideConfig> strideConfigs{{params.K,
params.N,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.K,
params.K,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.M,
params.N,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N},
{params.M,
params.K,
params.N,
params.M * params.K,
params.K * params.N,
params.M * params.N}};
for(auto& conf : strideConfigs)
{
// Full-signature overload: explicit strides, batch strides, batch count.
this->Run(params.M,
params.N,
params.K,
conf.strideA,
conf.strideB,
conf.strideC,
conf.batchStrideA,
conf.batchStrideB,
conf.batchStrideC,
params.batchCount);
}
}
}