[CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme
This commit is contained in:
Cong Ma
2025-11-07 12:13:45 -05:00
parent a967b4906b
commit bc26224ce1
11 changed files with 54 additions and 76 deletions

View File

@@ -10,11 +10,11 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_executable(${EXE_NAME} EXCLUDE_FROM_ALL
gemm_quant.cpp
gemm_aquant_quantgrouped.cpp
gemm_bquant_quantgourped_prefill_bf8i4.cpp
gemm_bquant_quantgourped_prefill_fp8i4.cpp
gemm_bquant_quantgourped_prefill_bf8.cpp
gemm_bquant_quantgourped_prefill_fp8.cpp
gemm_bquant_quantgourped_preshuffleb_prefill.cpp
gemm_bquant_quantgrouped_prefill_bf8i4.cpp
gemm_bquant_quantgrouped_prefill_fp8i4.cpp
gemm_bquant_quantgrouped_prefill_bf8.cpp
gemm_bquant_quantgrouped_prefill_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb_prefill.cpp
)
target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()

View File

@@ -40,23 +40,31 @@ This will result in an executable `build/bin/tile_example_gemm_quant_basic`
## example
```
args:
-b batch size (default:1)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: C)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
-h Print help message (default:false)
-m m dimension (default:3840)
-n n dimension (default:4096)
-k k dimension (default:2048)
-a_layout A tensor data layout - Row or Column (default:R)
-b_layout B tensor data layout - Row or Column (default:C)
-bq_layout Bq tensor data layout - Row or Column (default:C)
-c_layout C tensor data layout - Row or Column (default:R)
-stride_a Tensor A stride (default:0)
-stride_q Tensor AQ stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
-prec Data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4 (default:fp8)
-warmup Number of iterations before benchmark the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:1000)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-split_k SplitK value (default:1)
-device Device id that will be used to run the kernel (default:0)
-init 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache Flush cache before running the kernel (default:true)
-rotating_count Rotating count (default:1000)
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
-preshuffleb Enable preshuffle of tensor B (default:false)
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
```
Users need to select the correct config mapping for each quant mode:

View File

@@ -1,10 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"

View File

@@ -1,10 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"

View File

@@ -1,10 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"

View File

@@ -1,10 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"

View File

@@ -1,10 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
#include "run_gemm_quant_example.inc"

View File

@@ -1,15 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using non-preshuffled configuration.
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
// This is currently done separately to avoid too verbose dispatching.
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <stdexcept>
#include <string>
#include <tuple>
@@ -22,14 +16,14 @@
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("h", "false", "print help message")
arg_parser.insert("h", "false", "Print help message")
.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("a_layout", "R", "A tensor data layout - Row or Column")
.insert("b_layout", "C", "B tensor data layout - Row or Column")
.insert("bq_layout", "C", "Bq tensor data layout - Row or Column")
.insert("c_layout", "R", "C tensor data layout - Row or Column")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_q", "0", "Tensor AQ stride")
.insert("stride_b", "0", "Tensor B stride")
@@ -37,17 +31,17 @@ auto create_args(int argc, char* argv[])
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec",
"fp8",
"data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
"Data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
.insert("warmup", "50", "Number of iterations before benchmark the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("device", "0", "device id that will be used to run the kernel, default 0")
.insert("split_k", "1", "SplitK value")
.insert("device", "0", "Device id that will be used to run the kernel")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, default to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1")
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
.insert("preshuffleb", "false", "Enable preshuffle tensor B, default false")
.insert("flush_cache", "true", "Flush cache before running the kernel")
.insert("rotating_count", "1000", "Rotating count")
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
.insert("group_size",
"1x1x128",
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -17,6 +17,8 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
size_t combined_hash = 0;
for(const auto& str : inputs)
{
// Hash combine using golden ratio constant and bit shifts for good distribution and
// order-dependent mixing
combined_hash ^= hasher(str) + 0x9e3779b9 + (combined_hash << 6) + (combined_hash >> 2);
}
return combined_hash;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstring>
@@ -7,7 +7,6 @@
#include <ostream>
#include <random>
#include <stdexcept>
#include <stdexcept>
#include <string>
#include <tuple>