mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
initial stream-k implementation with example (#699)
* initial stream-k implementation with example
* fix unexpected change in err
* improve performance a little bit by reorganizing the pipeline
* improve perf a little bit by swizzling the block idx
* add profiler
* update example
* fix spelling
* shrink karg for streamk
* support dynamic buffer using memory coherence glc_slc bit from template
* control memory coherence while constructing the dynamic buffer
* update reduction for streamk (not ready yet)
* add template parameter to make_dynamic_buffer to support amd_buffer coherence setting
* fix build issue
* fix several bugs
* now result is correct, everything works (but has scratch)
* remove scratch by manually resetting the coordinate
* update device code
* fix a bug in final reduce
* fix something in example
* update async memset
* fix enum as camel case
* modify coherence enum name
* clean code and use atomic streamk by default
* remove unused var
* throw exception if there is an empty pointer
* fix format
* fix CI warning
* fix type in init
* modify CI error
* filter out on gfx10+
* restore changed example code
---------
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
This commit is contained in:
@@ -33,6 +33,19 @@ struct ProblemSize final
|
||||
ck::index_t StrideC = 4096;
|
||||
};
|
||||
|
||||
struct ProblemSizeStreamK final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
ck::index_t NumSKBlocks = -1;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -48,8 +61,17 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
inline bool
|
||||
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||
template <typename ProblemType>
|
||||
bool parse_cmd_args(int, char*[], ProblemType&, ExecutionConfig&)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSize>(int argc,
|
||||
char* argv[],
|
||||
ProblemSize& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -87,3 +109,52 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
char* argv[],
|
||||
ProblemSizeStreamK& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc >= 10)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[7]);
|
||||
problem_size.StrideB = std::stoi(argv[8]);
|
||||
problem_size.StrideC = std::stoi(argv[9]);
|
||||
|
||||
if(argc >= 11)
|
||||
{
|
||||
problem_size.NumSKBlocks = std::stoi(argv[10]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg10: NumSKBlocks(optional)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user