fix build

This commit is contained in:
Max Podkorytov
2026-01-05 18:21:35 -06:00
parent 03c030203e
commit 4c0afcd71e
2 changed files with 144 additions and 162 deletions

View File

@@ -58,106 +58,6 @@ signal_chunk_ready(uint32_t* signals, int chunk_idx, hipStream_t stream)
&signals[chunk_idx], &ready, sizeof(uint32_t), hipMemcpyHostToDevice, stream));
}
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
// Parse async-specific arguments
const bool enable_async = arg_parser.get_int("enable_async") != 0;
const ck_tile::index_t tiles_per_chunk_m = arg_parser.get_int("tiles_per_chunk_m");
const ck_tile::index_t tile_idx_pivot_m = arg_parser.get_int("tile_idx_pivot_m");
std::cout << "\n=== Async Parameters ===" << std::endl;
std::cout << " enable_async: " << (enable_async ? "YES (will allocate chunk signals)" : "NO")
<< std::endl;
std::cout << " tiles_per_chunk_m: " << tiles_per_chunk_m << std::endl;
std::cout << " tile_idx_pivot_m: " << tile_idx_pivot_m << std::endl;
// Create async args (chunk signals will be allocated in the example function)
ck_tile::PersistentAsyncArgs async_args(
tiles_per_chunk_m, nullptr, tile_idx_pivot_m, enable_async);
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_persistent_async_example<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(
arg_parser, Row{}, Col{}, Row{}, async_args);
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_persistent_async_example<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(
arg_parser, Row{}, Row{}, Row{}, async_args);
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_persistent_async_example<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(
arg_parser, Col{}, Row{}, Row{}, async_args);
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_persistent_async_example<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(
arg_parser, Col{}, Col{}, Row{}, async_args);
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(ck_tile::ArgParser& arg_parser)
{
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
@@ -378,7 +278,6 @@ int run_grouped_gemm_persistent_async_example(ck_tile::ArgParser& arg_parser,
// Launch persistent async kernel
std::cout << "\nLaunching persistent async GEMM kernel..." << std::endl;
const bool splitk = kbatch > 1;
float ave_time =
invoke_grouped_gemm_persistent<GemmConfig,
@@ -390,7 +289,7 @@ int run_grouped_gemm_persistent_async_example(ck_tile::ArgParser& arg_parser,
ck_tile::tuple<>,
ALayout,
BLayout,
CLayout>(stream, group_count, kargs_ptr, splitk);
CLayout>(stream, group_count, kargs_ptr);
std::size_t total_flops = 0;
std::size_t total_bytes = 0;
@@ -476,6 +375,106 @@ int run_grouped_gemm_persistent_async_example(ck_tile::ArgParser& arg_parser,
return 0;
}
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
// Parse async-specific arguments
const bool enable_async = arg_parser.get_int("enable_async") != 0;
const ck_tile::index_t tiles_per_chunk_m = arg_parser.get_int("tiles_per_chunk_m");
const ck_tile::index_t tile_idx_pivot_m = arg_parser.get_int("tile_idx_pivot_m");
std::cout << "\n=== Async Parameters ===" << std::endl;
std::cout << " enable_async: " << (enable_async ? "YES (will allocate chunk signals)" : "NO")
<< std::endl;
std::cout << " tiles_per_chunk_m: " << tiles_per_chunk_m << std::endl;
std::cout << " tile_idx_pivot_m: " << tile_idx_pivot_m << std::endl;
// Create async args (chunk signals will be allocated in the example function)
ck_tile::PersistentAsyncArgs async_args(
tiles_per_chunk_m, nullptr, tile_idx_pivot_m, enable_async);
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_persistent_async_example<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(
arg_parser, Row{}, Col{}, Row{}, async_args);
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_persistent_async_example<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(
arg_parser, Row{}, Row{}, Row{}, async_args);
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_persistent_async_example<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(
arg_parser, Col{}, Row{}, Row{}, async_args);
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_persistent_async_example<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(
arg_parser, Col{}, Col{}, Row{}, async_args);
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(ck_tile::ArgParser& arg_parser)
{
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);

View File

@@ -15,8 +15,7 @@ template <typename GroupedGemKernelParam,
typename CLayout>
float invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
void* kargs_ptr)
{
constexpr bool TransposeC = false;
constexpr bool DoubleSmemBuffer = false;
@@ -47,68 +46,52 @@ float invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
CLayout,
TransposeC>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = memory_operation_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
if(splitk)
if(s.log_level_ > 0)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
else
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
}