Weight Preshuffle Block Scale gemm support (#2877)

* initial commit

* remove extra files

* fixing errors

* updated ReadMe file for mapping of diff quants with diff configs

* addressing review comments

* addressing review comments

* Resolved merge conflicts

* [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled

The get_preshuffle_or was not working as expected, which led to incorrect behavior
in the quantization preshuffle process. This change replaces it with the more reliable
is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied.

* initial commit

* debugging

* working fp8 for init constant

* fp8 working with all inits

* updated block level code with comments

* changing the loop iter

* debugging

* debugging

* debugging

* code fix

* code clean up

* clang formatted

* Add comment

* code cleanup

* clang formatted

* merge conflicts fixes

* applying the latest int4 changes to the piepline

* fixing test code for updated traits

* Adding gtest

* review comments addressed

* addressing review comments

* remove c++20 code

* added flush cache changes

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: root <root@banff-cyxtera-s73-2.ctr.dcgpu>
This commit is contained in:
Khushbu Agarwal
2025-09-29 12:46:37 -07:00
committed by GitHub
parent 2e9428eb63
commit 81458a6681
17 changed files with 1129 additions and 53 deletions

View File

@@ -47,5 +47,6 @@ User need to select correct mapping of config for each quant mode:
| For selecting AQuant | aquant | GemmConfigQuant |
| For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant |
| For selecting BQuant | bquant | GemmConfigQuant |
| For selecting PreShuffle Weight matrix with Bquant | bquant | GemmConfigPreshuffleB_Bquant_decode (or) GemmConfigPreshuffleB_Bquant_prefill
| For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant |

View File

@@ -23,7 +23,6 @@ template <typename GemmConfig,
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
// B datatype is safe to use as compute type as it should be at least fp8
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
@@ -41,10 +40,14 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::PreshuffleQuant,
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode>;
QuantMode,
ALayout, // for AQLayout
BLayout, // for BQLayout
GemmConfig::DoubleSmemBuffer>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
@@ -53,7 +56,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -110,9 +116,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
@@ -160,9 +169,49 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = 0;
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
ck_tile::RotatingMemWrapper<typename TypeConfig::ADataType,
typename TypeConfig::BDataType>
rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(
hipMemsetAsync(args.c_ptr,
0,
args.M * args.N * sizeof(typename TypeConfig::CDataType),
s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
@@ -180,6 +229,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant) &&
GemmConfig::PreshuffleB)
{
throw std::runtime_error(
"Preshuffling weight matrix is not supported for AQuant or RowColQuant");
}
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
@@ -391,4 +448,7 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigPreshuffleB_Bquant_decode>(argc, argv);
}

View File

@@ -91,6 +91,7 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr bool PreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
};
@@ -145,6 +146,26 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
static constexpr bool PreshuffleQuant = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
@@ -222,7 +243,6 @@ auto create_args(int argc, char* argv[])
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
@@ -240,7 +260,7 @@ auto create_args(int argc, char* argv[])
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1", "rotating count, defaults to 1")
.insert("rotating_count", "1000", "rotating count, defaults to 1")
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
bool result = arg_parser.parse(argc, argv);

View File

@@ -24,6 +24,22 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename GemmConfig,
typename TypeConfig,
typename ALayout,
@@ -121,6 +137,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = " << quant_type_to_string(QuantMode)
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
@@ -393,17 +410,27 @@ int run_gemm_example_with_layouts(int argc,
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PreshuffleB)
{
b_k_n_dev = shuffle_b<GemmConfig>(b_k_n);
}
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
if constexpr(GemmConfig::PreshuffleB)
{
b_k_n_dev = shuffle_b<GemmConfig>(b_k_n);
}
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
@@ -509,7 +536,7 @@ int run_gemm_example_with_layouts(int argc,
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{