Sync with the debug branch

This commit is contained in:
Sudhir Kylasa
2025-03-26 04:06:30 -04:00
parent a93030a78d
commit 446d1ebeed
4 changed files with 184 additions and 27 deletions

View File

@@ -259,6 +259,46 @@ int run_gemm_example_with_layouts(int argc,
b_k_n.SetZero();
}
// set 1 column in A and 1 Row in B to perform outer product.
// and test the results.
//const ck_tile::index_t K_len = a_m_k.get_length(1);
const ck_tile::index_t M_len = a_m_k.get_length(0);
const ck_tile::index_t N_len = b_k_n.get_length(1);
// Fill 0th column in A
ck_tile::half_t dd = 1;
for(int i = 0; i < M_len; i++)
{
int j = 0;
{
a_m_k(i, j) = dd;
}
int k = 8;
{
a_m_k(i, k) = dd++;
}
}
// Fill 0th row in B
dd = 2;
int i = 0;
{
for(int j=0; j < N_len; j++)
{
b_k_n(i, j) = dd;
}
}
i = 8;
{
for(int j=0; j < N_len; j++)
{
b_k_n(i, j) = dd;
}
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

View File

@@ -75,6 +75,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<AccDataType,
CDataType,
CLayout,
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
/*
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -89,7 +103,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
UniversalGemmProblem::TransposeC>>;
*/
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -213,6 +228,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5)
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
@@ -253,6 +279,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
/*
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
{
if(a_layout == "R" && b_layout == "C")
@@ -273,11 +300,13 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
else
{
if(a_layout == "R" && b_layout == "R")
*/
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Row{}, Row{});
argc, argv, Row{}, Col{}, Row{});
}
/*
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
@@ -293,11 +322,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
}
*/
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
//}
}
int run_gemm_example(int argc, char* argv[])
@@ -310,10 +340,19 @@ int run_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp16")
//if(data_type == "fp16")
{
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
//return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
//return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
// a_layout, b_layout, argc, argv);
//return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
// a_layout, b_layout, argc, argv);
}
/*
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
@@ -337,10 +376,12 @@ int run_gemm_example(int argc, char* argv[])
a_layout, b_layout, argc, argv);
}
#endif
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
*/
}
int main(int argc, char* argv[])

View File

@@ -186,7 +186,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 20)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
@@ -194,6 +194,34 @@ check_err(const Range& out,
res = false;
}
}
int total_err_count = err_count - 20;
err_count = 0;
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count >= total_err_count)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
}
for(std::size_t i = 0; i < ref.size() / 32; i++)
{
for(std::size_t j = 0; j < 32; j++)
{
const double o = type_convert<float>(*std::next(std::begin(out), i * 32 + j));
std::cerr << std::setw(10) << o;
}
std::cerr << std::endl;
}
if(!res)
{
const float error_percent =
@@ -246,7 +274,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 20)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
@@ -254,6 +282,34 @@ check_err(const Range& out,
res = false;
}
}
int total_err_count = err_count - 20;
err_count = 0;
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count >= total_err_count)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
}
for(std::size_t i = 0; i < ref.size() / 32; i++)
{
for(std::size_t j = 0; j < 32; j++)
{
const double o = type_convert<float>(*std::next(std::begin(out), i * 32 + j));
std::cerr << std::setw(10) << o;
}
std::cerr << std::endl;
}
if(!res)
{
const float error_percent =
@@ -305,7 +361,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 20)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
@@ -313,6 +369,34 @@ check_err(const Range& out,
res = false;
}
}
int total_err_count = err_count - 20;
err_count = 0;
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count >= total_err_count)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
}
for(std::size_t i = 0; i < ref.size() / 32; i++)
{
for(std::size_t j = 0; j < 32; j++)
{
const double o = type_convert<float>(*std::next(std::begin(out), i * 32 + j));
std::cerr << std::setw(10) << o;
}
std::cerr << std::endl;
}
if(!res)
{
const float error_percent =
@@ -367,7 +451,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
res = false;
}
}
}
if(!res)
{
const float error_percent =

View File

@@ -225,7 +225,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// define ping, pong steps here as lambda functions.
auto MemoryOpsStep = [&](auto idx) {
auto MemoryOpsStep = [&]() {
// Memory read half here.
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
@@ -255,36 +255,28 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
{
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
}
// transfer from LDS to registers
if(idx == 0)
{
Base::LocalPrefetch(a_tile_0, a_lds_window);
Base::LocalPrefetch(b_tile_0, b_lds_window);
}
else
{
Base::LocalPrefetch(a_tile_1, a_lds_window);
Base::LocalPrefetch(b_tile_1, b_lds_window);
}
};
auto ComputeStep = [&](auto idx) {
if(idx == 0)
{
Base::LocalPrefetch(a_tile_0, a_lds_window);
Base::LocalPrefetch(b_tile_0, b_lds_window);
block_gemm(c_block_tile, a_tile_0, b_tile_0);
// tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile);
}
else
{
Base::LocalPrefetch(a_tile_1, a_lds_window);
Base::LocalPrefetch(b_tile_1, b_lds_window);
block_gemm(c_block_tile, a_tile_1, b_tile_1);
// tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile);
}
};
if(op_id == 0)
if(op_id == 1)
{
MemoryOpsStep(group_id);
MemoryOpsStep();
}
// start the main loop.
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop) * 2 - 1;
@@ -295,9 +287,9 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
block_sync_lds();
op_id = (op_id + 1) % num_stages_;
if(op_id == 0)
if(op_id == 1)
{
MemoryOpsStep(group_id);
MemoryOpsStep();
}
else
{
@@ -310,7 +302,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
// Handle Tail Number here.
block_sync_lds();
if(op_id == 0)
if(op_id == 1)
{
ComputeStep(group_id);
}