mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
[CK TILE GEMM] Refactor block_scale_gemm examples
- Update Readme
This commit is contained in:
@@ -10,11 +10,11 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_executable(${EXE_NAME} EXCLUDE_FROM_ALL
|
||||
gemm_quant.cpp
|
||||
gemm_aquant_quantgrouped.cpp
|
||||
gemm_bquant_quantgourped_prefill_bf8i4.cpp
|
||||
gemm_bquant_quantgourped_prefill_fp8i4.cpp
|
||||
gemm_bquant_quantgourped_prefill_bf8.cpp
|
||||
gemm_bquant_quantgourped_prefill_fp8.cpp
|
||||
gemm_bquant_quantgourped_preshuffleb_prefill.cpp
|
||||
gemm_bquant_quantgrouped_prefill_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_prefill_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_prefill_bf8.cpp
|
||||
gemm_bquant_quantgrouped_prefill_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb_prefill.cpp
|
||||
)
|
||||
target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
|
||||
@@ -40,23 +40,31 @@ This will result in an executable `build/bin/tile_example_gemm_quant_basic`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-b batch size (default:1)
|
||||
-m m dimension (default:1024)
|
||||
-n n dimension (default:2048)
|
||||
-k k dimension (default:64)
|
||||
-a_layout Tensor A data layout (default: R)
|
||||
-b_layout Tensor B data layout (default: C)
|
||||
-c_layout Tensor C data layout (default: R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
|
||||
-h Print help message (default:false)
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-a_layout A tensor data layout - Row or Column (default:R)
|
||||
-b_layout B tensor data layout - Row or Column (default:C)
|
||||
-bq_layout Bq tensor data layout - Row or Column (default:C)
|
||||
-c_layout C tensor data layout - Row or Column (default:R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_q Tensor AQ stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
|
||||
-prec Data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4 (default:fp8)
|
||||
-warmup Number of iterations before benchmark the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:1000)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k SplitK value (default:1)
|
||||
-device Device id that will be used to run the kernel (default:0)
|
||||
-init 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-flush_cache Flush cache before running the kernel (default:true)
|
||||
-rotating_count Rotating count (default:1000)
|
||||
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
|
||||
-preshuffleb Enable preshuffle of tensor B (default:false)
|
||||
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
|
||||
```
|
||||
|
||||
Users need to select the correct config mapping for each quant mode:
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid too verbose dispatching.
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid too verbose dispatching.
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid too verbose dispatching.
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid too verbose dispatching.
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid too verbose dispatching.
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This example demonstrates 2D block scale quantization (N×K) for BQuant
|
||||
// using non-preshuffled configuration.
|
||||
// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example
|
||||
// This is currently done separately to avoid too verbose dispatching.
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
@@ -22,14 +16,14 @@
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("h", "false", "print help message")
|
||||
arg_parser.insert("h", "false", "Print help message")
|
||||
.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row or Column")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row or Column")
|
||||
.insert("bq_layout", "C", "Bq tensor data layout - Row or Column")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row or Column")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_q", "0", "Tensor AQ stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
@@ -37,17 +31,17 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec",
|
||||
"fp8",
|
||||
"data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
|
||||
"Data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
|
||||
.insert("warmup", "50", "Number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("device", "0", "device id that will be used to run the kernel, default 0")
|
||||
.insert("split_k", "1", "SplitK value")
|
||||
.insert("device", "0", "Device id that will be used to run the kernel")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, default to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1")
|
||||
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle tensor B, default false")
|
||||
.insert("flush_cache", "true", "Flush cache before running the kernel")
|
||||
.insert("rotating_count", "1000", "Rotating count")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
|
||||
.insert("group_size",
|
||||
"1x1x128",
|
||||
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -17,6 +17,8 @@ inline size_t hash_multiple_strings(const std::vector<std::string>& inputs)
|
||||
size_t combined_hash = 0;
|
||||
for(const auto& str : inputs)
|
||||
{
|
||||
// Hash combine using golden ratio constant and bit shifts for good distribution and
|
||||
// order-dependent mixing
|
||||
combined_hash ^= hasher(str) + 0x9e3779b9 + (combined_hash << 6) + (combined_hash >> 2);
|
||||
}
|
||||
return combined_hash;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <cstring>
|
||||
@@ -7,7 +7,6 @@
|
||||
#include <ostream>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user