diff --git a/CMakeLists.txt b/CMakeLists.txt index da5a86523e..e8add521b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -564,7 +564,7 @@ include_directories(BEFORE ${HIP_INCLUDE_DIRS} ) -SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") +SET(BUILD_DEV OFF CACHE BOOL "BUILD_DEV") if(BUILD_DEV) add_compile_options(-Werror) add_compile_options(-Weverything) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 60dced003b..e37f76cd8d 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -13,6 +13,56 @@ #include "flatmm_basic.hpp" #include +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "128", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8") + .insert("persistent", "0", "0: no persistent, 1: persistent kernel") + .insert("warp_tile", + "0", + "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +auto shuffle_b_v1(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; + constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp; + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Tile, + FlatmmConfig::N_Warp, + FlatmmConfig::N_Warp_Tile, + NRepeat, + k_ / FlatmmConfig::K_Warp_Tile, + divisor, + FlatmmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); +} + +#include "run_flatmm_example.inc" +/* template constexpr const char* DataTypeToString() { @@ -62,25 +112,6 @@ auto shuffle_b(const ck_tile::HostTensor& t) return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } -template -auto shuffle_b_v1(const ck_tile::HostTensor& t) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; - constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp; - ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Tile, - FlatmmConfig::N_Warp, - FlatmmConfig::N_Warp_Tile, - NRepeat, - k_ / FlatmmConfig::K_Warp_Tile, - divisor, - FlatmmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); -} - template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, @@ -101,6 +132,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K, // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } +*/ template & args, return ave_time; } +/* template typename FlatmmConfig> int run_flatmm_example(int argc, char* argv[]) diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index fd3c76170b..aeeecb2382 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -97,7 +97,10 @@ template + typename ScaleM, + typename ScaleN, + bool UsePersistentKernel = false, + typename CDEElementWise = ck_tile::element_wise::PassThrough> float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, ck_tile::DeviceMem& b_shuffle_dev_buf, ck_tile::DeviceMem& c_dev_buf, @@ -108,21 +111,25 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, ck_tile::index_t stride_B, ck_tile::index_t stride_C, ck_tile::index_t kbatch, + ScaleM scale_m, + ScaleN scale_n, int n_warmup, int n_repeat) { - ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(), - b_shuffle_dev_buf.GetDeviceBuffer(), - {}, - c_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - {}, - stride_C}; + ck_tile::ScaleFlatmmHostArgs args = {a_dev_buf.GetDeviceBuffer(), + b_shuffle_dev_buf.GetDeviceBuffer(), + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C, + scale_m, + scale_n}; float ave_time = flatmm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); @@ -152,6 +161,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, return ave_time; } + template