mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
Sync with the debug branch
This commit is contained in:
@@ -259,6 +259,46 @@ int run_gemm_example_with_layouts(int argc,
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
|
||||
// set 1 column in A and 1 Row in B to perform outer product.
|
||||
// and test the results.
|
||||
//const ck_tile::index_t K_len = a_m_k.get_length(1);
|
||||
const ck_tile::index_t M_len = a_m_k.get_length(0);
|
||||
const ck_tile::index_t N_len = b_k_n.get_length(1);
|
||||
|
||||
// Fill 0th column in A
|
||||
ck_tile::half_t dd = 1;
|
||||
for(int i = 0; i < M_len; i++)
|
||||
{
|
||||
int j = 0;
|
||||
{
|
||||
a_m_k(i, j) = dd;
|
||||
}
|
||||
int k = 8;
|
||||
{
|
||||
a_m_k(i, k) = dd++;
|
||||
}
|
||||
}
|
||||
|
||||
// Fill 0th row in B
|
||||
dd = 2;
|
||||
int i = 0;
|
||||
{
|
||||
for(int j=0; j < N_len; j++)
|
||||
{
|
||||
b_k_n(i, j) = dd;
|
||||
}
|
||||
}
|
||||
|
||||
i = 8;
|
||||
{
|
||||
for(int j=0; j < N_len; j++)
|
||||
{
|
||||
b_k_n(i, j) = dd;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -75,6 +75,20 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
|
||||
|
||||
/*
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -89,7 +103,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
*/
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -213,6 +228,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5)
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -253,6 +279,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
/*
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
@@ -273,11 +300,13 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
*/
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
/*
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
@@ -293,11 +322,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
*/
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
//}
|
||||
}
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
@@ -310,10 +340,19 @@ int run_gemm_example(int argc, char* argv[])
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
//if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
|
||||
//return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
|
||||
//return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
// a_layout, b_layout, argc, argv);
|
||||
|
||||
//return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
// a_layout, b_layout, argc, argv);
|
||||
}
|
||||
/*
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
@@ -337,10 +376,12 @@ int run_gemm_example(int argc, char* argv[])
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
#endif
|
||||
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
Reference in New Issue
Block a user