mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Fix builder code and choices to compile a GEMM kernel.
I've only verfied that the kernel compiles. Some of my choices, like float32 types and having the epilogue set the member, are not valid template parameters. I now have this indentical to a default GEMM universal kernel. I also fixed some other small logical mistakes I made. The code currently outputs the GetName results for some of the classes: ``` Kernel name: gemm_bf16_pipeline_AgBgCrCompV3_16x64x128_256_1x4_0x0x0 Shape: tile_gemm_shape_16x64x128x4_1x4x1_16x16x32 Problem: gemm_problem_256_0x0x0_Intrawave Pipeline: pipeline_AgBgCrCompV3_16x64x128_256_1x4_0x0x0 ```
This commit is contained in:
@@ -3,13 +3,10 @@
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
@@ -85,7 +82,7 @@ class Gemm
|
||||
template <DefinesGemmTypes Types>
|
||||
struct GemmConfigForTypes
|
||||
{
|
||||
using PrecType = Types::AccDataType;
|
||||
using PrecType = Types::ADataType;
|
||||
|
||||
static constexpr int CK_TILE_PIPELINE_COMPUTE_V3 = 1;
|
||||
static consteval auto get_k_warp_tile(auto M_Warp_Tile)
|
||||
@@ -183,17 +180,11 @@ struct GemmBuilder
|
||||
typename Types::AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
GemmConfig::Scheduler,
|
||||
true,
|
||||
ck_tile::TailNumber::Full>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
GemmShape,
|
||||
Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename Types::ADataType,
|
||||
@@ -213,10 +204,10 @@ struct GemmBuilder
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
ck_tile::memory_operation_enum::atomic_add,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmPipeline>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -50,9 +50,9 @@ namespace ckb = ck_tile::builder;
|
||||
|
||||
struct MyGemmTypes
|
||||
{
|
||||
using ADataType = float;
|
||||
using BDataType = float;
|
||||
using CDataType = float;
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
};
|
||||
|
||||
@@ -77,10 +77,10 @@ int main()
|
||||
example::Gemm gemm;
|
||||
|
||||
// Describe the GEMM kernel:
|
||||
std::cout << "Shape: " << example::Builder::GemmShape::GetName() << std::endl;
|
||||
std::cout << "Problem: " << example::Builder::UniversalGemmProblem::GetName() << std::endl;
|
||||
// std::cout << "Pipeline: " << example::Builder::GemmPipeline::GetName() << std::endl;
|
||||
// std::cout << "Kernel name: " << Kernel::GetName() << std::endl;
|
||||
std::cout << "Kernel name: " << example::Kernel::GetName() << std::endl;
|
||||
std::cout << "Shape: " << example::Builder::GemmShape::GetName() << std::endl;
|
||||
std::cout << "Problem: " << example::Builder::UniversalGemmProblem::GetName() << std::endl;
|
||||
std::cout << "Pipeline: " << example::Builder::GemmPipeline::GetName() << std::endl;
|
||||
|
||||
// Try GPU execution.
|
||||
try
|
||||
|
||||
Reference in New Issue
Block a user