mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Moe gemm activation (#2026)
* fix useless code and remove useless oob * clang format * fix coredump in e2e test * fix2 * fix clang format * fix output oob * impl int64 but result not correct * int64 index ok now * input output all ok * fix uint32 * revert v1 test * use uint32 * mork to support 13w tokens * moe sorting fix moebuf * fix merge * update moe api fix aiter build * fix build * fuse silu * silu ok * acale ok * add silu * change code * gemm2 ok * gufusion compatible ok, fix warnings * gu fusion for m32 m64 ok * support bf16 cshuffle * i4 gemm2 ok * i4 gemm2 ok and i4 gemm1 build * 16x16 run ok * change flops; change cshuffle dtype * fuse gelu silu act in moe gemm1 * fp8 with act ready * int4 act ready * remove useless changes * remove useless code change * fix clang format * add the arch limit of int4 moe gemm * fuse moe activation * fix fp8 16x16 * fix no quant case * fix bugs * fix fp8 gufusion bug * remove useless comments * refine activation code & complete moe example * fix int8 bugs * merge tkw1 --------- Co-authored-by: coderfeli <coderfeli@163.com> Co-authored-by: feli <felix.li@amd.com> Co-authored-by: illsilin <Illia.Silin@amd.com> Co-authored-by: root <root@hjbog-srdc-51.amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
|
||||
if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132)
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1)
|
||||
target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -25,7 +25,6 @@ template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
// using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
|
||||
@@ -36,7 +35,7 @@ using A0DataType = F8;
|
||||
using B0DataType = F8;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
@@ -61,27 +60,25 @@ struct MulABScale
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1) const
|
||||
{
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
// for gate, a_scale, b_scale, fuse silu,
|
||||
struct MulABScaleSilu
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float>(EDataType& e,
|
||||
const float& c,
|
||||
const float& d0,
|
||||
const float& d1) const
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1) const
|
||||
{
|
||||
// act
|
||||
float x0 = 0;
|
||||
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
|
||||
e = ck::type_convert<EDataType>(x0);
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, EDataType, EDataType>(
|
||||
EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -95,11 +92,19 @@ struct MulABScaleExpertWeight
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
// warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside.
|
||||
// tofix:felix
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
@@ -107,16 +112,14 @@ struct MulABScaleExpertWeight
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
|
||||
// using DsLayout = DsLayoutGate;
|
||||
// using DsDataType = DsDataTypeGate;
|
||||
// using CDEElementOp = MulABScale; // combine MulRoutedWeight = false
|
||||
|
||||
// using CDEElementOp = MulABScaleSiluMulGate;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
|
||||
{
|
||||
@@ -155,22 +158,21 @@ using BElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MXDLPerWave = 2;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t NPerBlock = 64;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = true;
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
@@ -188,8 +190,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
|
||||
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -201,15 +203,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 8;
|
||||
ck::index_t valid_tile_num = 8;
|
||||
ck::index_t tokens = 128;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t tokens = 64;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
// ck::index_t tokens = batch * topk;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
@@ -255,28 +255,22 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 1, 1};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
// const ck::index_t experts = 8;
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
max_token_id.mData = {valid_size};
|
||||
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
@@ -290,13 +284,12 @@ int main(int argc, char* argv[])
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
// expert_ids.savetxt("expert_ids.txt", "int");
|
||||
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<D1DataType> d1_e_n(
|
||||
HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
@@ -304,6 +297,7 @@ int main(int argc, char* argv[])
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
@@ -312,25 +306,25 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{-2, 2});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0, 1});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
@@ -349,9 +343,7 @@ int main(int argc, char* argv[])
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
// a0_t_k.savetxt("a.txt");
|
||||
// d0_t_n.savetxt("d0_t_n.txt", "int");
|
||||
// d1_e_n.savetxt("d1_e_n.txt", "int");
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
@@ -369,7 +361,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
|
||||
preShuffleBuffer(
|
||||
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
@@ -408,9 +401,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
|
||||
sizeof(B0DataType) * K * N * experts +
|
||||
sizeof(B0DataType) * K * N * 2 * experts +
|
||||
sizeof(EDataType) * valid_tile_num * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -437,6 +430,7 @@ int main(int argc, char* argv[])
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ActOP,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
@@ -446,7 +440,9 @@ int main(int argc, char* argv[])
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
d0_t_n,
|
||||
b0_e_n_k,
|
||||
d1_e_n,
|
||||
c_t_k_n,
|
||||
d2_e_n,
|
||||
PassThrough{},
|
||||
@@ -472,15 +468,14 @@ int main(int argc, char* argv[])
|
||||
c_t_k_n(t, topk_id, n),
|
||||
d0_t_n(t, n),
|
||||
d1_e_n(e, n),
|
||||
1.f);
|
||||
d2_e_n(e, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
// e_t_n_device_result.savetxt("out.txt");
|
||||
// e_t_n_host_result.savetxt("ref.txt");
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ using A0DataType = F8;
|
||||
using B0DataType = I4;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
@@ -47,7 +47,8 @@ using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, ELayout>;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// for gate, a_scale, b_scale
|
||||
struct MulABScale
|
||||
@@ -56,42 +57,32 @@ struct MulABScale
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// for gate, a_scale, b_scale, fuse silu,
|
||||
struct MulABScaleSilu
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float>(EDataType& e,
|
||||
const float& c,
|
||||
const float& d0,
|
||||
const float& d1) const
|
||||
{
|
||||
// act
|
||||
float x0 = 0;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16);
|
||||
#else
|
||||
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
|
||||
#endif
|
||||
e = ck::type_convert<EDataType>(x0);
|
||||
}
|
||||
};
|
||||
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
@@ -102,13 +93,19 @@ struct MulABScaleExpertWeight
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
#endif
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
@@ -116,15 +113,18 @@ struct MulABScaleExpertWeight
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2 * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
#endif
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
|
||||
|
||||
// using CDEElementOp = MulABScale; // combine MulRoutedWeight = true
|
||||
|
||||
#if 1
|
||||
void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
|
||||
@@ -165,54 +165,24 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
#if 0
|
||||
static constexpr ck::index_t MPerBlock = 64;
|
||||
static constexpr ck::index_t MXDLPerWave = 1;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock,
|
||||
AK1, BK1,
|
||||
MNPerXDL, MNPerXDL,
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
// clang-format on
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, MPerBlock, 128, 128,
|
||||
256, MPerBlock, 64, 128,
|
||||
16, 32,
|
||||
32, 32,
|
||||
4, 1,
|
||||
16, 16,
|
||||
8, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
|
||||
2, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -220,13 +190,10 @@ int main(int argc, char* argv[])
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// tokens = 1
|
||||
// topk = 1
|
||||
// experts = 8
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096 * 2;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t N = 14336;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
@@ -266,20 +233,20 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 0, 0};
|
||||
max_token_id.mData = {valid_size};
|
||||
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
@@ -294,11 +261,12 @@ int main(int argc, char* argv[])
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N * 2}, {1, StrideDs[1]}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
@@ -306,6 +274,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
|
||||
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
@@ -314,11 +283,11 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{-2, 2});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
@@ -497,9 +466,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
|
||||
sizeof(B0DataType) / 2 * K * N * experts +
|
||||
sizeof(B0DataType) / 2 * K * N * 2 * experts +
|
||||
sizeof(EDataType) * valid_tile_num * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -526,6 +495,7 @@ int main(int argc, char* argv[])
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Act_OP,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
@@ -535,7 +505,9 @@ int main(int argc, char* argv[])
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
d0_t_n,
|
||||
b0_e_n_k,
|
||||
d1_e_n,
|
||||
c_t_k_n,
|
||||
d2_e_n,
|
||||
PassThrough{},
|
||||
@@ -561,13 +533,13 @@ int main(int argc, char* argv[])
|
||||
c_t_k_n(t, topk_id, n),
|
||||
d0_t_n(t, n),
|
||||
d1_e_n(e, n),
|
||||
1.f);
|
||||
d2_e_n(e, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
// using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
|
||||
@@ -36,7 +35,7 @@ using A0DataType = F8;
|
||||
using B0DataType = F8;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
@@ -48,7 +47,6 @@ using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
@@ -62,11 +60,19 @@ struct MulABScaleExpertWeight
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
// warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside.
|
||||
// tofix:felix
|
||||
(void)d0;
|
||||
e = ck::type_convert<EDataType>(c * d1 * d2);
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
@@ -119,14 +125,12 @@ using CDEElementOp = MulABScaleExpertWeight;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t MXDLPerWave = 2;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t NXDLPerWave = 4;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
|
||||
// static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint
|
||||
// static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
|
||||
static constexpr ck::index_t CShuffleNLane = 32;
|
||||
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
@@ -135,7 +139,7 @@ static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
@@ -164,8 +168,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>;
|
||||
4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
|
||||
@@ -177,16 +181,13 @@ int main(int argc, char* argv[])
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// tokens = 1
|
||||
// topk = 1
|
||||
// experts = 8
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 6;
|
||||
ck::index_t valid_tile_num = 6;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 128;
|
||||
@@ -212,6 +213,18 @@ int main(int argc, char* argv[])
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
sorted_tile_num = std::stoi(argv[7]);
|
||||
valid_tile_num = std::stoi(argv[8]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
@@ -229,15 +242,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
// const ck::index_t experts = 8;
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
// max_token_id.mData[0] = valid_size;
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
|
||||
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
|
||||
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
@@ -249,7 +260,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
@@ -263,8 +274,7 @@ int main(int argc, char* argv[])
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
@@ -315,12 +325,7 @@ int main(int argc, char* argv[])
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
// a0_t_k_k.savetxt("a.txt");
|
||||
// expert_ids.savetxt("expert_ids.txt", "int");
|
||||
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
// d0_t_n.savetxt("d0_t_n.txt", "int");
|
||||
// d1_e_n.savetxt("d1_e_n.txt", "int");
|
||||
// d2_e_n.savetxt("d2_e_n.txt", "int");
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
@@ -398,7 +403,7 @@ int main(int argc, char* argv[])
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<CShuffleDataType> c_t_n({tokens, N});
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
|
||||
@@ -406,7 +411,7 @@ int main(int argc, char* argv[])
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
float,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -439,8 +444,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
// e_t_n_device_result.savetxt("out.txt");
|
||||
// e_t_n_host_result.savetxt("ref.txt");
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
|
||||
@@ -62,11 +62,13 @@ struct MulABScaleExpertWeight
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d1 * d2 * 16);
|
||||
e = ck::type_convert<EDataType>(c * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d1 * d2);
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
#endif
|
||||
}
|
||||
// for reference cpu
|
||||
@@ -125,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t NXDLPerWave = 1;
|
||||
static constexpr ck::index_t MXDLPerWave = 8;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t CShuffleNLane = 32;
|
||||
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
|
||||
@@ -149,8 +151,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>;
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -159,9 +161,6 @@ int main(int argc, char* argv[])
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// tokens = 1
|
||||
// topk = 1
|
||||
// experts = 8
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
|
||||
@@ -0,0 +1,621 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
// Primary template: intentionally empty. The only implementation visible in this file is the
// partial specialization for BlockGemmPipelineScheduler::Intrawave that follows.
//
// Naming notes for readers matching parameters against that specialization:
//   - the first parameter is a BlockGemmPipelineScheduler value even though it is named
//     BlkGemmPipelineVer;
//   - the last parameter is spelled KPacks here and KPack in the specialization.
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
// Reports whether the K loop is long enough to need the main (hot) loop, i.e. whether
// num_loop exceeds PrefetchStages.
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return PrefetchStages < num_loop;
}
|
||||
|
||||
// Classifies the loop tail by the parity of num_loop: TailNumber::Even when num_loop is
// divisible by two, TailNumber::Odd otherwise.
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    const bool has_odd_count = (num_loop % 2) != 0;
    return has_odd_count ? TailNumber::Odd : TailNumber::Even;
}
|
||||
|
||||
// Emits the sched_group_barrier hints for one main-loop step. Nothing here moves data; the
// function only issues __builtin_amdgcn_sched_group_barrier calls, in three groups:
//   1. per B buffer-load instruction : 1 MFMA, then 1 VMEM read;
//   2. per A buffer-load instruction : MFMA, DS write, MFMA, VMEM read;
//   3. per pair of A LDS reads       : 1 MFMA, then 2 DS reads.
// The instruction counts come from HotLoopInstList. The loop index i is unused in every
// group (assigned to `ignore`); only the repetition count matters.
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A local
// DS reads are grouped two at a time, so the trip count is half the read count.
// NOTE(review): if the count is an odd integer the division drops the last read - confirm
// that cannot happen for the supported tile shapes.
|
||||
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
|
||||
});
|
||||
}
|
||||
|
||||
// Software-pipelined block GEMM with two B operands that share one A tile. Every MFMA step
// feeds the same A vector to xdlops_gemm.Run twice: once with B, writing c_thread_buf, and
// once with B_up, writing c_thread_buf_up.
//
// Template parameters:
//   HasMainLoop - compile-time switch that enables the do/while main body.
//   TailNum     - TailNumber::Even runs two trailing MFMA passes; any other value runs one.
//
// Arguments:
//   a_grid_desc, a_grid_buf, a_blockwise_copy, a_block_copy_step
//       Source of A: fetched with RunRead and advanced with MoveSrcSliceWindow.
//   a_block_desc, a_block_buf
//       Staging buffer that RunWrite fills and that a_thread_copy_ reads into a_thread_buf.
//   b_grid_desc, b_block_copy_step
//       Shared by both B copies.
//   b_blockwise_copy / b_grid_buf and b_blockwise_copy_up / b_grid_buf_up
//       Load the two B operands directly into per-thread buffers.
//   b_block_buf
//       Unused by this pipeline (explicitly ignored).
//   c_thread_buf, c_thread_buf_up
//       Accumulators; both are cleared before the first MFMA.
//   num_loop
//       Only read by the main-loop exit test. Presumably HasMainLoop and TailNum are derived
//       from the same value by the caller (BlockHasHotloop / BlockLoopTailNum) - TODO confirm
//       at the call site.
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_blockwise_copy_up,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
const BGridBuffer& b_grid_buf_up,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
CThreadBuffer& c_thread_buf_up,
|
||||
index_t num_loop) const
|
||||
|
||||
{
|
||||
// B never passes through b_block_buf in this pipeline.
ignore = b_block_buf;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Two slots per B operand: raw BDataType tiles here, ComputeDataType (dequantized) tiles
// below. The main loop alternates between slot I0 and slot I1.
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}>
|
||||
b_thread_dequant_bufs_up;
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I0));
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs_up(I0));
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
c_thread_buf_up.Clear();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// One pipeline step. In order: load both B operands into slot local_read_buf, write the
// previously fetched A to a_block_buf and fetch the next A, run the MFMAs on the
// dequantized B in slot mfma_reg_buf, then refill a_thread_buf from a_block_buf and
// dequantize slot local_read_buf for the following step.
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf));
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs_up
|
||||
[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(local_read_buf));
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs_up(local_read_buf));
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
// Two steps per trip with the slots swapped, which is why i advances by 2. The last step
// always leaves the freshly dequantized B tile in slot I0.
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
// tail
|
||||
// Even tail: load one more B tile into slot I1, then run the MFMAs on slot I0 followed by
// the MFMAs on slot I1.
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
|
||||
block_sync_lds();
|
||||
// No explicit buffer index here, unlike the RunWrite calls above.
// NOTE(review): relies on RunWrite's default index - confirm it selects the buffer that the
// last RunRead filled (index I0 on every path that reaches this point).
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
// B VGPR->VGPR dequant
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs(I1));
|
||||
|
||||
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1),
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_dequant_bufs_up(I1));
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
// Any other tail: a single MFMA pass over slot I0, which holds the tile dequantized by the
// prologue or by the last main-loop step.
else
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_dequant_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Per-thread descriptors and copy helpers used by Run().
protected:
|
||||
// MRepeat MWave MLane KRepeat KLane KPack
|
||||
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
|
||||
// Packed per-thread A layout: (MRepeat, 1, 1, KRepeat, 1, KPack).
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
|
||||
|
||||
// Copies one KPack-long slice per call from the 6-D A block view into a_thread_buf,
// converting ADataType to ComputeDataType.
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
|
||||
|
||||
// Packed per-thread B layout: (NRepeat, 1, KRepeat, KPack).
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Static-buffer to static-buffer copy that converts BDataType to ComputeDataType with a
// PassThrough element op.
// NOTE(review): both descriptor type arguments are decltype(b_block_desc_n0_n1_k0_k1), while
// Run() passes b_thread_desc_ as the destination descriptor - confirm the two layouts agree.
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_k0_k1),
|
||||
decltype(b_block_desc_n0_n1_k0_k1),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
KPack>;
|
||||
|
||||
const PassThrough b_element_op{};
|
||||
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,573 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::c_thread_desc_;
|
||||
using Base::MWaves;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
/// Tells whether the main (hot) loop is needed for the given K-loop count.
///
/// @param num_loop  Number of K-loop iterations.
/// @return true when `num_loop` is strictly greater than PrefetchStages.
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return PrefetchStages < num_loop;
}
|
||||
|
||||
/// Classifies the loop tail by the parity of the K-loop count.
///
/// @param num_loop  Number of K-loop iterations.
/// @return TailNumber::Even when `num_loop` is divisible by 2, TailNumber::Odd otherwise.
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    if(num_loop % 2 != 0)
    {
        return TailNumber::Odd;
    }
    return TailNumber::Even;
}
|
||||
|
||||
/// Emits the scheduling-group hints for one hot-loop iteration.
///
/// The function only issues `__builtin_amdgcn_sched_group_barrier` requests and
/// performs no computation. Mask values used below:
/// 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.
__device__ static constexpr auto HotLoopScheduler()
{
    constexpr auto a_lds_read_count    = HotLoopInstList::A_LDS_Read_Inst_Num;
    constexpr auto a_global_load_count = HotLoopInstList::A_Buffer_Load_Inst_Num;
    // B load count is scaled by MWaves and doubled.
    // NOTE(review): the factor 2 presumably accounts for the two B operands -- confirm.
    constexpr auto b_global_load_count = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2;

    // Number of MFMAs grouped with each B load; doubled for the larger block tiles.
    constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
    constexpr auto mfma_per_b_load =
        (MPerBlock >= 128 && NPerBlock >= 64) ? 2 * mfma_interleave : mfma_interleave;

    // B global: a group of MFMAs followed by one global read.
    static_for<0, b_global_load_count, 1>{}([&](auto step) {
        ignore = step;
        __builtin_amdgcn_sched_group_barrier(0x008, mfma_per_b_load, 0); // MFMA
        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);               // VMEM read
    });

    // A global: MFMA, DS write, MFMA, global read.
    static_for<0, a_global_load_count, 1>{}([&](auto step) {
        ignore = step;
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
    });

    // A local: one MFMA per group of DS reads. With MPerXDL == 32 the reads are paired,
    // so half as many groups are emitted, each containing two reads.
    constexpr auto ds_read_groups     = MPerXDL == 32 ? a_lds_read_count / 2 : a_lds_read_count;
    constexpr auto ds_reads_per_group = MPerXDL == 32 ? 2 : 1;

    static_for<0, ds_read_groups, 1>{}([&](auto step) {
        ignore = step;
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                  // MFMA
        __builtin_amdgcn_sched_group_barrier(0x100, ds_reads_per_group, 0); // DS read
    });
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_blockwise_copy_up,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
const BGridBuffer& b_grid_buf_up,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
CThreadBuffer& c_thread_buf_up,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_buf;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// // Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// // Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
c_thread_buf_up.Clear();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf));
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// MRepeat MWave MLane KRepeat KLane KPack
|
||||
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp"
|
||||
@@ -33,57 +35,112 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool GUFusion = false>
|
||||
constexpr auto BlockGemmBPreshufflePipeline_Selector()
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(std::is_same<ADataType, BDataType>::value)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
|
||||
@@ -46,7 +46,8 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 =
|
||||
BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType, TransposeC>{};
|
||||
@@ -333,7 +334,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
|
||||
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
|
||||
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ template <typename ThreadGroup,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun,
|
||||
typename IndexType,
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v4r1_gather
|
||||
@@ -58,7 +59,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
|
||||
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src_element_op,
|
||||
@@ -190,6 +191,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun,
|
||||
IndexType,
|
||||
GatherDim,
|
||||
NumThreadScratch>;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -42,6 +42,7 @@ template <typename ThreadGroup,
|
||||
index_t DstScalarPerVector,
|
||||
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
typename IndexType,
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
index_t ScatterWeightIdx = 3,
|
||||
@@ -133,13 +134,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,7 +149,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
@@ -169,10 +169,9 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunRead(src_descs, src_bufs);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
@@ -230,6 +229,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
DstScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
IndexType,
|
||||
ScatterDim,
|
||||
OutputScatter,
|
||||
ScatterWeightIdx,
|
||||
|
||||
@@ -65,9 +65,12 @@ template <typename ALayout,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
bool PerTokenQuant = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
@@ -132,7 +135,12 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
PerTokenQuant,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
@@ -247,10 +255,10 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
|
||||
4 * (1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b =
|
||||
NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
|
||||
constexpr auto estimated_reg_c =
|
||||
MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
|
||||
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
|
||||
4 * (2) * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
|
||||
BlockSize / 4 * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
@@ -270,8 +278,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
@@ -281,8 +287,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
@@ -297,8 +301,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
@@ -308,8 +310,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
@@ -329,8 +329,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
|
||||
|
||||
@@ -26,12 +26,17 @@ namespace ck {
|
||||
// two lds chunks.
|
||||
// 2. Occupied __shared__ is not released until the whole shader ends, i.e. AB and C may not use the same lds
|
||||
// buffer when we declare __shared__ inside blkgemmpipe
|
||||
|
||||
// Selector for the fused activation used in the gate/up ("gu fusion") epilogue.
// The enumerators are compared against the integer ActivationOperation template
// parameter, whose default value 0 corresponds to gelu_and_mul.
enum Activation
|
||||
{
|
||||
gelu_and_mul = 0, // Gelu is applied to the gate value, then multiplied by up
|
||||
silu_and_mul = 1 // Silu is applied to the gate value, then multiplied by up
|
||||
};
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
bool IsInputGemm = false,
|
||||
bool MulRoutedWeight = true,
|
||||
TailNumber TailNum = TailNumber::Even>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -45,22 +50,19 @@ __global__ void
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
TailNum>(karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -70,8 +72,6 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
bool IsInputGemm = false,
|
||||
bool MulRoutedWeight = true,
|
||||
TailNumber TailNum = TailNumber::Even>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -86,23 +86,20 @@ __global__ void
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
TailNum>(karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -154,7 +151,12 @@ template <typename ALayout,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOperation = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
bool PerTokenQuant = false,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ADataType,
|
||||
@@ -227,6 +229,7 @@ struct GridwiseMoeGemm
|
||||
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
|
||||
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
|
||||
const index_t gridy = NSwizzle ? 1 : mblock;
|
||||
|
||||
return std::make_tuple(gridx, gridy, 1);
|
||||
}
|
||||
|
||||
@@ -305,7 +308,7 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
|
||||
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
|
||||
IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
@@ -497,8 +500,8 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
|
||||
template <typename ELayout>
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
|
||||
__host__ __device__ static auto MakeCGridDescriptor_M_N(
|
||||
IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
@@ -909,7 +912,8 @@ struct GridwiseMoeGemm
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack>())>;
|
||||
KPack,
|
||||
IsInputGemm>())>;
|
||||
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
@@ -1141,9 +1145,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const index_t* p_sorted_token_ids,
|
||||
const index_t* p_sorted_expert_ids,
|
||||
const index_t* p_max_token_id,
|
||||
@@ -1203,6 +1205,7 @@ struct GridwiseMoeGemm
|
||||
return {blockIdx.x, blockIdx.y};
|
||||
}
|
||||
}();
|
||||
|
||||
const index_t block_n_id = block_mn.first;
|
||||
const index_t block_m_id = block_mn.second;
|
||||
const index_t token0 =
|
||||
@@ -1218,7 +1221,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets;
|
||||
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
@@ -1226,9 +1229,10 @@ struct GridwiseMoeGemm
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
gather_offsets(m0) = token_offset * problem.K;
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
@@ -1239,7 +1243,6 @@ struct GridwiseMoeGemm
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
|
||||
@@ -1269,6 +1272,7 @@ struct GridwiseMoeGemm
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
IndexType,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1311,24 +1315,74 @@ struct GridwiseMoeGemm
|
||||
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
|
||||
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
|
||||
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
|
||||
decltype(c_thread_buf) c_thread_buf_up;
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
float,
|
||||
c_thread_buf.num_of_v_,
|
||||
c_thread_buf.s_per_v,
|
||||
true>
|
||||
c_thread_buf_fp32;
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_blockwise_copy,
|
||||
b_blockwise_copy_up,
|
||||
b_grid_buf,
|
||||
b_grid_buf_up,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
c_thread_buf_up,
|
||||
num_k_block_main_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
}
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
@@ -1356,6 +1410,185 @@ struct GridwiseMoeGemm
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
|
||||
// mul scales
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
const float* p_scale_b = p_ds_grid[I1];
|
||||
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
|
||||
if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
|
||||
{
|
||||
if constexpr(PerTokenQuant)
|
||||
{
|
||||
constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
|
||||
p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
|
||||
get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
|
||||
}
|
||||
else
|
||||
{
|
||||
p_scale_b += expert_id;
|
||||
}
|
||||
|
||||
vector_type<int32_t, 4> scale_token_ids;
|
||||
vector_type<float, 4> topk_weights;
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
|
||||
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
|
||||
if constexpr(PerTokenQuant)
|
||||
{
|
||||
scale_token_ids =
|
||||
*c_style_pointer_cast<const vector_type<int32_t, M4>*>(
|
||||
p_sorted_token_ids + m_pos);
|
||||
}
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
|
||||
float scale_a = [&]() {
|
||||
if constexpr(PerTokenQuant)
|
||||
{
|
||||
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
|
||||
const index_t token_offset = fused_token & 0xffffff;
|
||||
return token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset]
|
||||
: 0.0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return p_sorted_weights_0[0];
|
||||
}
|
||||
}();
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, m2 * M4 + m4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
float gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
float up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
const float scale_up =
|
||||
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
|
||||
PerTokenQuant];
|
||||
float gate = scale_a * scale_b * c_thread_buf[cidx];
|
||||
float up = scale_a * scale_up * c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf_fp32(cidx) =
|
||||
scale_a * scale_b * c_thread_buf[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
|
||||
topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
|
||||
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, m2 * M4 + m4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
|
||||
c_thread_buf_fp32[cidx];
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
@@ -1453,17 +1686,8 @@ struct GridwiseMoeGemm
|
||||
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
const DDataType* ptr_ = p_ds_grid[i];
|
||||
// hack logic here to support different kind of strides. todo fix it.
|
||||
// ascale t, 1; bscale E, N, 1, move ptr to E
|
||||
if(i.value == 1)
|
||||
{
|
||||
ptr_ +=
|
||||
expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
|
||||
}
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
@@ -1526,7 +1750,8 @@ struct GridwiseMoeGemm
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
@@ -1538,7 +1763,6 @@ struct GridwiseMoeGemm
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
@@ -1568,35 +1792,21 @@ struct GridwiseMoeGemm
|
||||
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
|
||||
constexpr auto ENThreads =
|
||||
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets;
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos =
|
||||
block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
|
||||
: 0.0;
|
||||
IndexType token_offset = fused_token & 0xffffff;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
if constexpr(sizeof(ADataType) < 2)
|
||||
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
|
||||
else
|
||||
weight = p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -1604,7 +1814,7 @@ struct GridwiseMoeGemm
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_thread_buf_fp32,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
@@ -1617,8 +1827,7 @@ struct GridwiseMoeGemm
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights);
|
||||
scatter_offsets);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
@@ -1643,9 +1852,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
|
||||
const index_t* p_sorted_expert_ids,
|
||||
const index_t* p_max_token_id,
|
||||
@@ -1721,7 +1928,7 @@ struct GridwiseMoeGemm
|
||||
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
|
||||
token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, AMRepeats>
|
||||
StaticallyIndexedArray<IndexType, AMRepeats>
|
||||
gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
@@ -1730,7 +1937,7 @@ struct GridwiseMoeGemm
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
gather_offsets(m0) = token_offset * problem.K;
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
|
||||
@@ -1773,6 +1980,7 @@ struct GridwiseMoeGemm
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
IndexType,
|
||||
1,
|
||||
2>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1967,11 +2175,12 @@ struct GridwiseMoeGemm
|
||||
const DDataType* ptr_ = p_ds_grid[i];
|
||||
// hack logic here to support different kind of strides. todo fix it.
|
||||
// ascale t, 1; bscale E, N, 1, move ptr to E
|
||||
if(i.value == 1)
|
||||
{
|
||||
ptr_ +=
|
||||
expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
|
||||
}
|
||||
// if(i.value == 1)
|
||||
// {
|
||||
// ptr_ +=
|
||||
// expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N :
|
||||
// 1);
|
||||
// }
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
|
||||
},
|
||||
@@ -2036,7 +2245,8 @@ struct GridwiseMoeGemm
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
@@ -2078,12 +2288,9 @@ struct GridwiseMoeGemm
|
||||
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
|
||||
constexpr auto ENThreads =
|
||||
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos =
|
||||
@@ -2091,23 +2298,11 @@ struct GridwiseMoeGemm
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
float weight = token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
|
||||
: 0.0;
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
if constexpr(sizeof(ADataType) < 2)
|
||||
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
|
||||
else
|
||||
weight = p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -2128,8 +2323,7 @@ struct GridwiseMoeGemm
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights);
|
||||
scatter_offsets);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
|
||||
@@ -41,6 +41,7 @@ template <typename SliceLengths,
|
||||
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
|
||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||
// save addr computation
|
||||
typename IndexType,
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
@@ -88,7 +89,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
|
||||
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
|
||||
src_element_op_(src_element_op),
|
||||
@@ -221,7 +222,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
auto gather_offset =
|
||||
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
|
||||
|
||||
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const IndexType ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, true);
|
||||
|
||||
@@ -935,7 +936,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
DstCoord dst_coord_;
|
||||
const SrcElementwiseOperation src_element_op_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
|
||||
StaticallyIndexedArray<IndexType, gather_num> gather_offsets_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -43,6 +43,7 @@ template <typename SrcDatas,
|
||||
index_t DstScalarPerVector,
|
||||
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename IndexType,
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
index_t ScatterWeightIdx = 3,
|
||||
@@ -153,7 +154,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
@@ -172,31 +172,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
src_coords_[i]);
|
||||
|
||||
oob_val = oob_val & is_src_valid;
|
||||
if(i.value == ScatterWeightIdx)
|
||||
{
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
|
||||
"scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter =
|
||||
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
|
||||
src_vectors(i).template AsType<float>()(j) =
|
||||
scatter_weights(Number<iScatter>{});
|
||||
});
|
||||
}
|
||||
else if constexpr(SrcScalarPerVectors{}[i] == 1)
|
||||
{
|
||||
auto data_types = SrcDatas{};
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
const auto tmp =
|
||||
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
|
||||
}
|
||||
else
|
||||
{
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
}
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
});
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
@@ -412,7 +389,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
OOBCheck(thread_scratch_id);
|
||||
@@ -420,8 +397,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
|
||||
// loop over space-filling curve
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
auto scatter_offset = 0;
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
IndexType scatter_offset = 0;
|
||||
if constexpr(OutputScatter)
|
||||
{
|
||||
constexpr auto iScatter =
|
||||
@@ -431,8 +408,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
@@ -488,10 +467,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunRead(src_descs, src_bufs);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@ template <AddressSpaceEnum BufferAddressSpace,
|
||||
typename T,
|
||||
typename ElementSpaceSize,
|
||||
bool InvalidElementUseNumericalZeroValue,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
|
||||
typename IndexType = index_t>
|
||||
struct DynamicBuffer
|
||||
{
|
||||
using type = T;
|
||||
@@ -59,16 +60,16 @@ struct DynamicBuffer
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
|
||||
__host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; }
|
||||
|
||||
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
|
||||
__host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; }
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value ||
|
||||
!is_native_type<X>(),
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
|
||||
__host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
@@ -79,7 +80,7 @@ struct DynamicBuffer
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_LOAD
|
||||
bool constexpr use_amd_buffer_addressing = true;
|
||||
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
@@ -140,7 +141,7 @@ struct DynamicBuffer
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
|
||||
__host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::Set)
|
||||
{
|
||||
@@ -191,8 +192,8 @@ struct DynamicBuffer
|
||||
|
||||
template <typename DstBuffer, index_t NumElemsPerThread>
|
||||
__host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
|
||||
index_t src_offset,
|
||||
index_t dst_offset,
|
||||
IndexType src_offset,
|
||||
IndexType dst_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
// Copy data from global to LDS memory using direct loads.
|
||||
@@ -214,7 +215,7 @@ struct DynamicBuffer
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value ||
|
||||
!is_native_type<X>(),
|
||||
bool>::type = false>
|
||||
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
|
||||
__host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
@@ -224,8 +225,8 @@ struct DynamicBuffer
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_STORE
|
||||
bool constexpr use_amd_buffer_addressing = true;
|
||||
#if CK_USE_AMD_BUFFER_LOAD
|
||||
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
@@ -342,11 +343,12 @@ struct DynamicBuffer
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if 0
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
#else
|
||||
// if(i >= 2169041600)
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
#endif
|
||||
}
|
||||
@@ -357,7 +359,7 @@ struct DynamicBuffer
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
|
||||
__host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x)
|
||||
{
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
|
||||
@@ -378,12 +380,14 @@ struct DynamicBuffer
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
|
||||
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
|
||||
sizeof(IndexType) <= sizeof(int32_t) &&
|
||||
(is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
@@ -408,12 +412,12 @@ struct DynamicBuffer
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x)
|
||||
__host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr IndexType scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr IndexType scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -421,8 +425,9 @@ struct DynamicBuffer
|
||||
static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
@@ -455,6 +460,17 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el
|
||||
p, element_space_size};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum BufferAddressSpace,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
|
||||
typename T,
|
||||
typename ElementSpaceSize>
|
||||
__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p,
|
||||
ElementSpaceSize element_space_size)
|
||||
{
|
||||
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true, coherence, long_index_t>{
|
||||
p, element_space_size};
|
||||
}
|
||||
|
||||
template <
|
||||
AddressSpaceEnum BufferAddressSpace,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
|
||||
|
||||
@@ -23,6 +23,13 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
|
||||
return generate_tuple_for(f, make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
|
||||
{
|
||||
|
||||
@@ -23,12 +23,14 @@ template <typename ADataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
bool MulRoutedWeight = false,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
index_t ActivationType_ = 0,
|
||||
bool MulRoutedWeight = true,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceMoeGemm : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
static constexpr auto ActivationType = ActivationType_;
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ck::index_t>& sorted_token_ids,
|
||||
@@ -36,7 +38,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ck::index_t>& max_token_id,
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<float>& a_scale_t,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
const Tensor<float>& b_scale_e_n,
|
||||
Tensor<CDataType>& c_t_k_n,
|
||||
const Tensor<D2DataType>& d2,
|
||||
AElementwiseOperation a_element_op,
|
||||
@@ -47,7 +51,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
max_token_id_{max_token_id},
|
||||
sorted_tile_size_{sorted_tile_size},
|
||||
a_t_k_{a_t_k},
|
||||
a_scale_t_{a_scale_t},
|
||||
b_e_n_k_{b_e_n_k},
|
||||
b_scale_e_n_{b_scale_e_n},
|
||||
c_t_k_n_{c_t_k_n},
|
||||
d2_{d2},
|
||||
a_element_op_{a_element_op},
|
||||
@@ -61,7 +67,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ck::index_t>& max_token_id_;
|
||||
index_t sorted_tile_size_;
|
||||
const Tensor<ADataType>& a_t_k_;
|
||||
const Tensor<float>& a_scale_t_;
|
||||
const Tensor<BDataType>& b_e_n_k_;
|
||||
const Tensor<float>& b_scale_e_n_;
|
||||
Tensor<CDataType>& c_t_k_n_;
|
||||
const Tensor<D2DataType>& d2_;
|
||||
|
||||
@@ -77,11 +85,17 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
static_assert(ActivationType < 2, "Not supported activation type");
|
||||
const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2];
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_t_k_.mDesc.GetLengths()[1];
|
||||
AccDataType v_acc_up{0};
|
||||
ComputeTypeB v_b_up{0};
|
||||
AccDataType v_acc{0};
|
||||
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
|
||||
const int t = arg.sorted_token_ids_(m) & 0xffffff;
|
||||
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
|
||||
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
|
||||
@@ -102,7 +116,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_a = i4_to_f32_gfx9(i4);
|
||||
#else
|
||||
v_a = i4 - 8;
|
||||
v_a = i4 - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -112,42 +126,79 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
|
||||
uint8_t i4 = 0;
|
||||
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
|
||||
uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data;
|
||||
uint8_t i4 = 0;
|
||||
uint8_t i4_up = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
{
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
i4_up = (i4x2_up >> 0) & 0xf;
|
||||
}
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
{
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
i4_up = (i4x2_up >> 4) & 0xf;
|
||||
}
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_b = i4_to_f32_gfx9(i4);
|
||||
v_b = i4_to_f32_gfx9(i4);
|
||||
v_b_up = i4_to_f32_gfx9(i4_up);
|
||||
#else
|
||||
v_b = i4 - 8;
|
||||
v_b = i4 - 8;
|
||||
v_b_up = i4_up - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n));
|
||||
arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n));
|
||||
}
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
v_acc_up += ck::type_convert<AccDataType>(v_a) *
|
||||
ck::type_convert<AccDataType>(v_b_up);
|
||||
}
|
||||
CDataType v_c{0};
|
||||
|
||||
CDataType v_c_up{0};
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
v_acc *= v_topk_w;
|
||||
v_acc_up *= v_topk_w;
|
||||
}
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
arg.c_element_op_(v_c_up, v_acc_up);
|
||||
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c;
|
||||
if constexpr(ActivationType == 1)
|
||||
{
|
||||
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
v_c_up *= 16;
|
||||
v_c *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(v_c, v_c);
|
||||
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
|
||||
}
|
||||
else if constexpr(ActivationType == 0)
|
||||
{
|
||||
v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t);
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
v_c_up *= 16;
|
||||
v_c *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(v_c, v_c);
|
||||
v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t);
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const ck::index_t max_token_id = arg.max_token_id_(0);
|
||||
make_ParallelTensorFunctor(
|
||||
f_mk_kn_mn, max_token_id, arg.c_t_k_n_.mDesc.GetLengths()[2])(
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
@@ -173,7 +224,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ck::index_t>& max_token_id,
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<float>& a_scale_n,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
const Tensor<float>& b_scale_e_n,
|
||||
Tensor<CDataType>& c_t_k_n,
|
||||
const Tensor<D2DataType>& d2,
|
||||
AElementwiseOperation a_element_op,
|
||||
@@ -185,7 +238,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
max_token_id,
|
||||
sorted_tile_size,
|
||||
a_t_k,
|
||||
a_scale_n,
|
||||
b_e_n_k,
|
||||
b_scale_e_n,
|
||||
c_t_k_n,
|
||||
d2,
|
||||
a_element_op,
|
||||
|
||||
@@ -25,7 +25,7 @@ template <typename ADataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
bool MulRoutedWeight = false,
|
||||
bool MulRoutedWeight = true,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
|
||||
Reference in New Issue
Block a user