mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
merge M grouped flatmm
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
|
||||
add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
|
||||
|
||||
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
|
||||
|
||||
@@ -11,3 +12,4 @@ list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo)
|
||||
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
|
||||
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1")
|
||||
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
382
example/ck_tile/18_flatmm/grouped_flatmm.cpp
Normal file
382
example/ck_tile/18_flatmm/grouped_flatmm.cpp
Normal file
@@ -0,0 +1,382 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "flatmm_basic.hpp"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "512,256,1024", "m dimension")
|
||||
.insert("Ns", "512,512,512", "n dimension")
|
||||
.insert("Ks", "1024,1024,512", "k dimension")
|
||||
.insert("group_count", "3", "group count")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("mode", "general", "grouped gemm mode: [general | contiguous], general by default")
|
||||
.insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
typename KernelArguments>
|
||||
float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel =
|
||||
ck_tile::GroupedFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
// if(!Kernel::IsSupportedArgument(kargs))
|
||||
// {
|
||||
// throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
// }
|
||||
|
||||
// if(s.flush_cache_)
|
||||
// {
|
||||
// std::cout << "Flushing cache..." << std::endl;
|
||||
// static constexpr ck_tile::index_t APackedSize =
|
||||
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
// static constexpr ck_tile::index_t BPackedSize =
|
||||
// std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
// ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
// args.group_count * args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
// ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
// args.K, args.group_count * args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
// auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
// auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
// ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
// kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
// rotating_mem.Print();
|
||||
|
||||
// auto run_flush_cache = [&]() {
|
||||
// // flush icache
|
||||
// ck_tile::flush_icache();
|
||||
// // rotating mem
|
||||
// rotating_mem.Next();
|
||||
// // clear c mem
|
||||
// if(args.k_batch > 1)
|
||||
// hipGetErrorString(hipMemsetAsync(
|
||||
// args.e_ptr, 0, args.group_count * args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
// };
|
||||
// ave_time = ck_tile::launch_kernel_preprocess(
|
||||
// s,
|
||||
// run_flush_cache,
|
||||
// ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
// Kernel{}, grids, blocks, 0, kargs));
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
// }
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_flatmm_example.inc"
|
||||
|
||||
template<template <typename PreType> typename FlatmmConfig>
|
||||
int run_grouped_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string mode = arg_parser.get_str("mode");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int scale_opt = arg_parser.get_int("scale");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mode == "general")
|
||||
{
|
||||
// if(data_type == "fp16")
|
||||
// {
|
||||
// run_grouped_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(data_type == "bf16")
|
||||
// {
|
||||
// run_grouped_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(data_type == "fp8")
|
||||
// {
|
||||
// run_grouped_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else if(data_type == "bf8")
|
||||
// {
|
||||
// run_grouped_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
// argc, argv, Row{}, Col{}, Row{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// throw std::runtime_error("Unsupported data_type!");
|
||||
// }
|
||||
}
|
||||
else if(mode == "contiguous")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else if(mode == "masked")
|
||||
{
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_masked_grouped_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_masked_grouped_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_masked_grouped_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_masked_grouped_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported mode!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_grouped_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_grouped_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
// }
|
||||
// else if(warp_tile == 2)
|
||||
// {
|
||||
// return !run_grouped_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// return !run_grouped_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
// }
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
935
example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc
Normal file
935
example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc
Normal file
@@ -0,0 +1,935 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
// template <typename FlatmmConfig,
|
||||
// typename ADataType,
|
||||
// typename BDataType,
|
||||
// typename DsDatatype,
|
||||
// typename AccDataType,
|
||||
// typename CDataType,
|
||||
// typename ALayout,
|
||||
// typename BLayout,
|
||||
// typename DsLayout,
|
||||
// typename CLayout,
|
||||
// typename ScaleM,
|
||||
// typename ScaleN,
|
||||
// typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
// float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::GroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
|
||||
// {
|
||||
// float ave_time =
|
||||
// grouped_flatmm<FlatmmConfig,
|
||||
// ADataType,
|
||||
// BDataType,
|
||||
// DsDatatype,
|
||||
// AccDataType,
|
||||
// CDataType,
|
||||
// ALayout,
|
||||
// BLayout,
|
||||
// DsLayout,
|
||||
// CLayout,
|
||||
// false,
|
||||
// CDEElementWise>(
|
||||
// args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
// std::string op_name{"Grouped Gemm"};
|
||||
|
||||
// std::size_t flop = 0, num_btype = 0;
|
||||
// for(int j = 0; j < args.group_count; ++j)
|
||||
// {
|
||||
// flop += std::size_t(2) * args.M[j] * args.N[j] * args.K[j];
|
||||
|
||||
// num_btype += sizeof(ADataType) * args.M[j] * args.K[j] +
|
||||
// sizeof(BDataType) * args.K[j] * args.N[j] +
|
||||
// sizeof(CDataType) * args.M[j] * args.N[j];
|
||||
// }
|
||||
|
||||
// float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
// float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
// std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
// << gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
// return ave_time;
|
||||
// }
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
|
||||
{
|
||||
float ave_time =
|
||||
grouped_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup, int n_repeat, int val_m, const ck_tile::MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
|
||||
{
|
||||
float ave_time =
|
||||
grouped_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * val_m * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * val_m * args.K +
|
||||
sizeof(BDataType) * args.N * args.K * args.group_count +
|
||||
sizeof(CDataType) * val_m * args.N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// template <typename PrecType,
|
||||
// typename FlatmmConfig,
|
||||
// int ScaleGranularityM = -1,
|
||||
// int ScaleGranularityN = -1,
|
||||
// typename ALayout,
|
||||
// typename BLayout,
|
||||
// typename CLayout>
|
||||
// int run_grouped_flatmm_example_with_layouts(int argc,
|
||||
// char* argv[],
|
||||
// const ALayout a_layout = ALayout{},
|
||||
// const BLayout b_layout = BLayout{},
|
||||
// [[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
// {
|
||||
// auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
// if(!result)
|
||||
// {
|
||||
// return -1;
|
||||
// };
|
||||
|
||||
// using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
// using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
// using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
// using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
// const int group_count = arg_parser.get_int("group_count");
|
||||
// const int repeat = arg_parser.get_int("repeat");
|
||||
// const int warmup = arg_parser.get_int("warmup");
|
||||
|
||||
// std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
// std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
// std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
// std::vector<ck_tile::index_t> stride_As;
|
||||
// std::vector<ck_tile::index_t> stride_Bs;
|
||||
// std::vector<ck_tile::index_t> stride_Cs;
|
||||
// ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
// std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
// std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
// std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
// std::vector<ck_tile::HostTensor<AccDataType>> per_token_scales;
|
||||
// std::vector<ck_tile::HostTensor<AccDataType>> per_channel_scales;
|
||||
|
||||
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_shfl_dev_buf;
|
||||
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> per_token_scales_dev_buf;
|
||||
// std::vector<std::unique_ptr<ck_tile::DeviceMem>> per_channel_scales_dev_buf;
|
||||
|
||||
// std::vector<void*> group_a_ptrs;
|
||||
// std::vector<void*> group_b_ptrs;
|
||||
// std::vector<void*> group_c_ptrs;
|
||||
// std::vector<ck_tile::FlatmmScalePointer<ScaleGranularityM>> group_scale_a_ptrs;
|
||||
// std::vector<ck_tile::FlatmmScalePointer<ScaleGranularityN>> group_scale_b_ptrs;
|
||||
|
||||
// if(!(int(Ms.size()) == group_count && int(Ns.size()) == group_count &&
|
||||
// int(Ks.size()) == group_count))
|
||||
// {
|
||||
// std::cout << "Please check the input data." << std::endl;
|
||||
// for(int i = 0; i < group_count; i++)
|
||||
// {
|
||||
// Ms.push_back(256 + 256 * i);
|
||||
// Ns.push_back(128 + 128 * i);
|
||||
// Ks.push_back(512 + 512 * i);
|
||||
// }
|
||||
// }
|
||||
|
||||
// for(int i = 0; i < group_count; ++i)
|
||||
// {
|
||||
// const ck_tile::index_t M = Ms[i];
|
||||
// const ck_tile::index_t N = Ns[i];
|
||||
// const ck_tile::index_t K = Ks[i];
|
||||
|
||||
// stride_As.push_back(ck_tile::get_default_stride(M, K, 0, is_row_major(a_layout)));
|
||||
// stride_Bs.push_back(ck_tile::get_default_stride(K, N, 0, is_row_major(b_layout)));
|
||||
// stride_Cs.push_back(ck_tile::get_default_stride(M, N, 0, is_row_major(c_layout)));
|
||||
|
||||
// a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
// ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
// b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
// ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
|
||||
// c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
// ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(c_layout))));
|
||||
|
||||
// per_token_scales.push_back(ck_tile::HostTensorDescriptor({M}, {1}));
|
||||
// per_channel_scales.push_back(ck_tile::HostTensorDescriptor({N}, {1}));
|
||||
|
||||
// std::cout << "gemm[" << i << "]"
|
||||
// << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
// << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
|
||||
// ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
// ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_k_n_tensors[i]);
|
||||
// ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(per_token_scales[i]);
|
||||
// ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(per_channel_scales[i]);
|
||||
// ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensors[i]);
|
||||
|
||||
// a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
// a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
// b_shfl_dev_buf.push_back(
|
||||
// std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
|
||||
// c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
// c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
|
||||
// a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
// b_shfl_dev_buf[i]->ToDevice(b_shuffle_host.data());
|
||||
// c_m_n_tensors[i].SetZero();
|
||||
// c_m_n_dev_buf[i]->SetZero();
|
||||
// per_token_scales_dev_buf[i]->ToDevice(per_token_scales[i].data());
|
||||
// per_channel_scales_dev_buf[i]->ToDevice(per_channel_scales[i].data());
|
||||
// auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
|
||||
// static_cast<float*>(per_token_scales_dev_buf[i]->GetDeviceBuffer())};
|
||||
// auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
|
||||
// static_cast<float*>(per_channel_scales_dev_buf[i]->GetDeviceBuffer())};
|
||||
|
||||
// group_a_ptrs.push_back(a_m_k_dev_buf[i]->GetDeviceBuffer());
|
||||
// group_b_ptrs.push_back(b_shfl_dev_buf[i]->GetDeviceBuffer());
|
||||
// group_c_ptrs.push_back(c_m_n_dev_buf[i]->GetDeviceBuffer());
|
||||
// group_scale_a_ptrs.push_back(per_token_scale_dev_ptr);
|
||||
// group_scale_b_ptrs.push_back(per_channel_scale_dev_ptr);
|
||||
|
||||
|
||||
// }
|
||||
|
||||
// ck_tile::DeviceMem group_m_dev_buf(group_count * sizeof(ck_tile::index_t));
|
||||
// ck_tile::DeviceMem group_n_dev_buf(group_count * sizeof(ck_tile::index_t));
|
||||
// ck_tile::DeviceMem group_k_dev_buf(group_count * sizeof(ck_tile::index_t));
|
||||
// ck_tile::DeviceMem group_a_ptrs_dev_buf(group_count * sizeof(void*));
|
||||
// ck_tile::DeviceMem group_b_ptrs_dev_buf(group_count * sizeof(void*));
|
||||
// ck_tile::DeviceMem group_c_ptrs_dev_buf(group_count * sizeof(void*));
|
||||
// ck_tile::DeviceMem group_stride_a_dev_buf(group_count * sizeof(ck_tile::index_t));
|
||||
// ck_tile::DeviceMem group_stride_b_dev_buf(group_count * sizeof(ck_tile::index_t));
|
||||
// ck_tile::DeviceMem group_stride_c_dev_buf(group_count * sizeof(ck_tile::index_t));
|
||||
// ck_tile::DeviceMem group_scale_a_ptrs_dev_buf(group_count * sizeof(void*));
|
||||
// ck_tile::DeviceMem group_scale_b_ptrs_dev_buf(group_count * sizeof(void*));
|
||||
|
||||
// group_m_dev_buf.ToDevice(Ms.data());
|
||||
// group_n_dev_buf.ToDevice(Ns.data());
|
||||
// group_k_dev_buf.ToDevice(Ks.data());
|
||||
// group_a_ptrs_dev_buf.ToDevice(group_a_ptrs.data());
|
||||
// group_b_ptrs_dev_buf.ToDevice(group_b_ptrs.data());
|
||||
// group_c_ptrs_dev_buf.ToDevice(group_c_ptrs.data());
|
||||
// group_stride_a_dev_buf.ToDevice(stride_As.data());
|
||||
// group_stride_b_dev_buf.ToDevice(stride_Bs.data());
|
||||
// group_stride_c_dev_buf.ToDevice(stride_Cs.data());
|
||||
// group_scale_a_ptrs_dev_buf.ToDevice(group_scale_a_ptrs.data());
|
||||
// group_scale_b_ptrs_dev_buf.ToDevice(group_scale_b_ptrs.data());
|
||||
|
||||
// using ScaleAPtr = decltype(group_scale_a_ptrs[0]);
|
||||
// using ScaleBPtr = decltype(group_scale_b_ptrs[0]);
|
||||
|
||||
// ck_tile::GroupedFlatmmHostArgs<ScaleAPtr, ScaleBPtr> kernal_args{
|
||||
// group_count,
|
||||
// static_cast<ck_tile::index_t*>(group_m_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<ck_tile::index_t*>(group_n_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<ck_tile::index_t*>(group_k_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<const void**>(group_a_ptrs_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<ck_tile::index_t*>(group_stride_a_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<const void**>(group_b_ptrs_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<ck_tile::index_t*>(group_stride_b_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<void**>(group_c_ptrs_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<ck_tile::index_t*>(group_stride_c_dev_buf.GetDeviceBuffer()),
|
||||
// kbatch,
|
||||
// static_cast<ScaleAPtr*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
// static_cast<ScaleBPtr*>(per_channel_scale_dev_buf.GetDeviceBuffer())
|
||||
// };
|
||||
|
||||
// invoke_gemm<FlatmmConfig,
|
||||
// ADataType,
|
||||
// BDataType,
|
||||
// ck_tile::tuple<>,
|
||||
// AccDataType,
|
||||
// CDataType,
|
||||
// ALayout,
|
||||
// BLayout,
|
||||
// ck_tile::tuple<>,
|
||||
// CLayout,
|
||||
// ScaleAPtr,
|
||||
// ScaleBPtr>(
|
||||
// warmup, repeat, kernal_args);
|
||||
|
||||
// for(int i = 0; i < group_count; i++)
|
||||
// {
|
||||
// c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
// }
|
||||
|
||||
// bool pass{true};
|
||||
// if(arg_parser.get_int("v") == 1)
|
||||
// {
|
||||
// for(int i = 0; i < group_count; ++i)
|
||||
// {
|
||||
// ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
// Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
|
||||
// c_m_n_host_ref.SetZero();
|
||||
// ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
// a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
// const float max_accumulated_value =
|
||||
// *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
// const auto rtol_atol =
|
||||
// calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
// Ks[i], 1 /*kbatch*/, max_accumulated_value);
|
||||
// pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
// c_m_n_host_ref,
|
||||
// "Error: Incorrect results!",
|
||||
// rtol_atol.at(ck_tile::number<0>{}),
|
||||
// rtol_atol.at(ck_tile::number<1>{}));
|
||||
// std::cout << "gemm[" << i
|
||||
// << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
// << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
// << std::endl;
|
||||
// }
|
||||
// std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
// }
|
||||
// else if(arg_parser.get_int("v") == 2)
|
||||
// {
|
||||
// for(int i = 0; i < group_count; ++i)
|
||||
// {
|
||||
// ck_tile::index_t M = Ms[i];
|
||||
// ck_tile::index_t N = Ns[i];
|
||||
// ck_tile::index_t K = Ks[i];
|
||||
// ck_tile::index_t stride_A = stride_As[i];
|
||||
// ck_tile::index_t stride_B = stride_Bs[i];
|
||||
// ck_tile::index_t stride_C = stride_Cs[i];
|
||||
|
||||
// ADataType* d_A;
|
||||
// BDataType* d_B;
|
||||
// CDataType* d_C;
|
||||
|
||||
// ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
|
||||
// ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
// ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
|
||||
// ck_tile::hip_check_error(hipMemcpy(
|
||||
// d_A, a_m_k_tensors[i].data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice));
|
||||
// ck_tile::hip_check_error(hipMemcpy(
|
||||
// d_B, b_k_n_tensors[i].data(), N * K * sizeof(BDataType), hipMemcpyHostToDevice));
|
||||
|
||||
// ck_tile::reference_gemm_gpu<ADataType,
|
||||
// BDataType,
|
||||
// AccDataType,
|
||||
// CDataType,
|
||||
// ALayout,
|
||||
// BLayout,
|
||||
// CLayout>(
|
||||
// d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
// ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
// ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
// ck_tile::hip_check_error(hipMemcpy(
|
||||
// c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
|
||||
|
||||
// ck_tile::hip_check_error(hipFree(d_A));
|
||||
// ck_tile::hip_check_error(hipFree(d_B));
|
||||
// ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
// const float max_accumulated_value =
|
||||
// *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
|
||||
// const auto rtol_atol =
|
||||
// calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
// K, kbatch, max_accumulated_value);
|
||||
|
||||
// float rtol = 1e-3;
|
||||
// float atol = 1e-3;
|
||||
|
||||
// pass = ck_tile::check_err(
|
||||
// c_m_n_tensors[i], c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
// std::cout << "gemm[" << i << "]\nRelative error threshold: " << rtol
|
||||
// << " Absolute error threshold: " << atol << std::endl;
|
||||
// std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail")
|
||||
// << std::endl;
|
||||
// }
|
||||
// }
|
||||
|
||||
// return pass;
|
||||
// }
|
||||
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
int ScaleGranularityM = -1,
|
||||
int ScaleGranularityN = -1,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_contiguous_grouped_flatmm_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
constexpr int BlockM = FlatmmConfig::M_Tile;
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
|
||||
if(!(int(Ms.size()) == group_count))
|
||||
{
|
||||
std::cout << "Please check the input data." << std::endl;
|
||||
// padding additional Ms if needed
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 64 * i);
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::index_t M =
|
||||
std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
|
||||
// round up to the multiple of BlockM
|
||||
return acc + (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
});
|
||||
std::cout << "Total M: " << M << std::endl;
|
||||
ck_tile::index_t N = Ns[0];
|
||||
ck_tile::index_t K = Ks[0];
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
ck_tile::index_t stride_A = 0;
|
||||
ck_tile::index_t stride_B = 0;
|
||||
ck_tile::index_t stride_C = 0;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k_tensor(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout))));
|
||||
|
||||
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
|
||||
|
||||
std::vector<ck_tile::index_t> m_indices(M);
|
||||
int indices_fill_start = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
int group_m = Ms[i];
|
||||
int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
for(int j = 0; j < padded_group_m; j++)
|
||||
{
|
||||
m_indices[indices_fill_start + j] = j < group_m ? i : -1; // -1 for padding
|
||||
}
|
||||
indices_fill_start += padded_group_m;
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
|
||||
|
||||
constexpr int N_Warp_Tile = FlatmmConfig::N_Warp_Tile;
|
||||
assert(N % N_Warp_Tile == 0 &&
|
||||
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem per_channel_scale_dev_buf(
|
||||
per_channel_scale.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_dev_buf->SetZero();
|
||||
|
||||
ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t));
|
||||
m_indices_dev_buf.ToDevice(m_indices.data());
|
||||
|
||||
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
|
||||
b_shfl_dev_buf->ToDevice(b_shuffle_host.data());
|
||||
|
||||
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
|
||||
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
|
||||
|
||||
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
|
||||
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
|
||||
ck_tile::ContiguousGroupedFlatmmHostArgs<decltype(per_token_scale_dev_ptr), decltype(per_channel_scale_dev_ptr)> kernal_args{
|
||||
static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_m_k_dev_buf->GetDeviceBuffer(),
|
||||
stride_A,
|
||||
b_shfl_dev_buf->GetDeviceBuffer(),
|
||||
stride_B,
|
||||
{},{},
|
||||
c_m_n_dev_buf->GetDeviceBuffer(),
|
||||
stride_C,
|
||||
kbatch,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())
|
||||
};
|
||||
|
||||
invoke_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(per_token_scale_dev_ptr),
|
||||
decltype(per_channel_scale_dev_ptr)>(
|
||||
warmup, repeat, kernal_args);
|
||||
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Not support v=1 host verification in contiguous grouped gemm, use "
|
||||
"v=2 device verification instead");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMemset(d_C, 0, M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::index_t acc_m = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::index_t padded_M = (Ms[i] + BlockM - 1) / BlockM * BlockM;
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_tensor.data() + i * N * K,
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + acc_m * K,
|
||||
d_B,
|
||||
d_C + acc_m * N,
|
||||
padded_M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
acc_m += padded_M;
|
||||
}
|
||||
ck_tile::hip_check_error(hipMemcpy(
|
||||
c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
float rtol = 1e-3;
|
||||
float atol = 1e-3;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
int ScaleGranularityM = -1,
|
||||
int ScaleGranularityN = -1,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_masked_grouped_flatmm_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
constexpr int BlockM = FlatmmConfig::M_Tile;
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
|
||||
if(!(int(Ms.size()) == group_count))
|
||||
{
|
||||
std::cout << "Please check the input data." << std::endl;
|
||||
// padding additional Ms if needed
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 64 * i);
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::index_t M = 4096; // Ms[0];
|
||||
ck_tile::index_t N = Ns[0];
|
||||
ck_tile::index_t K = Ks[0];
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
ck_tile::index_t stride_A = K;
|
||||
ck_tile::index_t stride_B = K;
|
||||
ck_tile::index_t stride_C = N;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(group_count * M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(group_count * M, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k_tensor(
|
||||
ck_tile::host_tensor_descriptor(group_count * M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(c_layout))));
|
||||
|
||||
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({group_count * M}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({group_count * N}, {1}));
|
||||
|
||||
std::vector<ck_tile::index_t> m_indices(group_count);
|
||||
int indices_fill_start = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
int group_m = Ms[i];
|
||||
int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
for(int j = 0; j < padded_group_m; j++)
|
||||
{
|
||||
m_indices[i] = group_m;
|
||||
// m_indices[i] = padded_group_m; // -1 for padding
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
|
||||
|
||||
constexpr int N_Warp_Tile = FlatmmConfig::N_Warp_Tile;
|
||||
assert(N % N_Warp_Tile == 0 &&
|
||||
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem per_channel_scale_dev_buf(
|
||||
per_channel_scale.get_element_space_size_in_bytes());
|
||||
c_m_n_dev_buf->SetZero();
|
||||
|
||||
ck_tile::DeviceMem m_indices_dev_buf(group_count * sizeof(ck_tile::index_t));
|
||||
m_indices_dev_buf.ToDevice(m_indices.data());
|
||||
|
||||
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
|
||||
b_shfl_dev_buf->ToDevice(b_shuffle_host.data());
|
||||
|
||||
|
||||
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
|
||||
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
|
||||
|
||||
|
||||
|
||||
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
|
||||
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
ck_tile::MaskedGroupedFlatmmHostArgs<decltype(per_token_scale_dev_ptr), decltype(per_channel_scale_dev_ptr)> kernal_args{
|
||||
static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
|
||||
group_count,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_m_k_dev_buf->GetDeviceBuffer(),
|
||||
stride_A,
|
||||
b_shfl_dev_buf->GetDeviceBuffer(),
|
||||
stride_B,
|
||||
{},{},
|
||||
c_m_n_dev_buf->GetDeviceBuffer(),
|
||||
stride_C,
|
||||
kbatch,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())
|
||||
};
|
||||
int sum_val_m = 0;
|
||||
for (int gi = 0; gi < group_count; gi++)
|
||||
{
|
||||
sum_val_m += m_indices[gi];
|
||||
}
|
||||
|
||||
invoke_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(per_token_scale_dev_ptr),
|
||||
decltype(per_channel_scale_dev_ptr)>(
|
||||
warmup, repeat, sum_val_m, kernal_args);
|
||||
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Not support v=1 host verification in contiguous grouped gemm, use "
|
||||
"v=2 device verification instead");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, group_count * M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMemset(d_C, 0, group_count * M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::index_t acc_m = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_tensor.data() + i * N * K,
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
|
||||
{
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + i * M * K,
|
||||
d_B,
|
||||
d_C + i * M * N,
|
||||
m_indices[i],
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_blockwise_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + i * M * K,
|
||||
d_B,
|
||||
d_C + i * M * N,
|
||||
m_indices[i],
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
ScaleGranularityM,
|
||||
ScaleGranularityN,
|
||||
K,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()) + i * M,
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())) + i * N;
|
||||
}
|
||||
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_host.data() + i * M * N,
|
||||
d_C + i * M * N,
|
||||
M * N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
float rtol = 1e-3;
|
||||
float atol = 1e-3;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -32,6 +32,12 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
|
||||
__launch_bounds__(MaxThreadPerBlock) __global__ void kentry2(Args... args)
|
||||
{
|
||||
Kernel{}(args...);
|
||||
}
|
||||
|
||||
//
|
||||
// return a anonymous functor(lambda) to be called later
|
||||
// the KernelImpl should be a class without non-static data member, or let's say
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
6
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Executable file → Normal file
6
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Executable file → Normal file
@@ -244,7 +244,9 @@ struct FlatmmKernel
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
||||
|
||||
template<class ScaleM, class ScaleN>
|
||||
using KernelArgs = FlatmmKernelArgs<ScaleM, ScaleN, DsLayout::size()>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -751,7 +753,7 @@ struct FlatmmKernel
|
||||
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
|
||||
int partition_idx = blockIdx.x) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
|
||||
465
include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp
Normal file
465
include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp
Normal file
@@ -0,0 +1,465 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
struct GroupedFlatmmHostArgs
|
||||
{
|
||||
CK_TILE_HOST GroupedFlatmmHostArgs() = default;
|
||||
CK_TILE_HOST GroupedFlatmmHostArgs(index_t group_count_,
|
||||
index_t* M_,
|
||||
index_t* N_,
|
||||
index_t* K_,
|
||||
const void** a_ptr_,
|
||||
index_t* stride_A_,
|
||||
const void** b_shuffle_ptr_,
|
||||
index_t* stride_B_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
void** c_ptr_,
|
||||
index_t* stride_C_,
|
||||
index_t k_batch_,
|
||||
ScaleM scale_m_ = nullptr,
|
||||
ScaleN scale_n_ = nullptr)
|
||||
: group_count(group_count_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
a_ptr(a_ptr_),
|
||||
stride_A(stride_A_),
|
||||
b_shuffle_ptr(b_shuffle_ptr_),
|
||||
stride_B(stride_B_),
|
||||
ds_ptr(ds_ptr_),
|
||||
stride_Ds(stride_Ds_),
|
||||
c_ptr(c_ptr_),
|
||||
stride_C(stride_C_),
|
||||
k_batch(k_batch_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t group_count;
|
||||
index_t* M;
|
||||
index_t* N;
|
||||
index_t* K;
|
||||
const void** a_ptr;
|
||||
index_t* stride_A;
|
||||
const void** b_shuffle_ptr;
|
||||
index_t* stride_B;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
void** e_ptr;
|
||||
void** c_ptr;
|
||||
};
|
||||
index_t* stride_C;
|
||||
index_t k_batch;
|
||||
ScaleM scale_m = nullptr;
|
||||
ScaleN scale_n = nullptr;
|
||||
};
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
struct ContiguousGroupedFlatmmHostArgs
|
||||
{
|
||||
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs() = default;
|
||||
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t* M_indices_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const void* a_ptr_,
|
||||
index_t stride_A_,
|
||||
const void* b_shuffle_ptr_,
|
||||
index_t stride_B_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
void* c_ptr_,
|
||||
index_t stride_C_,
|
||||
index_t k_batch_,
|
||||
ScaleM scale_m_ = nullptr,
|
||||
ScaleN scale_n_ = nullptr)
|
||||
: M_indices(M_indices_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
a_ptr(a_ptr_),
|
||||
stride_A(stride_A_),
|
||||
b_shuffle_ptr(b_shuffle_ptr_),
|
||||
stride_B(stride_B_),
|
||||
ds_ptr(ds_ptr_),
|
||||
stride_Ds(stride_Ds_),
|
||||
c_ptr(c_ptr_),
|
||||
stride_C(stride_C_),
|
||||
k_batch(k_batch_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t* M_indices;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
const void* a_ptr;
|
||||
index_t stride_A;
|
||||
const void* b_shuffle_ptr;
|
||||
index_t stride_B;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
ScaleM scale_m = nullptr;
|
||||
ScaleN scale_n = nullptr;
|
||||
};
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
struct MaskedGroupedFlatmmHostArgs
|
||||
{
|
||||
CK_TILE_HOST MaskedGroupedFlatmmHostArgs() = default;
|
||||
CK_TILE_HOST MaskedGroupedFlatmmHostArgs(index_t* M_indices_,
|
||||
index_t group_count_,
|
||||
index_t Max_M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const void* a_ptr_,
|
||||
index_t stride_A_,
|
||||
const void* b_shuffle_ptr_,
|
||||
index_t stride_B_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
void* c_ptr_,
|
||||
index_t stride_C_,
|
||||
index_t k_batch_,
|
||||
ScaleM scale_m_ = nullptr,
|
||||
ScaleN scale_n_ = nullptr)
|
||||
: M_indices(M_indices_),
|
||||
group_count(group_count_),
|
||||
M(Max_M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
a_ptr(a_ptr_),
|
||||
stride_A(stride_A_),
|
||||
b_shuffle_ptr(b_shuffle_ptr_),
|
||||
stride_B(stride_B_),
|
||||
ds_ptr(ds_ptr_),
|
||||
stride_Ds(stride_Ds_),
|
||||
c_ptr(c_ptr_),
|
||||
stride_C(stride_C_),
|
||||
k_batch(k_batch_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t* M_indices;
|
||||
index_t group_count;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
const void* a_ptr;
|
||||
index_t stride_A;
|
||||
const void* b_shuffle_ptr;
|
||||
index_t stride_B;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
ScaleM scale_m = nullptr;
|
||||
ScaleN scale_n = nullptr;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
||||
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
|
||||
{
|
||||
using UnderlyingGemmKernel = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
|
||||
using BlockGemmShape = typename UnderlyingGemmKernel::BlockGemmShape;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
|
||||
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
|
||||
CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
return concat(
|
||||
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size, GroupedFlatmmKernel, GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
|
||||
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
<< ", persistent_block_size: " << persistent_block_size << std::endl;
|
||||
|
||||
assert(kernelArgs.k_batch == 1);
|
||||
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size, GroupedFlatmmKernel, ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);
|
||||
|
||||
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
<< ", persistent_block_size: " << persistent_block_size
|
||||
<< ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
assert(kernelArgs.k_batch == 1);
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch);
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize([[maybe_unused]] const MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size, GroupedFlatmmKernel, MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
// const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);
|
||||
|
||||
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
<< ", persistent_block_size: " << persistent_block_size << std::endl;
|
||||
|
||||
assert(kernelArgs.k_batch == 1);
|
||||
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
|
||||
}
|
||||
|
||||
template<typename HostArgs>
|
||||
CK_TILE_HOST static constexpr auto MakeKernelArgs(const HostArgs& hostArgs)
|
||||
{
|
||||
return hostArgs;
|
||||
}
|
||||
// CK_TILE_HOST static constexpr auto
|
||||
// MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
|
||||
// {
|
||||
// return hostArgs;
|
||||
// }
|
||||
// CK_TILE_HOST static constexpr auto
|
||||
// MakeKernelArgs(const MaskedGroupedFlatmmHostArgs& hostArgs)
|
||||
// {
|
||||
// return hostArgs;
|
||||
// }
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
|
||||
{
|
||||
int group_idx = 0;
|
||||
int block_linear_idx = blockIdx.x;
|
||||
int total_block_cnt = gridDim.x;
|
||||
|
||||
UnderlyingGemmKernel underlying_kernel{};
|
||||
for(; group_idx < kargs.group_count; ++group_idx)
|
||||
{
|
||||
const index_t M = kargs.M[group_idx];
|
||||
const index_t N = kargs.N[group_idx];
|
||||
const index_t group_block_cnt = TilePartitioner::GridSize(M, N);
|
||||
|
||||
while(block_linear_idx < group_block_cnt)
|
||||
{
|
||||
// Found the group this block belongs to
|
||||
// create the kernel args for the underlying flatmm kernel
|
||||
typename UnderlyingGemmKernel::template KernelArgs<ScaleM, ScaleN> impl_kargs{
|
||||
kargs.a_ptr[group_idx],
|
||||
kargs.b_shuffle_ptr[group_idx],
|
||||
kargs.ds_ptr,
|
||||
kargs.c_ptr[group_idx],
|
||||
kargs.M[group_idx],
|
||||
kargs.N[group_idx],
|
||||
kargs.K[group_idx],
|
||||
kargs.stride_A[group_idx],
|
||||
kargs.stride_B[group_idx],
|
||||
kargs.stride_Ds,
|
||||
kargs.stride_C[group_idx],
|
||||
kargs.k_batch,
|
||||
};
|
||||
// call the underlying flatmm kernel
|
||||
underlying_kernel(impl_kargs, block_linear_idx);
|
||||
block_linear_idx += total_block_cnt;
|
||||
}
|
||||
block_linear_idx -= group_block_cnt;
|
||||
}
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
CK_TILE_DEVICE void operator()(ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
|
||||
{
|
||||
int block_linear_idx = blockIdx.x;
|
||||
int total_block_cnt = gridDim.x;
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
UnderlyingGemmKernel underlying_kernel{};
|
||||
for(; block_linear_idx < total_work_tile_cnt; block_linear_idx += total_block_cnt)
|
||||
{
|
||||
auto [block_m_idx, block_n_idx] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(block_linear_idx);
|
||||
// get the group index from the M_indices
|
||||
int group_idx = kargs.M_indices[block_m_idx * BlockGemmShape::kM];
|
||||
|
||||
typename UnderlyingGemmKernel::template KernelArgs<ScaleM, ScaleN> impl_kargs{
|
||||
kargs.a_ptr,
|
||||
static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
|
||||
kargs.ds_ptr,
|
||||
kargs.c_ptr,
|
||||
kargs.M,
|
||||
kargs.N,
|
||||
kargs.K,
|
||||
kargs.stride_A,
|
||||
kargs.stride_B,
|
||||
kargs.stride_Ds,
|
||||
kargs.stride_C,
|
||||
kargs.k_batch,
|
||||
};
|
||||
// call the underlying flatmm kernel
|
||||
underlying_kernel(impl_kargs, block_linear_idx);
|
||||
}
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
CK_TILE_DEVICE void operator()(MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor> kargs) const
|
||||
{
|
||||
int group_idx = 0;
|
||||
int block_linear_idx = blockIdx.x;
|
||||
int total_block_cnt = gridDim.x;
|
||||
|
||||
UnderlyingGemmKernel underlying_kernel{};
|
||||
for(; group_idx < kargs.group_count; ++group_idx)
|
||||
{
|
||||
const index_t valid_M = kargs.M_indices[group_idx];
|
||||
const index_t N = kargs.N;
|
||||
const index_t group_block_cnt = TilePartitioner::GridSize(valid_M, N);
|
||||
|
||||
while(block_linear_idx < group_block_cnt)
|
||||
{
|
||||
// Found the group this block belongs to
|
||||
// create the kernel args for the underlying flatmm kernel
|
||||
typename UnderlyingGemmKernel::template KernelArgs<ScaleM, ScaleN> impl_kargs{
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + group_idx * kargs.M * kargs.K,
|
||||
static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
|
||||
kargs.ds_ptr,
|
||||
static_cast<CDataType*>(kargs.c_ptr) + group_idx * kargs.M * kargs.N,
|
||||
valid_M,
|
||||
kargs.N,
|
||||
kargs.K,
|
||||
kargs.stride_A,
|
||||
kargs.stride_B,
|
||||
kargs.stride_Ds,
|
||||
kargs.stride_C,
|
||||
kargs.k_batch,
|
||||
};
|
||||
// call the underlying flatmm kernel
|
||||
underlying_kernel(impl_kargs, block_linear_idx);
|
||||
block_linear_idx += total_block_cnt;
|
||||
}
|
||||
block_linear_idx -= group_block_cnt;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -112,7 +112,7 @@ struct GemmTile1DPartitioner
|
||||
* @param N GEMM's N dimension.
|
||||
* @return dim3 Structure holding grid's X,Y and Z dimensions.
|
||||
*/
|
||||
CK_TILE_HOST static auto
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
|
||||
{
|
||||
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
|
||||
@@ -16,7 +16,7 @@ fi
|
||||
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_DEV=ON \
|
||||
|
||||
Reference in New Issue
Block a user