Files
composable_kernel/experimental/gemm_builder/gemm_example.cpp
John Shumway e49ceff3f5 Fix builder code and choices to compile a GEMM kernel.
I've only verified that the kernel compiles.

Some of my choices, like float32 types and having the epilogue set the member, are not valid template parameters. I now have this identical to
a default GEMM universal kernel.

I also fixed some other small logical mistakes I made.

The code currently outputs the GetName results for some of the classes:

```
Kernel name: gemm_bf16_pipeline_AgBgCrCompV3_16x64x128_256_1x4_0x0x0
Shape:       tile_gemm_shape_16x64x128x4_1x4x1_16x16x32
Problem:     gemm_problem_256_0x0x0_Intrawave
Pipeline:    pipeline_AgBgCrCompV3_16x64x128_256_1x4_0x0x0
```
2025-08-04 16:50:37 +00:00

102 lines
2.9 KiB
C++

// gemm_example.cpp (formerly hello_world.cpp)
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <hip/hip_runtime.h>
#include "gemm_builder.h"
namespace example {
// Helper to allocate device memory.
template <typename T>
auto AllocDevMem(const size_t n)
{
auto hip_deleter = [](int* ptr) {
if(!ptr)
{
return;
}
if(hipError_t err = hipFree(ptr); err != hipSuccess)
{
throw std::runtime_error(std::string("Error during hipFree: ") +
hipGetErrorString(err));
}
std::cout << "hipFree called for device memory at " << ptr << std::endl;
};
std::unique_ptr<int, decltype(hip_deleter)> d_data(nullptr, hip_deleter);
// Allocate memory on the device
void* ptr = nullptr;
if(hipError_t err = hipMalloc(&ptr, n * sizeof(T)); err != hipSuccess)
{
throw std::runtime_error(std::string("Error during hipMalloc: ") + hipGetErrorString(err));
}
std::cout << "Allocated device memory at " << ptr << std::endl;
// Transfer ownership to the unique_ptr
d_data.reset(static_cast<int*>(ptr));
return d_data;
}
// Short alias for the builder namespace used by the declarations below.
namespace ckb = ck_tile::builder;
// Reproduce example/ck_tile/03_gemm/universal_gemm.cpp
//
// That example is somewhat hard to follow, but the basic idea is that
// the function invoke_gemm (in run_gemm_example.inc) gets called with
// a GemmConfigComputeV3 (in gemm_utils.hpp), which calls the function
// gemm (in universal_gemm.cpp).
// Data types handed to the GEMM builder: the A, B and C aliases are all
// ck_tile::bf16_t, while AccDataType is float.
struct MyGemmTypes
{
    using ADataType   = ck_tile::bf16_t;
    using BDataType   = ck_tile::bf16_t;
    using CDataType   = ck_tile::bf16_t;
    using AccDataType = float;
};
// Layout tags handed to the GEMM builder: A and C use ckb::RowMajor,
// B uses ckb::ColMajor.
struct MyGemmLayout
{
    using ALayout = ckb::RowMajor;
    using BLayout = ckb::ColMajor;
    using CLayout = ckb::RowMajor;
};
// GEMM builder instantiated with the data types and layouts declared above.
using Builder = ckb::GemmBuilder<MyGemmTypes, MyGemmLayout>;
// Type that main() default-constructs and calls run() on.
using Gemm = Builder::value;
// Kernel type exposed by the builder; main() prints its GetName().
using Kernel = Builder::Kernel;
} // namespace example
int main()
{
// Create the GEMM kernel.
const int M = 1024, N = 2048, K = 64;
example::Gemm gemm;
// Describe the GEMM kernel:
std::cout << "Kernel name: " << example::Kernel::GetName() << std::endl;
std::cout << "Shape: " << example::Builder::GemmShape::GetName() << std::endl;
std::cout << "Problem: " << example::Builder::UniversalGemmProblem::GetName() << std::endl;
std::cout << "Pipeline: " << example::Builder::GemmPipeline::GetName() << std::endl;
// Try GPU execution.
try
{
auto a_dev = example::AllocDevMem<float>(M * K);
auto b_dev = example::AllocDevMem<float>(K * N);
auto c_dev = example::AllocDevMem<float>(M * N);
gemm.run({.m = M, .n = N, .k = K, .a = a_dev.get(), .b = b_dev.get(), .c = c_dev.get()});
}
catch(const std::exception& e)
{
std::cerr << "Exception: " << e.what() << std::endl;
return 1;
}
return 0;
}