mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Merge commit 'aaa35f0bbfa45dadc4380ddd6e0224668ddb97b4' into develop
This commit is contained in:
@@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
## (Unreleased) Composable Kernel 1.3.0
|
||||
|
||||
### Added
|
||||
* Added preshuffleB support for abquant mode in blockscale GEMM.
|
||||
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
|
||||
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
|
||||
* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
|
||||
|
||||
@@ -14,7 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
quant_grouped_gemm_bf8_rowcol.cpp
|
||||
quant_grouped_gemm_bf8_tensor.cpp
|
||||
)
|
||||
|
||||
add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp)
|
||||
add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp)
|
||||
add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
@@ -25,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
278
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp
Normal file
278
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "abquant_grouped_gemm.hpp"
|
||||
|
||||
// Non-persistent grouped gemm for ABQuant
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
// Persistent grouped gemm tileloop for ABQuant
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
GemmConfig::TransposeC>;
|
||||
|
||||
using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_abquant_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int result1 = run_abquant_grouped_gemm_example(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
171
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp
Normal file
171
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp
Normal file
@@ -0,0 +1,171 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <bool Persistent_>
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool Persistent = Persistent_;
|
||||
};
|
||||
|
||||
template <typename PrecType, bool Persistent>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <ck_tile::QuantType QuantMode>
|
||||
struct GemmQuantConfig;
|
||||
|
||||
// ABQuant specialization for GemmQuantConfig
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert(
|
||||
"stride_As",
|
||||
"",
|
||||
"Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
|
||||
// can be set to zero if
|
||||
// Ms/Ns/Ks is not empty
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
|
||||
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
|
||||
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.")
|
||||
.insert("bquant_group_size", "1x1x128", "BQuant group size. 1x1x128 (default) or 1x128x128")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
// Forward declaration of the non-persistent version
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr);
|
||||
|
||||
// Forward declaration of the tileloop version for persistent kernels
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr);
|
||||
@@ -0,0 +1,604 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::ABQuantGrouped,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_abquant_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
{
|
||||
// Workspace memory allocated to hold the gemm descriptions.
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if constexpr(!GemmConfig::Persistent)
|
||||
{
|
||||
ave_time = grouped_gemm_abquant<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
if(args[0].k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("Split-K not supported yet for persistent kernel");
|
||||
}
|
||||
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AQuantGroupSize,
|
||||
typename BQuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout>
|
||||
int run_abquant_grouped_gemm_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BQLayout bq_layout = BQLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
const int init_method = arg_parser.get_int("init");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
std::cout << "WARNING: Data validation enabled with SplitK and more than"
|
||||
<< "1 warmup/repeat. Disabling validation." << std::endl;
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> AQs; // dimension of AQ tensor is calculated from A tensor
|
||||
std::vector<ck_tile::index_t> BQs; // dimension of BQ tensor is calculated from B tensor
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
std::vector<ck_tile::index_t> stride_AQs = arg_parser.get_int_vec("stride_AQs");
|
||||
std::vector<ck_tile::index_t> stride_BQs = arg_parser.get_int_vec("stride_BQs");
|
||||
|
||||
ck_tile::index_t AQK, BQK;
|
||||
|
||||
if(!valid_input_data(
|
||||
group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs))
|
||||
{
|
||||
std::cout << "Please check the input data. Default values will be used." << std::endl;
|
||||
|
||||
// Clear existing (invalid) data before adding defaults
|
||||
Ms.clear();
|
||||
Ns.clear();
|
||||
Ks.clear();
|
||||
stride_As.clear();
|
||||
stride_Bs.clear();
|
||||
stride_Cs.clear();
|
||||
stride_AQs.clear();
|
||||
stride_BQs.clear();
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
// Let get_default_stride calculate based on layout
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
|
||||
std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
c_m_n_tensors.reserve(group_count);
|
||||
aq_tensors.reserve(group_count);
|
||||
bq_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
aq_dev_buf.reserve(group_count);
|
||||
bq_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
// For ABQuantGrouped, both A and B need quantization
|
||||
static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
"This file only supports ABQuantGrouped mode");
|
||||
|
||||
AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize
|
||||
BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize
|
||||
if(K % AQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
if(K % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
|
||||
stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
|
||||
stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
|
||||
|
||||
if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
aq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(aq_tensors[i].get_element_space_size_in_bytes()));
|
||||
bq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_aq,
|
||||
p_bq,
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
BQK,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
stride_Cs[i],
|
||||
stride_AQs[i],
|
||||
stride_BQs[i]});
|
||||
}
|
||||
|
||||
float ave_time = invoke_abquant_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
{
|
||||
flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K;
|
||||
|
||||
num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K +
|
||||
sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N +
|
||||
sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
if(validate)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Reference implementation for ABQuantGrouped
|
||||
ck_tile::reference_gemm_abquant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
Ks[i], kbatch, max_accumulated_value);
|
||||
pass &=
|
||||
ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "gemm[" << i
|
||||
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_grouped_gemm_json_results<ALayout, BLayout, CLayout>(arg_parser.get_str("jsonfile"),
|
||||
op_name,
|
||||
group_count,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename GemmConfig, typename BQuantGroupSize>
|
||||
int run_abquant_grouped_gemm_example_prec_type_with_bquant(
|
||||
std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
using AQDataType = typename Types::AccDataType;
|
||||
using BQDataType = typename Types::AccDataType;
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R" && c_layout == "R")
|
||||
{
|
||||
return run_abquant_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
QuantMode>(
|
||||
argc, argv, Col{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PrecType, typename GemmConfig>
|
||||
int run_abquant_grouped_gemm_example_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
std::string bquant_group_size,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
if(bquant_group_size == "1x1x128")
|
||||
{
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_abquant_grouped_gemm_example_prec_type_with_bquant<PrecType,
|
||||
GemmConfig,
|
||||
BQuantGroupSize>(
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else if(bquant_group_size == "1x128x128")
|
||||
{
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return run_abquant_grouped_gemm_example_prec_type_with_bquant<PrecType,
|
||||
GemmConfig,
|
||||
BQuantGroupSize>(
|
||||
a_layout, b_layout, c_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported BQuantGroupSize! Use 1x1x128 or 1x128x128.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PrecType>
|
||||
int run_abquant_gemm_example_persistency(std::string a_layout,
|
||||
std::string b_layout,
|
||||
std::string c_layout,
|
||||
bool persistent,
|
||||
std::string bquant_group_size,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
if(persistent)
|
||||
{
|
||||
using GemmConfig = typename GemmQuantConfig<
|
||||
ck_tile::QuantType::ABQuantGrouped>::template GemmConfig<PrecType, true>;
|
||||
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
|
||||
a_layout, b_layout, c_layout, bquant_group_size, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GemmConfig = typename GemmQuantConfig<
|
||||
ck_tile::QuantType::ABQuantGrouped>::template GemmConfig<PrecType, false>;
|
||||
return run_abquant_grouped_gemm_example_prec_type<PrecType, GemmConfig>(
|
||||
a_layout, b_layout, c_layout, bquant_group_size, argc, argv);
|
||||
}
|
||||
}
|
||||
|
||||
int run_abquant_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string c_layout = arg_parser.get_str("c_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
bool persistent = arg_parser.get_bool("persistent");
|
||||
const std::string bquant_group_size = arg_parser.get_str("bquant_group_size");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run_abquant_gemm_example_persistency<ck_tile::fp8_t>(
|
||||
a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_abquant_gemm_example_persistency<ck_tile::bf8_t>(
|
||||
a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
@@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory(
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -9,36 +9,194 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
@@ -47,10 +205,63 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
|
||||
@@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>>;
|
||||
using AQuantPipeline =
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
|
||||
|
||||
using BQuantPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using ABQuantPipeline =
|
||||
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>>>;
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
AQuantPipeline,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ABQuantPipeline,
|
||||
BQuantPipeline>>>;
|
||||
|
||||
constexpr bool TiledPermuteN =
|
||||
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
|
||||
@@ -532,7 +540,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
@@ -908,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if((QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>) &&
|
||||
GemmConfig::PreshuffleB)
|
||||
@@ -938,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
|
||||
!GemmConfig::PreshuffleQuant)
|
||||
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
@@ -24,6 +25,8 @@
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// BQ (scale tensor) is block distributed tensor.
|
||||
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
|
||||
// B is block window on block distributed tensor.
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename BlockPolicy_>
|
||||
struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
{
|
||||
private:
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
// Threadblock GEMM tile size
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN;
|
||||
static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
// number of warps along M and N for threadblock's GEMM problem size
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
|
||||
"Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
|
||||
static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
|
||||
"Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
|
||||
"Error! WarpGemm's M is not consistent with BlockGemmShape!");
|
||||
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
|
||||
"Error! WarpGemm's N is not consistent with BlockGemmShape!");
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / BQuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
"Error! Warps should cover all Block tile!");
|
||||
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
|
||||
"Error! Warps should cover all Block tile!");
|
||||
|
||||
// Currently tested combinations (A, B, BQ)
|
||||
// 1. fp8, fp8, fp32 -> f32
|
||||
// 2. bf8, bf8, fp32 -> f32
|
||||
// 3. i4, fp8, (fp8/fp32) -> f32
|
||||
// 4. i4, bf8, (fp8/fp32) -> f32
|
||||
static_assert(
|
||||
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
|
||||
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
|
||||
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<BQDataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<ComputeDataType, fp8_t> || std::is_same_v<ComputeDataType, bf8_t>) &&
|
||||
std::is_same_v<CDataType, fp32_t>);
|
||||
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
static constexpr bool TransposeC = Problem::TransposeC;
|
||||
};
|
||||
|
||||
public:
|
||||
using Traits = GemmTraits_<Problem_, BlockPolicy_>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto idxM = I0;
|
||||
static constexpr auto idxN = I1;
|
||||
static constexpr auto idxK = I2;
|
||||
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
static constexpr auto warp_size = get_warp_size();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 128 / (1 * 16) = 8
|
||||
static constexpr index_t NIterPerWarp =
|
||||
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); // 128 / (4 * 16) = 2
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 128 / 16 = 8
|
||||
static constexpr auto MIter_2nd_last =
|
||||
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
: MIterPerWarp * KIterPerWarp;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
return BlockGemmQuantCommon<CDataType, WG, MIterPerWarp, MWarp, NIterPerWarp, NWarp>::
|
||||
MakeCBlockTile();
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor,
|
||||
typename ABlockTensor,
|
||||
typename BFlatBlockTensor,
|
||||
typename AQBlockTensor,
|
||||
typename BQBlockTensor,
|
||||
typename ABlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
ABlockTensor& a_warp_tensor,
|
||||
BFlatBlockTensor& b_warp_tensor,
|
||||
AQBlockTensor& aq_block_tensor,
|
||||
BQBlockTensor& bq_block_tensor,
|
||||
ABlockWindow& a_warp_windows) const
|
||||
{
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using AccTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
statically_indexed_array<statically_indexed_array<AccTensor, NIterPerWarp>, MIterPerWarp>
|
||||
c_acc;
|
||||
|
||||
auto zero_accumulators = [&] {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
|
||||
c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
|
||||
}); // make sure WG::CWarpTensor exposes a clear/zero
|
||||
});
|
||||
});
|
||||
};
|
||||
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
zero_accumulators();
|
||||
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// warp GEMM
|
||||
WG{}(c_acc(mIter)(nIter),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor(nIter)(number<kIter>{}));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
// Could be deleted
|
||||
if constexpr((mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(aq_block_tensor);
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
index_t reg_offset = [&]() {
|
||||
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
{
|
||||
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter * KPerBlockBQ + kQScale;
|
||||
}
|
||||
}();
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_reg_f =
|
||||
aq_picker.template cvt_scale_to_fp32<BQDataType>(scale_reg);
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
float a_scale_reg_f = aq_picker.template pick<c_row>();
|
||||
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
|
||||
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
|
||||
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -322,6 +322,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
constexpr index_t reg_offset = nIter;
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
// cross lane ops
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
@@ -280,12 +280,13 @@ struct QuantGemmKernel
|
||||
// Helper: Create Pre-shuffled Quantization Tensor Descriptor
|
||||
// ===================================================================
|
||||
template <index_t KPerBlockBQ,
|
||||
index_t NPerBlockBQ,
|
||||
index_t NPerBlock,
|
||||
index_t WarpTileN,
|
||||
index_t GetVectorSizeBQ,
|
||||
typename BQDataType_>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B)
|
||||
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B)
|
||||
{
|
||||
// Step 1: Calculate base BQ tensor dimensions
|
||||
// ----------------------------------------------------------
|
||||
@@ -304,8 +305,9 @@ struct QuantGemmKernel
|
||||
// ----------------------------------------------------------
|
||||
// Pad the X dimension to be a multiple of block_tile_size to ensure
|
||||
// each thread block can process complete tiles without edge cases
|
||||
const auto block_tile_size = NPerBlock * KPerBlockBQ;
|
||||
const auto bq_pad0_desc = transform_tensor_descriptor(
|
||||
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ;
|
||||
|
||||
const auto bq_pad0_desc = transform_tensor_descriptor(
|
||||
bq_desc,
|
||||
make_tuple(make_pass_through_transform(bq_y),
|
||||
make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))),
|
||||
@@ -318,7 +320,7 @@ struct QuantGemmKernel
|
||||
// This separates the work into tiles that can be processed by
|
||||
// individual warps/waves
|
||||
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size = WarpTileN * KPerBlockBQ;
|
||||
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ;
|
||||
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
|
||||
|
||||
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
@@ -813,12 +815,18 @@ struct QuantGemmKernel
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
|
||||
"PreshuffleQuant with BQuantGrouped currently only supports "
|
||||
"ColumnMajor BQ layout");
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
|
||||
return MakePreshuffledQuantTensorView<
|
||||
GemmPipeline::KPerBlockBQ,
|
||||
GemmPipeline::NPerBlockBQ,
|
||||
GemmPipeline::NPerBlock,
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
|
||||
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
|
||||
GemmPipeline::GetVectorSizeBQ()>(
|
||||
bq_ptr,
|
||||
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
|
||||
QuantGroupSize::kN,
|
||||
kargs.QK_B);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -879,13 +887,38 @@ struct QuantGemmKernel
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
constexpr auto block_n =
|
||||
TilePartitioner::NPerBlock /
|
||||
QuantGroupSize::kN; // Number of N-dimension quantization groups per block
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(
|
||||
I1); // Number of N-dimension elements per warp
|
||||
constexpr auto warp_per_group =
|
||||
(QuantGroupSize::kN <
|
||||
warp_n) // Determine how many warps share the same scale in N-dimension
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
constexpr auto bqk_per_block =
|
||||
TilePartitioner::KPerBlock /
|
||||
QuantGroupSize::kK; // Number of K-dimension quantization groups per block
|
||||
constexpr auto
|
||||
tile_window_width = // The pre-shuffled layout flattens warp_n ×
|
||||
// bqk_per_block scales per row, Padded up to warp_size
|
||||
// to ensure coalesced memory access.
|
||||
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_n / warp_n;
|
||||
auto block_n_idx = i_n / block_n;
|
||||
|
||||
// Adapts based on fine vs coarse quantization granularity:
|
||||
// - Fine-grained (QuantGroupSize::kN < warp_n):
|
||||
// Multiple quant groups per warp → fewer rows needed per block.
|
||||
// height = block_n / warp_per_group
|
||||
//
|
||||
// - Coarse-grained (QuantGroupSize::kN >= warp_n):
|
||||
// Each row represents one quant group.
|
||||
// height = block_n
|
||||
constexpr auto tile_window_height =
|
||||
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
|
||||
auto block_n_idx =
|
||||
i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a
|
||||
// block index.
|
||||
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
@@ -1125,596 +1158,6 @@ struct QuantGemmKernel
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
|
||||
make_tuple(aq_x, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
|
||||
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_pad0_desc = transform_tensor_descriptor(
|
||||
aq_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size =
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
|
||||
|
||||
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
aq_pad0_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto aq_pad1_desc = transform_tensor_descriptor(
|
||||
aq_unmerge_pad0_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_pass_through_transform(wave_tile_count_x),
|
||||
make_right_pad_transform(
|
||||
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
const auto pad_wave_size =
|
||||
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
aq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||
}
|
||||
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped) &&
|
||||
!PreshuffleQuant)
|
||||
{
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else // Column major AQ
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions
|
||||
make_tuple(kargs.stride_AQ, 1), // Same stride pattern
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, 0), // broadcasting over n
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type for this
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
index_t kFlatK = GemmPipeline::flatKPerWarp *
|
||||
(splitk_batch_offset.splitted_k /
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
else
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(0, 1), // broadcasting over m
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
|
||||
"PreshuffleQuant with BQuantGrouped currently only supports "
|
||||
"ColumnMajor BQ layout");
|
||||
|
||||
return MakePreshuffledQuantTensorView<
|
||||
GemmPipeline::KPerBlockBQ,
|
||||
GemmPipeline::NPerBlock,
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
|
||||
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
|
||||
}
|
||||
else
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN]
|
||||
// Dimensions: [K/QuantGroupK, N/QuantGroupN]
|
||||
// Strides: [N/QuantGroupN, 1]
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
|
||||
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
// For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK]
|
||||
// Dimensions: [N/QuantGroupN, K/QuantGroupK]
|
||||
// Strides: [K/QuantGroupK, 1]
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
|
||||
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
|
||||
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type for this
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(
|
||||
a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// no padding
|
||||
const auto& aq_pad_view = [&]() { return views.at(I1); }();
|
||||
|
||||
const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / 2>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
else
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// no padding
|
||||
const auto& bq_pad_view = [&]() { return views.at(I3); }();
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I4);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
const auto& bq_pad_view = views.at(I3);
|
||||
const auto& c_pad_view = views.at(I4);
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& aq_block_window = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_m / warp_m;
|
||||
auto block_m_idx = i_m / block_m;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_m_idx * tile_window_height, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(aq_pad_view,
|
||||
make_tuple(number<block_m>{}, number<aqk_per_block>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else // Column major AQ
|
||||
{
|
||||
return make_tile_window(aq_pad_view,
|
||||
make_tuple(number<aqk_per_block>{}, number<block_m>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_tile_window(aq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type?
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<GemmPipeline::flatNPerWarp>{},
|
||||
number<GemmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / 2>{}),
|
||||
{i_n, 0});
|
||||
else
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_block_window = [&]() {
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_tile_window(bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_n / warp_n;
|
||||
auto block_n_idx = i_n / block_n;
|
||||
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type here
|
||||
}
|
||||
}();
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(
|
||||
a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
|
||||
@@ -484,6 +484,17 @@ struct QuantGroupedGemmKernel
|
||||
tail_num,
|
||||
smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
@@ -499,7 +510,8 @@ struct QuantGroupedGemmKernel
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
kQuantType == QuantType::BQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
}
|
||||
@@ -527,7 +539,8 @@ struct QuantGroupedGemmKernel
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
kQuantType == QuantType::BQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
}
|
||||
|
||||
@@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
@@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
BlockSize,
|
||||
NPerBlock / WarpGemm::kN,
|
||||
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
|
||||
VecLoadSize,
|
||||
Problem::BQuantGroupSize::kN,
|
||||
Problem::BQuantGroupSize::kK,
|
||||
BQLayout,
|
||||
PreshuffleQuant>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
@@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
KPerBlockBQ, // Logical K dimension
|
||||
NPerBlockBQ, // Logical N dimension
|
||||
Problem::BQuantGroupSize::kN,
|
||||
Problem::BQuantGroupSize::kK,
|
||||
BQLayout>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
@@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
const BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
0)
|
||||
(PreshuffleQuant)
|
||||
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0)
|
||||
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
|
||||
: make_array(0, KPerBlockBQ);
|
||||
|
||||
|
||||
@@ -192,6 +192,7 @@ template <typename BlockGemmShape,
|
||||
index_t KPerTile,
|
||||
index_t NPerTile,
|
||||
index_t NPerQ,
|
||||
index_t KPerQ,
|
||||
typename BQLayout = tensor_layout::gemm::ColumnMajor,
|
||||
bool PreshuffleQuant = false>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
@@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
static_assert(num_warps == MWarps * NWarps * KWarps);
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
|
||||
///
|
||||
/// This function determines the optimal thread distribution pattern for loading and applying
|
||||
/// quantization scales to the B matrix based on the quantization group size (NPerQ) relative
|
||||
/// to warp dimensions.
|
||||
///
|
||||
/// Three distinct distribution patterns are handled:
|
||||
///
|
||||
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
|
||||
/// - Multiple quantization groups exist within a single warp's N-dimension
|
||||
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
|
||||
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast
|
||||
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
|
||||
///
|
||||
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
|
||||
/// - Each warp handles exactly one quantization scale
|
||||
/// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN
|
||||
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
|
||||
///
|
||||
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
|
||||
/// - Quantization group spans multiple warps
|
||||
/// - All warps share the same scale value
|
||||
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
|
||||
///
|
||||
/// @return A static tile distribution encoding for the BQ scale tensor
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
// Preshuffle only supported for ColumnMajor currently
|
||||
@@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
// ColumnMajor only for preshuffle
|
||||
constexpr index_t X1 = warp_size;
|
||||
constexpr index_t X0 = NPerTile / warp_size;
|
||||
constexpr index_t Y1 = NWarps;
|
||||
constexpr index_t Y0 = KPerTile / Y1;
|
||||
// =============================================================================
|
||||
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION
|
||||
// =============================================================================
|
||||
// For pre-shuffled quantization, the BQ scale tensor has been reorganized
|
||||
// (pre-shuffled) to optimize memory access patterns during dequantization.
|
||||
//
|
||||
// Tile Dimensions:
|
||||
// - K-axis (Y in encoding): Corresponds to the K-dimension iteration
|
||||
// - N-axis (X in encoding): Flattened scale index combining N and K groups
|
||||
//
|
||||
// The encoding distributes work across threads such that each thread loads
|
||||
// the correct pre-shuffled scale for its corresponding B-matrix elements.
|
||||
// =============================================================================
|
||||
if constexpr(NPerQ <= WarpGemm::kN)
|
||||
{
|
||||
// =========================================================================
|
||||
// CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN)
|
||||
// =========================================================================
|
||||
// Multiple quantization scales exist within a single warp's N-dimension.
|
||||
// Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp.
|
||||
//
|
||||
// Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
|
||||
// → 2 scales per warp in N, 2 K-groups per block
|
||||
constexpr auto N1 = BlockGemmShape::kK /
|
||||
KPerQ; // Number of K-dimension quantization groups per block,
|
||||
// Each K-group of KPerQ elements shares the same scale.
|
||||
constexpr auto N0 =
|
||||
WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ
|
||||
// <= WarpGemm::kN, each warp handles multiple scales.
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension
|
||||
constexpr auto NR0 =
|
||||
warp_size /
|
||||
(N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization
|
||||
constexpr auto K1 = NWarps; // Number of warps distributed along this dimension
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile
|
||||
constexpr auto KR = 1; // No replication in K-dimension
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps>,
|
||||
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 1>, sequence<2>>,
|
||||
tuple<sequence<0, 1>, sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 0, 2, 1, 3>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else if constexpr(NPerQ < WarpGemm::kN * NWarps)
|
||||
{
|
||||
// =========================================================================
|
||||
// CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN *
|
||||
// NWarps)
|
||||
// =========================================================================
|
||||
// Each warp handles exactly one quantization scale in N-dimension.
|
||||
// Some warps share the same scale (KR > 1 creates warp grouping).
|
||||
//
|
||||
// Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
|
||||
// → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)
|
||||
|
||||
constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale
|
||||
constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales)
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
|
||||
constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ)
|
||||
constexpr auto NR0 =
|
||||
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
|
||||
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// =========================================================================
|
||||
// CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps)
|
||||
// =========================================================================
|
||||
// The quantization group spans ALL warps in N-dimension.
|
||||
// All warps share the same scale value for their N-tiles.
|
||||
//
|
||||
// Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
|
||||
// → 128 >= 16*4=64, so all 4 warps use the same scale
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
|
||||
constexpr auto N0 = 1; // Minimal (1) since scale is shared across N
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = 32; // Fixed broadcast size
|
||||
constexpr auto NR0 =
|
||||
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
|
||||
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
|
||||
///
|
||||
/// This function determines the optimal thread distribution pattern for loading and
|
||||
/// applying quantization scales to the B matrix based on the quantization group size
|
||||
/// (NPerQ) relative to warp dimensions.
|
||||
///
|
||||
/// Three distinct distribution patterns are handled:
|
||||
///
|
||||
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
|
||||
/// - Multiple quantization groups exist within a single warp's N-dimension
|
||||
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
|
||||
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale
|
||||
/// broadcast
|
||||
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
|
||||
///
|
||||
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
|
||||
/// - Each warp handles exactly one quantization scale
|
||||
/// - Scales are distributed across warps with replication factor XR = NPerQ /
|
||||
/// WarpGemm::kN
|
||||
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
|
||||
///
|
||||
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
|
||||
/// - Quantization group spans multiple warps
|
||||
/// - All warps share the same scale value
|
||||
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
|
||||
///
|
||||
/// @return A static tile distribution encoding for the BQ scale tensor
|
||||
if constexpr(NPerQ < WarpGemm::kN)
|
||||
{
|
||||
// Case 1: Fine-grained - multiple quantization scales within a single warp
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
|
||||
{
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
|
||||
|
||||
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution()
|
||||
{
|
||||
return GemmAQuantPipelineAgBgCrDefaultPolicy::MakeAQDramTileDistribution<Problem>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
|
||||
{
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
|
||||
{
|
||||
return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
// as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed;
|
||||
// move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here
|
||||
// temporarily
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
// TODO : Use a custom block policy for AsBrCr
|
||||
using BlockGemmPolicy =
|
||||
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockGemmWeightPreshuffleABQuantARegBRegCReg<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,611 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = GemmWPABQuantPipelineAgBgCrPolicy>
|
||||
struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
|
||||
{
|
||||
using Base = WeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockWeightPreshuffle = remove_cvref_t<
|
||||
decltype(PipelinePolicy::template GetBlockWeightPreshuffleBQuant<Problem>())>;
|
||||
|
||||
static constexpr auto config =
|
||||
BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
using Base::kKPerBlock;
|
||||
using Base::kMPerBlock;
|
||||
using Base::kNPerBlock;
|
||||
|
||||
using Base::KIterPerWarp;
|
||||
using Base::MIterPerWarp;
|
||||
using Base::NIterPerWarp;
|
||||
|
||||
using Base::BlockSize;
|
||||
|
||||
using Base::kPadK;
|
||||
using Base::kPadM;
|
||||
using Base::kPadN;
|
||||
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
|
||||
using Base::MWarp;
|
||||
using Base::NWarp;
|
||||
|
||||
using Base::KPerBlockPerIter;
|
||||
using Base::MPerBlockPerIter;
|
||||
|
||||
using Base::flatKPerWarp;
|
||||
using Base::flatNPerWarp;
|
||||
|
||||
using Base::m_preload;
|
||||
|
||||
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
|
||||
static constexpr index_t KPerBlockAQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, AQuantGroupSize::kK);
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK);
|
||||
static constexpr index_t GetVectorSizeAQ()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeAQ<Problem>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeBQ()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeBQ<Problem>();
|
||||
}
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
|
||||
return concat('_', "bquant_pipeline_AgBgCrV2_preshuffleB",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock),
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeAQ(), GetVectorSizeBQ()),
|
||||
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
template <index_t nloop>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// Estimated number of VMEM vector loads for A per block:
|
||||
// total A bytes / (threads per block * vector width)
|
||||
constexpr index_t Aload_inst =
|
||||
(kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
|
||||
// Estimated number of VMEM vector loads for B per block:
|
||||
// total B bytes / (threads per block * vector width)
|
||||
constexpr index_t Bload_inst =
|
||||
(kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
|
||||
|
||||
// Estimated number of VMEM loads for B's quant data (e.g. scales / zp).
|
||||
// First ceil-divide by quant group size (how many elements share one scale),
|
||||
// then by vector width to get an approximate number of vector loads.
|
||||
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
|
||||
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
|
||||
BQuantGroupSize::kK * BQuantGroupSize::kK),
|
||||
VectorLoadSize);
|
||||
|
||||
// ToDo: Hardcoded, need to change in future. How many instruction emit per iteration
|
||||
constexpr index_t kLdsInstCycle = 8;
|
||||
// Total VMEM load instructions (A + B + quant data)
|
||||
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
|
||||
// Approximate number of LDS reads per block
|
||||
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
|
||||
// Approximate number of LDS writes per block
|
||||
// (e.g., writing A from VMEM into LDS once per A load)
|
||||
constexpr index_t ds_write_inst = Aload_inst;
|
||||
// Number of MFMA instructions per wave for one block tile:
|
||||
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
|
||||
// How often (in MFMA units) we should insert DS (LDS) operations.
|
||||
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
|
||||
// How often (in MFMA units) we should insert VMEM buffer loads.
|
||||
// buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load
|
||||
// is assumed to cover at most 4 MFMA instructions.
|
||||
constexpr index_t buffer_load_rep =
|
||||
min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma
|
||||
|
||||
static_for<0, nloop, 1>{}([&](auto) {
|
||||
static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA
|
||||
|
||||
// Insert LDS read/write groups periodically based on ds_rep.
|
||||
// The % pattern staggers READ and WRITE so they don't collapse
|
||||
// into the same cycle in the model.
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
|
||||
}
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
|
||||
}
|
||||
|
||||
if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0)
|
||||
{
|
||||
if constexpr(ds_write_inst > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
|
||||
}
|
||||
}
|
||||
// Always mark some VALU work in the loop to reflect auxiliary scalar
|
||||
// or vector ALU instructions that coexist with MFMA (Blockscale calculation).
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
static constexpr bool PreshuffleB = Problem::PreshuffleB;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
template <TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AQDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t m,
|
||||
index_t n,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
(void)m;
|
||||
(void)n;
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BFlatBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BQDataType, remove_cvref_t<typename BQDramBlockWindowTmp::DataType>>,
|
||||
"A/B/BQ Dram block window should have the same data type as appropriate "
|
||||
"([A|B|BQ]DataType) defined in Problem definition!");
|
||||
|
||||
constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
static_assert(!is_a_col_major, "A must be row major (col major not supported yet)");
|
||||
|
||||
constexpr bool is_bq_col_major = std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
static_assert(!is_b_row_major, "B must be col major (row major not supported yet)");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
// Double-Buffering (loop_count=2) for full load/compute overlap.
|
||||
const index_t loop_count = 2;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// A tile in LDS
|
||||
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
|
||||
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem);
|
||||
ADataType* p_a_lds_pong =
|
||||
reinterpret_cast<ADataType*>(static_cast<char*>(p_smem) + smem_size);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
auto a_copy_lds_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
auto block_weight_preshuffle = BlockWeightPreshuffle();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
auto aq_copy_dram_window =
|
||||
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
aq_dram_block_window_tmp.get_window_lengths(),
|
||||
aq_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeAQDramTileDistribution<Problem>());
|
||||
// BQ DRAM window for load
|
||||
auto bq_copy_dram_window =
|
||||
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bq_dram_block_window_tmp.get_window_lengths(),
|
||||
bq_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeBQDramTileDistribution<Problem>());
|
||||
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// Strictly not needed given type deduction, but helps with readability
|
||||
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
|
||||
using AQBlockTile =
|
||||
decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
|
||||
using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
|
||||
using BQBlockTile =
|
||||
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
|
||||
|
||||
// Load tile 0 for BQ data directly into registers for block tile
|
||||
AQBlockTile aq_block_tile, aq_block_tile_2;
|
||||
BQBlockTile bq_block_tile, bq_block_tile_2;
|
||||
aq_block_tile = load_tile(aq_copy_dram_window);
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
// move BQ to tile 1
|
||||
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
// Prefill A0
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// preload A00,A10 from lds
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
|
||||
m_preload>
|
||||
a_warp_tensor;
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / loop_count;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor_ping,
|
||||
aq_block_tile,
|
||||
bq_block_tile,
|
||||
a_warp_windows_ping);
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
aq_block_tile_2 = load_tile(aq_copy_dram_window);
|
||||
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
// Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
aq_block_tile = load_tile(aq_copy_dram_window);
|
||||
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i+1
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor_pong,
|
||||
aq_block_tile_2,
|
||||
bq_block_tile_2,
|
||||
a_warp_windows_pong);
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
iCounter--;
|
||||
HotLoopScheduler<loop_count>();
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
aq_block_tile_2 = load_tile(aq_copy_dram_window);
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// GEMM loopK-1
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor_ping,
|
||||
aq_block_tile,
|
||||
bq_block_tile,
|
||||
a_warp_windows_ping);
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
// GEMM loopK
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor_pong,
|
||||
aq_block_tile_2,
|
||||
bq_block_tile_2,
|
||||
a_warp_windows_pong);
|
||||
HotLoopScheduler<loop_count>();
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor_ping,
|
||||
aq_block_tile,
|
||||
bq_block_tile,
|
||||
a_warp_windows_ping);
|
||||
Base::LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
// Replace lines 485-526 with a single optimized operator:
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AQDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem,
|
||||
index_t m = 0,
|
||||
index_t n = 0) const // Default value for non-preshuffle case
|
||||
{
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
aq_dram_block_window_tmp,
|
||||
bq_dram_block_window_tmp,
|
||||
m,
|
||||
n,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AQDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem,
|
||||
index_t n = 0) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
return operator()<tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
aq_dram_block_window_tmp,
|
||||
bq_dram_block_window_tmp,
|
||||
n, // dummy value, won't be used
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
@@ -352,8 +354,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
}
|
||||
else
|
||||
@@ -427,8 +431,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
}
|
||||
else
|
||||
@@ -462,8 +468,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
}
|
||||
else
|
||||
|
||||
@@ -9,6 +9,7 @@ add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(grouped_gemm_preshuffle)
|
||||
add_subdirectory(grouped_gemm_multi_d)
|
||||
add_subdirectory(grouped_gemm_quant)
|
||||
add_subdirectory(grouped_gemm_abquant)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(gemm_multi_abd)
|
||||
add_subdirectory(gemm_streamk)
|
||||
|
||||
@@ -39,6 +39,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle
|
||||
test_gemm_quant_abquant_preshuffle_2d.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# AQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
|
||||
test_gemm_quant_aquant_prefill.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using ABQuantGrouped =
|
||||
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for ABQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantPreshuffleBTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -894,10 +894,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
std::conditional_t<PreshuffleB == false,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
@@ -926,8 +926,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
std::conditional_t<PreshuffleB == false,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
|
||||
16
test/ck_tile/grouped_gemm_abquant/CMakeLists.txt
Normal file
16
test/ck_tile/grouped_gemm_abquant/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x1x128 test_grouped_gemm_abquant_1x1x128.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_abquant_1x1x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x128x128 test_grouped_gemm_abquant_1x128x128.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_abquant_1x128x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_abquant_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
|
||||
// AQuant group size is fixed at 1x1x128
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// BQuant group size: 1x128x128
|
||||
using BQuantGroupSize_1x128x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_ABQuant_1x128x128 = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent
|
||||
|
||||
// FP8 variants
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>,
|
||||
|
||||
// BF8 variants
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x128x128, KernelTypes_ABQuant_1x128x128);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x128x128
|
||||
#include "test_grouped_gemm_abquant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_abquant_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
|
||||
// AQuant group size is fixed at 1x1x128
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// BQuant group size: 1x1x128
|
||||
using BQuantGroupSize_1x1x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_ABQuant_1x1x128 = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent
|
||||
|
||||
// FP8 variants
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
|
||||
|
||||
// BF8 variants
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x1x128, KernelTypes_ABQuant_1x1x128);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x1x128
|
||||
#include "test_grouped_gemm_abquant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -0,0 +1,87 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_CLASS_NAME, Basic)
|
||||
{
|
||||
const int group_count = 6;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
std::vector<int> stride_AQs;
|
||||
std::vector<int> stride_BQs;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
|
||||
// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop
|
||||
// Using 256x256x128 to match the test kernel's tile size (M_Tile=128, N_Tile=128, K_Tile=128)
|
||||
TYPED_TEST(TEST_CLASS_NAME, SmallUniform)
|
||||
{
|
||||
const int group_count = 2;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
std::vector<int> stride_AQs;
|
||||
std::vector<int> stride_BQs;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256);
|
||||
Ns.push_back(256);
|
||||
Ks.push_back(256);
|
||||
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_CLASS_NAME, OddTail)
|
||||
{
|
||||
const int group_count = 2;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
std::vector<int> stride_AQs;
|
||||
std::vector<int> stride_BQs;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256);
|
||||
Ns.push_back(256);
|
||||
Ks.push_back(128);
|
||||
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
@@ -0,0 +1,530 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <sstream>
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemmABQuant : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using AQDataType = std::tuple_element_t<4, Tuple>;
|
||||
using BDataType = std::tuple_element_t<5, Tuple>;
|
||||
using BQDataType = std::tuple_element_t<6, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<7, Tuple>;
|
||||
using CDataType = std::tuple_element_t<8, Tuple>;
|
||||
using AQuantGroupSize = std::tuple_element_t<9, Tuple>;
|
||||
using BQuantGroupSize = std::tuple_element_t<10, Tuple>;
|
||||
static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using AQLayout = Row;
|
||||
using BQLayout = Col;
|
||||
|
||||
static constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped;
|
||||
|
||||
struct GemmConfig
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(ADataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<ADataType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr bool IsPersistent = Persistent;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename Config>
|
||||
float invoke_grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
|
||||
ck_tile::sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
|
||||
ck_tile::sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::
|
||||
TileGemmTraits<Config::kPadM, Config::kPadN, Config::kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<Config::kPadM,
|
||||
Config::kPadN,
|
||||
Config::kPadK,
|
||||
false,
|
||||
Config::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
Config::TransposeC,
|
||||
Config::DoubleSmemBuffer,
|
||||
Config::IsPersistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = Config::Scheduler;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
Config::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Config::M_Warp,
|
||||
Config::N_Warp,
|
||||
Config::M_Warp_Tile,
|
||||
Config::N_Warp_Tile,
|
||||
Config::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName()
|
||||
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
|
||||
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<Config::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename Config>
|
||||
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
|
||||
ck_tile::sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
|
||||
ck_tile::sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<Config::kPadM,
|
||||
Config::kPadN,
|
||||
Config::kPadK,
|
||||
false,
|
||||
Config::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
Config::TransposeC,
|
||||
Config::DoubleSmemBuffer,
|
||||
Config::IsPersistent>;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
Config::TransposeC>;
|
||||
|
||||
using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Config::M_Warp,
|
||||
Config::N_Warp,
|
||||
Config::M_Warp_Tile,
|
||||
Config::N_Warp_Tile,
|
||||
Config::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<Config::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
std::vector<int>& stride_As,
|
||||
std::vector<int>& stride_Bs,
|
||||
std::vector<int>& stride_Cs,
|
||||
std::vector<int>& stride_AQs,
|
||||
std::vector<int>& stride_BQs,
|
||||
const int group_count = 8)
|
||||
{
|
||||
ck_tile::index_t AQK, BQK;
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
|
||||
std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
c_m_n_tensors.reserve(group_count);
|
||||
aq_tensors.reserve(group_count);
|
||||
bq_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
aq_dev_buf.reserve(group_count);
|
||||
bq_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
AQK = K / AQuantGroupSize::kK;
|
||||
BQK = K / BQuantGroupSize::kK;
|
||||
|
||||
if(K % AQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
if(K % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode");
|
||||
}
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{}));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{}));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
|
||||
stride_AQs[i] =
|
||||
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout{}));
|
||||
stride_BQs[i] =
|
||||
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout{}));
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{}))));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(BLayout{}))));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(BQLayout{}))));
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << " aq: " << aq_tensors[i].mDesc
|
||||
<< " bq: " << bq_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
aq_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
aq_tensors[i].get_element_space_size_in_bytes()));
|
||||
bq_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_aq,
|
||||
p_bq,
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
BQK,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
stride_Cs[i],
|
||||
stride_AQs[i],
|
||||
stride_BQs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
for(const auto& arg : gemm_descs)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, false, 1};
|
||||
ck_tile::hip_check_error(
|
||||
hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
invoke_grouped_gemm_persistent<GemmConfig>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto stream = ck_tile::stream_config{nullptr, false, 1};
|
||||
invoke_grouped_gemm_abquant<GemmConfig>(gemm_descs, stream, kargs_ptr);
|
||||
}
|
||||
|
||||
// Copy results back to host for validation
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_abquant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value);
|
||||
pass &=
|
||||
ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "gemm[" << i
|
||||
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
// Aliases for split test files
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmABQuant_1x1x128 = TestCkTileGroupedGemmABQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmABQuant_1x128x128 = TestCkTileGroupedGemmABQuant<Tuple>;
|
||||
Reference in New Issue
Block a user