Fix builder code and choices to compile a GEMM kernel.

I've only verified that the kernel compiles.

Some of my choices, like float32 types and having the epilogue set the member, are not valid template parameters. I have now made this identical to
a default GEMM universal kernel.

I also fixed some other small logical mistakes I made.

The code currently outputs the GetName results for some of the classes:

```
Kernel name: gemm_bf16_pipeline_AgBgCrCompV3_16x64x128_256_1x4_0x0x0
Shape:       tile_gemm_shape_16x64x128x4_1x4x1_16x16x32
Problem:     gemm_problem_256_0x0x0_Intrawave
Pipeline:    pipeline_AgBgCrCompV3_16x64x128_256_1x4_0x0x0
```
This commit is contained in:
John Shumway
2025-08-04 16:50:37 +00:00
parent 79d34d53dd
commit e49ceff3f5
2 changed files with 18 additions and 27 deletions

View File

@@ -3,13 +3,10 @@
#include <concepts>
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
namespace ck_tile::builder {
@@ -85,7 +82,7 @@ class Gemm
template <DefinesGemmTypes Types>
struct GemmConfigForTypes
{
using PrecType = Types::AccDataType;
using PrecType = Types::ADataType;
static constexpr int CK_TILE_PIPELINE_COMPUTE_V3 = 1;
static consteval auto get_k_warp_tile(auto M_Warp_Tile)
@@ -183,17 +180,11 @@ struct GemmBuilder
typename Types::AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::Scheduler>;
GemmConfig::Scheduler,
true,
ck_tile::TailNumber::Full>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
GemmShape,
Traits>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename Types::ADataType,
@@ -213,10 +204,10 @@ struct GemmBuilder
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
ck_tile::memory_operation_enum::set,
ck_tile::memory_operation_enum::atomic_add,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmPipeline>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
};
} // namespace ck_tile::builder

View File

@@ -50,9 +50,9 @@ namespace ckb = ck_tile::builder;
struct MyGemmTypes
{
using ADataType = float;
using BDataType = float;
using CDataType = float;
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using CDataType = ck_tile::bf16_t;
using AccDataType = float;
};
@@ -77,10 +77,10 @@ int main()
example::Gemm gemm;
// Describe the GEMM kernel:
std::cout << "Shape: " << example::Builder::GemmShape::GetName() << std::endl;
std::cout << "Problem: " << example::Builder::UniversalGemmProblem::GetName() << std::endl;
// std::cout << "Pipeline: " << example::Builder::GemmPipeline::GetName() << std::endl;
// std::cout << "Kernel name: " << Kernel::GetName() << std::endl;
std::cout << "Kernel name: " << example::Kernel::GetName() << std::endl;
std::cout << "Shape: " << example::Builder::GemmShape::GetName() << std::endl;
std::cout << "Problem: " << example::Builder::UniversalGemmProblem::GetName() << std::endl;
std::cout << "Pipeline: " << example::Builder::GemmPipeline::GetName() << std::endl;
// Try GPU execution.
try