mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Sync with the debug branch
This commit is contained in:
@@ -259,6 +259,46 @@ int run_gemm_example_with_layouts(int argc,
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
|
||||
// Set two columns (0 and 8) in A and two rows (0 and 8) in B to perform outer products
|
||||
// and test the results.
|
||||
//const ck_tile::index_t K_len = a_m_k.get_length(1);
|
||||
const ck_tile::index_t M_len = a_m_k.get_length(0);
|
||||
const ck_tile::index_t N_len = b_k_n.get_length(1);
|
||||
|
||||
// Fill columns 0 and 8 of A; in each row both get the same value, which starts at 1 and is incremented per row
|
||||
ck_tile::half_t dd = 1;
|
||||
for(int i = 0; i < M_len; i++)
|
||||
{
|
||||
int j = 0;
|
||||
{
|
||||
a_m_k(i, j) = dd;
|
||||
}
|
||||
int k = 8;
|
||||
{
|
||||
a_m_k(i, k) = dd++;
|
||||
}
|
||||
}
|
||||
|
||||
// Fill rows 0 and 8 of B with the constant 2
|
||||
dd = 2;
|
||||
int i = 0;
|
||||
{
|
||||
for(int j=0; j < N_len; j++)
|
||||
{
|
||||
b_k_n(i, j) = dd;
|
||||
}
|
||||
}
|
||||
|
||||
i = 8;
|
||||
{
|
||||
for(int j=0; j < N_len; j++)
|
||||
{
|
||||
b_k_n(i, j) = dd;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -75,6 +75,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
|
||||
/*
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -89,7 +103,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
*/
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -213,6 +228,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5)
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -253,6 +279,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
/*
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
@@ -273,11 +300,13 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
*/
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
/*
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
@@ -293,11 +322,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
*/
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
//}
|
||||
}
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
@@ -310,10 +340,19 @@ int run_gemm_example(int argc, char* argv[])
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
//if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
|
||||
//return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
|
||||
//return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
// a_layout, b_layout, argc, argv);
|
||||
|
||||
//return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
// a_layout, b_layout, argc, argv);
|
||||
}
|
||||
/*
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
@@ -337,10 +376,12 @@ int run_gemm_example(int argc, char* argv[])
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
#endif
|
||||
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -186,7 +186,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 20)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -194,6 +194,34 @@ check_err(const Range& out,
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
int total_err_count = err_count - 20;
|
||||
err_count = 0;
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count >= total_err_count)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
for(std::size_t i = 0; i < ref.size() / 32; i++)
|
||||
{
|
||||
for(std::size_t j = 0; j < 32; j++)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i * 32 + j));
|
||||
std::cerr << std::setw(10) << o;
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
@@ -246,7 +274,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 20)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -254,6 +282,34 @@ check_err(const Range& out,
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
int total_err_count = err_count - 20;
|
||||
err_count = 0;
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count >= total_err_count)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
for(std::size_t i = 0; i < ref.size() / 32; i++)
|
||||
{
|
||||
for(std::size_t j = 0; j < 32; j++)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i * 32 + j));
|
||||
std::cerr << std::setw(10) << o;
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
@@ -305,7 +361,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < 20)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -313,6 +369,34 @@ check_err(const Range& out,
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
int total_err_count = err_count - 20;
|
||||
err_count = 0;
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
||||
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
||||
err = std::abs(o - r);
|
||||
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count >= total_err_count)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
for(std::size_t i = 0; i < ref.size() / 32; i++)
|
||||
{
|
||||
for(std::size_t j = 0; j < 32; j++)
|
||||
{
|
||||
const double o = type_convert<float>(*std::next(std::begin(out), i * 32 + j));
|
||||
std::cerr << std::setw(10) << o;
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
@@ -367,7 +451,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
|
||||
@@ -225,7 +225,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
auto MemoryOpsStep = [&]() {
|
||||
// Memory read half here.
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
@@ -255,36 +255,28 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
|
||||
}
|
||||
|
||||
// transfer from LDS to registers
|
||||
if(idx == 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_0, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_0, b_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_1, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
}
|
||||
};
|
||||
|
||||
// ComputeStep(idx): one compute stage of the pipeline.
// Selects a register tile pair by idx (0 -> a_tile_0/b_tile_0, anything else ->
// a_tile_1/b_tile_1), loads that pair through Base::LocalPrefetch from
// a_lds_window/b_lds_window, and then hands it to block_gemm together with
// c_block_tile.
// NOTE(review): block_gemm presumably accumulates into c_block_tile -- confirm
// against the block GEMM implementation.
auto ComputeStep = [&](auto idx) {
|
||||
if(idx == 0)
|
||||
{
|
||||
// Load tile pair 0 from the LDS windows right before it is consumed.
Base::LocalPrefetch(a_tile_0, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_0, b_lds_window);
|
||||
block_gemm(c_block_tile, a_tile_0, b_tile_0);
|
||||
// Disabled debug aid: would add 1 to every element of c_block_tile.
// tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Same sequence as above, using tile pair 1.
Base::LocalPrefetch(a_tile_1, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
block_gemm(c_block_tile, a_tile_1, b_tile_1);
|
||||
// Disabled debug aid: would add 1 to every element of c_block_tile.
// tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
if(op_id == 0)
|
||||
if(op_id == 1)
|
||||
{
|
||||
MemoryOpsStep(group_id);
|
||||
MemoryOpsStep();
|
||||
}
|
||||
// start the main loop.
|
||||
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop) * 2 - 1;
|
||||
@@ -295,9 +287,9 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
block_sync_lds();
|
||||
op_id = (op_id + 1) % num_stages_;
|
||||
|
||||
if(op_id == 0)
|
||||
if(op_id == 1)
|
||||
{
|
||||
MemoryOpsStep(group_id);
|
||||
MemoryOpsStep();
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -310,7 +302,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
|
||||
// Handle Tail Number here.
|
||||
block_sync_lds();
|
||||
if(op_id == 0)
|
||||
if(op_id == 1)
|
||||
{
|
||||
ComputeStep(group_id);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user