Sync with the debug branch

This commit is contained in:
Sudhir Kylasa
2025-03-26 04:06:30 -04:00
parent a93030a78d
commit 446d1ebeed
4 changed files with 184 additions and 27 deletions

View File

@@ -259,6 +259,46 @@ int run_gemm_example_with_layouts(int argc,
b_k_n.SetZero();
}
// set 1 column in A and 1 Row in B to perform outer product.
// and test the results.
//const ck_tile::index_t K_len = a_m_k.get_length(1);
const ck_tile::index_t M_len = a_m_k.get_length(0);
const ck_tile::index_t N_len = b_k_n.get_length(1);
// Fill 0th column in A
ck_tile::half_t dd = 1;
for(int i = 0; i < M_len; i++)
{
int j = 0;
{
a_m_k(i, j) = dd;
}
int k = 8;
{
a_m_k(i, k) = dd++;
}
}
// Fill 0th row in B
dd = 2;
int i = 0;
{
for(int j=0; j < N_len; j++)
{
b_k_n(i, j) = dd;
}
}
i = 8;
{
for(int j=0; j < N_len; j++)
{
b_k_n(i, j) = dd;
}
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

View File

@@ -75,6 +75,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<AccDataType,
CDataType,
CLayout,
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
/*
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -89,7 +103,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
UniversalGemmProblem::TransposeC>>;
*/
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -213,6 +228,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5)
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
@@ -253,6 +279,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
/*
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
{
if(a_layout == "R" && b_layout == "C")
@@ -273,11 +300,13 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
else
{
if(a_layout == "R" && b_layout == "R")
*/
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Row{}, Row{});
argc, argv, Row{}, Col{}, Row{});
}
/*
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
@@ -293,11 +322,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
}
*/
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
//}
}
int run_gemm_example(int argc, char* argv[])
@@ -310,10 +340,19 @@ int run_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp16")
//if(data_type == "fp16")
{
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
//return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
//return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
// a_layout, b_layout, argc, argv);
//return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
// a_layout, b_layout, argc, argv);
}
/*
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
@@ -337,10 +376,12 @@ int run_gemm_example(int argc, char* argv[])
a_layout, b_layout, argc, argv);
}
#endif
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
*/
}
int main(int argc, char* argv[])