mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[Ck_tile] smoothquant (#1617)
* fix compile error * fix typo of padding * Add smoothquant op * Add smoothquant instance library * refine type * add test script * Re-generate smoothquant.hpp * Always use 'current year' in copyright * use Generic2dBlockShape instead * Add vector = 8 instance back * Find exe path automatically * Simplify the api condition * Remove debugging code * update year * Add blank line between function declaration * explicitly cast return value to dim3 * refine return value * Fix default warmup and repeat value * Add comment * refactor sommthquant cmake * Add README * Fix typo --------- Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -11,11 +11,11 @@ namespace ck_tile {
|
||||
// host side args
|
||||
struct Rmsnorm2dFwdHostArgs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_gamma;
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
|
||||
void* p_y;
|
||||
void* p_invRms;
|
||||
void* p_y; // [m, n], output, fp16/bf16
|
||||
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
|
||||
|
||||
float epsilon;
|
||||
|
||||
@@ -83,7 +83,7 @@ struct Rmsnorm2dFwd
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
return (hargs.m + Block_M - 1) / Block_M;
|
||||
return dim3(integer_divide_ceil(hargs.m, Block_M));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
@@ -149,7 +149,7 @@ struct Rmsnorm2dFwd
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadM>{});
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
/*
|
||||
// clang-format off
|
||||
|
||||
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
|
||||
|
||||
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
|
||||
+<----------------------< Repeat_N(2)>--------------------->+
|
||||
| |
|
||||
+<-- <WarpPerBlock_N(2)> -->+
|
||||
Warp_N
|
||||
+--------------+--------------+--------------+--------------+----+----------------+
|
||||
Warp_M | wrap_0 | wrap_1 | | ^ ^
|
||||
+--------------+--------------+ | <WarpPerBlock_M(2)> |
|
||||
| wrap_2 | wrap_3 | | v
|
||||
+--------------+--------------+--------------+--------------+----+ Block_M
|
||||
| | |
|
||||
+ + |
|
||||
| | | v
|
||||
+--------------+--------------+--------------+--------------+ +
|
||||
|
||||
each Warp-tile (e.g 16 thrd per row)
|
||||
|
||||
Vector_N (contiguous pixels each thrd holds along N, or vector size)
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
|
||||
+-----------+-----------+-----------+-----------+-----------+
|
||||
// clang-format on
|
||||
*/
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename WarpPerBlock_, // num warps along seq<M, N>
|
||||
typename WarpTile_, // warp size, seq<M, N>
|
||||
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
|
||||
index_t BlockSize_ =
|
||||
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
|
||||
struct Rmsnorm2dShape
|
||||
{
|
||||
// block size
|
||||
static constexpr index_t Block_M = BlockTile_::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile_::at(number<1>{});
|
||||
|
||||
// num warps along seq<M, N>, within each block
|
||||
static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});
|
||||
|
||||
// warp size
|
||||
static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WarpTile_::at(number<1>{});
|
||||
|
||||
static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
|
||||
static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
|
||||
// repeat of each thread along seq<M, N>
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
// vector size along seq<M, N>
|
||||
static constexpr index_t Vector_M = Vector_::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector_::at(number<1>{});
|
||||
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user