mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Add support for contiguous grouped gemm
This commit is contained in:
@@ -19,8 +19,9 @@ template <typename ADataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float grouped_flatmm(const ck_tile::GroupedFlatmmHostArgs& args, const ck_tile::stream_config& s)
|
||||
typename CLayout,
|
||||
typename KernelArguments>
|
||||
float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
@@ -76,12 +77,12 @@ float grouped_flatmm(const ck_tile::GroupedFlatmmHostArgs& args, const ck_tile::
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -184,34 +185,69 @@ int run_grouped_flatmm_example(int argc, char* argv[])
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string mode = arg_parser.get_str("mode");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
if(mode == "general")
|
||||
{
|
||||
run_grouped_flatmm_example_with_layouts<ck_tile::half_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_grouped_flatmm_example_with_layouts<ck_tile::half_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_grouped_flatmm_example_with_layouts<ck_tile::bf16_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_grouped_flatmm_example_with_layouts<ck_tile::fp8_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_grouped_flatmm_example_with_layouts<ck_tile::bf8_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
else if(mode == "contiguous")
|
||||
{
|
||||
run_grouped_flatmm_example_with_layouts<ck_tile::bf16_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_grouped_flatmm_example_with_layouts<ck_tile::fp8_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_grouped_flatmm_example_with_layouts<ck_tile::bf8_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::half_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf16_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::fp8_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf8_t>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
throw std::runtime_error("Unsupported mode!");
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -29,11 +29,11 @@
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
|
||||
#endif
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
#if (CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
#elif (CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
@@ -153,9 +153,9 @@ struct GemmConfig<ck_tile::fp8_t>
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
#elif defined(USING_MFMA_32x32x64_F8) // MI350 FP8 32X32 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
@@ -165,9 +165,9 @@ struct GemmConfig<ck_tile::fp8_t>
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
#elif defined(USING_MFMA_16x16x32_F8) // MI300 FP8 16X16
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
@@ -177,9 +177,9 @@ struct GemmConfig<ck_tile::fp8_t>
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
#elif defined(USING_MFMA_32x32x16_F8) // MI300 FP8 32X32 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 8;
|
||||
@@ -222,9 +222,9 @@ struct GemmConfig
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
#elif defined(USING_MFMA_32x32x64_F8) // MI350 FP8 32X32 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
@@ -234,9 +234,9 @@ struct GemmConfig
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
#elif defined(USING_MFMA_16x16x32_F16) // MI350 FP16 16X16 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
@@ -246,9 +246,9 @@ struct GemmConfig
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
#elif defined(USING_MFMA_32x32x16_F16) // MI350 FP16 32X32 (need tune)
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
@@ -323,15 +323,16 @@ struct GemmConfig
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "512,128,1024", "m dimension")
|
||||
.insert("Ns", "256,512,1024", "n dimension")
|
||||
.insert("Ks", "512,1024,1024", "k dimension")
|
||||
arg_parser.insert("Ms", "512,256,1024", "m dimension")
|
||||
.insert("Ns", "1024,512,256", "n dimension")
|
||||
.insert("Ks", "1024,1024,512", "k dimension")
|
||||
.insert("group_count", "3", "group count")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("mode", "general", "grouped gemm mode: [general | contiguous], general by default")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -89,6 +89,35 @@ float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::GroupedFlatmmHostAr
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float invoke_gemm(int n_warmup, int n_repeat, const ck_tile::ContiguousGroupedFlatmmHostArgs& args)
|
||||
{
|
||||
float ave_time =
|
||||
grouped_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_grouped_flatmm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
@@ -318,3 +347,181 @@ int run_grouped_flatmm_example_with_layouts(int argc,
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_contiguous_grouped_flatmm_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
constexpr int BlockM = GemmConfig<BDataType>::M_Tile;
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
|
||||
if(!(int(Ms.size()) == group_count))
|
||||
{
|
||||
std::cout << "Please check the input data." << std::endl;
|
||||
// padding additional Ms if needed
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 64 * i);
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::index_t M =
|
||||
std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
|
||||
// round up to the multiple of BlockM
|
||||
return acc + (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
});
|
||||
std::cout << "Total M: " << M << std::endl;
|
||||
ck_tile::index_t N = Ns[0];
|
||||
ck_tile::index_t K = Ks[0];
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
ck_tile::index_t stride_A = 0;
|
||||
ck_tile::index_t stride_B = 0;
|
||||
ck_tile::index_t stride_C = 0;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k_tensor(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout))));
|
||||
|
||||
std::vector<ck_tile::index_t> m_indices(std::size_t(M), -1);
|
||||
int indices_fill_start = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
int group_m = Ms[i];
|
||||
int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
for(int j = 0; j < padded_group_m; j++)
|
||||
{
|
||||
m_indices[indices_fill_start + j] = j < group_m ? i : -1; // -1 for padding
|
||||
}
|
||||
indices_fill_start += padded_group_m;
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_k_n_tensor);
|
||||
c_m_n_tensor.SetZero();
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<BDataType>(b_k_n_tensor);
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf->SetZero();
|
||||
|
||||
ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t));
|
||||
m_indices_dev_buf.ToDevice(m_indices.data());
|
||||
|
||||
ck_tile::ContiguousGroupedFlatmmHostArgs kernal_args{
|
||||
static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_m_k_dev_buf->GetDeviceBuffer(),
|
||||
stride_A,
|
||||
b_shfl_dev_buf->GetDeviceBuffer(),
|
||||
stride_B,
|
||||
c_m_n_dev_buf->GetDeviceBuffer(),
|
||||
stride_C,
|
||||
kbatch,
|
||||
};
|
||||
|
||||
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
warmup, repeat, kernal_args);
|
||||
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Not support v=1 host verification in contiguous grouped gemm, use "
|
||||
"v=2 device verification instead");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::index_t acc_m = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::index_t padded_M = (Ms[i] + BlockM - 1) / BlockM * BlockM;
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_tensor.data() + group_count * N * K,
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + acc_m * K,
|
||||
d_B,
|
||||
d_C + acc_m * N,
|
||||
padded_M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
acc_m += padded_M;
|
||||
}
|
||||
ck_tile::hip_check_error(hipMemcpy(
|
||||
c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
float rtol = 1e-3;
|
||||
float atol = 1e-3;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -53,10 +53,52 @@ struct GroupedFlatmmHostArgs
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
struct ContiguousGroupedFlatmmHostArgs
|
||||
{
|
||||
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs() = default;
|
||||
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t* M_indices_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const void* a_ptr_,
|
||||
index_t stride_A_,
|
||||
const void* b_shuffle_ptr_,
|
||||
index_t stride_B_,
|
||||
void* c_ptr_,
|
||||
index_t stride_C_,
|
||||
index_t k_batch_)
|
||||
: M_indices(M_indices_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
a_ptr(a_ptr_),
|
||||
stride_A(stride_A_),
|
||||
b_shuffle_ptr(b_shuffle_ptr_),
|
||||
stride_B(stride_B_),
|
||||
c_ptr(c_ptr_),
|
||||
stride_C(stride_C_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t* M_indices;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
const void* a_ptr;
|
||||
index_t stride_A;
|
||||
const void* b_shuffle_ptr;
|
||||
index_t stride_B;
|
||||
void* c_ptr;
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
||||
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
|
||||
{
|
||||
using UnderlyingGemmKernel = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
|
||||
using BlockGemmShape = typename UnderlyingGemmKernel::BlockGemmShape;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
|
||||
@@ -68,15 +110,13 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using GroupedFlatmmKernelArgs = GroupedFlatmmHostArgs;
|
||||
|
||||
CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
return concat(
|
||||
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto GridSize(const GroupedFlatmmKernelArgs& kernelArgs)
|
||||
template <class KernelArgs>
|
||||
CK_TILE_HOST_DEVICE static auto GridSizeImpl(const KernelArgs& kernelArgs)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
@@ -89,29 +129,41 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
// reinterpret_cast<void*>(GroupedFlatmmKernel::Kernel),
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size, GroupedFlatmmKernel, GroupedFlatmmKernelArgs>),
|
||||
reinterpret_cast<void*>(kentry2<block_size, GroupedFlatmmKernel, KernelArgs>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
|
||||
// print maxActiveBlocksPerCU and persistent_block_size
|
||||
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
// << ", persistent_block_size: " << persistent_block_size << std::endl;
|
||||
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
<< ", persistent_block_size: " << persistent_block_size << std::endl;
|
||||
|
||||
assert(kernelArgs.k_batch == 1);
|
||||
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedFlatmmKernelArgs
|
||||
MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs)
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs& kernelArgs)
|
||||
{
|
||||
return GridSizeImpl<GroupedFlatmmHostArgs>(kernelArgs);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs& kernelArgs)
|
||||
{
|
||||
return GridSizeImpl<ContiguousGroupedFlatmmHostArgs>(kernelArgs);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs)
|
||||
{
|
||||
return hostArgs;
|
||||
}
|
||||
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
|
||||
{
|
||||
return hostArgs;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedFlatmmKernelArgs kargs) const
|
||||
CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs kargs) const
|
||||
{
|
||||
int group_idx = 0;
|
||||
int block_linear_idx = blockIdx.x;
|
||||
@@ -147,6 +199,36 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
block_linear_idx -= group_block_cnt;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(ContiguousGroupedFlatmmHostArgs kargs) const
|
||||
{
|
||||
int block_linear_idx = blockIdx.x;
|
||||
int total_block_cnt = gridDim.x;
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
UnderlyingGemmKernel underlying_kernel{};
|
||||
for(; block_linear_idx < total_work_tile_cnt; block_linear_idx += total_block_cnt)
|
||||
{
|
||||
auto [block_m_idx, block_n_idx] = TilePartitioner::GetOutputTileIndex(block_linear_idx);
|
||||
// get the group index from the M_indices
|
||||
int group_idx = kargs.M_indices[block_m_idx * BlockGemmShape::kM];
|
||||
|
||||
typename UnderlyingGemmKernel::FlatmmKernelArgs impl_kargs{
|
||||
kargs.a_ptr,
|
||||
static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
|
||||
kargs.c_ptr,
|
||||
kargs.M,
|
||||
kargs.N,
|
||||
kargs.K,
|
||||
kargs.stride_A,
|
||||
kargs.stride_B,
|
||||
kargs.stride_C,
|
||||
kargs.k_batch,
|
||||
};
|
||||
// call the underlying flatmm kernel
|
||||
underlying_kernel(impl_kargs, block_linear_idx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user