int4 act ready

This commit is contained in:
root
2025-03-28 07:45:39 +00:00
parent 529a1732cd
commit de65682298
3 changed files with 17 additions and 18 deletions

View File

@@ -134,7 +134,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t ActOP = 0;
static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -59,8 +59,8 @@ struct MulABScale
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float>(
EDataType& e, const EDataType& c, const float& d0, const float& d1) const
{
(void) d0;
(void) d1;
(void)d0;
(void)d1;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * 16);
#else
@@ -71,8 +71,8 @@ struct MulABScale
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1) const
{
(void) d0;
(void) d1;
(void)d0;
(void)d1;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * 16);
#else
@@ -81,7 +81,6 @@ struct MulABScale
}
};
using CDEElementOp = MulABScale;
#if 1
@@ -125,7 +124,8 @@ using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
@@ -138,7 +138,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
2, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, true, ck::index_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, true, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -152,7 +152,7 @@ int main(int argc, char* argv[])
// experts = 8
// per expert:
// GEMM shape
ck::index_t N = 14336 * 2;
ck::index_t N = 14336;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
@@ -402,8 +402,7 @@ int main(int argc, char* argv[])
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
ck::get_device_name() != "gfx950")
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
@@ -440,7 +439,8 @@ int main(int argc, char* argv[])
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
PassThrough,
Act_OP>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -59,7 +59,7 @@ struct MulABScaleExpertWeight
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
@@ -151,7 +151,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, false, ck::index_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -409,8 +409,7 @@ int main(int argc, char* argv[])
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
ck::get_device_name() != "gfx950")
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "