mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fuse moe ffn1 and activation op
1. fused moe related activation op. 2. inserted act_op parameters to DeviceMoeGemm. - 0: gelu - 1: silu - 2: swiglu 3. fixed moe gemm example.
This commit is contained in:
@@ -12,6 +12,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
|
||||
if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132)
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1)
|
||||
target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -36,7 +36,7 @@ using A0DataType = F8;
|
||||
using B0DataType = F8;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
@@ -59,27 +59,25 @@ struct MulABScale
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1) const
|
||||
{
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
// for gate, a_scale, b_scale, fuse silu,
|
||||
struct MulABScaleSilu
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float>(EDataType& e,
|
||||
const float& c,
|
||||
const float& d0,
|
||||
const float& d1) const
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1) const
|
||||
{
|
||||
// act
|
||||
float x0 = 0;
|
||||
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
|
||||
e = ck::type_convert<EDataType>(x0);
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, EDataType, EDataType>(
|
||||
EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -87,8 +85,6 @@ struct MulABScaleSilu
|
||||
// using DsDataType = DsDataTypeGate;
|
||||
using CDEElementOp = MulABScale;
|
||||
|
||||
// using CDEElementOp = MulABScaleSiluMulGate;
|
||||
|
||||
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16 / sizeof(B0DataType);
|
||||
@@ -125,9 +121,9 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MXDLPerWave = 2;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t MXDLPerWave = 1;
|
||||
static constexpr ck::index_t NXDLPerWave = 1;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
@@ -138,6 +134,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu, 1: silu, 2: swiglu
|
||||
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
@@ -157,8 +154,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, uint32_t, A0DataType>;
|
||||
1, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, true, int32_t, A0DataType>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -172,9 +169,9 @@ int main(int argc, char* argv[])
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t sorted_tile_num = 8;
|
||||
ck::index_t valid_tile_num = 8;
|
||||
ck::index_t tokens = 64;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
// ck::index_t tokens = batch * topk;
|
||||
@@ -224,7 +221,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 0};
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 1};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
@@ -236,7 +233,8 @@ int main(int argc, char* argv[])
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// max_token_id.mData = {valid_size};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
@@ -261,10 +259,13 @@ int main(int argc, char* argv[])
|
||||
// expert_ids.savetxt("expert_ids.txt", "int");
|
||||
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
// Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N * 2}, {StrideDs[1] ? StrideDs[1] *
|
||||
// N * 2: 1, StrideDs[1]}));
|
||||
Tensor<D1DataType> d1_e_n(
|
||||
HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
@@ -278,10 +279,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
@@ -290,10 +291,10 @@ int main(int argc, char* argv[])
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
@@ -311,7 +312,7 @@ int main(int argc, char* argv[])
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
// a0_t_k.savetxt("a.txt");
|
||||
// d0_t_n.savetxt("d0_t_n.txt", "int");
|
||||
d0_t_n.savetxt("d0_t_n.txt", "float");
|
||||
// d1_e_n.savetxt("d1_e_n.txt", "int");
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
@@ -329,7 +330,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
|
||||
preShuffleBuffer(
|
||||
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
@@ -367,9 +369,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
|
||||
sizeof(B0DataType) * K * N * experts +
|
||||
sizeof(B0DataType) * K * N * 2 * experts +
|
||||
sizeof(EDataType) * valid_tile_num * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -394,7 +396,8 @@ int main(int argc, char* argv[])
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
PassThrough,
|
||||
ActOP>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
@@ -403,7 +406,9 @@ int main(int argc, char* argv[])
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
d0_t_n,
|
||||
b0_e_n_k,
|
||||
d1_e_n,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
@@ -432,10 +437,10 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
// e_t_n_device_result.savetxt("out.txt");
|
||||
// e_t_n_host_result.savetxt("ref.txt");
|
||||
e_t_n_device_result.savetxt("out.txt");
|
||||
e_t_n_host_result.savetxt("ref.txt");
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -36,7 +36,7 @@ using A0DataType = F8;
|
||||
using B0DataType = I4;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
@@ -55,42 +55,32 @@ struct MulABScale
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// for gate, a_scale, b_scale, fuse silu,
|
||||
struct MulABScaleSilu
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float>(EDataType& e,
|
||||
const float& c,
|
||||
const float& d0,
|
||||
const float& d1) const
|
||||
{
|
||||
// act
|
||||
float x0 = 0;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16);
|
||||
#else
|
||||
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
|
||||
#endif
|
||||
e = ck::type_convert<EDataType>(x0);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScale;
|
||||
|
||||
#if 1
|
||||
@@ -132,53 +122,24 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
#if 0
|
||||
static constexpr ck::index_t MPerBlock = 64;
|
||||
static constexpr ck::index_t MXDLPerWave = 1;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock,
|
||||
AK1, BK1,
|
||||
MNPerXDL, MNPerXDL,
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, MPerBlock, 128, 128,
|
||||
256, MPerBlock, 64, 128,
|
||||
16, 32,
|
||||
32, 32,
|
||||
4, 1,
|
||||
16, 16,
|
||||
8, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>;
|
||||
2, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, true, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -191,14 +152,14 @@ int main(int argc, char* argv[])
|
||||
// experts = 8
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
ck::index_t N = 14336 * 2;
|
||||
ck::index_t N = 14336;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 64;
|
||||
ck::index_t tokens = 644;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
@@ -260,11 +221,12 @@ int main(int argc, char* argv[])
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N * 2}, {1, StrideDs[1]}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
@@ -279,10 +241,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
@@ -440,7 +402,8 @@ int main(int argc, char* argv[])
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
if(!device_op.IsSupportedArgument(argument) ||
|
||||
!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
@@ -450,9 +413,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
|
||||
sizeof(B0DataType) / 2 * K * N * experts +
|
||||
sizeof(B0DataType) / 2 * K * N * 2 * experts +
|
||||
sizeof(EDataType) * valid_tile_num * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -477,7 +440,8 @@ int main(int argc, char* argv[])
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
PassThrough,
|
||||
Act_OP>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
@@ -486,7 +450,9 @@ int main(int argc, char* argv[])
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
d0_t_n,
|
||||
b0_e_n_k,
|
||||
d1_e_n,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -36,7 +36,7 @@ using A0DataType = F8;
|
||||
using B0DataType = F8;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
@@ -62,11 +62,19 @@ struct MulABScaleExpertWeight
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
// warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside.
|
||||
// tofix:felix
|
||||
(void)d0;
|
||||
e = ck::type_convert<EDataType>(c * d1 * d2);
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
@@ -119,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t MXDLPerWave = 2;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t NXDLPerWave = 4;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
|
||||
// static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint
|
||||
@@ -163,8 +171,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, uint32_t, A0DataType>;
|
||||
4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
|
||||
@@ -184,8 +192,8 @@ int main(int argc, char* argv[])
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 6;
|
||||
ck::index_t valid_tile_num = 6;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 128;
|
||||
@@ -245,13 +253,13 @@ int main(int argc, char* argv[])
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
// max_token_id.mData[0] = valid_size;
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
|
||||
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
|
||||
// max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = 0;//eids[i];
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
@@ -409,22 +417,21 @@ int main(int argc, char* argv[])
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<CShuffleDataType> c_t_n({tokens, N});
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
|
||||
B0DataType,
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
|
||||
B0DataType,
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
D2DataType,
|
||||
float,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -62,11 +62,13 @@ struct MulABScaleExpertWeight
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d1 * d2 * 16);
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d1 * d2);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
}
|
||||
// for reference cpu
|
||||
@@ -125,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t NXDLPerWave = 1;
|
||||
static constexpr ck::index_t MXDLPerWave = 8;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t CShuffleNLane = 32;
|
||||
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
|
||||
@@ -148,8 +150,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>;
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -407,7 +409,8 @@ int main(int argc, char* argv[])
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
if(!device_op.IsSupportedArgument(argument) ||
|
||||
!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
|
||||
@@ -0,0 +1,620 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_blockwise_copy_up,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
const BGridBuffer& b_grid_buf_up,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
CThreadBuffer& c_thread_buf_up,
|
||||
index_t num_loop) const
|
||||
|
||||
{
|
||||
ignore = b_block_buf;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}>
|
||||
b_thread_dequant_bufs_up;
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// // Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// // Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I0));
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs_up(I0));
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf));
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs_up
|
||||
[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(local_read_buf));
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs_up(local_read_buf));
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I1));
|
||||
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs_up(I1));
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// MRepeat MWave MLane KRepeat KLane KPack
|
||||
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_k0_k1),
|
||||
decltype(b_block_desc_n0_n1_k0_k1),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
KPack>;
|
||||
|
||||
const PassThrough b_element_op{};
|
||||
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,573 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::c_thread_desc_;
|
||||
using Base::MWaves;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b =
|
||||
HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2;
|
||||
constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr(MPerBlock >= 128 && NPerBlock >= 64)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
// if constexpr(i.value < num_buffer_load_inst_a) {
|
||||
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
// }
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_blockwise_copy_up,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
const BGridBuffer& b_grid_buf_up,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
CThreadBuffer& c_thread_buf_up,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_buf;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// // Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// // Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
c_thread_buf_up.Clear();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf));
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// MRepeat MWave MLane KRepeat KLane KPack
|
||||
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp"
|
||||
@@ -33,57 +35,112 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool GUFusion = false>
|
||||
constexpr auto BlockGemmBPreshufflePipeline_Selector()
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(std::is_same<ADataType, BDataType>::value)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -46,7 +46,9 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
|
||||
// static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 =
|
||||
BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType, TransposeC>{};
|
||||
@@ -181,6 +183,23 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
|
||||
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
|
||||
|
||||
/**
|
||||
* @brief Constructor for BlockwiseGemmXdlops_pipeline_base.
|
||||
*
|
||||
* This constructor initializes the thread copy objects for matrices A and B.
|
||||
* It also performs several compile-time checks to ensure the correctness of the
|
||||
* matrix tile descriptors.
|
||||
*
|
||||
* @param a_origin The origin data index for matrix A.
|
||||
* @param b_origin The origin data index for matrix B.
|
||||
*
|
||||
* @note The constructor includes static assertions to ensure that:
|
||||
* - The matrix tile descriptors for A and B are known at compile-time.
|
||||
* - The number of threads in the thread block matches the product of MWaves, NWaves, and
|
||||
* WaveSize.
|
||||
* - The dimensions of the block are divisible by the product of the corresponding XDL and
|
||||
* repeat dimensions.
|
||||
*/
|
||||
__host__ __device__
|
||||
BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
@@ -316,7 +335,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
|
||||
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
|
||||
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -134,13 +134,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,10 +169,9 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunRead(src_descs, src_bufs);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -65,8 +65,10 @@ template <typename ALayout,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool PerTokenQuant = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
@@ -132,7 +134,10 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
PerTokenQuant,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
@@ -234,10 +239,10 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
|
||||
4 * (1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b =
|
||||
NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
|
||||
constexpr auto estimated_reg_c =
|
||||
MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
|
||||
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
|
||||
4 * (2) * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
|
||||
BlockSize / 4 * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
@@ -257,7 +262,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
@@ -267,7 +271,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
@@ -282,7 +285,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
@@ -292,7 +294,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
@@ -312,7 +313,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
|
||||
|
||||
@@ -26,11 +26,18 @@ namespace ck {
|
||||
// two lds chunks.
|
||||
// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
|
||||
// buffer when we declare __shared__ inside blkgemmpipe
|
||||
|
||||
enum Activation
|
||||
{
|
||||
gelu = 0,
|
||||
silu = 1,
|
||||
swiglu = 2
|
||||
};
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
bool IsInputGemm = false,
|
||||
TailNumber TailNum = TailNumber::Even>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -44,7 +51,7 @@ __global__ void
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
@@ -66,7 +73,6 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
bool IsInputGemm = false,
|
||||
TailNumber TailNum = TailNumber::Even>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -81,21 +87,20 @@ __global__ void
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::
|
||||
template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -147,7 +152,10 @@ template <typename ALayout,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOperation = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool PerTokenQuant = false,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
@@ -221,6 +229,7 @@ struct GridwiseMoeGemm
|
||||
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
|
||||
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
|
||||
const index_t gridy = NSwizzle ? 1 : mblock;
|
||||
|
||||
return std::make_tuple(gridx, gridy, 1);
|
||||
}
|
||||
|
||||
@@ -491,8 +500,8 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
|
||||
template <typename ELayout>
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
|
||||
__host__ __device__ static auto MakeCGridDescriptor_M_N(
|
||||
IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
@@ -903,7 +912,8 @@ struct GridwiseMoeGemm
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack>())>;
|
||||
KPack,
|
||||
IsInputGemm>())>;
|
||||
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
@@ -1135,7 +1145,6 @@ struct GridwiseMoeGemm
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
bool IsInputGemm = true,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const index_t* p_sorted_token_ids,
|
||||
const index_t* p_sorted_expert_ids,
|
||||
@@ -1221,7 +1230,8 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
@@ -1232,7 +1242,6 @@ struct GridwiseMoeGemm
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
|
||||
@@ -1305,24 +1314,66 @@ struct GridwiseMoeGemm
|
||||
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
|
||||
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
|
||||
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
|
||||
|
||||
decltype(c_thread_buf) c_thread_buf_up;
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_blockwise_copy,
|
||||
b_blockwise_copy_up,
|
||||
b_grid_buf,
|
||||
b_grid_buf_up,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
c_thread_buf_up,
|
||||
num_k_block_main_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
}
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
@@ -1350,6 +1401,94 @@ struct GridwiseMoeGemm
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
|
||||
// mul scales
|
||||
const float* p_scale_b = p_ds_grid[I1];
|
||||
if constexpr(PerTokenQuant)
|
||||
{
|
||||
constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
|
||||
p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
|
||||
get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
|
||||
}
|
||||
else
|
||||
{
|
||||
p_scale_b += expert_id;
|
||||
}
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
vector_type<int32_t, 4> scale_token_ids;
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
const float scale_b = p_scale_b[n0 * NWave * PerTokenQuant];
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
|
||||
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
|
||||
if constexpr(PerTokenQuant)
|
||||
{
|
||||
scale_token_ids =
|
||||
*c_style_pointer_cast<const vector_type<int32_t, M4>*>(
|
||||
p_sorted_token_ids + m_pos);
|
||||
}
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
|
||||
float scale_a = [&]() {
|
||||
if constexpr(PerTokenQuant)
|
||||
{
|
||||
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
|
||||
const index_t token_offset = fused_token & 0xffffff;
|
||||
return token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset]
|
||||
: 0.0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return p_sorted_weights_0[0];
|
||||
}
|
||||
}();
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, m2 * M4 + m4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu)
|
||||
{
|
||||
tensor_operation::element_wise::Silu{}(c_thread_buf(cidx),
|
||||
c_thread_buf(cidx));
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu)
|
||||
{
|
||||
tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx),
|
||||
c_thread_buf(cidx));
|
||||
}
|
||||
else if(ActivationOperation == Activation::swiglu)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NPerXdl + problem.N) * PerTokenQuant];
|
||||
auto gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
auto up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
gate = gate * math::rcp(1.0 + math::exp(-gate));
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf(cidx) = scale_a * scale_b *
|
||||
topk_weights.AsType<float>()[m4] *
|
||||
c_thread_buf[cidx];
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
@@ -1447,17 +1586,8 @@ struct GridwiseMoeGemm
|
||||
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
const DDataType* ptr_ = p_ds_grid[i];
|
||||
// hack logic here to support different kind of strides. todo fix it.
|
||||
// ascale t, 1; bscale E, N, 1, move ptr to E
|
||||
if(i.value == 1)
|
||||
{
|
||||
ptr_ +=
|
||||
expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
|
||||
}
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
@@ -1520,7 +1650,7 @@ struct GridwiseMoeGemm
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
@@ -1531,17 +1661,8 @@ struct GridwiseMoeGemm
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
// using BufferType = std::conditional_t<
|
||||
// std::is_same_v<IndexType, long_index_t>,
|
||||
// decltype(make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())),
|
||||
// decltype(make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()))
|
||||
// >;
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// BufferType c_grid_buf = std::is_same_v<IndexType, long_index_t> ?
|
||||
// make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()) :
|
||||
// make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// space filling curve for threadwise C in VGPR
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
@@ -1571,36 +1692,21 @@ struct GridwiseMoeGemm
|
||||
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
|
||||
constexpr auto ENThreads =
|
||||
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<IndexType, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock]
|
||||
// & 0xff000000) >> 24;
|
||||
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos =
|
||||
block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
IndexType token_offset = fused_token & 0xffffff;
|
||||
float weight = token_offset < static_cast<IndexType>(problem.NumTokens)
|
||||
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
|
||||
: 0.0;
|
||||
IndexType token_offset = fused_token & 0xffffff;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -1621,8 +1727,7 @@ struct GridwiseMoeGemm
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights);
|
||||
scatter_offsets);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
@@ -1647,7 +1752,6 @@ struct GridwiseMoeGemm
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
bool IsInputGemm = true,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
|
||||
const index_t* p_sorted_expert_ids,
|
||||
@@ -1971,11 +2075,12 @@ struct GridwiseMoeGemm
|
||||
const DDataType* ptr_ = p_ds_grid[i];
|
||||
// hack logic here to support different kind of strides. todo fix it.
|
||||
// ascale t, 1; bscale E, N, 1, move ptr to E
|
||||
if(i.value == 1)
|
||||
{
|
||||
ptr_ +=
|
||||
expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
|
||||
}
|
||||
// if(i.value == 1)
|
||||
// {
|
||||
// ptr_ +=
|
||||
// expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N :
|
||||
// 1);
|
||||
// }
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
},
|
||||
@@ -2040,7 +2145,7 @@ struct GridwiseMoeGemm
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
@@ -2083,15 +2188,9 @@ struct GridwiseMoeGemm
|
||||
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
|
||||
constexpr auto ENThreads =
|
||||
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<IndexType, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock]
|
||||
// & 0xff000000) >> 24;
|
||||
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos =
|
||||
@@ -2099,21 +2198,11 @@ struct GridwiseMoeGemm
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = token_offset < static_cast<IndexType>(problem.NumTokens)
|
||||
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
|
||||
: 0.0;
|
||||
// float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -2134,8 +2223,7 @@ struct GridwiseMoeGemm
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights);
|
||||
scatter_offsets);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -154,7 +154,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
@@ -173,31 +172,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
src_coords_[i]);
|
||||
|
||||
oob_val = oob_val & is_src_valid;
|
||||
if(i.value == ScatterWeightIdx)
|
||||
{
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
|
||||
"scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter =
|
||||
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
|
||||
src_vectors(i).template AsType<float>()(j) =
|
||||
scatter_weights(Number<iScatter>{});
|
||||
});
|
||||
}
|
||||
else if constexpr(SrcScalarPerVectors{}[i] == 1)
|
||||
{
|
||||
auto data_types = SrcDatas{};
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
const auto tmp =
|
||||
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
|
||||
}
|
||||
else
|
||||
{
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
}
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
});
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
@@ -421,7 +397,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
|
||||
// loop over space-filling curve
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
IndexType scatter_offset = 0;
|
||||
if constexpr(OutputScatter)
|
||||
{
|
||||
@@ -432,7 +408,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
|
||||
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
@@ -491,10 +467,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunRead(src_descs, src_bufs);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,11 +22,13 @@ template <typename ADataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
index_t ActivationType_ = 0,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceMoeGemm : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
static constexpr auto ActivationType = ActivationType_;
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ck::index_t>& sorted_token_ids,
|
||||
@@ -34,7 +36,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ck::index_t>& max_token_id,
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<float>& a_scale_t,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
const Tensor<float>& b_scale_e_n,
|
||||
Tensor<CDataType>& c_t_k_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
@@ -44,7 +48,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
max_token_id_{max_token_id},
|
||||
sorted_tile_size_{sorted_tile_size},
|
||||
a_t_k_{a_t_k},
|
||||
a_scale_t_{a_scale_t},
|
||||
b_e_n_k_{b_e_n_k},
|
||||
b_scale_e_n_{b_scale_e_n},
|
||||
c_t_k_n_{c_t_k_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
@@ -57,7 +63,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ck::index_t>& max_token_id_;
|
||||
index_t sorted_tile_size_;
|
||||
const Tensor<ADataType>& a_t_k_;
|
||||
const Tensor<float>& a_scale_t_;
|
||||
const Tensor<BDataType>& b_e_n_k_;
|
||||
const Tensor<float>& b_scale_e_n_;
|
||||
Tensor<CDataType>& c_t_k_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -72,11 +80,20 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
if constexpr(ActivationType > 2)
|
||||
{
|
||||
static_assert(false, "Not supported activation type");
|
||||
}
|
||||
const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2];
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_t_k_.mDesc.GetLengths()[1];
|
||||
AccDataType v_acc_up{0};
|
||||
ComputeTypeB v_b_up{0};
|
||||
AccDataType v_acc{0};
|
||||
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
|
||||
const int t = arg.sorted_token_ids_(m) & 0xffffff;
|
||||
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
|
||||
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
|
||||
@@ -96,7 +113,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_a = i4_to_f32_gfx9(i4);
|
||||
#else
|
||||
v_a = i4 - 8;
|
||||
v_a = i4 - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -106,37 +123,75 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
|
||||
uint8_t i4 = 0;
|
||||
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
|
||||
uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data;
|
||||
uint8_t i4 = 0;
|
||||
uint8_t i4_up = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
{
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
i4_up = (i4x2_up >> 0) & 0xf;
|
||||
}
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
{
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
i4_up = (i4x2_up >> 4) & 0xf;
|
||||
}
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_b = i4_to_f32_gfx9(i4);
|
||||
v_b = i4_to_f32_gfx9(i4);
|
||||
v_b_up = i4_to_f32_gfx9(i4_up);
|
||||
#else
|
||||
v_b = i4 - 8;
|
||||
v_b = i4 - 8;
|
||||
v_b_up = i4_up - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n));
|
||||
if constexpr(ActivationType == 2)
|
||||
{
|
||||
arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n));
|
||||
}
|
||||
}
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
|
||||
if constexpr(ActivationType == 2)
|
||||
{
|
||||
v_acc_up += ck::type_convert<AccDataType>(v_a) *
|
||||
ck::type_convert<AccDataType>(v_b_up);
|
||||
}
|
||||
}
|
||||
CDataType v_c{0};
|
||||
CDataType v_c_up{0};
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c;
|
||||
if constexpr(ActivationType == 2)
|
||||
{
|
||||
arg.c_element_op_(v_c_up, v_acc_up);
|
||||
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
|
||||
v_c = v_c * (1.0 / (1.0 + math::exp(-v_c)));
|
||||
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(ActivationType == 1)
|
||||
{
|
||||
tensor_operation::element_wise::Silu{}(v_c, v_c);
|
||||
}
|
||||
else if constexpr(ActivationType == 0)
|
||||
{
|
||||
tensor_operation::element_wise::Gelu{}(v_c, v_c);
|
||||
}
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const ck::index_t max_token_id = arg.max_token_id_(0);
|
||||
make_ParallelTensorFunctor(
|
||||
f_mk_kn_mn, max_token_id, arg.c_t_k_n_.mDesc.GetLengths()[2])(
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
@@ -162,7 +217,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ck::index_t>& max_token_id,
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<float>& a_scale_n,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
const Tensor<float>& b_scale_e_n,
|
||||
Tensor<CDataType>& c_t_k_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
@@ -173,7 +230,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
max_token_id,
|
||||
sorted_tile_size,
|
||||
a_t_k,
|
||||
a_scale_n,
|
||||
b_e_n_k,
|
||||
b_scale_e_n,
|
||||
c_t_k_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
|
||||
Reference in New Issue
Block a user