diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 6cb40e45d1..fc67e3eaa2 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -259,6 +259,46 @@ int run_gemm_example_with_layouts(int argc, b_k_n.SetZero(); } + + // set 1 column in A and 1 Row in B to perform outer product. + // and test the results. + //const ck_tile::index_t K_len = a_m_k.get_length(1); + const ck_tile::index_t M_len = a_m_k.get_length(0); + const ck_tile::index_t N_len = b_k_n.get_length(1); + + // Fill 0th column in A + ck_tile::half_t dd = 1; + for(int i = 0; i < M_len; i++) + { + int j = 0; + { + a_m_k(i, j) = dd; + } + int k = 8; + { + a_m_k(i, k) = dd++; + } + } + + // Fill 0th row in B + dd = 2; + int i = 0; + { + for(int j=0; j < N_len; j++) + { + b_k_n(i, j) = dd; + } + } + + i = 8; + { + for(int j=0; j < N_len; j++) + { + b_k_n(i, j) = dd; + } + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index eef8d3b60e..178702f743 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -75,6 +75,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& tail_number_v>; using GemmPipeline = GEMM_PIPELINE; + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; + + + /* using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + UniversalGemmProblem::TransposeC>>; + */ using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -213,6 +228,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& Run(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } #endif } else @@ -253,6 +279,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + /* if constexpr(std::is_same_v) { if(a_layout == "R" && b_layout == "C") @@ -273,11 +300,13 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } else { - if(a_layout == "R" && b_layout == "R") + */ + if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}); + argc, argv, Row{}, Col{}, Row{}); } + /* else if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( @@ -293,11 +322,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } + */ else { throw std::runtime_error("Unsupported memory layout for the input matrices!"); } - } + //} } int run_gemm_example(int argc, char* argv[]) @@ -310,10 +340,19 @@ int run_gemm_example(int argc, char* argv[]) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - if(data_type == "fp16") + //if(data_type == "fp16") { return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + + //return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + + //return run_gemm_example_prec_type( + // a_layout, b_layout, argc, argv); + + //return run_gemm_example_prec_type( + // a_layout, b_layout, argc, argv); } + /* else if(data_type == "bf16") { return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); @@ -337,10 +376,12 @@ int run_gemm_example(int argc, char* argv[]) a_layout, b_layout, argc, argv); } #endif + else { throw std::runtime_error("Unsupported data type for this operation !!!"); } + */ } int main(int argc, char* argv[]) diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 745c18d6dd..acdbe66bf2 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -186,7 +186,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 20) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -194,6 +194,34 @@ check_err(const Range& out, res = false; } } + int total_err_count = err_count - 20; + err_count = 0; + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count >= total_err_count) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + for(std::size_t i = 0; i < ref.size() / 32; i++) + { + for(std::size_t j = 0; j < 32; j++) + { + const double o = type_convert(*std::next(std::begin(out), i * 32 + j)); + std::cerr << std::setw(10) << o; + } + std::cerr << std::endl; + } if(!res) { const float error_percent = @@ -246,7 +274,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 20) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -254,6 +282,34 @@ check_err(const Range& out, res = false; } } + int total_err_count = err_count - 20; + err_count = 0; + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count >= total_err_count) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + for(std::size_t i = 0; i < ref.size() / 32; i++) + { + for(std::size_t j = 0; j < 32; j++) + { + const double o = type_convert(*std::next(std::begin(out), i * 32 + j)); + std::cerr << std::setw(10) << o; + } + std::cerr << std::endl; + } if(!res) { const float error_percent = @@ -305,7 +361,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 20) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -313,6 +369,34 @@ check_err(const Range& out, res = false; } } + int total_err_count = err_count - 20; + err_count = 0; + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count >= total_err_count) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + for(std::size_t i = 0; i < ref.size() / 32; i++) + { + for(std::size_t j = 0; j < 32; j++) + { + const double o = type_convert(*std::next(std::begin(out), i * 32 + j)); + std::cerr << std::setw(10) << o; + } + std::cerr << std::endl; + } if(!res) { const float error_percent = @@ -367,7 +451,7 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } res = false; } - } + } if(!res) { const float error_percent = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index a20398cf60..7580c28845 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -225,7 +225,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // define ping, pong steps here as lambda functions. - auto MemoryOpsStep = [&](auto idx) { + auto MemoryOpsStep = [&]() { // Memory read half here. Base::GlobalPrefetch( a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); @@ -255,36 +255,28 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); } - - // transfer from LDS to registers - if(idx == 0) - { - Base::LocalPrefetch(a_tile_0, a_lds_window); - Base::LocalPrefetch(b_tile_0, b_lds_window); - } - else - { - Base::LocalPrefetch(a_tile_1, a_lds_window); - Base::LocalPrefetch(b_tile_1, b_lds_window); - } }; auto ComputeStep = [&](auto idx) { if(idx == 0) { + Base::LocalPrefetch(a_tile_0, a_lds_window); + Base::LocalPrefetch(b_tile_0, b_lds_window); block_gemm(c_block_tile, a_tile_0, b_tile_0); // tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile); } else { + Base::LocalPrefetch(a_tile_1, a_lds_window); + Base::LocalPrefetch(b_tile_1, b_lds_window); block_gemm(c_block_tile, a_tile_1, b_tile_1); // tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile); } }; - if(op_id == 0) + if(op_id == 1) { - MemoryOpsStep(group_id); + MemoryOpsStep(); } // start the main loop. index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop) * 2 - 1; @@ -295,9 +287,9 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 block_sync_lds(); op_id = (op_id + 1) % num_stages_; - if(op_id == 0) + if(op_id == 1) { - MemoryOpsStep(group_id); + MemoryOpsStep(); } else { @@ -310,7 +302,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 // Handle Tail Number here. block_sync_lds(); - if(op_id == 0) + if(op_id == 1) { ComputeStep(group_id); }