Merge commit '990d645578b4a195f5c5b8479eeef47d828faa98' into develop

This commit is contained in:
assistant-librarian[bot]
2025-05-20 23:06:37 +00:00
parent 488476841a
commit 271978ec7c
2 changed files with 36 additions and 9 deletions

View File

@@ -67,13 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using TypeConfig = AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>;
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using XDataType = typename TypeConfig::XDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = float;
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using XDataType = typename TypeConfig::XDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = float;
using UnquantYDataType = ck_tile::null_type;
// host verify
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
@@ -184,6 +185,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Rmsnorm2d
{
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n});
// CAUTION: the kernel uses the ComputeDataType version of x, but we use XDataType here for
// simplicity
@@ -191,8 +193,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
InvRmsDataType,
UnquantYDataType>(
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
}
// yscale

View File

@@ -1,5 +1,29 @@
[Back to the main page](../README.md)
# Composable Kernel profiler
## Profiler GEMM UNIVERSAL kernels
```bash
# arg1: tensor operation (gemm_universal: Universal GEMM)
# arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16->f8; 7: f8->bf16, comp f8; 8: f16@i4; 9: bf16@i4)
# arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
# 1: A[m, k] * B[n, k] = C[m, n];
# 2: A[k, m] * B[k, n] = C[m, n];
# 3: A[k, m] * B[n, k] = C[m, n])
# arg4: verification (0: no; 1: yes)
# arg5: initialization (0: no init; 1: integer value; 2: decimal value)
# arg6: print tensor value (0: no; 1: yes)
# arg7: time kernel (0: no; 1: yes)
# arg8 to 13: M, N, K, StrideA, StrideB, StrideC
# arg14: split K into multiple batches
# optional:
# arg15: number of warm-up cycles (default 1)
# arg16: number of iterations (default 10)
# arg17: memory for rotating buffer (default 0, size in MB)
################ op datatype layout verify init print time M N K StrideA StrideB StrideC SplitK WarmupCycles Iterations MemoryBuffer
./bin/ckProfiler gemm_universal 1 0 1 1 0 1 4096 4096 4096 4096 4096 4096 1 1 10 0
```
## Profile GEMM kernels
```bash
#arg1: tensor operation (gemm=GEMM)