mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
I've only verified that the kernel compiles. Some of my choices, like float32 types and having the epilogue set the member, are not valid template parameters. I now have this identical to a default GEMM universal kernel. I also fixed some other small logical mistakes I made. The code currently outputs the GetName results for some of the classes: ``` Kernel name: gemm_bf16_pipeline_AgBgCrCompV3_16x64x128_256_1x4_0x0x0 Shape: tile_gemm_shape_16x64x128x4_1x4x1_16x16x32 Problem: gemm_problem_256_0x0x0_Intrawave Pipeline: pipeline_AgBgCrCompV3_16x64x128_256_1x4_0x0x0 ```
102 lines
2.9 KiB
C++
102 lines
2.9 KiB
C++
|
|
// gemm_example.cpp (formerly hello_world.cpp)
|
|
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>

#include <hip/hip_runtime.h>

#include "gemm_builder.h"
|
|
|
|
namespace example {
|
|
|
|
// Helper to allocate device memory.
|
|
template <typename T>
|
|
auto AllocDevMem(const size_t n)
|
|
{
|
|
auto hip_deleter = [](int* ptr) {
|
|
if(!ptr)
|
|
{
|
|
return;
|
|
}
|
|
if(hipError_t err = hipFree(ptr); err != hipSuccess)
|
|
{
|
|
throw std::runtime_error(std::string("Error during hipFree: ") +
|
|
hipGetErrorString(err));
|
|
}
|
|
std::cout << "hipFree called for device memory at " << ptr << std::endl;
|
|
};
|
|
std::unique_ptr<int, decltype(hip_deleter)> d_data(nullptr, hip_deleter);
|
|
|
|
// Allocate memory on the device
|
|
void* ptr = nullptr;
|
|
if(hipError_t err = hipMalloc(&ptr, n * sizeof(T)); err != hipSuccess)
|
|
{
|
|
throw std::runtime_error(std::string("Error during hipMalloc: ") + hipGetErrorString(err));
|
|
}
|
|
std::cout << "Allocated device memory at " << ptr << std::endl;
|
|
|
|
// Transfer ownership to the unique_ptr
|
|
d_data.reset(static_cast<int*>(ptr));
|
|
return d_data;
|
|
}
|
|
|
|
// Short alias for the ck_tile::builder namespace, used by the declarations below.
namespace ckb = ck_tile::builder;
|
|
|
|
// Reproduce example/ck_tile/03_gemm/universal_gemm.cpp
|
|
//
|
|
// That example is kind of hard to follow, but the basic idea is that
|
|
// the function invoke_gemm (in run_gemm_example.inc) gets called with
|
|
// a GemmConfigComputeV3 (in gemm_utils.hpp), which calls the function
|
|
// gemm (in universal_gemm.cpp).
|
|
|
|
struct MyGemmTypes
|
|
{
|
|
using ADataType = ck_tile::bf16_t;
|
|
using BDataType = ck_tile::bf16_t;
|
|
using CDataType = ck_tile::bf16_t;
|
|
using AccDataType = float;
|
|
};
|
|
|
|
struct MyGemmLayout
|
|
{
|
|
using ALayout = ckb::RowMajor;
|
|
using BLayout = ckb::ColMajor;
|
|
using CLayout = ckb::RowMajor;
|
|
};
|
|
|
|
using Builder = ckb::GemmBuilder<MyGemmTypes, MyGemmLayout>;
|
|
|
|
using Gemm = Builder::value;
|
|
using Kernel = Builder::Kernel;
|
|
|
|
} // namespace example
|
|
|
|
int main()
|
|
{
|
|
// Create the GEMM kernel.
|
|
const int M = 1024, N = 2048, K = 64;
|
|
example::Gemm gemm;
|
|
|
|
// Describe the GEMM kernel:
|
|
std::cout << "Kernel name: " << example::Kernel::GetName() << std::endl;
|
|
std::cout << "Shape: " << example::Builder::GemmShape::GetName() << std::endl;
|
|
std::cout << "Problem: " << example::Builder::UniversalGemmProblem::GetName() << std::endl;
|
|
std::cout << "Pipeline: " << example::Builder::GemmPipeline::GetName() << std::endl;
|
|
|
|
// Try GPU execution.
|
|
try
|
|
{
|
|
auto a_dev = example::AllocDevMem<float>(M * K);
|
|
auto b_dev = example::AllocDevMem<float>(K * N);
|
|
auto c_dev = example::AllocDevMem<float>(M * N);
|
|
|
|
gemm.run({.m = M, .n = N, .k = K, .a = a_dev.get(), .b = b_dev.get(), .c = c_dev.get()});
|
|
}
|
|
catch(const std::exception& e)
|
|
{
|
|
std::cerr << "Exception: " << e.what() << std::endl;
|
|
return 1;
|
|
}
|
|
|
|
return 0;
|
|
}
|