supporting prefill shapes for preshuffle block scale gemm (#2975)

* debugging

* debugging for prefill shapes

* comment unused code

* fix for prefill shapes

* clearing up the code

* add int4 to universal gemm example

* clang formatted

* adding test for prefill shapes in block scale gemm

* lil improv on the block pipeline

* Address Review Comment

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-10-10 15:36:24 -07:00
committed by GitHub
parent 9d060d3e3c
commit 3c39d279ab
10 changed files with 137 additions and 89 deletions

View File

@@ -4,8 +4,18 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
- Row and Column-wise scaled: scaling implemented in Epilogue
- Tensor-wise scaled: scaling implemented in Epilogue
- Row and Column-wise scaled: All of the rowwise elements in A Matrix and columwise elements in B Matrix will share the same quantization element and the elementwisde operation will complete in epilogue.
- Tensor-wise scaled: Share the same scalar scale across the whole tensor of A or B
---
## Features
- **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM.
- **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading
- **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension.
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix).
- **Validation**: CPU/GPU validation and error tolerance options.
## build
```

View File

@@ -47,6 +47,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
QuantMode,
ALayout, // for AQLayout
BLayout, // for BQLayout
false,
GemmConfig::DoubleSmemBuffer>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
@@ -450,4 +451,4 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv); }

View File

@@ -166,6 +166,26 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
@@ -261,7 +281,7 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1")
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);