mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_TILE] Migrate CK Tile examples to Tests to autorun on CI (#2421)
[CK_TILE] Add new ck tile unit test * Add new ck tile unit test smoke-gemm-universal * Add new ck tile unit test smoke-gemm-basic * Add new ck tile unit test topk_softmax * Add new ck tile unit test add_rmsnorm2d_rdquant_fwd
This commit is contained in:
69
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.hpp
Normal file
69
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename InType,
|
||||
typename OutType,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig;
|
||||
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using InvRmsDataType = ck_tile::half_t;
|
||||
using UnquantYDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using InvRmsDataType = ck_tile::bf16_t;
|
||||
using UnquantYDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct rmsnorm2d_fwd_traits
|
||||
{
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_o; // output precision
|
||||
|
||||
// if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
|
||||
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
|
||||
// can set arbitrary(will skip check)
|
||||
std::string prec_sm; // x-scale, used for [1*N] input smooth quant
|
||||
std::string prec_sy; // y-scale, used for [M*1] output for next layer
|
||||
|
||||
bool save_rms;
|
||||
bool save_unquant;
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
Reference in New Issue
Block a user