This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd.cpp)
target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS})
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})
set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd")
add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} example_add_rmsnorm2d_rdquant_fwd.cpp)
target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,78 @@
# Add + RMSNorm2D + Rowwise Dynamic Quantization (RDQuant) with CK Tile
This example demonstrates a fused kernel for elementwise addition, 2D RMSNorm, and rowwise dynamic quantization using the CK Tile programming model. This pattern is common in LLMs for efficient normalization and quantized inference.
---
## Algorithm and Math
Given input $X$ and residual $R$:
1. **Elementwise Add**: $Z = X + R$
2. **RMSNorm**: $\text{rms}(Z) = \sqrt{\frac{1}{N} \sum_{i=1}^N Z_i^2 + \epsilon}$, $Y_i = \frac{Z_i}{\text{rms}(Z)} \cdot \gamma_i$
3. **Rowwise Dynamic Quantization**:
- For each row, $s = \max(|Y|) / 127$
- $Q_i = \text{round}(Y_i / s)$, $Q_i \in \text{int8}$
**Output**:
- Quantized tensor $Q$ (int8)
- Per-row scale $s$ (fp32)
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile (row or block).
- **Tile Engine**: Loads tiles, performs add, RMSNorm, and quantization.
- **Pipeline**: Modular, can be extended for further fusion.
---
## Build & Run
```bash
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_add_rmsnorm2d_rdquant_fwd -j`nproc`
```
This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd`
### Arguments
```bash
args:
-m m dimension (default:3328)
-n n dimension (default:4096)
-stride stride per row, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5)
-save_x save rms(invrms) or not. set to 1 in training case (default:1)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec precision (default:fp16)
-quant precision (default:int8)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:add_rmsnorm2d_rdquant_fwd.json)
```
---
## Source Structure
- **Kernel**: [`add_rmsnorm2d_rdquant_fwd.hpp`](add_rmsnorm2d_rdquant_fwd.hpp) (tile-programming kernel template)
- **Executable**: [`add_rmsnorm2d_rdquant_fwd.cpp`](add_rmsnorm2d_rdquant_fwd.cpp), [`example_add_rmsnorm2d_rdquant_fwd.cpp`](example_add_rmsnorm2d_rdquant_fwd.cpp)
- **Build**: `CMakeLists.txt`, `instances/`, `script/`
---
## Related CK Tile Examples
- [10_rmsnorm2d](../10_rmsnorm2d/README.md): RMSNorm2D with tiles
- [12_smoothquant](../12_smoothquant/README.md): SmoothQuant with tiles
- [02_layernorm2d](../02_layernorm2d/README.md): LayerNorm2D with tiles
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)

View File

@@ -0,0 +1,330 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "add_rmsnorm2d_rdquant_fwd.hpp"
#include <cstring>
#include "ck_tile/utility/json_dump.hpp"
// different threshold for different dtype
template <typename InputDataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("save_x", "1", "save rms(invrms) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("quant", "int8", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "add_rmsnorm2d_rdquant_fwd.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InputDataType, typename QuantizedDataType, bool SaveX>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
float epsilon = arg_parser.get_float("e");
std::string input_data_type = arg_parser.get_str("prec");
std::string quantized_data_type = arg_parser.get_str("quant");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
using TypeConfig = AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>;
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using GammaDataType = typename TypeConfig::GammaDataType;
using XDataType = typename TypeConfig::XDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = float;
using UnquantYDataType = ck_tile::null_type;
// host verify
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
ck_tile::HostTensor<BDataType> b_host({m, n}, {stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<XDataType> x_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
a_buf.ToDevice(a_host.data());
b_buf.ToDevice(b_host.data());
gamma_buf.ToDevice(gamma_host.data());
std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m
<< ", n:" << n << ", stride:" << stride << std::flush;
add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX};
add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(),
b_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(),
x_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
epsilon,
m,
n,
stride};
float ave_time = add_rmsnorm2d_rdquant_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(BDataType) * m * n +
sizeof(GammaDataType) * n + sizeof(YScaleDataType) * m +
sizeof(QYDataType) * m * n;
if constexpr(SaveX)
num_byte += sizeof(XDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(do_validation)
{
using YDataType = ComputeDataType;
using InvRmsDataType = InputDataType;
// Add
{
auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
ck_tile::reference_binary_elementwise<ADataType, BDataType, XDataType, ComputeDataType>(
a_host, b_host, x_host_ref, op);
if constexpr(SaveX)
{
x_buf.FromDevice(x_host_dev.data());
auto [rtol, atol] = get_elimit<XDataType>();
if(stride == n)
{
pass = ck_tile::check_err(x_host_dev,
x_host_ref,
std::string("x Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> x_host_dev_row(x_host_dev.begin() + i_r * stride,
x_host_dev.begin() + i_r * stride +
n);
std::vector<QYDataType> x_host_ref_row(x_host_ref.begin() + i_r * stride,
x_host_ref.begin() + i_r * stride +
n);
pass &= ck_tile::check_err(x_host_dev_row,
x_host_ref_row,
std::string("x[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
}
ck_tile::HostTensor<YDataType> y_host({m, n});
// Rmsnorm2d
{
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n});
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
// simplicity
ck_tile::reference_rmsnorm2d_fwd<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType,
UnquantYDataType>(
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<YDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_add_rmsnorm2d_rdquant_fwd_json(arg_parser.get_str("jsonfile"),
input_data_type,
quantized_data_type,
m,
n,
stride,
epsilon,
ave_time,
0,
gb_per_sec,
pass);
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string input_data_type = arg_parser.get_str("prec");
const std::string quantized_data_type = arg_parser.get_str("quant");
int save_x = arg_parser.get_int("save_x");
if(input_data_type == "fp16" && quantized_data_type == "int8" && save_x)
{
return run<ck_tile::half_t, ck_tile::int8_t, true>(arg_parser) ? 0 : -2;
}
else if(input_data_type == "fp16" && quantized_data_type == "int8" && !save_x)
{
return run<ck_tile::half_t, ck_tile::int8_t, false>(arg_parser) ? 0 : -2;
}
else if(input_data_type == "bf16" && quantized_data_type == "int8" && save_x)
{
return run<ck_tile::bf16_t, ck_tile::int8_t, true>(arg_parser) ? 0 : -2;
}
else if(input_data_type == "bf16" && quantized_data_type == "int8" && !save_x)
{
return run<ck_tile::bf16_t, ck_tile::int8_t, true>(arg_parser) ? 0 : -2;
}
else if(input_data_type == "fp16" && quantized_data_type == "fp8" && save_x)
{
return run<ck_tile::half_t, ck_tile::fp8_t, true>(arg_parser) ? 0 : -2;
}
else if(input_data_type == "fp16" && quantized_data_type == "fp8" && !save_x)
{
return run<ck_tile::half_t, ck_tile::fp8_t, false>(arg_parser) ? 0 : -2;
}
else if(input_data_type == "bf16" && quantized_data_type == "fp8" && save_x)
{
return run<ck_tile::bf16_t, ck_tile::fp8_t, true>(arg_parser) ? 0 : -2;
}
else if(input_data_type == "bf16" && quantized_data_type == "fp8" && !save_x)
{
return run<ck_tile::bf16_t, ck_tile::fp8_t, true>(arg_parser) ? 0 : -2;
}
return -3;
}

View File

@@ -0,0 +1,113 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp"
#include <string>
template <typename InputDataType, typename QuantizedDataType>
struct AddRmsnormRdquantTypeConfig;
template <>
struct AddRmsnormRdquantTypeConfig<ck_tile::half_t, ck_tile::int8_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using XDataType = ck_tile::half_t;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
template <>
struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t, ck_tile::int8_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t;
using XDataType = ck_tile::bf16_t;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
template <>
struct AddRmsnormRdquantTypeConfig<ck_tile::half_t, ck_tile::fp8_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using XDataType = ck_tile::half_t;
using YScaleDataType = float;
using QYDataType = ck_tile::fp8_t;
using ComputeDataType = float;
};
template <>
struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t, ck_tile::fp8_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t;
using XDataType = ck_tile::bf16_t;
using YScaleDataType = float;
using QYDataType = ck_tile::fp8_t;
using ComputeDataType = float;
};
// runtime args
struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs
{
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename InputDataType_,
typename QuantizedDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveX_,
bool kThreePass_>
struct add_rmsnorm2d_rdquant_fwd_traits_
{
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;
static constexpr bool kThreePass = kThreePass_;
};
template <typename Traits_>
float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_rdquant_fwd_args a);
// This is the public API, will be generated by script
struct add_rmsnorm2d_rdquant_fwd_traits
{
std::string input_data_type;
std::string quantized_data_type;
bool save_x;
};
float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits,
add_rmsnorm2d_rdquant_fwd_args,
const ck_tile::stream_config&);

View File

@@ -0,0 +1,280 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
using ADataType = DataType;
using BDataType = DataType;
using GammaDataType = DataType;
using XDataType = DataType;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
ck_tile::HostTensor<BDataType> b_host({m, n}, {stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<XDataType> x_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
a_buf.ToDevice(a_host.data());
b_buf.ToDevice(b_host.data());
gamma_buf.ToDevice(gamma_host.data());
constexpr bool kThreePass = true;
using BlockTile = ck_tile::sequence<4, 128>;
using Vector = ck_tile::sequence<1, 1>;
using ThreadPerBlock = ck_tile::sequence<4, 64>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<ADataType,
BDataType,
GammaDataType,
ComputeDataType,
XDataType,
YScaleDataType,
QYDataType,
Shape,
true, // kPadN
true, // kSaveX
kThreePass>;
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<Problem>;
using Pipeline = std::conditional_t<kThreePass, ThreePassPipeline, OnePassPipeline>;
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(),
b_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(),
x_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
epsilon,
m,
n,
stride};
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;
if(do_validation)
{
using YDataType = ComputeDataType;
using InvRmsDataType = DataType;
// Add
{
auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
ck_tile::reference_binary_elementwise<ADataType, BDataType, XDataType, ComputeDataType>(
a_host, b_host, x_host_ref, op);
x_buf.FromDevice(x_host_dev.data());
auto [rtol, atol] = get_elimit<XDataType>();
if(stride == n)
{
pass = ck_tile::check_err(
x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> x_host_dev_row(x_host_dev.begin() + i_r * stride,
x_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> x_host_ref_row(x_host_ref.begin() + i_r * stride,
x_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(x_host_dev_row,
x_host_ref_row,
std::string("x[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
ck_tile::HostTensor<YDataType> y_host({m, n});
// Rmsnorm2d
{
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
ck_tile::HostTensor<ck_tile::null_type> unquant_y_host_ref({m, n});
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
// simplicity
ck_tile::reference_rmsnorm2d_fwd<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<YDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
return -3;
}

View File

@@ -0,0 +1,227 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <ck_tile/core.hpp>
#include "add_rmsnorm2d_rdquant_fwd.hpp"
template <typename InputDataType_,
typename QuantizedDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveX_,
bool kThreePass_>
using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<InputDataType_,
QuantizedDataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveX_,
kThreePass_>;
template <typename input_data_type, typename quantized_data_type>
float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
add_rmsnorm2d_rdquant_fwd_args a,
const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
// rm rn tm tn vn pd x 3p
if(a.n <= 64) {
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 128) {
if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 4, 64, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 256) {
if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 4, 64, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 4, 64, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 512) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 4, 64, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 4, 64, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 4, 64, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 768) {
if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 4, 64, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 6, 4, 64, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1,12, 4, 64, 1, true, true, false>>(s, a);
}
else if(a.n <= 1024) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 2, 128, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 2, 128, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 2, 128, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 256, 1, true, true, false>>(s, a);
}
else if(a.n <= 1536) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 4, 64, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 2, 128, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 1, 256, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 6, 1, 256, 1, true, true, false>>(s, a);
}
else if(a.n <= 2048) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 1, 256, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 256, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 256, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 256, 1, true, true, false>>(s, a);
}
else if(a.n <= 3072) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 1, 128, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 1, 256, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 6, 1, 256, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 1, 1024, 1, true, true, false>>(s, a);
}
else if(a.n <= 4096) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 256, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 256, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 1024, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 1, true, true, false>>(s, a);
}
else if(a.n <= 8192) {
if(a.n<8192){
if(t.save_x){
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, true, true, false>>(s, a);
}
else{
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, true, false, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, true, false, false>>(s, a);
}
}
else{
if(t.save_x){
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, false, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, false, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, false, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, false, true, false>>(s, a);
}
else{
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, false, false, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, false, false, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, false, false, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, false, false, false>>(s, a);
}
}
}
else if(a.n > 8192) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, true, true, true>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, true, true, true>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, true, true, true>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, true, true, true>>(s, a);
}
return r;
// clang-format on
}
float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
add_rmsnorm2d_rdquant_fwd_args a,
const ck_tile::stream_config& s)
{
if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
t.save_x)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t, ck_tile::int8_t>(t, a, s);
}
else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
!t.save_x)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t, ck_tile::int8_t>(t, a, s);
}
else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
t.save_x)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t, ck_tile::int8_t>(t, a, s);
}
else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
!t.save_x)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t, ck_tile::int8_t>(t, a, s);
}
else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
t.save_x)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t, ck_tile::fp8_t>(t, a, s);
}
else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
!t.save_x)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t, ck_tile::fp8_t>(t, a, s);
}
else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
t.save_x)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t, ck_tile::fp8_t>(t, a, s);
}
else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
!t.save_x)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t, ck_tile::fp8_t>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
}

View File

@@ -0,0 +1,25 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
#if 0
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
#endif
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,14 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,14 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,14 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,25 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
#if 0
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
#endif
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,14 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,14 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,14 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,69 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <ck_tile/core.hpp>
#include "add_rmsnorm2d_rdquant_fwd.hpp"
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = add_rmsnorm2d_rdquant_fwd_args;
template <typename InputDataType_,
typename QuantizedDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<InputDataType_,
QuantizedDataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_>;
template <typename Traits_>
float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
{
using InputDataType = typename Traits_::InputDataType;
using QuantizedDataType = typename Traits_::QuantizedDataType;
using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::ADataType,
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::BDataType,
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::GammaDataType,
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::ComputeDataType,
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::XDataType,
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::YScaleDataType,
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::QYDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveX,
Traits_::kThreePass>;
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<PipelineProblem>;
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kThreePass, ThreePassPipeline, OnePassPipeline>;
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -0,0 +1,40 @@
#!/bin/sh
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000

View File

@@ -0,0 +1,33 @@
#!/bin/sh
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)"
for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13
$EXE -prec=$pr_i -m=17 -n=16
$EXE -prec=$pr_i -m=1 -n=100
$EXE -prec=$pr_i -m=4 -n=128
$EXE -prec=$pr_i -m=80 -n=127
$EXE -prec=$pr_i -m=22 -n=255 -stride=256
$EXE -prec=$pr_i -m=7 -n=599
$EXE -prec=$pr_i -m=19 -n=512
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
$EXE -prec=$pr_i -m=11 -n=510
$EXE -prec=$pr_i -m=171 -n=676 -stride=818
$EXE -prec=$pr_i -m=91 -n=636
$EXE -prec=$pr_i -m=12 -n=768 -stride=800
$EXE -prec=$pr_i -m=100 -n=766 -stride=812
$EXE -prec=$pr_i -m=31 -n=1024
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
$EXE -prec=$pr_i -m=8 -n=1501
$EXE -prec=$pr_i -m=3 -n=1826
$EXE -prec=$pr_i -m=5 -n=2040
$EXE -prec=$pr_i -m=7 -n=2734
$EXE -prec=$pr_i -m=1 -n=3182
$EXE -prec=$pr_i -m=9 -n=4096
$EXE -prec=$pr_i -m=3 -n=8192
$EXE -prec=$pr_i -m=1 -n=10547
$EXE -prec=$pr_i -m=3 -n=17134
done