mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
update f16xMXF4
This commit is contained in:
@@ -58,6 +58,46 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
|
||||
{
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size,
|
||||
F16xMXF4FlatmmKernel,
|
||||
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
// << ", persistent_block_size: " << persistent_block_size
|
||||
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
|
||||
Reference in New Issue
Block a user