mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)
* formatted
* formatted
* formatting
* formatting
* formatting
* [CK TILE GEMM] Refactor block_scale_gemm examples
  - Split cpp file to reduce building time
  - Support multiple GemmConfig
* [CK TILE GEMM] Refactor block_scale_gemm examples
  - Update Readme
* enable prefill shapes
* [CK TILE GEMM] Refactor block_scale_gemm examples
  - Add support for rowcol and tensor GEMM operations
* [CK TILE GEMM] Refactor block_scale_gemm examples
  - Update README
* adding preshuffle quant as new parameter and its associated new files
* remove debugging statements
* adding test
* enable preshuffle quant with permuteN
* updating readme and corresponding gemmconfigs
* updating cmake file
* fixing CI failures for grouped quant gemm
* addressing review comments
* fixing CI issue
* addressing review comments
* formatting
* formatting
* fixing aquant operator overloading
* formatting

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
@@ -33,47 +33,50 @@ mkdir build && cd build
 # you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
 ../script/cmake-ck-dev.sh ../ <arch>
 # Compile the quant kernels
-make tile_example_gemm_quant_basic -j
+make tile_example_gemm_quant -j
 ```
-This will result in an executable `build/bin/tile_example_gemm_quant_basic`
+This will result in an executable `build/bin/tile_example_gemm_quant`
 
 ## example
 ```
 args:
-    -h                Print help message (default:false)
-    -m                m dimension (default:3840)
-    -n                n dimension (default:4096)
-    -k                k dimension (default:2048)
-    -a_layout         A tensor data layout - Row or Column (default:R)
-    -b_layout         B tensor data layout - Row or Column (default:C)
-    -bq_layout        Bq tensor data layout - Row or Column (default:C)
-    -c_layout         C tensor data layout - Row or Column (default:R)
-    -stride_a         Tensor A stride (default:0)
-    -stride_q         Tensor AQ stride (default:0)
-    -stride_b         Tensor B stride (default:0)
-    -stride_c         Tensor C stride (default:0)
-    -v                0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
-    -prec             Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
-    -warmup           Number of iterations before benchmarking the kernel (default:50)
-    -repeat           Number of iterations to benchmark the kernel (default:1000)
-    -timer            gpu:gpu timer, cpu:cpu timer (default:gpu)
-    -split_k          SplitK value (default:1)
-    -device           Device id that will be used to run the kernel (default:0)
-    -init             0:random, 1:linear, 2:constant(1) (default:0)
-    -flush_cache      Flush cache before running the kernel (default:true)
-    -rotating_count   Rotating count (default:1000)
-    -quant_mode       Choose aquant, bquant, tensor or rowcol (default:bquant)
-    -preshuffleb      Enable preshuffle of tensor B (default:false)
-    -group_size       Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
+    -h                Print help message (default:false)
+    -m                m dimension (default:3840)
+    -n                n dimension (default:4096)
+    -k                k dimension (default:2048)
+    -a_layout         A tensor data layout - R for Row or C for Column (default:R)
+    -b_layout         B tensor data layout - R for Row or C for Column (default:C)
+    -bq_layout        Bq tensor data layout - R for Row or C for Column (default:C)
+    -c_layout         C tensor data layout - R for Row or C for Column (default:R)
+    -stride_a         Tensor A stride (default:0)
+    -stride_q         Tensor AQ stride (default:0)
+    -stride_b         Tensor B stride (default:0)
+    -stride_c         Tensor C stride (default:0)
+    -v                0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
+    -prec             Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
+    -warmup           Number of iterations before benchmarking the kernel (default:50)
+    -repeat           Number of iterations to benchmark the kernel (default:1000)
+    -timer            gpu:gpu timer, cpu:cpu timer (default:gpu)
+    -split_k          SplitK value (default:1)
+    -device           Device id that will be used to run the kernel (default:0)
+    -init             0:random, 1:linear, 2:constant(1) (default:0)
+    -flush_cache      Flush cache before running the kernel (default:true)
+    -rotating_count   Rotating count (default:1000)
+    -quant_mode       Choose aquant, bquant, tensor or rowcol (default:bquant)
+    -preshuffleb      Enable preshuffle of tensor B (default:false)
+    -preshufflequant  Enable preshuffle of quant tensor (default:false)
+    -group_size       Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
 ```
 
 User need to select correct mapping of config for each quant mode:
 
-| | quant_mode as runtime argument | Config in cpp file |
-|:--------|:-----:|-------|
-| For selecting AQuant | aquant | GemmConfigQuant |
-| For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant |
-| For selecting BQuant | bquant | GemmConfigQuant |
-| For selecting PreShuffle Weight matrix with Bquant | bquant | GemmConfigPreshuffleB_Bquant_decode (or) GemmConfigPreshuffleB_Bquant_prefill
-| For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant |
+| | quant_mode as runtime argument | Corresponding cpp file | GemmConfig at the top of cpp file |
+|:--------|:-----:|:-----:|-------|
+| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
+| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
+| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
+| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
+| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill
+| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
+| For selecting RowCol quant | rowcolquant | gemm_quant_rowcol| GemmConfigRowColQuant |
 
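As a rough illustration of how the arguments in the updated README might be combined, the sketch below assembles a command line for the BQuant path with the new `-preshufflequant` option. It rests on assumptions the diff does not confirm: that the script is run from the `build` directory (so the binary is at `./bin/tile_example_gemm_quant`), that flags are passed as `-name=value`, and that boolean flags accept `true`/`false`. The helper `build_quant_cmd` is a hypothetical name introduced only for this example.

```shell
#!/usr/bin/env bash
# Sketch only: compose a tile_example_gemm_quant invocation from README flags.
# Assumed (not confirmed by the diff): -name=value syntax, true/false booleans,
# and a binary path relative to the build directory.

BIN="${BIN:-./bin/tile_example_gemm_quant}"

# build_quant_cmd <quant_mode> <preshuffleb> <preshufflequant> [group_size]
# Hypothetical helper; prints the command line instead of running it.
build_quant_cmd() {
    local quant_mode="$1" preshuffleb="$2" preshufflequant="$3"
    local group_size="${4:-1x1x128}"   # README default group size
    printf '%s -m=3840 -n=4096 -k=2048 -prec=fp8 -quant_mode=%s -preshuffleb=%s -preshufflequant=%s -group_size=%s -v=1\n' \
        "$BIN" "$quant_mode" "$preshuffleb" "$preshufflequant" "$group_size"
}

# BQuant with a preshuffled quant tensor, B itself left unshuffled.
cmd="$(build_quant_cmd bquant false true)"
echo "$cmd"

# Execute only when the example has actually been built.
if [ -x "$BIN" ]; then
    $cmd
fi
```

According to the mapping table in the diff, each combination of `-quant_mode`, `-preshuffleb`, and `-preshufflequant` has a corresponding cpp file and GemmConfig, so a given invocation presumably only works when the matching config was compiled in.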