mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add persistent option on flatmm for tuning
This commit is contained in:
@@ -228,7 +228,8 @@ struct FlatmmKernel
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
@@ -255,37 +256,50 @@ struct FlatmmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = FlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
FlatmmKernel,
|
||||
FlatmmKernelArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, 0>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(M, N);
|
||||
|
||||
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
<< ", persistent_block_size: " << persistent_block_size
|
||||
<< ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
assert(KBatch == 1);
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, KBatch);
|
||||
|
||||
assert(!UsePersistentKernel);
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
|
||||
{
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = FlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
FlatmmKernel,
|
||||
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
<< ", persistent_block_size: " << persistent_block_size
|
||||
<< ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
assert(KBatch == 1);
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
@@ -371,6 +385,14 @@ struct FlatmmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -780,22 +802,8 @@ struct FlatmmKernel
|
||||
int partition_idx = blockIdx.x) const
|
||||
{
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
// GWS
|
||||
const int voffset = 0;
|
||||
const int vdata = 1;
|
||||
__shared__ int shared_part[1];
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
asm volatile("global_atomic_add %0, %1, %2, %3 sc0; \n\t"
|
||||
"s_waitcnt vmcnt(0); \n\t"
|
||||
: "=v"(partition_idx)
|
||||
: "v"(voffset), "v"(vdata), "s"(kargs.a_ptr));
|
||||
shared_part[0] = partition_idx % (1024 + 80);
|
||||
}
|
||||
block_sync_lds();
|
||||
partition_idx = shared_part[0];
|
||||
|
||||
while(partition_idx < total_work_tile_cnt)
|
||||
do
|
||||
{
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
@@ -830,17 +838,8 @@ struct FlatmmKernel
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
asm volatile("global_atomic_add %0, %1, %2, %3 sc0; \n\t"
|
||||
"s_waitcnt vmcnt(0); \n\t"
|
||||
: "=v"(partition_idx)
|
||||
: "v"(voffset), "v"(vdata), "s"(kargs.a_ptr));
|
||||
shared_part[0] = partition_idx % (1024 + 80);
|
||||
}
|
||||
block_sync_lds();
|
||||
partition_idx = shared_part[0];
|
||||
}
|
||||
partition_idx += gridDim.x;
|
||||
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user