mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
203 lines
7.3 KiB
C++
203 lines
7.3 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <cstring>
|
|
|
|
#include "config.h"
|
|
#include "ck_tile/host.hpp"
|
|
#include "gemm.hpp"
|
|
#include "reference_gemm.hpp"
|
|
|
|
/*
|
|
* Toy code of GEMM
|
|
* Assume simplest case.
|
|
* A [M, K]
|
|
* B [N, K]
|
|
* C [M, N]
|
|
*/
|
|
|
|
// elementwise lambda
|
|
struct CElementFunction
|
|
{
|
|
template <typename X>
|
|
CK_TILE_HOST_DEVICE auto operator()(const X& x) const
|
|
{
|
|
return x;
|
|
}
|
|
};
|
|
|
|
int main(int argc, char* argv[])
|
|
{
|
|
using ADataType = ck_tile::half_t;
|
|
using BDataType = ck_tile::half_t;
|
|
using AccDataType = float;
|
|
using CDataType = ck_tile::half_t;
|
|
|
|
ck_tile::index_t verification = 0;
|
|
ck_tile::index_t M = 3328;
|
|
ck_tile::index_t N = 4096;
|
|
ck_tile::index_t K = 4096;
|
|
|
|
if(argc == 2)
|
|
{
|
|
verification = std::stoi(argv[1]);
|
|
}
|
|
if(argc == 5)
|
|
{
|
|
verification = std::stoi(argv[1]);
|
|
M = std::stoi(argv[2]);
|
|
N = std::stoi(argv[3]);
|
|
K = std::stoi(argv[4]);
|
|
}
|
|
|
|
#if defined(KERNEL_A)
|
|
printf("*** Kernel A test *** \n");
|
|
printf(" --> Using mfma_32x32x(8x2)\n");
|
|
#elif defined(KERNEL_B)
|
|
printf("*** Kernel B test *** \n");
|
|
printf(" --> Using mfma_16x16x16\n");
|
|
#elif defined(KERNEL_C)
|
|
printf("*** Kernel C test *** \n");
|
|
printf(" --> Using mfma_16x16x(16x2)\n");
|
|
#elif defined(KERNEL_D)
|
|
printf("*** Kernel D test *** \n");
|
|
printf(" --> Using mfma_16x16x(16x2)\n");
|
|
printf(" --> XOR-based bank-conflict-free\n");
|
|
#elif defined(KERNEL_E)
|
|
printf("*** Kernel E test ***\n");
|
|
printf(" --> Using mfma_16x16x(16x2)\n");
|
|
printf(" --> XOR-based bank-conflict-free\n");
|
|
printf(" --> Adjust block tile shape\n");
|
|
#elif defined(KERNEL_F)
|
|
printf("*** Kernel F test ***\n");
|
|
printf(" --> Using mfma_16x16x(16x2)\n");
|
|
printf(" --> XOR-based bank-conflict-free\n");
|
|
printf(" --> Adjust block tile shape\n");
|
|
printf(" --> Enable prefetch\n");
|
|
#elif defined(KERNEL_G)
|
|
printf("*** Kernel G test ***\n");
|
|
printf(" --> Using mfma_16x16x(16x2)\n");
|
|
printf(" --> XOR-based bank-conflict-free\n");
|
|
printf(" --> Adjust block tile shape\n");
|
|
printf(" --> Enable prefetch\n");
|
|
printf(" --> Enable instruction schedule\n");
|
|
#elif defined(KERNEL_H)
|
|
printf("*** Kernel H test ***\n");
|
|
printf(" --> Using mfma_16x16x(16x2)\n");
|
|
printf(" --> XOR-based bank-conflict-free\n");
|
|
printf(" --> Adjust block tile shape\n");
|
|
printf(" --> Enable prefetch\n");
|
|
printf(" --> Enable instruction schedule\n");
|
|
printf(" --> Enable cache-aware thread blocks schedule\n");
|
|
#else
|
|
printf("*** Naive implementation test ***\n");
|
|
#endif
|
|
|
|
const ck_tile::index_t Lda = K;
|
|
const ck_tile::index_t Ldb = K;
|
|
const ck_tile::index_t Ldc = N;
|
|
|
|
const auto a_lengths = std::array<ck_tile::index_t, 2>{M, K};
|
|
const auto a_strides = std::array<ck_tile::index_t, 2>{Lda, 1};
|
|
|
|
const auto b_lengths = std::array<ck_tile::index_t, 2>{N, K};
|
|
const auto b_strides = std::array<ck_tile::index_t, 2>{Ldb, 1};
|
|
|
|
const auto c_lengths = std::array<ck_tile::index_t, 2>{M, N};
|
|
const auto c_strides = std::array<ck_tile::index_t, 2>{Ldc, 1};
|
|
|
|
// host verify
|
|
ck_tile::HostTensor<ADataType> a_host(a_lengths, a_strides);
|
|
ck_tile::HostTensor<BDataType> b_host(b_lengths, b_strides);
|
|
ck_tile::HostTensor<CDataType> c_host_dev(c_lengths, c_strides);
|
|
|
|
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
|
|
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);
|
|
|
|
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
|
ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes());
|
|
|
|
a_buf.ToDevice(a_host.mData.data());
|
|
b_buf.ToDevice(b_host.mData.data());
|
|
|
|
// Alignment
|
|
constexpr ck_tile::index_t kAAlignment = 8;
|
|
constexpr ck_tile::index_t kBAlignment = 8;
|
|
constexpr ck_tile::index_t kCAlignment = 8;
|
|
|
|
constexpr ck_tile::index_t kBlockSize = 256;
|
|
|
|
#ifdef ADJUST_BLOCK_TILE_SHAPE
|
|
constexpr ck_tile::index_t kGemmMPerBlock = 128;
|
|
constexpr ck_tile::index_t kGemmKPerBlock = 64;
|
|
#else
|
|
constexpr ck_tile::index_t kGemmMPerBlock = 256;
|
|
constexpr ck_tile::index_t kGemmKPerBlock = 32;
|
|
#endif
|
|
constexpr ck_tile::index_t kGemmNPerBlock = 128;
|
|
|
|
ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
|
|
|
|
std::cout << "grid size " << kGridSize << std::endl;
|
|
|
|
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
|
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
|
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
|
|
|
using gemm_kernel = ck_tile::Gemm<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
CDataType,
|
|
CElementFunction,
|
|
kAAlignment,
|
|
kBAlignment,
|
|
kCAlignment,
|
|
kBlockSize,
|
|
kGemmMPerBlock,
|
|
kGemmNPerBlock,
|
|
kGemmKPerBlock>;
|
|
|
|
float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000},
|
|
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
|
gemm_kernel{},
|
|
kGridSize,
|
|
kBlockSize,
|
|
0,
|
|
static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
|
|
static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
|
|
static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
|
|
M,
|
|
N,
|
|
K,
|
|
Lda,
|
|
Ldb,
|
|
Ldc,
|
|
CElementFunction{}));
|
|
auto pass = true;
|
|
|
|
if(verification)
|
|
{
|
|
// reference gemm
|
|
ck_tile::HostTensor<CDataType> c_host_ref(c_lengths, c_strides);
|
|
reference_basic_gemm<ADataType, ADataType, AccDataType, CDataType>(
|
|
a_host, b_host, c_host_ref);
|
|
c_buf.FromDevice(c_host_dev.mData.data());
|
|
pass &= ck_tile::check_err(c_host_dev, c_host_ref);
|
|
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
|
}
|
|
|
|
std::size_t flop = std::size_t(2) * M * N * K;
|
|
std::size_t num_btype =
|
|
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
|
|
|
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
|
|
|
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
|
|
|
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
|
<< std::endl;
|
|
|
|
return !pass;
|
|
}
|