This commit is contained in:
Feng Shijie
2025-07-28 08:24:51 +00:00
parent 5473f06461
commit 1b6d7cf407
4 changed files with 163 additions and 97 deletions

View File

@@ -255,6 +255,34 @@ struct FlatmmKernel
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = FlatmmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size,
FlatmmKernel,
FlatmmKernelArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, 0>>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(M, N);
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
<< ", persistent_block_size: " << persistent_block_size
<< ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
assert(KBatch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, KBatch);
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
@@ -751,37 +779,67 @@ struct FlatmmKernel
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
int partition_idx = blockIdx.x) const
{
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
// GWS
const int voffset = 0;
const int vdata = 1;
__shared__ int shared_part[1];
if(threadIdx.x == 0)
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
asm volatile("global_atomic_add %0, %1, %2, %3 sc0; \n\t"
"s_waitcnt vmcnt(0); \n\t"
: "=v"(partition_idx)
: "v"(voffset), "v"(vdata), "s"(kargs.a_ptr));
shared_part[0] = partition_idx % (1024 + 80);
}
block_sync_lds();
partition_idx = shared_part[0];
while(partition_idx < total_work_tile_cnt)
{
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
if(threadIdx.x == 0)
{
asm volatile("global_atomic_add %0, %1, %2, %3 sc0; \n\t"
"s_waitcnt vmcnt(0); \n\t"
: "=v"(partition_idx)
: "v"(voffset), "v"(vdata), "s"(kargs.a_ptr));
shared_part[0] = partition_idx % (1024 + 80);
}
block_sync_lds();
partition_idx = shared_part[0];
}
}
};