mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* Add shortcut to RMSNorm * Modify test for adding shortcut for RMSNorm * Add fused parameter into tests * 1. Add YDataType. 2. Move rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp * 1. Supports various strides and precisions. * Add support of Epilogue * Add fuse and epilogue support to rmsnorm ref * Modify rmsnorm example * Refactor tests/examples * Bug fix for newly added tests/examples * Bug fix for new tests 2 * Modify smoke test scripts remove dbg code * Supports non-smooth dynamic quant * Update Rmsnorm2dFwd::GetName() * rename xscale and prec_sx to smoothscale and prec_sm Bug fix after rename Remove files * change example_rmsnorm2d_fwd.cpp * update performance calculator * Fix issue in two-pass when fuse add is enabled * Remove comment of beta --------- Co-authored-by: rocking <ChunYu.Lai@amd.com>
115 lines
3.8 KiB
C++
115 lines
3.8 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host/kernel_launch.hpp"
|
|
#include "ck_tile/ops/smoothquant.hpp"
|
|
#include <string>
|
|
|
|
// Primary template of the per-input-type configuration. It is declared but never defined,
// so only input types that have an explicit specialization yield a complete type; naming
// the members of any other instantiation fails to compile.
template <class InputType>
struct SmoothquantTypeConfig;
|
|
|
|
template <>
|
|
struct SmoothquantTypeConfig<ck_tile::half_t>
|
|
{
|
|
using XDataType = ck_tile::half_t;
|
|
using SmoothScaleDataType = float;
|
|
using YScaleDataType = float;
|
|
using QYDataType = ck_tile::int8_t;
|
|
using ComputeDataType = float;
|
|
};
|
|
|
|
template <>
|
|
struct SmoothquantTypeConfig<ck_tile::bf16_t>
|
|
{
|
|
using XDataType = ck_tile::bf16_t;
|
|
using SmoothScaleDataType = float;
|
|
using YScaleDataType = float;
|
|
using QYDataType = ck_tile::int8_t;
|
|
using ComputeDataType = float;
|
|
};
|
|
|
|
// runtime args
|
|
struct smoothquant_args : public ck_tile::SmoothquantHostArgs
|
|
{
|
|
};
|
|
|
|
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
|
template <typename DataType_,
|
|
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
|
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
|
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
|
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
|
ck_tile::index_t Vector_N_, // vector size along N
|
|
bool kPadN_,
|
|
bool kTwoPass_>
|
|
struct smoothquant_traits_
|
|
{
|
|
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
|
|
|
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
|
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
|
static constexpr ck_tile::index_t total_warps =
|
|
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
|
|
|
// num of warps along m
|
|
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
|
if constexpr(is_warp_per_row)
|
|
{
|
|
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
|
return total_warps * (warpSize / ThreadPerBlock_N_);
|
|
}
|
|
else
|
|
{
|
|
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
|
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
|
}
|
|
}();
|
|
|
|
// num of warps along n
|
|
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
|
if constexpr(is_warp_per_row)
|
|
{
|
|
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
|
return 1;
|
|
}
|
|
else
|
|
{
|
|
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
|
return ThreadPerBlock_N_ / warpSize;
|
|
}
|
|
}();
|
|
|
|
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
|
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
|
|
|
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
|
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
|
|
|
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
|
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
|
|
|
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
|
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
|
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
|
using Vector = ck_tile::sequence<1, Vector_N_>;
|
|
|
|
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
|
|
|
static constexpr bool kPadN = kPadN_;
|
|
static constexpr bool kTwoPass = kTwoPass_;
|
|
};
|
|
|
|
template <typename Traits_>
|
|
float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);
|
|
|
|
// Public API type; per the original note, the API implementation is generated by a script.
// data_type is a runtime string key; presumably it names the input data type the generated
// dispatcher switches on — the accepted spellings are defined there (verify).
struct smoothquant_traits
{
    std::string data_type{};
};
|
|
|
|
float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);
|