diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.cpp b/example/ck_tile/42_mx_gemm/mx_gemm.cpp index b0b0c19e56..c329347e11 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.cpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.cpp @@ -60,40 +60,40 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf, scale_m, scale_n); - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; + // using GemmShape = ck_tile::TileGemmShape< + // ck_tile::sequence, + // ck_tile::sequence, + // ck_tile::sequence>; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + // using TilePartitioner = + // ck_tile::GemmSpatiallyLocalTilePartitioner; - using MXGemmTraits = ck_tile::TileGemmUniversalTraits; + // using MXGemmTraits = ck_tile::TileGemmUniversalTraits; - using MXPipelineProblem = MXGemmPipelineProblem; + // using MXPipelineProblem = MXGemmPipelineProblem; - // Use the new MX comp_async pipeline with MX scaling support - using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync; + // // Use the new MX comp_async pipeline with MX scaling support + // using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync; // Simplified invocation - comp_async handles hot loop and tail internally auto invoke_splitk_path = [&](auto split_k_) { diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.hpp b/example/ck_tile/42_mx_gemm/mx_gemm.hpp index dbabdff9ea..42c4a34da2 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.hpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.hpp @@ -43,7 +43,7 @@ struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0> struct MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 512; static constexpr ck_tile::index_t M_Warp = 1; @@ -74,7 +74,7 @@ struct MxGemmConfig struct MXfp4_GemmConfig16 : MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 256; }; diff --git a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc index 11f687a6ef..422c2b6833 100644 --- a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc +++ b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc @@ -154,17 +154,17 @@ int run_mx_gemm_example(int argc, char* argv[]) MXfp4_GemmConfig16, true>(argc, argv, Row{}, Col{}, Row{}); } - else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") - { - return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } + // else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") + // { + // return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + // } else { - throw std::runtime_error("Only fp4 and fp8 is supported currently!"); + throw std::runtime_error("Only fp4 is supported currently!"); } } else diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 351dcabe06..0c7efbfbf9 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -673,8 +673,8 @@ struct UniversalGemmKernel using AiLayout = remove_cvref_t>; using AiDataType = remove_cvref_t>; static_assert(GemmPipeline::GetVectorSizeA() == GemmPipeline::GetVectorSizeB(), "Vector size of A and B must be the same!"); - static_assert(GemmPipeline::GetVectorSizeA() == 16, "Vector size of A must be 16!"); - static_assert(GemmPipeline::GetVectorSizeB() == 16, "Vector size of B must be 16!"); + static_assert(GemmPipeline::GetVectorSizeA() == 32, "Vector size of A must be 16!"); + static_assert(GemmPipeline::GetVectorSizeB() == 32, "Vector size of B must be 16!"); if constexpr(std::is_same_v) { return make_naive_tensor_view( diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index e123cee9e1..e745de9d13 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -843,7 +843,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_DEVICE static constexpr index_t GetSmemSizeA() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { using ADataType = remove_cvref_t; constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor(); @@ -853,7 +853,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() { using BDataType = std::conditional_t, @@ -866,7 +866,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { constexpr index_t smem_size_a = GetSmemSizeA(); constexpr index_t smem_size_b = GetSmemSizeB(); diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 551e434ff9..7377430a50 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -316,10 +316,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< number{}); /// Check tile window traits for vector size - using ATileDstr = remove_cvref_t())>; + // using ATileDstr = remove_cvref_t())>; // static_assert(ATileDstr::LargestVec >= 16, "wrong! not implemented vector size"); // static_assert(ATileDstr::X1 >= 16, "wrong! not implemented vector size"); - using BTileDstr = remove_cvref_t())>; + // using BTileDstr = remove_cvref_t())>; // static_assert(BTileDstr::LargestVec >= 16, "wrong! not implemented vector size"); // static_assert(BTileDstr::X1 >= 16, "wrong! not implemented vector size"); using ATileType = remove_cvref_t{}])>; diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 8e4fa06888..6633e9493e 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -39,7 +39,15 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy // constexpr index_t vector_size_for_16_bytes = 16 / sizeof(ADataType); // return vector_size_for_16_bytes; - return 16; + static_assert(std::is_same_v, "ADataType must be pk_fp4_t or pk_fp4_raw_t"); + if constexpr(std::is_same_v || std::is_same_v) + { + return 32; + } + else + { + return 16; + } } template @@ -55,7 +63,15 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy // constexpr index_t vector_size_for_16_bytes = 16 / sizeof(BDataType); // return vector_size_for_16_bytes; - return 16; + static_assert(std::is_same_v, "BDataType must be pk_fp4_t or pk_fp4_raw_t"); + if constexpr(std::is_same_v || std::is_same_v) + { + return 32; + } + else + { + return 16; + } } // Override DRAM tile distributions to use the constrained vector sizes