mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
int4 act ready
This commit is contained in:
@@ -134,7 +134,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t ActOP = 0;
|
||||
static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu
|
||||
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -59,8 +59,8 @@ struct MulABScale
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1) const
|
||||
{
|
||||
(void) d0;
|
||||
(void) d1;
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
#else
|
||||
@@ -71,8 +71,8 @@ struct MulABScale
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1) const
|
||||
{
|
||||
(void) d0;
|
||||
(void) d1;
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
#else
|
||||
@@ -81,7 +81,6 @@ struct MulABScale
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
using CDEElementOp = MulABScale;
|
||||
|
||||
#if 1
|
||||
@@ -125,7 +124,8 @@ using BElementOp = PassThrough;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
@@ -138,7 +138,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
2, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, true, ck::index_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, true, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -152,7 +152,7 @@ int main(int argc, char* argv[])
|
||||
// experts = 8
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
ck::index_t N = 14336 * 2;
|
||||
ck::index_t N = 14336;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
@@ -402,8 +402,7 @@ int main(int argc, char* argv[])
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
|
||||
ck::get_device_name() != "gfx950")
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
@@ -440,7 +439,8 @@ int main(int argc, char* argv[])
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
PassThrough,
|
||||
Act_OP>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -59,7 +59,7 @@ struct MulABScaleExpertWeight
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
@@ -151,7 +151,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, false, ck::index_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -409,8 +409,7 @@ int main(int argc, char* argv[])
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
|
||||
ck::get_device_name() != "gfx950")
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
|
||||
Reference in New Issue
Block a user