mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE]Moe update index (#1672)
* update MOCK_ID for moe-sorting
* add moe-smoothquant
* update a comment
* fix format
* hot fix
* update topk in overflow case
* update comments
* update bf16 cvt
---------
Co-authored-by: valarLip <340077269@qq.com>
[ROCm/composable_kernel commit: 36c7ce4e0e]
This commit is contained in:
25
example/ck_tile/14_moe_smoothquant/CMakeLists.txt
Normal file
25
example/ck_tile/14_moe_smoothquant/CMakeLists.txt
Normal file
@@ -0,0 +1,25 @@
|
||||
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
message("adding ${TARGET_NAME}")
|
||||
# not using add_example_executable() to add target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
list(APPEND INSTANCE_SRCS ${source})
|
||||
endforeach()
|
||||
|
||||
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
set(COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
|
||||
endfunction(add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
|
||||
add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS})
|
||||
|
||||
15
example/ck_tile/14_moe_smoothquant/README.md
Normal file
15
example/ck_tile/14_moe_smoothquant/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# moe-smoothquant
|
||||
|
||||
This folder contains example for moe-smoothquant using ck_tile tile-programming implementation.
|
||||

|
||||
|
||||
Unlike standard smoothquant op, the input scale is from different expert `[expert, hidden]`, we need reuse the `topk-id` from previous `topk-softmax` and select the corresponding `expert` from current topk, and expand the output/per-token-scale by `topk`
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_smoothquant -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_smoothquant`
|
||||
@@ -0,0 +1,22 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
#if 0
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,22 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
#if 0
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true ,false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true ,false>>(const S&, A);
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true ,false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8,true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4,true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2,true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1,true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true, false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,12 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_smoothquant_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false>>(const S&, A);
|
||||
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,145 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "moe_smoothquant.hpp"
|
||||
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = moe_smoothquant_traits_<DataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename data_type>
|
||||
float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/,
|
||||
moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
float r = -1;
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd 2p
|
||||
if(a.hidden_size <= 64) {
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 128) {
|
||||
if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 4, 64, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 256) {
|
||||
if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 4, 64, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 4, 64, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 512) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 4, 64, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 4, 64, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 4, 64, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 8, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 768) {
|
||||
if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 4, 64, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 6, 4, 64, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1,12, 4, 64, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 1024) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 2, 128, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 2, 128, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 2, 128, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 256, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 1536) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 4, 64, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 2, 128, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 1, 256, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 6, 1, 256, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 2048) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 1, 1, 256, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 256, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 256, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 8, 1, 256, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 3072) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 1, 128, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 1, 256, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 6, 1, 256, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 3, 1, 1024, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size <= 4096) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, false>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, false>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, false>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, false>>(s, a);
|
||||
}
|
||||
else if(a.hidden_size > 4096) {
|
||||
if (a.hidden_size % 8 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, true>>(s, a);
|
||||
else if (a.hidden_size % 4 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, true>>(s, a);
|
||||
else if (a.hidden_size % 2 == 0)
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, true>>(s, a);
|
||||
else
|
||||
r = moe_smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, true>>(s, a);
|
||||
}
|
||||
return r;
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
float moe_smoothquant(moe_smoothquant_traits t,
|
||||
moe_smoothquant_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::fp16_t>(t, a, s);
|
||||
}
|
||||
else if(t.data_type.compare("bf16") == 0)
|
||||
{
|
||||
return moe_smoothquant_dispatch<ck_tile::bf16_t>(t, a, s);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("Without supported instances!");
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "moe_smoothquant.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = moe_smoothquant_args;
|
||||
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = moe_smoothquant_traits_<DataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename Traits_>
|
||||
float moe_smoothquant_(const S& s, A a)
|
||||
{
|
||||
using DataType = typename Traits_::DataType;
|
||||
|
||||
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
|
||||
typename MoeSmoothquantTypeConfig<DataType>::XDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::XScaleDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::ComputeDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::YScaleDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::QYDataType,
|
||||
typename Traits_::Shape,
|
||||
Traits_::kPadN,
|
||||
Traits_::kTwoPass>;
|
||||
|
||||
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
|
||||
using Kernel = ck_tile::MoeSmoothquant<Pipeline>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
BIN
example/ck_tile/14_moe_smoothquant/misc/moe-sm.png
Normal file
BIN
example/ck_tile/14_moe_smoothquant/misc/moe-sm.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 202 KiB |
264
example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp
Normal file
264
example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp
Normal file
@@ -0,0 +1,264 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "moe_smoothquant.hpp"
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-5;
|
||||
double atol = 1e-5;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-5;
|
||||
double atol = 1e-5;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::int8_t>()
|
||||
{
|
||||
// due to rounding, int8 quantization might have 1 abs error
|
||||
double rtol = 1;
|
||||
double atol = 1;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename IndexType>
|
||||
void topid_unique_gen(
|
||||
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
|
||||
{
|
||||
size_t total_size = topk * tokens;
|
||||
std::srand(seed);
|
||||
std::set<IndexType> unique_set;
|
||||
IndexType current_v;
|
||||
for(size_t i = 0; i < total_size; i++)
|
||||
{
|
||||
if(i % topk == 0)
|
||||
{
|
||||
unique_set.clear();
|
||||
}
|
||||
current_v = std::rand() % num_expert;
|
||||
while(unique_set.find(current_v) != unique_set.end())
|
||||
{
|
||||
current_v = std::rand() % num_expert;
|
||||
}
|
||||
unique_set.insert(current_v);
|
||||
host_tensor[i] = current_v;
|
||||
}
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("t", "3328", "tokens dimension")
|
||||
.insert("h", "4096", "hidden_size dimension")
|
||||
.insert("e", "32", "experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = hidden_size;
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
assert(stride >= hidden_size);
|
||||
|
||||
using TypeConfig = MoeSmoothquantTypeConfig<DataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using XScaleDataType = typename TypeConfig::XScaleDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({experts * hidden_size});
|
||||
ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({topk * tokens}, {1});
|
||||
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({topk * tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({topk * tokens, hidden_size}, {stride, 1});
|
||||
|
||||
topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
xscale_buf.ToDevice(xscale_host.data());
|
||||
topk_ids_buf.ToDevice(topk_ids_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride
|
||||
<< ", experts:" << experts << ", topk:" << topk << std::flush;
|
||||
|
||||
moe_smoothquant_traits traits{data_type};
|
||||
|
||||
moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
|
||||
xscale_buf.GetDeviceBuffer(),
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
yscale_buf.GetDeviceBuffer(),
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
hidden_size,
|
||||
experts,
|
||||
topk,
|
||||
stride,
|
||||
stride};
|
||||
|
||||
float ave_time = moe_smoothquant(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte =
|
||||
sizeof(XDataType) * tokens * hidden_size + sizeof(XScaleDataType) * topk * hidden_size +
|
||||
sizeof(YScaleDataType) * topk * tokens + sizeof(QYDataType) * topk * tokens * hidden_size;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
using YDataType = ComputeDataType;
|
||||
ck_tile::HostTensor<ComputeDataType> y_host({topk * tokens, hidden_size}, {stride, 1});
|
||||
// smooth outlier
|
||||
{
|
||||
auto f = [&](auto i_token) {
|
||||
for(int i_topk = 0; i_topk < topk; i_topk++)
|
||||
{
|
||||
auto i_expert = topk_ids_host(i_token, i_topk);
|
||||
|
||||
for(int i_h = 0; i_h < hidden_size; ++i_h)
|
||||
{
|
||||
auto v_xscale = ck_tile::type_convert<ComputeDataType>(
|
||||
xscale_host(i_expert * hidden_size + i_h));
|
||||
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h));
|
||||
// y_host(i_token * topk + i_topk, i_h) = v_x * v_xscale;
|
||||
y_host(i_topk * tokens + i_token, i_h) = v_x * v_xscale;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, tokens)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
// yscale
|
||||
{
|
||||
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({topk * tokens});
|
||||
|
||||
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
|
||||
ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
|
||||
y_host, y_rowwise_amax_host, ReduceAmax{});
|
||||
|
||||
auto op = [](const auto& v0) {
|
||||
return v0 /
|
||||
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
|
||||
};
|
||||
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
|
||||
y_rowwise_amax_host, yscale_host_ref, op);
|
||||
|
||||
yscale_buf.FromDevice(yscale_host_dev.mData.data());
|
||||
|
||||
auto [rtol, atol] = get_elimit<YScaleDataType>();
|
||||
pass &= ck_tile::check_err(yscale_host_dev,
|
||||
yscale_host_ref,
|
||||
std::string("yscale Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
// rowwise quantization
|
||||
{
|
||||
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
|
||||
y_host, yscale_host_ref, qy_host_ref);
|
||||
|
||||
qy_buf.FromDevice(qy_host_dev.data());
|
||||
auto [rtol, atol] = get_elimit<QYDataType>();
|
||||
|
||||
if(stride == hidden_size)
|
||||
{
|
||||
pass = ck_tile::check_err(qy_host_dev,
|
||||
qy_host_ref,
|
||||
std::string("qy Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i_r = 0; i_r < topk * tokens; i_r++)
|
||||
{
|
||||
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
|
||||
qy_host_dev.begin() + i_r * stride +
|
||||
hidden_size);
|
||||
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
|
||||
qy_host_ref.begin() + i_r * stride +
|
||||
hidden_size);
|
||||
pass &= ck_tile::check_err(qy_host_dev_row,
|
||||
qy_host_ref_row,
|
||||
std::string("qy[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
114
example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp
Normal file
114
example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp
Normal file
@@ -0,0 +1,114 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/smoothquant.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename DataType>
|
||||
struct MoeSmoothquantTypeConfig;
|
||||
|
||||
template <>
|
||||
struct MoeSmoothquantTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MoeSmoothquantTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kTwoPass_>
|
||||
struct moe_smoothquant_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct moe_smoothquant_traits
|
||||
{
|
||||
std::string data_type;
|
||||
};
|
||||
|
||||
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);
|
||||
37
example/ck_tile/14_moe_smoothquant/script/perf_test.sh
Executable file
37
example/ck_tile/14_moe_smoothquant/script/perf_test.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
|
||||
EXE=build/bin/tile_example_moe_smoothquant
|
||||
|
||||
$EXE -t=1 -h=1 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=80 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=128 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=144 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=168 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=184 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=256 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=288 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=344 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=376 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=448 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=512 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=924 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=1024 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=1078 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=1996 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -t=700 -h=4080 -v=1 -prec=bf16 -repeat=1000
|
||||
|
||||
$EXE -t=700 -h=80 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=128 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=144 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=168 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=184 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=256 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=288 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=344 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=376 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=448 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=512 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=924 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=1024 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=1078 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=1996 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -t=700 -h=4080 -v=1 -prec=fp16 -repeat=1000
|
||||
30
example/ck_tile/14_moe_smoothquant/script/smoke_test.sh
Executable file
30
example/ck_tile/14_moe_smoothquant/script/smoke_test.sh
Executable file
@@ -0,0 +1,30 @@
|
||||
#!/bin/sh
|
||||
EXE=build/bin/tile_example_moe_smoothquant
|
||||
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
$EXE -prec=$pr_i -t=99 -h=13
|
||||
$EXE -prec=$pr_i -t=17 -h=16
|
||||
$EXE -prec=$pr_i -t=1 -h=100
|
||||
$EXE -prec=$pr_i -t=4 -h=128
|
||||
$EXE -prec=$pr_i -t=80 -h=127
|
||||
$EXE -prec=$pr_i -t=22 -h=255 -stride=256
|
||||
$EXE -prec=$pr_i -t=7 -h=599
|
||||
$EXE -prec=$pr_i -t=19 -h=512
|
||||
$EXE -prec=$pr_i -t=33 -h=313 -stride=1000
|
||||
$EXE -prec=$pr_i -t=11 -h=510
|
||||
$EXE -prec=$pr_i -t=171 -h=676 -stride=818
|
||||
$EXE -prec=$pr_i -t=91 -h=636
|
||||
$EXE -prec=$pr_i -t=12 -h=768 -stride=800
|
||||
$EXE -prec=$pr_i -t=100 -h=766 -stride=812
|
||||
$EXE -prec=$pr_i -t=31 -h=1024
|
||||
$EXE -prec=$pr_i -t=64 -h=1000 -stride=1004
|
||||
$EXE -prec=$pr_i -t=8 -h=1501
|
||||
$EXE -prec=$pr_i -t=3 -h=1826
|
||||
$EXE -prec=$pr_i -t=5 -h=2040
|
||||
$EXE -prec=$pr_i -t=7 -h=2734
|
||||
$EXE -prec=$pr_i -t=1 -h=3182
|
||||
$EXE -prec=$pr_i -t=9 -h=4096
|
||||
$EXE -prec=$pr_i -t=3 -h=8192
|
||||
$EXE -prec=$pr_i -t=1 -h=10547
|
||||
$EXE -prec=$pr_i -t=3 -h=17134
|
||||
done
|
||||
@@ -13,3 +13,4 @@ add_subdirectory(10_rmsnorm2d)
|
||||
add_subdirectory(11_add_rmsnorm2d_rdquant)
|
||||
add_subdirectory(12_smoothquant)
|
||||
add_subdirectory(13_moe_sorting)
|
||||
add_subdirectory(14_moe_smoothquant)
|
||||
|
||||
@@ -64,6 +64,7 @@
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
|
||||
|
||||
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
|
||||
@@ -225,3 +226,7 @@
|
||||
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
|
||||
#define CK_TILE_WORKAROUND_SWDEV_383542 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
|
||||
#endif
|
||||
|
||||
@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
|
||||
truncate_with_nan,
|
||||
truncate,
|
||||
standard_asm,
|
||||
rta_asm, // round to nearest away
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f)
|
||||
return uint16_t(u.int32);
|
||||
}
|
||||
|
||||
// TODO: do we need this on host?
|
||||
CK_TILE_HOST
|
||||
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
uint16_t float_to_bf16_rta_asm(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
struct
|
||||
{
|
||||
uint16_t lo;
|
||||
uint16_t hi;
|
||||
};
|
||||
} u = {f};
|
||||
|
||||
const uint32_t low_nan = 0x7fff;
|
||||
const uint32_t hi_nan = 0x7fff0000;
|
||||
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
uint32x2_t check_nan;
|
||||
|
||||
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
|
||||
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
|
||||
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
|
||||
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
|
||||
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
|
||||
|
||||
// Note: in above code snipet, we use hi 16 bit
|
||||
return u.hi;
|
||||
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
|
||||
return float_to_bf16_rtn_asm(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
|
||||
return float_to_bf16_rta_asm(f);
|
||||
else
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
|
||||
|
||||
template <typename WeightType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
const HostTensor<WeightType>& weights,
|
||||
@@ -20,8 +23,14 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
std::vector<std::vector<IndexType>> expert_tokens(experts,
|
||||
std::vector<IndexType>(unit_size, num_token));
|
||||
// allocate a temp buffer, and fill the value with [number_token|topk]
|
||||
std::vector<std::vector<IndexType>> expert_tokens(
|
||||
experts,
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
std::vector<IndexType>(unit_size, MOE_SORTING_MOCK_ID(num_token, topk)));
|
||||
#else
|
||||
std::vector<IndexType>(unit_size, num_token));
|
||||
#endif
|
||||
std::vector<std::vector<WeightType>> expert_token_weights(
|
||||
experts, std::vector<WeightType>(unit_size, 0));
|
||||
std::vector<IndexType> expert_slices(experts, 1);
|
||||
@@ -42,12 +51,19 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
expert_token_weights[e].resize(new_size);
|
||||
for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
|
||||
{
|
||||
expert_tokens[e][i] = num_token;
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
expert_tokens[e][i] = MOE_SORTING_MOCK_ID(num_token, topk);
|
||||
#else
|
||||
expert_tokens[e][i] = num_token;
|
||||
#endif
|
||||
expert_token_weights[e][i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
expert_tokens[e][idx] = t;
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
expert_tokens[e][idx] = MOE_SORTING_MOCK_ID(t, k);
|
||||
#else
|
||||
expert_tokens[e][idx] = t;
|
||||
#endif
|
||||
expert_token_weights[e][idx] = w;
|
||||
expert_slice_idxs[e]++;
|
||||
}
|
||||
@@ -75,4 +91,7 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
unit_cnt *= unit_size;
|
||||
return;
|
||||
}
|
||||
|
||||
#undef MOE_SORTING_MOCK_ID
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -12,20 +12,77 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
|
||||
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
|
||||
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
|
||||
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
|
||||
//
|
||||
// 32bit 0........23 24.....31 bit
|
||||
// (data) -> (token_id | topk_id)
|
||||
// low 24 bit is for token id, top 8 bit is for topk id
|
||||
//
|
||||
// the input after smooth-quant is [topk, token, hidden_dim], originally it is [token, hidden_dim]
|
||||
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
//
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
//
|
||||
// * different from vLLM
|
||||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
|
||||
// 2)need sorted_weight_ptr
|
||||
// 3) use num_sorted_tiles_ptr, already divided by M_a
|
||||
//
|
||||
// * below used for indexing
|
||||
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
|
||||
// 2) sorted_weight_ptr
|
||||
// 3) sorted_expert_ids_ptr
|
||||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
|
||||
struct MoeSortingHostArgs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
const void* p_topk_ids; // [token, topk]
|
||||
const void* p_weights; // [token, topk]
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
// we fused the setzero of output of fused-moe buffer
|
||||
// set this pointer to nullptr will skip this operation
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t unit_size;
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t moe_buf_bytes;
|
||||
index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
@@ -183,8 +240,14 @@ struct MoeSortingKernel
|
||||
index_t expert_id = topk_id[i];
|
||||
index_t rank_post_pad =
|
||||
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
|
||||
p_sorted_token_ids[rank_post_pad] = MOE_SORTING_MOCK_ID(curr_token_id, curr_topk_id);
|
||||
#else
|
||||
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
#endif
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
|
||||
}
|
||||
|
||||
@@ -195,8 +258,13 @@ struct MoeSortingKernel
|
||||
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
|
||||
while(expert_offset < cumsum[tid + 1])
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[expert_offset] =
|
||||
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[expert_offset] = prefill_token;
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
#endif
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
expert_offset++;
|
||||
}
|
||||
}
|
||||
@@ -229,4 +297,7 @@ struct MoeSortingKernel
|
||||
smem);
|
||||
}
|
||||
};
|
||||
|
||||
#undef MOE_SORTING_MOCK_ID
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp"
|
||||
#include "ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp"
|
||||
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp"
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// host side args
|
||||
struct MoeSmoothquantHostArgs
|
||||
{
|
||||
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
|
||||
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
|
||||
void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale
|
||||
void* p_qy; // [topk * tokens, hidden_size], output
|
||||
|
||||
index_t tokens;
|
||||
index_t hidden_size;
|
||||
index_t experts;
|
||||
index_t topk;
|
||||
index_t x_stride; // input x row stride
|
||||
index_t y_stride; // output y stride(stride for topk)
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
template <typename Pipeline_>
|
||||
struct MoeSmoothquant
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
static_assert(Problem::BlockShape::Repeat_M == 1);
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
|
||||
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
|
||||
void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale
|
||||
void* p_qy; // [topk, tokens, hidden_size], output
|
||||
|
||||
index_t tokens;
|
||||
index_t hidden_size;
|
||||
index_t experts;
|
||||
index_t topk;
|
||||
index_t x_stride; // input x row stride
|
||||
index_t y_stride; // output y stride(stride for topk)
|
||||
};
|
||||
using Hargs = MoeSmoothquantHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_xscale,
|
||||
hargs.p_topk_ids,
|
||||
hargs.p_yscale,
|
||||
hargs.p_qy,
|
||||
hargs.tokens,
|
||||
hargs.hidden_size,
|
||||
hargs.experts,
|
||||
hargs.topk,
|
||||
hargs.x_stride,
|
||||
hargs.y_stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using S_ = typename Problem::BlockShape;
|
||||
auto surfix = [&] () {
|
||||
std::string n;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kTwoPass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
const index_t i_topk = blockIdx.x;
|
||||
const index_t i_token = blockIdx.y * Block_M;
|
||||
const index_t i_token_in_thrd =
|
||||
__builtin_amdgcn_readfirstlane(threadIdx.x / Problem::BlockShape::ThreadPerBlock_N);
|
||||
|
||||
const index_t i_expert = reinterpret_cast<const index_t*>(
|
||||
kargs.p_topk_ids)[(i_token + i_token_in_thrd) * kargs.topk + i_topk];
|
||||
|
||||
// [tokens ,hidden_size]
|
||||
const auto x_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.tokens, kargs.hidden_size),
|
||||
make_tuple(kargs.x_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
|
||||
}();
|
||||
|
||||
// [experts, hidden_size],
|
||||
const auto xscale_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XScaleDataType*>(kargs.p_xscale) + i_expert * kargs.hidden_size,
|
||||
make_tuple(kargs.hidden_size),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
// [topk, tokens]
|
||||
auto yscale_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YScaleDataType*>(kargs.p_yscale) + i_topk * kargs.tokens,
|
||||
make_tuple(kargs.tokens),
|
||||
make_tuple(1),
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {i_token});
|
||||
}();
|
||||
|
||||
// [topk, tokens, hidden_size]
|
||||
auto qy_window = [&]() {
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<QYDataType*>(kargs.p_qy) + i_topk * kargs.tokens * kargs.y_stride,
|
||||
make_tuple(kargs.tokens, kargs.hidden_size),
|
||||
make_tuple(kargs.y_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
|
||||
}();
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user