mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] Row/Col quant gemm (#2729)
* Add cshuffle epilogue test
* add the poc implementation to the epilogue and tests
* refactor cshuffle epilogue
* WIP: adding tensor/tile usage to scale_tile
* fix usage of tile_elementwise_inout
* add gemm_quant_kernel for generalizing gemm quant kernel
* Add problem specific to different quants, add QuantType to Traits
* Add quant_type to quant_kernel template parameters
* Create aq/bq_block_windows and views depending on QuantType
* Use tile windows as inputs in cshuffle epilogue
* Fix some issues in epilogue
* initial new example code for new general gemm quant kernel test
* Fix issues in kernel
* Add verification check for rowcol Quantmode
* use AccDataType instead of AQ in pipeline
* fix aquant preshuffle
* fix formatting
* some cleanup
* remove gemm_aquant_basic.cpp
* remove gemm_aquant_kernel.hpp
* fix tests for the renamed quant kernel
* fix formatting
* clean example files
* fix some merge conflicts
* fix preshufflequant rename issue
* fix some templates after merging with develop
* fix test preshuffle parameter
* fix formatting
* Unify bquant kernel to the common quant kernel
* remove bquant kernel also from common header
* fix formatting
* clean up commented code
* fix formatting config hpp
* fix merge mistake
* Non-const for movable windows
* fix formatting
* Fix grammar in README
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
* Remove #include<bit> and clean up example
* fix strides
* Add some descriptions for move_windows
---------
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
[ROCm/composable_kernel commit: c6010f2953]
This commit is contained in:
@@ -6,8 +6,12 @@ endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp)
|
||||
target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
|
||||
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp)
|
||||
target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp)
|
||||
target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
# GEMM Matrix Multiplication
|
||||
# Quant GEMM Matrix Multiplication
|
||||
|
||||
This folder contains example for Block Scale GEMM using ck_tile tile-programming implementation.
|
||||
This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation.
|
||||
|
||||
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
|
||||
- Row and Column-wise scaled: scaling implemented in Epilogue
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The aquant pipeline method on the gemm calculation
|
||||
make tile_example_gemm_aquant_basic -j
|
||||
# Compile the quant kernels
|
||||
make tile_example_gemm_quant_basic -j
|
||||
make tile_example_gemm_bquant_basic -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_aquant_basic`
|
||||
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
|
||||
|
||||
## example
|
||||
```
|
||||
@@ -22,15 +25,16 @@ args:
|
||||
-n n dimension (default:2048)
|
||||
-k k dimension (default:64)
|
||||
-a_layout Tensor A data layout (default: R)
|
||||
-b_layout Tensor B data layout (default: R)
|
||||
-b_layout Tensor B data layout (default: C)
|
||||
-c_layout Tensor C data layout (default: R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
|
||||
-prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-quant_mode Which quant method to use (aquant, rowcol)
|
||||
```
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ComputeDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
constexpr bool transposed_warp_gemm = true;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
transposed_warp_gemm,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel =
|
||||
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
#include "run_gemm_aquant_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for A.");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }
|
||||
@@ -21,7 +21,7 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
@@ -50,13 +50,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -109,8 +110,10 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel =
|
||||
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
float gemm_calc_bquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
@@ -50,13 +50,14 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmBQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -108,8 +109,10 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel =
|
||||
ck_tile::BQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
376
example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
Normal file
376
example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
Normal file
@@ -0,0 +1,376 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = GemmConfig::kPadM;
|
||||
constexpr bool kPadN = GemmConfig::kPadN;
|
||||
constexpr bool kPadK = GemmConfig::kPadK;
|
||||
|
||||
constexpr int kBlockPerCu = GemmConfig::kBlockPerCu;
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
// B datatype is safe to use as compute type as it should be at least fp8
|
||||
using ComputeDataType = typename TypeConfig::BDataType;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
constexpr bool transposed_warp_gemm = false;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using CodegenPipelineProblem = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::QDataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
using CodegenGemmPipeline =
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<CodegenPipelineProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize, QuantMode>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for A.");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void print_help(const char* program_name)
|
||||
{
|
||||
std::cout
|
||||
<< "Usage: " << program_name << " [OPTIONS]\n"
|
||||
<< "\n Parameters:\n"
|
||||
<< " -quant_mode=MODE aquant (A quantization) or rowcol (row/column quantization)\n"
|
||||
<< " -v=LEVEL 0=No validation, 1=CPU validation, 2=GPU validation\n"
|
||||
<< " -prec=TYPE Data types: fp8, bf8, i4fp8, i4bf8, i4f32fp8, i4f32bf8\n"
|
||||
<< " -m, -n, -k=SIZE Matrix dimensions (M×K) @ (K×N) = (M×N)\n"
|
||||
<< "\nExample:\n"
|
||||
<< " " << program_name << " -quant_mode=rowcol -v=1 -prec=fp8\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
|
||||
for(int i = 1; i < argc; ++i)
|
||||
{
|
||||
if(strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0)
|
||||
{
|
||||
print_help(argv[0]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode! Use 'aquant' or 'rowcol'");
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported quantization mode! Use 'aquant' or 'rowcol'");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
ck_tile::fp8_t>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'aquant'.");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
ck_tile::bf8_t>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'aquant'.");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'aquant'.");
|
||||
}
|
||||
}
|
||||
else if(data_type == "i4f32bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
||||
|
||||
if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>,
|
||||
TypeConfig,
|
||||
128,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported quantization mode for this datatype! Use 'aquant'.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }
|
||||
@@ -83,14 +83,13 @@ struct GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigDecode : public GemmConfigBase
|
||||
struct GemmConfigQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
@@ -105,29 +104,6 @@ struct GemmConfigDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -148,9 +124,7 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -237,19 +211,20 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_q", "0", "Tensor AQ stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent")
|
||||
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr");
|
||||
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr")
|
||||
.insert("quant_mode", "aquant", "Choose aquant (default) or rowcol");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <bit>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
|
||||
@@ -59,7 +58,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::AQuantGemmHostArgs args;
|
||||
ck_tile::QuantGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
@@ -68,7 +67,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.QK = AQK;
|
||||
args.QK_A = AQK;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <bit>
|
||||
#include <random>
|
||||
|
||||
template <typename Layout>
|
||||
@@ -60,7 +59,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::BQuantGemmHostArgs args;
|
||||
ck_tile::QuantGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.bq_ptr = bq_bqk_n_dev_buf.GetDeviceBuffer();
|
||||
@@ -69,7 +68,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.QK = BQK;
|
||||
args.QK_B = BQK;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
404
example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Normal file
404
example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Normal file
@@ -0,0 +1,404 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
|
||||
{
|
||||
if(t.get_lengths().size() != 2)
|
||||
{
|
||||
throw std::runtime_error("Host tensor is not rank 2 tensor.");
|
||||
}
|
||||
int m_ = t.get_lengths()[0];
|
||||
int aqk_ = t.get_lengths()[1];
|
||||
if(aqk_ % block_aq_k != 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& aq_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::DeviceMem* bq_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t AQK,
|
||||
ck_tile::index_t BQK,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_AQ,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_BQ,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::QuantGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.aq_ptr = aq_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.bq_ptr = (bq_dev_buf != nullptr) ? bq_dev_buf->GetDeviceBuffer() : nullptr;
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.QK_A = AQK;
|
||||
args.QK_B = BQK;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.stride_AQ = stride_AQ;
|
||||
args.stride_BQ = stride_BQ;
|
||||
|
||||
float ave_time = gemm_calc_quant<GemmConfig,
|
||||
TypeConfig,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
QuantMode>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte = sizeof(typename TypeConfig::ADataType) * M * K +
|
||||
sizeof(typename TypeConfig::QDataType) * M * AQK +
|
||||
sizeof(typename TypeConfig::BDataType) * N * K +
|
||||
sizeof(typename TypeConfig::CDataType) * M * N;
|
||||
if(bq_dev_buf != nullptr)
|
||||
{
|
||||
num_byte += sizeof(typename TypeConfig::QDataType) * BQK;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
|
||||
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
|
||||
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
std::cout << " StrideBQ =" << stride_BQ;
|
||||
}
|
||||
std::cout << " A_Type = " << DataTypeTraits<typename TypeConfig::ADataType>::name
|
||||
<< " AQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name
|
||||
<< " B_Type = " << DataTypeTraits<typename TypeConfig::BDataType>::name;
|
||||
if constexpr(!std::is_same_v<typename TypeConfig::QDataType, void>)
|
||||
{
|
||||
std::cout << " BQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name;
|
||||
}
|
||||
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
|
||||
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
<< " QuantMode = "
|
||||
<< (QuantMode == ck_tile::QuantType::AQuantGrouped ? "AQuantGrouped" : "RowColQuant")
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BQLayout bq_layout = BQLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using AQDataType = typename TypeConfig::QDataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using BQDataType = typename TypeConfig::QDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using CDataType = typename TypeConfig::CDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if(K % QuantGroupSize != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be aligned with QuantGroupSize for AQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::index_t AQK, BQK;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
AQK = K / QuantGroupSize; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1]
|
||||
BQK = N; // Column quantization: tensor shape [1, N]
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported QuantMode");
|
||||
}
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t stride_BQ = 0;
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
// Conditional stride calculation based on QuantMode
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout));
|
||||
stride_BQ = ck_tile::get_default_stride(1, N, stride_BQ, is_row_major(bq_layout));
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
// Create AQ tensor with appropriate shape
|
||||
ck_tile::HostTensor<AQDataType> aq_tensor = [&]() {
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
return ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
|
||||
}
|
||||
}();
|
||||
// Create BQ tensor only for RowColQuant mode
|
||||
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
std::cout << "Monotonic initialization is not supported." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_tensor);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
aq_tensor.SetZero();
|
||||
b_k_n.SetZero();
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_tensor_ptr->SetZero();
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem aq_dev_buf(aq_tensor.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_dev_buf_ptr =
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes());
|
||||
}
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant && QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
shuffle_aq(aq_tensor, GemmConfig::K_Tile / QuantGroupSize);
|
||||
aq_dev_buf.ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
aq_dev_buf.ToDevice(aq_tensor.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
|
||||
}
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
TypeConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
QuantMode>(a_m_k_dev_buf,
|
||||
aq_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
(QuantMode == ck_tile::QuantType::RowColQuant) ? bq_dev_buf_ptr.get()
|
||||
: nullptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
BQK,
|
||||
stride_A,
|
||||
stride_AQ,
|
||||
stride_B,
|
||||
stride_BQ,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(a_m_k, aq_tensor, b_k_n, c_m_n_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k, aq_tensor, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
|
||||
}
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -115,6 +115,92 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<AQDataType>& aq_m_1,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
const HostTensor<BQDataType>& bq_1_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<AccDataType, float>);
|
||||
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
|
||||
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
// Init accumulator
|
||||
AccDataType v_acc = 0;
|
||||
// Get row scale for A and column scale for B
|
||||
float a_scale = aq_m_1(m, 0);
|
||||
float b_scale = bq_1_n(0, n);
|
||||
|
||||
// Compute the dot product
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
// Process A data
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
|
||||
// Process B data
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
v_acc = v_acc * a_scale * b_scale;
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -210,8 +213,12 @@ struct CShuffleEpilogue
|
||||
KPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
@@ -257,11 +264,120 @@ struct CShuffleEpilogue
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
template <auto iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE void
|
||||
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
|
||||
{
|
||||
// Load tiles
|
||||
const auto scale_m_tile = load_tile(scale_m_window);
|
||||
const auto scale_n_tile = load_tile(scale_n_window);
|
||||
|
||||
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
|
||||
tile_elementwise_inout(
|
||||
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
|
||||
|
||||
// Move scale windows
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
}
|
||||
}
|
||||
|
||||
template <auto iAccess, typename OAccTile, typename LdsTile>
|
||||
CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
|
||||
{
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
}
|
||||
|
||||
template <typename LdsTile, typename InLdsWindow>
|
||||
CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
|
||||
{
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
}
|
||||
|
||||
template <typename DramWindows, typename COutTensor>
|
||||
CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
|
||||
{
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
}
|
||||
|
||||
template <typename OutDramWindow, typename COutTensor>
|
||||
CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
|
||||
const COutTensor& c_out_tensor)
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Move both the output and D tensors windows for the next access.
|
||||
*/
|
||||
template <auto iAccess, typename OutDramWindow, typename DDramWindows>
|
||||
CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
|
||||
{
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
// move the output dram window
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
// move windows for each of the D matrices (inputs for element-wise)
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
|
||||
struct EmptyScale
|
||||
{
|
||||
};
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
typename ScaleM = EmptyScale,
|
||||
typename ScaleN = EmptyScale>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
void* p_smem,
|
||||
const ScaleM& scale_m = {},
|
||||
const ScaleN& scale_n = {})
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
@@ -282,9 +398,6 @@ struct CShuffleEpilogue
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
@@ -306,60 +419,46 @@ struct CShuffleEpilogue
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
constexpr bool has_scales =
|
||||
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyScale{};
|
||||
}
|
||||
}();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
|
||||
}
|
||||
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
cast_lds_tile(lds_tile, in_lds_window);
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
apply_d_tensors(d_dram_windows, c_out_tensor);
|
||||
store_to_dram(out_dram_window, c_out_tensor);
|
||||
move_windows<iAccess>(out_dram_window, d_dram_windows);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -5,8 +5,7 @@
|
||||
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_bquant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
|
||||
@@ -1,679 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BQuantGemmProblem
|
||||
{
|
||||
CK_TILE_HOST BQuantGemmProblem() = default;
|
||||
CK_TILE_HOST BQuantGemmProblem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_BQ_)
|
||||
: M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
QK(QK_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_C(stride_C_),
|
||||
stride_BQ(stride_BQ_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t QK;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t stride_BQ;
|
||||
};
|
||||
|
||||
struct BQuantGemmHostArgs : public BQuantGemmProblem
|
||||
{
|
||||
CK_TILE_HOST BQuantGemmHostArgs() = default;
|
||||
CK_TILE_HOST BQuantGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
const void* bq_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_BQ_)
|
||||
: BQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_BQ_),
|
||||
a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
bq_ptr(bq_ptr_),
|
||||
c_ptr(c_ptr_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const void* bq_ptr;
|
||||
void* c_ptr;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
struct BQuantGemmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const void* bq_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t QK;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t stride_BQ;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct BQuantGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename GemmPipeline::BQLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename GemmPipeline::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr BQuantGemmKernelArgs
|
||||
MakeKernelArgs(const BQuantGemmHostArgs& hostArgs)
|
||||
{
|
||||
return BQuantGemmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.bq_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.QK,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.stride_BQ,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const BQuantGemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
|
||||
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const BQuantGemmKernelArgs& kargs)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
if(kargs.QK % GemmPipeline::GetVectorSizeBQ() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support N that is not a multiple of NPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support N that is not a multiple of NPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
const BQuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.N, kargs.QK),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, bq_tensor_view, b_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_pad_view = [&]() {
|
||||
const auto& bq_tensor_view = views.at(I1);
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return pad_tensor_view(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
// TODO: Add support for padding.
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, bq_pad_view, b_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& bq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_block_window = [&]() {
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_n, 0});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
}
|
||||
}();
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, bq_block_window, b_block_window, c_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const BQuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(BQuantGemmKernelArgs kargs) const
|
||||
{
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
RunGemm(a_ptr, b_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -12,117 +12,218 @@
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct AQuantGemmProblem
|
||||
namespace detail {
|
||||
// Helper templates for safe type extraction
|
||||
template <typename T, typename Default>
|
||||
struct get_aq_layout_or
|
||||
{
|
||||
CK_TILE_HOST AQuantGemmProblem() = default;
|
||||
CK_TILE_HOST AQuantGemmProblem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_AQ_)
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
requires requires { typename T::AQLayout; }
|
||||
struct get_aq_layout_or<T, Default>
|
||||
{
|
||||
using type = typename T::AQLayout;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
struct get_bq_layout_or
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
requires requires { typename T::BQLayout; }
|
||||
struct get_bq_layout_or<T, Default>
|
||||
{
|
||||
using type = typename T::BQLayout;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
struct get_aq_data_type_or
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
requires requires { typename T::AQDataType; }
|
||||
struct get_aq_data_type_or<T, Default>
|
||||
{
|
||||
using type = typename T::AQDataType;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
struct get_bq_data_type_or
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
requires requires { typename T::BQDataType; }
|
||||
struct get_bq_data_type_or<T, Default>
|
||||
{
|
||||
using type = typename T::BQDataType;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
struct get_preshuffle_or
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
requires requires { typename T::PreshuffleQuant; }
|
||||
struct get_preshuffle_or<T, Default>
|
||||
{
|
||||
using type = typename T::PreshuffleQuant;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
struct QuantGemmProblem
|
||||
{
|
||||
CK_TILE_HOST QuantGemmProblem() = default;
|
||||
CK_TILE_HOST QuantGemmProblem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_A_,
|
||||
index_t QK_B_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_AQ_,
|
||||
index_t stride_BQ_)
|
||||
: M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
QK(QK_),
|
||||
QK_A(QK_A_),
|
||||
QK_B(QK_B_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_C(stride_C_),
|
||||
stride_AQ(stride_AQ_)
|
||||
stride_AQ(stride_AQ_),
|
||||
stride_BQ(stride_BQ_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t QK;
|
||||
index_t QK_A;
|
||||
index_t QK_B;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t stride_AQ;
|
||||
index_t stride_BQ;
|
||||
};
|
||||
|
||||
struct AQuantGemmHostArgs : public AQuantGemmProblem
|
||||
struct QuantGemmHostArgs : public QuantGemmProblem
|
||||
{
|
||||
CK_TILE_HOST AQuantGemmHostArgs() = default;
|
||||
CK_TILE_HOST AQuantGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
const void* aq_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_AQ_)
|
||||
: AQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_AQ_),
|
||||
CK_TILE_HOST QuantGemmHostArgs() = default;
|
||||
CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
const void* aq_ptr_,
|
||||
const void* bq_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_A_,
|
||||
index_t QK_B_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_AQ_,
|
||||
index_t stride_BQ_)
|
||||
: QuantGemmProblem(
|
||||
M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
|
||||
a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
aq_ptr(aq_ptr_),
|
||||
bq_ptr(bq_ptr_),
|
||||
c_ptr(c_ptr_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const void* aq_ptr;
|
||||
void* c_ptr;
|
||||
index_t k_batch;
|
||||
const void* a_ptr = nullptr;
|
||||
const void* b_ptr = nullptr;
|
||||
const void* aq_ptr = nullptr;
|
||||
const void* bq_ptr = nullptr;
|
||||
void* c_ptr = nullptr;
|
||||
index_t k_batch = 0;
|
||||
};
|
||||
|
||||
struct AQuantGemmKernelArgs
|
||||
struct QuantGemmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const void* aq_ptr;
|
||||
const void* bq_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t QK;
|
||||
index_t QK_A;
|
||||
index_t QK_B;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t stride_AQ;
|
||||
index_t stride_BQ;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct AQuantGemmKernel
|
||||
template <typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
QuantType QuantType_>
|
||||
struct QuantGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
using AQLayout = remove_cvref_t<
|
||||
typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
|
||||
using BQLayout = remove_cvref_t<
|
||||
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant;
|
||||
static constexpr bool PreshuffleQuant = remove_cvref_t<
|
||||
typename detail::get_preshuffle_or<GemmPipeline, std::false_type>::type>::value;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
using AQDataType =
|
||||
remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
|
||||
using BQDataType =
|
||||
remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
|
||||
|
||||
static constexpr auto I0 = number<0>(); // A Tensor
|
||||
static constexpr auto I1 = number<1>(); // AQ Tensor
|
||||
static constexpr auto I2 = number<2>(); // B Tensor
|
||||
static constexpr auto I3 = number<3>(); // BQ Tensor
|
||||
static constexpr auto I4 = number<4>(); // C Tensor
|
||||
|
||||
static constexpr auto kQuantType = QuantType_;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
||||
return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -133,22 +234,25 @@ struct AQuantGemmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr AQuantGemmKernelArgs
|
||||
MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
|
||||
CK_TILE_HOST static constexpr QuantGemmKernelArgs
|
||||
MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
|
||||
{
|
||||
return AQuantGemmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.aq_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.QK,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.stride_AQ,
|
||||
hostArgs.k_batch};
|
||||
return QuantGemmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.aq_ptr,
|
||||
hostArgs.bq_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.QK_A,
|
||||
hostArgs.QK_B,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.stride_AQ,
|
||||
hostArgs.stride_BQ,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -158,7 +262,7 @@ struct AQuantGemmKernel
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs,
|
||||
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
|
||||
@@ -198,7 +302,7 @@ struct AQuantGemmKernel
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs)
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
@@ -209,14 +313,31 @@ struct AQuantGemmKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0)
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: no kernel currently uses BQuant like this:
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -350,8 +471,9 @@ struct AQuantGemmKernel
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
const AQuantGemmKernelArgs& kargs,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
@@ -380,71 +502,85 @@ struct AQuantGemmKernel
|
||||
return ck_tile::integer_least_multiple(length, alignment) - length;
|
||||
};
|
||||
|
||||
const auto& make_preshuffled_aq_tensor_view = [&]() {
|
||||
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_y = kargs.QK / GemmPipeline::KPerBlockAQ;
|
||||
|
||||
const auto aq_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
|
||||
make_tuple(aq_x, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
|
||||
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_pad0_desc = transform_tensor_descriptor(
|
||||
aq_desc,
|
||||
make_tuple(make_pass_through_transform(aq_y),
|
||||
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size =
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
|
||||
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
|
||||
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
aq_pad0_desc,
|
||||
make_tuple(make_pass_through_transform(aq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto aq_pad1_desc = transform_tensor_descriptor(
|
||||
aq_unmerge_pad0_desc,
|
||||
make_tuple(make_pass_through_transform(aq_y),
|
||||
make_pass_through_transform(wave_tile_count_x),
|
||||
make_right_pad_transform(
|
||||
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
const auto pad_wave_size =
|
||||
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
aq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||
};
|
||||
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
return make_preshuffled_aq_tensor_view();
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
|
||||
|
||||
const auto aq_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
|
||||
make_tuple(aq_x, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
|
||||
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
|
||||
const auto aq_pad0_desc = transform_tensor_descriptor(
|
||||
aq_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
|
||||
const auto wave_tile_size =
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
|
||||
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
aq_pad0_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
const auto aq_pad1_desc = transform_tensor_descriptor(
|
||||
aq_unmerge_pad0_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(aq_y),
|
||||
make_pass_through_transform(wave_tile_count_x),
|
||||
make_right_pad_transform(
|
||||
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
const auto pad_wave_size =
|
||||
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
aq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
|
||||
make_pass_through_transform(pad_wave_size)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||
}
|
||||
else
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK),
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, 0), // broadcasting over n
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type for this
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
@@ -510,6 +646,32 @@ struct AQuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(0, 1), // broadcasting over m
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.N, kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type for this
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -532,7 +694,8 @@ struct AQuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view);
|
||||
return make_tuple(
|
||||
a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -556,6 +719,7 @@ struct AQuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// no padding
|
||||
const auto& aq_pad_view = [&]() { return views.at(I1); }();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
@@ -576,9 +740,12 @@ struct AQuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// no padding
|
||||
const auto& bq_pad_view = [&]() { return views.at(I3); }();
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
const auto& c_tensor_view = views.at(I4);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
@@ -595,7 +762,7 @@ struct AQuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view);
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -605,7 +772,8 @@ struct AQuantGemmKernel
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
const auto& bq_pad_view = views.at(I3);
|
||||
const auto& c_pad_view = views.at(I4);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -625,14 +793,13 @@ struct AQuantGemmKernel
|
||||
}();
|
||||
|
||||
const auto& aq_block_window = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block =
|
||||
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block =
|
||||
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_m / warp_m;
|
||||
@@ -642,13 +809,27 @@ struct AQuantGemmKernel
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_m_idx * tile_window_height, 0});
|
||||
}
|
||||
else
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_tile_window(aq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type?
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
@@ -668,12 +849,36 @@ struct AQuantGemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& bq_block_window = [&]() {
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_tile_window(bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
{i_n, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type here
|
||||
}
|
||||
}();
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window);
|
||||
return make_tuple(
|
||||
a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -695,16 +900,17 @@ struct AQuantGemmKernel
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const AQuantGemmKernelArgs& kargs,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
@@ -713,22 +919,51 @@ struct AQuantGemmKernel
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
auto& c_block_window = gemm_tile_windows.at(I4);
|
||||
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const
|
||||
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
|
||||
{
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
@@ -740,13 +975,15 @@ struct AQuantGemmKernel
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -14,6 +14,7 @@ namespace ck_tile {
|
||||
template <typename ADataType_,
|
||||
typename AQDataType_,
|
||||
typename BDataType_,
|
||||
typename BQDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
@@ -23,12 +24,12 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>
|
||||
struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using Base = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
@@ -44,6 +45,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
|
||||
using typename Base::CDataType;
|
||||
using typename Base::ComputeDataType;
|
||||
using AQDataType = remove_cvref_t<AQDataType_>;
|
||||
using BQDataType = remove_cvref_t<BQDataType_>;
|
||||
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
|
||||
@@ -63,6 +65,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
|
||||
using Base::VectorLoadSize;
|
||||
|
||||
using AQLayout = remove_cvref_t<typename Traits::AQLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
|
||||
|
||||
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
@@ -75,7 +78,7 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_aquant_problem",
|
||||
return concat('_', "gemm_quant_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
@@ -94,6 +97,13 @@ struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
return kPadK ? 1 : GetAlignmentAQ();
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
|
||||
{
|
||||
return VectorLoadSize / sizeof(BQDataType);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -108,18 +118,19 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase<ADataType_,
|
||||
AQDataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
AQDataType_,
|
||||
BDataType_,
|
||||
void, // no BQDataType for AQuant
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
@@ -132,96 +143,42 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
struct GemmBQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using Base = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>;
|
||||
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using typename Base::ADataType;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::ComputeDataType;
|
||||
using BQDataType = remove_cvref_t<BQDataType_>;
|
||||
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CLayout;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
|
||||
using Base::kBlockSize;
|
||||
|
||||
using Base::kPadK;
|
||||
using Base::kPadM;
|
||||
using Base::kPadN;
|
||||
|
||||
using Base::DoubleSmemBuffer;
|
||||
using Base::VectorLoadSize;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
|
||||
|
||||
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
|
||||
static_assert(Scheduler == GemmPipelineScheduler::Intrawave);
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_bquant_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
"QuantGroupSize",
|
||||
kQuantGroupSize);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
|
||||
{
|
||||
return VectorLoadSize / sizeof(BQDataType);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
|
||||
};
|
||||
using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
void, // no AQDataType for BQuant
|
||||
BDataType_,
|
||||
BQDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
false, // no TransposeC
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename BQDataType_,
|
||||
typename CDataType_,
|
||||
typename AccDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool TransposeC_ = false,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
using GemmBQuantPipelineProblem = GemmBQuantPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
BQDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
|
||||
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
AccDataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
1, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_>;
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,9 +4,17 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <cstdint>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct QuantType : std::uint16_t
|
||||
{
|
||||
AQuantGrouped = 0,
|
||||
BQuantGrouped = 1,
|
||||
RowColQuant = 2
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
@@ -14,19 +22,24 @@ template <bool kPadM_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
typename AQLayout_ = ALayout_>
|
||||
struct TileGemmAQuantTraits
|
||||
QuantType QuantType_,
|
||||
typename AQLayout_ = ALayout_,
|
||||
typename BQLayout_ = BLayout_>
|
||||
struct TileGemmQuantTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr QuantType kQuantType = QuantType_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using AQLayout = AQLayout_;
|
||||
using BQLayout = BQLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
@@ -35,31 +48,4 @@ struct TileGemmAQuantTraits
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool PreshuffleQuant_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
typename BQLayout_ = BLayout_>
|
||||
struct TileGemmBQuantTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
using BQLayout = BQLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -23,4 +23,5 @@ add_subdirectory(add_rmsnorm2d_rdquant)
|
||||
add_subdirectory(gemm_block_scale)
|
||||
add_subdirectory(utility)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(epilogue)
|
||||
add_subdirectory(atomic_add_op)
|
||||
|
||||
1
test/ck_tile/epilogue/CMakeLists.txt
Normal file
1
test/ck_tile/epilogue/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_gtest_executable(test_ck_tile_cshuffle_epilogue test_cshuffle_epilogue.cpp)
|
||||
84
test/ck_tile/epilogue/test_cshuffle_epilogue.cpp
Normal file
84
test/ck_tile/epilogue/test_cshuffle_epilogue.cpp
Normal file
@@ -0,0 +1,84 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_cshuffle_epilogue_util.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
class CShuffleEpilogueTest : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
void SetUp() override {}
|
||||
};
|
||||
|
||||
TEST_F(CShuffleEpilogueTest, BasicHalfTest)
|
||||
{
|
||||
// Basic test configuration with half_t data types
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
constexpr index_t kMPerBlock = 256;
|
||||
constexpr index_t kNPerBlock = 256;
|
||||
constexpr index_t MWave = 2;
|
||||
constexpr index_t NWave = 2;
|
||||
constexpr index_t MPerXdl = 32;
|
||||
constexpr index_t NPerXdl = 32;
|
||||
constexpr index_t KPerXdl = 8;
|
||||
|
||||
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
ODataType,
|
||||
kMPerBlock,
|
||||
kNPerBlock,
|
||||
MWave,
|
||||
NWave,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl>;
|
||||
|
||||
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>();
|
||||
EXPECT_TRUE(result) << "Basic CShuffleEpilogue test failed";
|
||||
}
|
||||
|
||||
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
|
||||
{
|
||||
// Basic test configuration with half_t data types
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
constexpr index_t kMPerBlock = 256;
|
||||
constexpr index_t kNPerBlock = 256;
|
||||
constexpr index_t MWave = 2;
|
||||
constexpr index_t NWave = 2;
|
||||
constexpr index_t MPerXdl = 32;
|
||||
constexpr index_t NPerXdl = 32;
|
||||
constexpr index_t KPerXdl = 8;
|
||||
|
||||
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
ODataType,
|
||||
kMPerBlock,
|
||||
kNPerBlock,
|
||||
MWave,
|
||||
NWave,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl>;
|
||||
|
||||
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(true);
|
||||
EXPECT_TRUE(result) << "Scale CShuffleEpilogue test failed";
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
191
test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp
Normal file
191
test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp
Normal file
@@ -0,0 +1,191 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Simple test kernel to invoke the CShuffleEpilogue
|
||||
template <typename Problem, index_t M, index_t N, bool UseScale>
|
||||
__global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data,
|
||||
float* m_scale,
|
||||
float* n_scale)
|
||||
{
|
||||
using Epilogue = CShuffleEpilogue<Problem>;
|
||||
|
||||
static_assert(Problem::kMPerBlock <= M && Problem::kNPerBlock <= N,
|
||||
"Block size must fit in tensor dimensions");
|
||||
|
||||
// Allocate shared memory for epilogue
|
||||
__shared__ char smem[Epilogue::GetSmemSize()];
|
||||
|
||||
// Create accumulator tile
|
||||
constexpr auto lds_distribution_encode =
|
||||
make_static_tile_distribution(Epilogue::MakeLdsDistributionEncode());
|
||||
auto acc_tile =
|
||||
make_static_distributed_tensor<typename Epilogue::AccDataType>(lds_distribution_encode);
|
||||
|
||||
// Fill acc_tile with a simple pattern
|
||||
auto& acc_buffer = acc_tile.get_thread_buffer();
|
||||
acc_buffer[0] = 2.0F;
|
||||
|
||||
// Create output tensor view
|
||||
auto output_tensor_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(output_data,
|
||||
make_tuple(M, N),
|
||||
make_tuple(N, 1),
|
||||
number<Epilogue::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
|
||||
// Create output tile window
|
||||
auto output_tile_window =
|
||||
make_tile_window(output_tensor_view,
|
||||
make_tuple(number<Problem::kMPerBlock>{}, number<Problem::kNPerBlock>{}),
|
||||
{0, 0});
|
||||
|
||||
// Create empty D tensors tuple (we're ignoring ds_dram_windows for this test)
|
||||
auto empty_ds = make_tuple();
|
||||
|
||||
// Call the epilogue
|
||||
if constexpr(UseScale)
|
||||
{
|
||||
const auto m_scale_window = make_tile_window(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
m_scale, make_tuple(M, N), make_tuple(1, 0), number<1>{}, number<1>{}),
|
||||
make_tuple(number<Problem::kMPerBlock>{}, number<Problem::kNPerBlock>{}),
|
||||
{0, 0});
|
||||
const auto n_scale_window = make_tile_window(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
n_scale, make_tuple(M, N), make_tuple(0, 1), number<1>{}, number<1>{}),
|
||||
make_tuple(number<Problem::kMPerBlock>{}, number<Problem::kNPerBlock>{}),
|
||||
{0, 0});
|
||||
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, m_scale_window, n_scale_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem);
|
||||
}
|
||||
}
|
||||
|
||||
// Test configuration helper
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename ODataType,
|
||||
index_t kM,
|
||||
index_t kN,
|
||||
index_t MWave,
|
||||
index_t NWave,
|
||||
index_t MPerXdl,
|
||||
index_t NPerXdl,
|
||||
index_t KPerXdl>
|
||||
using SimpleCShuffleEpilogueProblem =
|
||||
CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // Empty Ds datatype tuple
|
||||
AccDataType,
|
||||
ODataType,
|
||||
ck_tile::tuple<>, // Empty Ds layout
|
||||
tensor_layout::gemm::RowMajor, // ELayout
|
||||
ck_tile::element_wise::PassThrough, // CDElementwise
|
||||
kM,
|
||||
kN,
|
||||
MWave,
|
||||
NWave,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
false, // isCTransposed,
|
||||
memory_operation_enum::set>;
|
||||
|
||||
template <typename Problem, index_t M, index_t N>
|
||||
bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
{
|
||||
using ODataType = typename Problem::ODataType;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
std::cout << "Running CShuffleEpilogue test with M=" << M << ", N=" << N
|
||||
<< ", MPerBlock=" << kMPerBlock << ", NPerBlock=" << kNPerBlock
|
||||
<< ", BlockSize=" << kBlockSize << std::endl;
|
||||
|
||||
// Allocate host memory
|
||||
const size_t output_size = M * N;
|
||||
|
||||
std::vector<ODataType> host_output(output_size, static_cast<ODataType>(0));
|
||||
|
||||
// Allocate device memory
|
||||
ODataType* device_output;
|
||||
|
||||
HIP_CHECK_ERROR(hipMalloc(&device_output, output_size * sizeof(ODataType)));
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
device_output, host_output.data(), output_size * sizeof(ODataType), hipMemcpyHostToDevice));
|
||||
|
||||
// Launch kernel
|
||||
dim3 gridSize(1, 1, 1);
|
||||
dim3 blockSize(kBlockSize, 1, 1);
|
||||
|
||||
if(use_scale)
|
||||
{
|
||||
float* m_scale;
|
||||
float* n_scale;
|
||||
std::vector<float> h_m_scale(M, 1.0F);
|
||||
std::vector<float> h_n_scale(N, 1.0F);
|
||||
h_n_scale[1] = 2.0F; // multiply one col only with 2
|
||||
HIP_CHECK_ERROR(hipMalloc(&m_scale, M * sizeof(float)));
|
||||
HIP_CHECK_ERROR(hipMalloc(&n_scale, N * sizeof(float)));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice));
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, true>
|
||||
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, false>
|
||||
<<<gridSize, blockSize>>>(device_output, nullptr, nullptr);
|
||||
}
|
||||
|
||||
// Check for kernel launch errors
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
// Copy results back
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost));
|
||||
|
||||
// Basic verification - just check that output has a 2, and 4 if using scaling
|
||||
bool has_2 =
|
||||
type_convert<float>(host_output[0]) > 1.9F && type_convert<float>(host_output[0]) < 2.1F;
|
||||
bool scale_has_4 = true;
|
||||
if(use_scale)
|
||||
{
|
||||
scale_has_4 = type_convert<float>(host_output[1]) > 3.9F &&
|
||||
type_convert<float>(host_output[1]) < 4.1F;
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
HIP_CHECK_ERROR(hipFree(device_output));
|
||||
|
||||
return has_2 && scale_has_4;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -240,4 +240,4 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
@@ -55,13 +55,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
false, // preshuffle
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -114,8 +115,10 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel =
|
||||
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -185,7 +188,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::AQuantGemmHostArgs args;
|
||||
ck_tile::QuantGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
@@ -194,7 +197,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.QK = AQK;
|
||||
args.QK_A = AQK;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
Reference in New Issue
Block a user