mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Fix pre-commit error
This commit is contained in:
6
example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp
Executable file → Normal file
6
example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.cpp
Executable file → Normal file
@@ -36,7 +36,8 @@
|
||||
* @param chunk_idx Index of chunk to signal
|
||||
* @param stream HIP stream for async operations
|
||||
*/
|
||||
[[maybe_unused]] static void signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream)
|
||||
[[maybe_unused]] static void
|
||||
signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream)
|
||||
{
|
||||
uint32_t ready = 1;
|
||||
ck_tile::hip_check_error(hipMemcpyAsync(
|
||||
@@ -67,7 +68,7 @@ int main(int argc, char* argv[])
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
|
||||
auto res = invoke_grouped_gemm_persistent_async<ck_tile::half_t>(
|
||||
a_layout, b_layout, data_type, arg_parser,
|
||||
, tiles_per_chunk_m, tile_idx_pivot_m);
|
||||
@@ -76,6 +77,5 @@ int main(int argc, char* argv[])
|
||||
|
||||
*/
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
206
example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp
Executable file → Normal file
206
example/ck_tile/17_grouped_gemm/grouped_gemm_persistent_async.hpp
Executable file → Normal file
@@ -3,106 +3,112 @@
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
|
||||
template <typename GroupedGemKernelParam,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename DsLayout,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
template <typename GroupedGemKernelParam, typename ADataType, typename BDataType, typename AccDataType, typename DsDataType, typename CDataType, typename DsLayout, typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
if(splitk)
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName()
|
||||
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
|
||||
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
};
|
||||
|
||||
if(splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
}
|
||||
|
||||
9
example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp
Executable file → Normal file
9
example/ck_tile/17_grouped_gemm/persistent_async_utils.hpp
Executable file → Normal file
@@ -43,19 +43,20 @@ CK_TILE_DEVICE static void wait_chunk_signal(const uint32_t* chunk_signals, inde
|
||||
if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0)
|
||||
{
|
||||
volatile const uint32_t* signal_ptr = chunk_signals + chunk_idx;
|
||||
|
||||
|
||||
// Poll until chunk is ready (signal == 1)
|
||||
// Use acquire semantics for proper memory ordering
|
||||
uint32_t signal_value;
|
||||
do {
|
||||
do
|
||||
{
|
||||
signal_value = __builtin_nontemporal_load(signal_ptr);
|
||||
__builtin_amdgcn_s_sleep(1); // Brief sleep to reduce contention
|
||||
} while(signal_value == 0);
|
||||
|
||||
|
||||
// Memory fence with acquire semantics
|
||||
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent");
|
||||
}
|
||||
|
||||
|
||||
// Barrier to release all threads in the workgroup
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user