Merge commit 'e82ccbdaf7a24af7d14d65a61ad00f4f144f84f5' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-09 08:14:24 +00:00
parent e71c156d69
commit 0ed9296d16
4 changed files with 79 additions and 37 deletions

View File

@@ -71,7 +71,7 @@ auto create_args(int argc, char* argv[])
.insert("batch_stride_b", "2097152", "Batch B stride")
.insert("batch_stride_c", "524288", "Batch C stride")
.insert("batch_count", "8", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")

View File

@@ -246,23 +246,9 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType,
@@ -284,15 +270,6 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C,
batch_count);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
@@ -16,11 +16,11 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
// Kernel configurations under test: every combination of A/B layouts with a
// row-major C, all in fp16 with fp32 accumulation. The diff rendering had
// fused the pre-merge (commented-out) list with the post-merge list, leaving
// a commented-out separator comma (`>//,`) that breaks compilation; this is
// the reconstructed post-merge type list with all four layout pairs enabled.
using KernelTypes = ::testing::Types<
    //         ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
    std::tuple<    Row,     Row,     Row,       F16,       F16,         F32,       F16>,
    std::tuple<    Col,     Row,     Row,       F16,       F16,         F32,       F16>,
    std::tuple<    Row,     Col,     Row,       F16,       F16,         F32,       F16>,
    std::tuple<    Col,     Col,     Row,       F16,       F16,         F32,       F16>
    >;
// clang-format on

View File

@@ -1,9 +1,74 @@
#pragma once
// Problem-size descriptor for one batched-GEMM test case.
// NOTE: member order matters — instances are built with positional aggregate
// initialization ({M, N, K, batchCount}) in the test below.
struct GemmParams
{
int M;          // rows of A and of the output C
int N;          // columns of B and of the output C
int K;          // shared inner (reduction) dimension
int batchCount; // number of independent GEMM problems in the batch
};
// Leading-dimension and batch strides for one batched-GEMM invocation.
// NOTE: member order matters — instances are built with positional aggregate
// initialization in the test below, where strideA/strideB take either K or
// M/N depending on the row-/column-major layout being exercised.
struct StrideConfig
{
int strideA;      // leading dimension of A (K or M in the sweep below)
int strideB;      // leading dimension of B (N or K in the sweep below)
int strideC;      // leading dimension of C (N in the sweep below)
int batchStrideA; // element distance between consecutive A matrices (M*K)
int batchStrideB; // element distance between consecutive B matrices (K*N)
int batchStrideC; // element distance between consecutive C matrices (M*N)
};
// Sweeps a set of batched-GEMM problem sizes and, for each size, four stride
// configurations (the leading-dimension choices matching row/row, row/col,
// col/row and col/col A/B layouts) through the fixture's Run().
//
// The diff rendering had fused the pre-merge body (constexpr M/N/K plus a
// three-argument this->Run(M, N, K) call) with the post-merge parameterized
// sweep; the three-argument overload is superseded by the ten-argument one
// used below, so the stale lines are removed here.
TYPED_TEST(TestCkTileBatchedGemm, Basic)
{
    // {M, N, K, batchCount}: vary K and the batch count around a 256x256 tile.
    std::vector<GemmParams> gemmParams{{256, 256, 256, 1},
                                       {256, 256, 256, 2},
                                       {256, 256, 512, 2},
                                       {256, 256, 128, 2},
                                       {256, 256, 64, 2},
                                       {256, 256, 64, 3},
                                       {256, 256, 64, 4},
                                       {256, 256, 64, 8},
                                       {256, 256, 64, 16}};
    for(const auto& params : gemmParams)
    {
        // Leading-dimension strides for the four A/B layout combinations;
        // the batch strides are always the dense full-matrix element counts
        // (M*K, K*N, M*N), i.e. batches are packed back to back.
        std::vector<StrideConfig> strideConfigs{{params.K,
                                                 params.N,
                                                 params.N,
                                                 params.M * params.K,
                                                 params.K * params.N,
                                                 params.M * params.N},
                                                {params.K,
                                                 params.K,
                                                 params.N,
                                                 params.M * params.K,
                                                 params.K * params.N,
                                                 params.M * params.N},
                                                {params.M,
                                                 params.N,
                                                 params.N,
                                                 params.M * params.K,
                                                 params.K * params.N,
                                                 params.M * params.N},
                                                {params.M,
                                                 params.K,
                                                 params.N,
                                                 params.M * params.K,
                                                 params.K * params.N,
                                                 params.M * params.N}};
        for(const auto& conf : strideConfigs)
        {
            this->Run(params.M,
                      params.N,
                      params.K,
                      conf.strideA,
                      conf.strideB,
                      conf.strideC,
                      conf.batchStrideA,
                      conf.batchStrideB,
                      conf.batchStrideC,
                      params.batchCount);
        }
    }
}