Files
composable_kernel/example/ck_tile/12_smoothquant/smoothquant.hpp
rocking fbd654545a [Ck_tile] smoothquant (#1617)
* fix compile error

* fix typo of padding

* Add smoothquant op

* Add smoothquant instance library

* refine type

* add test script

* Re-generate smoothquant.hpp

* Always use 'current year' in copyright

* use Generic2dBlockShape instead

* Add vector = 8 instance back

* Find exe path automatically

* Simplify the api condition

* Remove debugging code

* update year

* Add blank line between function declaration

* explicitly cast return value to dim3

* refine return value

* Fix default warmup and repeat value

* Add comment

* refactor sommthquant cmake

* Add README

* Fix typo

---------

Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
2024-11-01 13:51:56 +08:00

115 lines
3.7 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>
template <typename DataType>
struct SmoothquantTypeConfig;
template <>
struct SmoothquantTypeConfig<ck_tile::half_t>
{
using XDataType = ck_tile::half_t;
using XScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t>
{
using XDataType = ck_tile::bf16_t;
using XScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
// runtime args
struct smoothquant_args : public ck_tile::SmoothquantHostArgs
{
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
struct smoothquant_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_>
float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);
// This is the public API, will be generated by script
struct smoothquant_traits
{
std::string data_type;
};
float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);