Add persistent kernel option to flatmm for tuning

This commit is contained in:
Feng Shijie
2025-07-29 15:42:58 +00:00
parent a587701117
commit c117a1986a
4 changed files with 143 additions and 117 deletions

View File

@@ -214,7 +214,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
@@ -271,10 +271,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
}
else
{
// ave_time =
// ck_tile::launch_kernel(s,
// ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
// Kernel{}, grids, blocks, 0, kargs));
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
@@ -289,10 +289,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
}
else
{
// Run(has_hot_loop_,
// tail_number_,
// ck_tile::integral_constant<ck_tile::memory_operation_enum,
// ck_tile::memory_operation_enum::atomic_add>{});
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
@@ -311,7 +311,8 @@ template <typename FlatmmConfig,
typename CLayout,
typename ScaleM,
typename ScaleN,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
bool UsePersistentKernel = false,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
@@ -354,7 +355,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
CLayout,
ScaleM,
ScaleN,
false,
UsePersistentKernel,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
@@ -393,6 +394,7 @@ auto create_args(int argc, char* argv[])
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
@@ -416,45 +418,67 @@ int run_flatmm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int scale_opt = arg_parser.get_int("scale");
int persistent_opt = arg_parser.get_int("persistent");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
// run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
// argc, argv, Row{}, Col{}, Row{});
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
argc, argv, Row{}, Col{}, Row{});
}
// else if(data_type == "bf16")
// {
// run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
else if(data_type == "fp8")
{
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
if(persistent_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
-1,
-1,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
else
{
if(persistent_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
1,
1,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
}
else if(data_type == "bf8")
{
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
// else if(data_type == "bf8")
// {
// if(scale_opt == 0)
// {
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
// argc, argv, Row{}, Col{}, Row{});
// }
// else
// {
// run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1,
// 1>(
// argc, argv, Row{}, Col{}, Row{});
// }
// }
else
{
throw std::runtime_error("Unsupported data_type!");
@@ -480,18 +504,18 @@ int main(int argc, char* argv[])
{
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_flatmm_example<FlatmmConfig32>(argc, argv);
// }
// else if(warp_tile == 2)
// {
// return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
// }
// else
// {
// return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
// }
else if(warp_tile == 1)
{
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
}
else if(warp_tile == 2)
{
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
}
else
{
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
}
catch(const std::runtime_error& e)
{

View File

@@ -4,8 +4,9 @@
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
bool UsePersistentKernel = false,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -222,20 +223,21 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::tuple<>,
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
decltype(per_channel_scale_dev_ptr),
UsePersistentKernel>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());

View File

@@ -228,7 +228,8 @@ struct FlatmmKernel
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
@@ -255,37 +256,50 @@ struct FlatmmKernel
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = FlatmmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size,
FlatmmKernel,
FlatmmKernelArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, 0>>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(M, N);
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
<< ", persistent_block_size: " << persistent_block_size
<< ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
assert(KBatch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, KBatch);
assert(!UsePersistentKernel);
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
{
    if constexpr(UsePersistentKernel)
    {
        // Persistent launch: size the grid to the number of blocks that can be
        // simultaneously resident on the device (CU count * max occupancy of
        // this kernel entry), capped by the total number of output work tiles.
        // Each launched block then iterates over multiple tiles.
        hipDeviceProp_t prop;
        int deviceId = 0; // default device
        constexpr int block_size = FlatmmKernel::BlockSize().x;
        int dync_smem_size       = 0;
        int maxActiveBlocksPerCU = 0;
        [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
        // Query occupancy for the concrete kernel entry instantiated with the
        // same kargs type that will actually be launched.
        e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &maxActiveBlocksPerCU,
            reinterpret_cast<void*>(
                kentry2<block_size,
                        FlatmmKernel,
                        FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
            block_size,
            dync_smem_size);
        const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
        const int total_work_tile_cnt   = TilePartitioner::GridSize(kargs.M, kargs.N);
        std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
                  << ", persistent_block_size: " << persistent_block_size
                  << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
        // Persistent mode does not support split-K (also rejected in
        // IsSupportedArgument). BUG FIX: the original asserted on `KBatch`,
        // the parameter name of the OTHER GridSize overload, which is not in
        // scope here and breaks non-NDEBUG builds; assert on kargs.k_batch.
        assert(kargs.k_batch == 1);
        return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
    }
    else
    {
        // Non-persistent launch: one block per output tile; the z dimension
        // carries the split-K factor.
        return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
    }
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
template <class ScaleM, class ScaleN>
@@ -371,6 +385,14 @@ struct FlatmmKernel
return false;
}
}
if constexpr(UsePersistentKernel)
{
if(kargs.k_batch != 1)
{
std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
return false;
}
}
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -780,22 +802,8 @@ struct FlatmmKernel
int partition_idx = blockIdx.x) const
{
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
// GWS
const int voffset = 0;
const int vdata = 1;
__shared__ int shared_part[1];
if(threadIdx.x == 0)
{
asm volatile("global_atomic_add %0, %1, %2, %3 sc0; \n\t"
"s_waitcnt vmcnt(0); \n\t"
: "=v"(partition_idx)
: "v"(voffset), "v"(vdata), "s"(kargs.a_ptr));
shared_part[0] = partition_idx % (1024 + 80);
}
block_sync_lds();
partition_idx = shared_part[0];
while(partition_idx < total_work_tile_cnt)
do
{
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
@@ -830,17 +838,8 @@ struct FlatmmKernel
i_m,
i_n);
}
if(threadIdx.x == 0)
{
asm volatile("global_atomic_add %0, %1, %2, %3 sc0; \n\t"
"s_waitcnt vmcnt(0); \n\t"
: "=v"(partition_idx)
: "v"(voffset), "v"(vdata), "s"(kargs.a_ptr));
shared_part[0] = partition_idx % (1024 + 80);
}
block_sync_lds();
partition_idx = shared_part[0];
}
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
}
};

View File

@@ -80,6 +80,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t kLdsAlignmentInBytes = 16;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();