// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include "config.h" #include "ck_tile/host.hpp" #include "gemm.hpp" #include "reference_gemm.hpp" /* * Toy code of GEMM * Assume simplest case. * A [M, K] * B [N, K] * C [M, N] */ // elementwise lambda struct CElementFunction { template CK_TILE_HOST_DEVICE auto operator()(const X& x) const { return x; } }; int main(int argc, char* argv[]) { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using AccDataType = float; using CDataType = ck_tile::half_t; ck_tile::index_t verification = 0; ck_tile::index_t M = 3328; ck_tile::index_t N = 4096; ck_tile::index_t K = 4096; if(argc == 2) { verification = std::stoi(argv[1]); } if(argc == 5) { verification = std::stoi(argv[1]); M = std::stoi(argv[2]); N = std::stoi(argv[3]); K = std::stoi(argv[4]); } #if defined(KERNEL_A) printf("*** Kernel A test *** \n"); printf(" --> Using mfma_32x32x(8x2)\n"); #elif defined(KERNEL_B) printf("*** Kernel B test *** \n"); printf(" --> Using mfma_16x16x16\n"); #elif defined(KERNEL_C) printf("*** Kernel C test *** \n"); printf(" --> Using mfma_16x16x(16x2)\n"); #elif defined(KERNEL_D) printf("*** Kernel D test *** \n"); printf(" --> Using mfma_16x16x(16x2)\n"); printf(" --> XOR-based bank-conflict-free\n"); #elif defined(KERNEL_E) printf("*** Kernel E test ***\n"); printf(" --> Using mfma_16x16x(16x2)\n"); printf(" --> XOR-based bank-conflict-free\n"); printf(" --> Adjust block tile shape\n"); #elif defined(KERNEL_F) printf("*** Kernel F test ***\n"); printf(" --> Using mfma_16x16x(16x2)\n"); printf(" --> XOR-based bank-conflict-free\n"); printf(" --> Adjust block tile shape\n"); printf(" --> Enable prefetch\n"); #elif defined(KERNEL_G) printf("*** Kernel G test ***\n"); printf(" --> Using mfma_16x16x(16x2)\n"); printf(" --> XOR-based bank-conflict-free\n"); printf(" --> Adjust block tile shape\n"); printf(" --> Enable prefetch\n"); printf(" --> Enable instruction schedule\n"); #elif defined(KERNEL_H) printf("*** Kernel H test ***\n"); printf(" --> Using mfma_16x16x(16x2)\n"); printf(" --> XOR-based bank-conflict-free\n"); printf(" --> Adjust block tile shape\n"); printf(" --> Enable prefetch\n"); printf(" --> Enable instruction schedule\n"); printf(" --> Enable cache-aware thread blocks schedule\n"); #else printf("*** Naive implementation test ***\n"); #endif const ck_tile::index_t Lda = K; const ck_tile::index_t Ldb = K; const ck_tile::index_t Ldc = N; const auto a_lengths = std::array{M, K}; const auto a_strides = std::array{Lda, 1}; const auto b_lengths = std::array{N, K}; const auto b_strides = std::array{Ldb, 1}; const auto c_lengths = std::array{M, N}; const auto c_strides = std::array{Ldc, 1}; // host verify ck_tile::HostTensor a_host(a_lengths, a_strides); ck_tile::HostTensor b_host(b_lengths, b_strides); ck_tile::HostTensor c_host_dev(c_lengths, c_strides); ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); a_buf.ToDevice(a_host.mData.data()); b_buf.ToDevice(b_host.mData.data()); // Alignment constexpr ck_tile::index_t kAAlignment = 8; constexpr ck_tile::index_t kBAlignment = 8; constexpr ck_tile::index_t kCAlignment = 8; constexpr ck_tile::index_t kBlockSize = 256; #ifdef ADJUST_BLOCK_TILE_SHAPE constexpr ck_tile::index_t kGemmMPerBlock = 128; constexpr ck_tile::index_t kGemmKPerBlock = 64; #else constexpr ck_tile::index_t kGemmMPerBlock = 256; constexpr ck_tile::index_t kGemmKPerBlock = 32; #endif constexpr ck_tile::index_t kGemmNPerBlock = 128; ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock); std::cout << "grid size " << kGridSize << std::endl; constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; using gemm_kernel = ck_tile::Gemm; float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000}, ck_tile::make_kernel( gemm_kernel{}, kGridSize, kBlockSize, 0, static_cast(a_buf.GetDeviceBuffer()), static_cast(b_buf.GetDeviceBuffer()), static_cast(c_buf.GetDeviceBuffer()), M, N, K, Lda, Ldb, Ldc, CElementFunction{})); auto pass = true; if(verification) { // reference gemm ck_tile::HostTensor c_host_ref(c_lengths, c_strides); reference_basic_gemm( a_host, b_host, c_host_ref); c_buf.FromDevice(c_host_dev.mData.data()); pass &= ck_tile::check_err(c_host_dev, c_host_ref); std::cout << "valid:" << (pass ? "y" : "n") << std::endl; } std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; return !pass; }