diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index fffc86c994..ef7b3cac48 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -72,7 +72,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; + using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 23e99bc2a8..93524bd8d9 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -73,7 +73,7 @@ auto create_args(int argc, char* argv[]) .insert("n", "4096", "n dimension") .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index f727abe81d..3425da6712 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -194,22 +194,23 @@ int run_gemm_example(int argc, char* argv[]) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - if(a_layout == "R" && b_layout == "R") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") + // if(a_layout == "R" && b_layout == "R") + // { + // return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + // } + // else + if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - } + // else if(a_layout == "C" && b_layout == "C") + // { + // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + // } + // else if(a_layout == "C" && b_layout == "R") + // { + // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + // } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index d6fee879b1..433f5c0dcc 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -42,6 +42,9 @@ struct BlockGemmASmemBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); + // if(threadIdx.x == 0 && blockIdx.x==0) { + // printf("MPerBlock %d NPerBlock %d KPerBlock %d \n", MPerBlock, NPerBlock, KPerBlock); + // } constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -60,6 +63,12 @@ struct BlockGemmASmemBSmemCRegV1 const index_t iMWarp = get_warp_id() / NWarp; const index_t iNWarp = get_warp_id() % NWarp; + // if(threadIdx.x == 0 && blockIdx.x==0) { + // printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter); + // } + // MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8 + + // construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window.get_bottom_tensor_view(), @@ -136,7 +145,6 @@ struct BlockGemmASmemBSmemCRegV1 constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index 8dd1d1ec28..f510355aad 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -40,7 +40,8 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); } #else - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 2, 2); + // return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); #endif } else if constexpr(std::is_same_v && diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index c765b3ce9d..04091480d1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -112,10 +112,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { constexpr index_t smem_size_a = GetSmemSizeA(); constexpr index_t smem_size_b = GetSmemSizeB(); - index_t smem_size = 0; - smem_size += smem_size_a + smem_size_b; - - return smem_size; + return smem_size_a + smem_size_b; } template @@ -259,7 +256,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t M2 = get_warp_size() / K0; // coalesce reading for each blocks if constexpr(get_warp_size() % (M2 * K0) == 0) - { + {//Number{}, Number{}, Number{}))), constexpr index_t M1 = BlockSize / get_warp_size(); static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");