diff --git a/experimental/gemm_builder/gemm_builder.h b/experimental/gemm_builder/gemm_builder.h index 0203f3a89d..5e18b27c23 100644 --- a/experimental/gemm_builder/gemm_builder.h +++ b/experimental/gemm_builder/gemm_builder.h @@ -3,13 +3,10 @@ #include #include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" namespace ck_tile::builder { @@ -85,7 +82,7 @@ class Gemm template struct GemmConfigForTypes { - using PrecType = Types::AccDataType; + using PrecType = Types::ADataType; static constexpr int CK_TILE_PIPELINE_COMPUTE_V3 = 1; static consteval auto get_k_warp_tile(auto M_Warp_Tile) @@ -183,17 +180,11 @@ struct GemmBuilder typename Types::AccDataType, GemmShape, GemmUniversalTraits, - GemmConfig::Scheduler>; + GemmConfig::Scheduler, + true, + ck_tile::TailNumber::Full>; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernel; }; } // namespace ck_tile::builder diff --git a/experimental/gemm_builder/gemm_example.cpp b/experimental/gemm_builder/gemm_example.cpp index 81fbe2fb81..7671185090 100644 --- a/experimental/gemm_builder/gemm_example.cpp +++ b/experimental/gemm_builder/gemm_example.cpp @@ -50,9 +50,9 @@ namespace ckb = ck_tile::builder; struct MyGemmTypes { - using ADataType = float; - using BDataType = float; - using CDataType = float; + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using CDataType = ck_tile::bf16_t; using AccDataType = float; }; @@ -77,10 +77,10 @@ int main() example::Gemm gemm; // Describe the GEMM kernel: - std::cout << "Shape: " << example::Builder::GemmShape::GetName() << std::endl; - std::cout << "Problem: " << example::Builder::UniversalGemmProblem::GetName() << std::endl; - // std::cout << "Pipeline: " << example::Builder::GemmPipeline::GetName() << std::endl; - // std::cout << "Kernel name: " << Kernel::GetName() << std::endl; + std::cout << "Kernel name: " << example::Kernel::GetName() << std::endl; + std::cout << "Shape: " << example::Builder::GemmShape::GetName() << std::endl; + std::cout << "Problem: " << example::Builder::UniversalGemmProblem::GetName() << std::endl; + std::cout << "Pipeline: " << example::Builder::GemmPipeline::GetName() << std::endl; // Try GPU execution. try