mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Merge commit '7ea1508b59a0e8f89540d8d5f7eb3e7da9a50a62' into develop
This commit is contained in:
@@ -18,9 +18,12 @@ template <typename ADataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
typename CLayout,
|
||||
bool Persistent>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
if constexpr(Persistent)
|
||||
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
|
||||
@@ -213,11 +213,20 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
bool Persistent = false>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -162,7 +162,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
int n_repeat,
|
||||
bool persistent)
|
||||
{
|
||||
ck_tile::GemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
@@ -176,9 +177,31 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time =
|
||||
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
float ave_time;
|
||||
if(persistent)
|
||||
{
|
||||
ave_time = gemm_calc<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
true>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = gemm_calc<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
false>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -193,8 +216,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
<< " B_Type=" << DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -229,6 +252,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool persistent = arg_parser.get_int("persistent");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -316,7 +340,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
n_repeat,
|
||||
persistent);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
@@ -11,19 +11,7 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename Pipeline, ck_tile::TailNumber TN>
|
||||
void try_run(ck_tile::TailNumber tn)
|
||||
{
|
||||
if constexpr(Pipeline::PrefetchStages > static_cast<int>(TN))
|
||||
{
|
||||
if(tn == TN)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, TN>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -31,7 +19,8 @@ template <typename ADataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
typename CLayout,
|
||||
bool Persistent>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -60,7 +49,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity>;
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
@@ -74,64 +64,113 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
@@ -150,101 +189,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
|
||||
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
(try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
|
||||
};
|
||||
|
||||
check_tail(ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
|
||||
@@ -183,137 +183,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got "
|
||||
<< tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
std::ostringstream err;
|
||||
err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but "
|
||||
"got "
|
||||
<< tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -197,121 +197,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got "
|
||||
<< tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but "
|
||||
<< "got " << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "flatmm_basic.hpp"
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -115,9 +116,47 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time{0};
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
if(args.k_batch == 1)
|
||||
@@ -132,8 +171,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
@@ -133,4 +133,11 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -122,7 +122,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
|
||||
float ave_time =
|
||||
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
|
||||
Reference in New Issue
Block a user