mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
support swiglu activation and use rcpf to accelerate silu
This commit is contained in:
@@ -205,14 +205,13 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
: col * strideC + scatter_token_id;
|
||||
if constexpr(MoeGemmKind < 2)
|
||||
{
|
||||
AccDataType acc_gate = ActivationOp{}(acc);
|
||||
C[c_index] =
|
||||
ck_tile::type_convert<CDataType>(MoeGemmKind == 1 ? acc_gate * acc_up : acc_gate);
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(
|
||||
ActivationOp{}(acc, MoeGemmKind == 1 ? acc_up : 1));
|
||||
}
|
||||
else
|
||||
{
|
||||
CDataType res =
|
||||
ck_tile::type_convert<CDataType>(ActivationOp{}(acc * expert_weight_ptr[row]));
|
||||
// moe gemm2 doesn't use activation.
|
||||
CDataType res = ck_tile::type_convert<CDataType>(acc * expert_weight_ptr[row]);
|
||||
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
|
||||
ck_tile::fp16x2_t,
|
||||
ck_tile::bf16x2_t>;
|
||||
|
||||
@@ -863,7 +863,14 @@ struct Silu
|
||||
std::is_same_v<T, int32_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
constexpr T one = type_convert<T>(1);
|
||||
y = x * (one / (one + ck_tile::exp(-x)));
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
{
|
||||
y = x * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x));
|
||||
}
|
||||
else
|
||||
{
|
||||
y = x * (one / (one + ck_tile::exp(-x)));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -1218,7 +1225,7 @@ struct Swish
|
||||
|
||||
struct SoftRelu
|
||||
{
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1237,7 +1244,7 @@ struct SoftRelu
|
||||
struct Power
|
||||
{
|
||||
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma){};
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1259,7 +1266,7 @@ struct Power
|
||||
|
||||
struct ClippedRelu
|
||||
{
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1278,7 +1285,7 @@ struct ClippedRelu
|
||||
|
||||
struct LeakyRelu
|
||||
{
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1295,7 +1302,7 @@ struct LeakyRelu
|
||||
|
||||
struct Elu
|
||||
{
|
||||
Elu(float alpha = 1.f) : alpha_(alpha){};
|
||||
Elu(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
@@ -1312,7 +1319,7 @@ struct Elu
|
||||
|
||||
struct Logistic
|
||||
{
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha){};
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
|
||||
@@ -24,17 +24,21 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
|
||||
{
|
||||
if (TailNumber::Even == tail_num)
|
||||
if(TailNumber::Even == tail_num)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if (TailNumber::Odd == tail_num)
|
||||
else if(TailNumber::Odd == tail_num)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
// return run_func(bool_constant<true>{}, integral_constant<TailNumber,
|
||||
// TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -52,16 +56,17 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
|
||||
static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
static constexpr auto config =
|
||||
BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
@@ -105,64 +110,67 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload: MIterPerWarp * KIterPerWarp;
|
||||
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
: MIterPerWarp * KIterPerWarp;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
/*
|
||||
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
|
||||
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
|
||||
defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
|
||||
defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
|
||||
/*
|
||||
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
|
||||
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
|
||||
defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
|
||||
defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
|
||||
|
||||
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
|
||||
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
|
||||
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
|
||||
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
|
||||
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
|
||||
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
|
||||
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
|
||||
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
|
||||
|
||||
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
|
||||
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
*/
|
||||
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
|
||||
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
*/
|
||||
|
||||
// #if (defined(USING_MFMA_16x16x32_F8) || \
|
||||
// defined(USING_MFMA_32x32x16_F8) || \
|
||||
// defined(USING_MFMA_16x16x16_F16) || \
|
||||
// defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5
|
||||
// static constexpr auto mfma_per_wg = 2;
|
||||
// static constexpr auto dsread_per_wg = 1;
|
||||
// #elif (defined(USING_MFMA_16x16x32_F16) || \
|
||||
// defined(USING_MFMA_32x32x16_F16) || \
|
||||
// defined(USING_MFMA_16x16x128_F4) || \
|
||||
// defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1
|
||||
// static constexpr auto mfma_per_wg = 1;
|
||||
// static constexpr auto dsread_per_wg = 1;
|
||||
// #elif (defined(USING_MFMA_16x16x128_F8) || \
|
||||
// defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2
|
||||
// static constexpr auto mfma_per_wg = 1;
|
||||
// static constexpr auto dsread_per_wg = 2;
|
||||
// #endif
|
||||
#ifdef __gfx942__
|
||||
static constexpr index_t mfma_per_wg = 2;
|
||||
#else
|
||||
static constexpr index_t mfma_per_wg = 1;
|
||||
#endif
|
||||
static constexpr index_t dsread_per_wg = WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
|
||||
// #if (defined(USING_MFMA_16x16x32_F8) || \
|
||||
// defined(USING_MFMA_32x32x16_F8) || \
|
||||
// defined(USING_MFMA_16x16x16_F16) || \
|
||||
// defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5
|
||||
// static constexpr auto mfma_per_wg = 2;
|
||||
// static constexpr auto dsread_per_wg = 1;
|
||||
// #elif (defined(USING_MFMA_16x16x32_F16) || \
|
||||
// defined(USING_MFMA_32x32x16_F16) || \
|
||||
// defined(USING_MFMA_16x16x128_F4) || \
|
||||
// defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1
|
||||
// static constexpr auto mfma_per_wg = 1;
|
||||
// static constexpr auto dsread_per_wg = 1;
|
||||
// #elif (defined(USING_MFMA_16x16x128_F8) || \
|
||||
// defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2
|
||||
// static constexpr auto mfma_per_wg = 1;
|
||||
// static constexpr auto dsread_per_wg = 2;
|
||||
// #endif
|
||||
#ifdef __gfx942__
|
||||
static constexpr index_t mfma_per_wg = 2;
|
||||
#else
|
||||
static constexpr index_t mfma_per_wg = 1;
|
||||
#endif
|
||||
static constexpr index_t dsread_per_wg =
|
||||
WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
|
||||
|
||||
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
|
||||
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -185,25 +193,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
return PipelinePolicy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
// Init inst order
|
||||
index_t max_data_inst =
|
||||
dsread_perM > load_perM
|
||||
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
|
||||
: (load_perM > dswrite_perM ? load_perM : dswrite_perM);
|
||||
index_t max_data_inst = dsread_perM > load_perM
|
||||
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
|
||||
: (load_perM > dswrite_perM ? load_perM : dswrite_perM);
|
||||
index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
|
||||
index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
|
||||
|
||||
index_t inst_order[NIterPerWarp * 10];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for(int idx = 0; idx < NIterPerWarp * 10; idx++)
|
||||
{
|
||||
inst_order[idx] = 0;
|
||||
}
|
||||
|
||||
index_t index = 0;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for(int j = 0; j < max_data_inst; j++)
|
||||
{
|
||||
if(dswrite_perM > j)
|
||||
@@ -223,8 +231,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule IGLP
|
||||
#pragma unroll
|
||||
// Schedule IGLP
|
||||
#pragma unroll
|
||||
for(int j = 0; j < mfma_perM_perK; j++)
|
||||
{
|
||||
index_t inst_idx = 0;
|
||||
@@ -236,10 +244,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
inst_idx = mfma_perM_perK - 1;
|
||||
else
|
||||
inst_idx = mfma_perM_perK - j;
|
||||
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
#pragma unroll
|
||||
|
||||
#pragma unroll
|
||||
for(int r = 0; r < round_data_inst; r++)
|
||||
{
|
||||
if(r % 2 == 0)
|
||||
@@ -284,84 +292,84 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// -1 M6N1: 58 1 - - -
|
||||
// -1 M6N2: 59 - - 7 -
|
||||
// -1 M6N3: 60 2 - - -
|
||||
// -1 M7N0: 61 - - - -
|
||||
// -1 M7N1: 62 3 - - -
|
||||
// -1 M7N2: 63 - - 8 -
|
||||
// -1 M7N3: 64 4 - - -
|
||||
// 0 M0N0K0: 1 - - - 1
|
||||
// 0 M0N1: 2 5 - - -
|
||||
// 0 M0N2: 3 - - - 2
|
||||
// 0 M0N3: 4 6 - - -
|
||||
// 0 M1N0: 5 - - - 3
|
||||
// 0 M1N1: 6 7 - - -
|
||||
// 0 M1N2: 7 - - - 4
|
||||
// 0 M1N3: 8 8 - - -
|
||||
// 0 M2N0: 9 - - - 5
|
||||
// 0 M2N1: 10 9 - - -
|
||||
// 0 M2N2: 11 - - - 6
|
||||
// 0 M2N3: 12 10 - - -
|
||||
// 0 M3N0: 13 - 1 - 7
|
||||
// 0 M3N1: 14 11 - - -
|
||||
// 0 M3N2: 15 - - - 8
|
||||
// -1 M7N0: 61 - - - -
|
||||
// -1 M7N1: 62 3 - - -
|
||||
// -1 M7N2: 63 - - 8 -
|
||||
// -1 M7N3: 64 4 - - -
|
||||
// 0 M0N0K0: 1 - - - 1
|
||||
// 0 M0N1: 2 5 - - -
|
||||
// 0 M0N2: 3 - - - 2
|
||||
// 0 M0N3: 4 6 - - -
|
||||
// 0 M1N0: 5 - - - 3
|
||||
// 0 M1N1: 6 7 - - -
|
||||
// 0 M1N2: 7 - - - 4
|
||||
// 0 M1N3: 8 8 - - -
|
||||
// 0 M2N0: 9 - - - 5
|
||||
// 0 M2N1: 10 9 - - -
|
||||
// 0 M2N2: 11 - - - 6
|
||||
// 0 M2N3: 12 10 - - -
|
||||
// 0 M3N0: 13 - 1 - 7
|
||||
// 0 M3N1: 14 11 - - -
|
||||
// 0 M3N2: 15 - - - 8
|
||||
// 0 M3N3: 16 12 - - -
|
||||
// 0 M4N0: 17 - 2 - -
|
||||
// 0 M4N1: 18 13 - - -
|
||||
// 0 M4N2: 19 - - 1 -
|
||||
// 0 M4N0: 17 - 2 - -
|
||||
// 0 M4N1: 18 13 - - -
|
||||
// 0 M4N2: 19 - - 1 -
|
||||
// 0 M4N3: 20 14 - - -
|
||||
// 0 M5N0: 21 - 3 - -
|
||||
// 0 M5N1: 22 15 - - -
|
||||
// 0 M5N2: 23 - - 2 -
|
||||
// 0 M5N0: 21 - 3 - -
|
||||
// 0 M5N1: 22 15 - - -
|
||||
// 0 M5N2: 23 - - 2 -
|
||||
// 0 M5N3: 24 16 - - -
|
||||
// 0 M6N0: 25 - 4 - -
|
||||
// 0 M6N1: 26 17 - - -
|
||||
// 0 M6N2: 27 - - 3 -
|
||||
// 0 M6N0: 25 - 4 - -
|
||||
// 0 M6N1: 26 17 - - -
|
||||
// 0 M6N2: 27 - - 3 -
|
||||
// 0 M6N3: 28 18 - - -
|
||||
// 0 M7N0: 29 - - - -
|
||||
// 0 M7N1: 30 19 - - -
|
||||
// 0 M7N2: 31 - - 4 -
|
||||
// 0 M7N0: 29 - - - -
|
||||
// 0 M7N1: 30 19 - - -
|
||||
// 0 M7N2: 31 - - 4 -
|
||||
// 0 M7N3: 32 20 - - -
|
||||
// 0 M0N0K1: 33 - - - 9
|
||||
// 0 M0N1: 34 21 - - -
|
||||
// 0 M0N2: 35 - - - 10
|
||||
// 0 M0N3: 36 22 - - -
|
||||
// 0 M1N0: 37 - - - 11
|
||||
// 0 M1N1: 38 23 - - -
|
||||
// 0 M1N2: 39 - - - 12
|
||||
// 0 M0N0K1: 33 - - - 9
|
||||
// 0 M0N1: 34 21 - - -
|
||||
// 0 M0N2: 35 - - - 10
|
||||
// 0 M0N3: 36 22 - - -
|
||||
// 0 M1N0: 37 - - - 11
|
||||
// 0 M1N1: 38 23 - - -
|
||||
// 0 M1N2: 39 - - - 12
|
||||
// 0 M1N3: 40 24 - - -
|
||||
// 0 M2N0: 41 - - - 13
|
||||
// 0 M2N1: 42 25 - - -
|
||||
// 0 M2N2: 43 - - - 14
|
||||
// 0 M2N3: 44 26 - - -
|
||||
// 0 M3N0: 45 - 5 - 15
|
||||
// 0 M3N1: 46 27 - - -
|
||||
// 0 M3N2: 47 - - - 16
|
||||
// 0 M2N0: 41 - - - 13
|
||||
// 0 M2N1: 42 25 - - -
|
||||
// 0 M2N2: 43 - - - 14
|
||||
// 0 M2N3: 44 26 - - -
|
||||
// 0 M3N0: 45 - 5 - 15
|
||||
// 0 M3N1: 46 27 - - -
|
||||
// 0 M3N2: 47 - - - 16
|
||||
// 0 M3N3: 48 28 - - -
|
||||
// 0 M4N0: 49 - 6 - -
|
||||
// 0 M4N1: 50 29 - - -
|
||||
// 0 M4N2: 51 - - 5 -
|
||||
// 0 M4N0: 49 - 6 - -
|
||||
// 0 M4N1: 50 29 - - -
|
||||
// 0 M4N2: 51 - - 5 -
|
||||
// 0 M4N3: 52 30 - - -
|
||||
// 0 M5N0: 53 - 7 - -
|
||||
// 0 M5N1: 54 31 - - -
|
||||
// 0 M5N2: 55 - - 6 -
|
||||
// 0 M5N0: 53 - 7 - -
|
||||
// 0 M5N1: 54 31 - - -
|
||||
// 0 M5N2: 55 - - 6 -
|
||||
// 0 M5N3: 56 32 - - -
|
||||
// 0 M6N0: 57 - 8 - -
|
||||
// 0 M6N1: 58 1 - - -
|
||||
// 0 M6N2: 59 - - 7 -
|
||||
// 0 M6N0: 57 - 8 - -
|
||||
// 0 M6N1: 58 1 - - -
|
||||
// 0 M6N2: 59 - - 7 -
|
||||
// 0 M6N3: 60 2 - - -
|
||||
// 0 M7N0: 61 - - - -
|
||||
// 0 M7N1: 62 3 - - -
|
||||
// 0 M7N2: 63 - - 8 -
|
||||
// 0 M7N0: 61 - - - -
|
||||
// 0 M7N1: 62 3 - - -
|
||||
// 0 M7N2: 63 - - 8 -
|
||||
// 0 M7N3: 64 4 - - -
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
// Calculate ds_read number per M
|
||||
dsread_perM = dsread_per_wg;
|
||||
@@ -380,10 +388,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
else
|
||||
{
|
||||
dswrite_perM =
|
||||
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
dswrite_perM = (dswrite_num_perK -
|
||||
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
// Add ds write when ds write data > needed
|
||||
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
|
||||
@@ -397,13 +405,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
load_perM =
|
||||
((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
|
||||
: 0) +
|
||||
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep : 0);
|
||||
: 0) +
|
||||
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
|
||||
: 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_perM =
|
||||
(Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep : 0;
|
||||
load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
@@ -413,18 +423,18 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
|
||||
{
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
// Calculate ds_read number per M
|
||||
dsread_perM = dsread_per_wg;
|
||||
@@ -443,10 +453,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
else
|
||||
{
|
||||
dswrite_perM =
|
||||
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
dswrite_perM = (dswrite_num_perK -
|
||||
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
// Add ds write when ds write data > needed
|
||||
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
|
||||
@@ -459,7 +469,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
if(mIter < HalfMIter)
|
||||
{
|
||||
load_perM =
|
||||
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep : 0);
|
||||
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
|
||||
: 0);
|
||||
}
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
@@ -468,20 +479,20 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
|
||||
{
|
||||
#pragma unroll
|
||||
{
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
// Calculate ds_read number per M
|
||||
if ((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
dsread_perM = dsread_per_wg;
|
||||
if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
dsread_perM = dsread_per_wg;
|
||||
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
@@ -507,7 +518,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
"wrong!");
|
||||
|
||||
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
@@ -517,7 +528,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
@@ -525,8 +536,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
auto a_lds_block_ping =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
@@ -543,22 +556,22 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp = make_tile_window(
|
||||
a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
auto a_warp_window_pong_tmp = make_tile_window(
|
||||
a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
@@ -569,7 +582,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
@@ -619,7 +632,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
@@ -632,7 +644,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
@@ -640,19 +652,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// Prefill A0
|
||||
// if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
// {
|
||||
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
// shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
// store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
// }
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -668,12 +667,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
block_sync_lds();
|
||||
|
||||
// preload A00,A10... from lds
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), m_preload> a_warp_tensor;
|
||||
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
|
||||
m_preload>
|
||||
a_warp_tensor;
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -687,7 +689,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
@@ -701,7 +703,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -709,14 +711,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
@@ -724,14 +728,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
@@ -745,10 +751,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
|
||||
// Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
@@ -757,12 +764,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
@@ -782,10 +789,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
@@ -793,14 +802,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
@@ -814,7 +825,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -830,7 +842,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
@@ -847,14 +859,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
@@ -862,14 +876,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
@@ -880,11 +896,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
|
||||
Last2ndHotLoopScheduler();
|
||||
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
@@ -892,27 +909,31 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
//barrier
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
@@ -930,14 +951,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
@@ -945,14 +968,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
@@ -974,7 +999,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](const ADataType & a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
|
||||
@@ -77,11 +77,59 @@ enum class MoeFlatmmKind
|
||||
kFFN_gemm2,
|
||||
};
|
||||
|
||||
namespace moe {
|
||||
|
||||
struct MoeSilu
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE T operator()(T gate, T linear = 1) const
|
||||
{
|
||||
ck_tile::element_wise::Silu{}(gate, gate);
|
||||
return gate * linear;
|
||||
};
|
||||
};
|
||||
|
||||
struct Swiglu
|
||||
{
|
||||
float alpha = 1.702f; // default value used in gpt-oss
|
||||
float limit = 7.0f; // default value used in gpt-oss
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
Swiglu() = default;
|
||||
CK_TILE_HOST_DEVICE
|
||||
Swiglu(float alpha_, float limit_) : alpha(alpha_), limit(limit_) {}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
|
||||
{
|
||||
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, int32_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
constexpr T one = type_convert<T>(1);
|
||||
|
||||
gate = gate < limit ? gate : limit;
|
||||
linear = linear < limit ? (linear > -limit ? linear : -limit) : limit;
|
||||
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
{
|
||||
return gate * __builtin_amdgcn_rcpf(one + ck_tile::exp(alpha * -gate)) * (linear + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return gate * (one / (one + ck_tile::exp(alpha * -gate))) * (linear + 1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace moe
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename FlatmmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
MoeFlatmmKind kind,
|
||||
typename FusedActivation = element_wise::Silu>
|
||||
typename FusedActivation = moe::MoeSilu>
|
||||
struct MoeFlatmmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
@@ -900,11 +948,9 @@ struct MoeFlatmmKernel
|
||||
});
|
||||
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
gate_tensor.get_thread_buffer().at(idx));
|
||||
lds_tile[0].get_thread_buffer().at(idx) =
|
||||
gate_tensor.get_thread_buffer().at(idx) *
|
||||
up_tensor.get_thread_buffer().at(idx);
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
up_tensor.get_thread_buffer().at(idx));
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -937,8 +983,8 @@ struct MoeFlatmmKernel
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx),
|
||||
lds_tile[0].get_thread_buffer().at(idx));
|
||||
lds_tile[0].get_thread_buffer().at(idx) =
|
||||
ActivationOp{}(lds_tile[0].get_thread_buffer().at(idx));
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1022,11 +1068,9 @@ struct MoeFlatmmKernel
|
||||
});
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
gate_tensor.get_thread_buffer().at(idx));
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx) =
|
||||
gate_tensor.get_thread_buffer().at(idx) *
|
||||
up_tensor.get_thread_buffer().at(idx);
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
up_tensor.get_thread_buffer().at(idx));
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -1068,8 +1112,8 @@ struct MoeFlatmmKernel
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
ActivationOp{}(lds_tile[write_stage].get_thread_buffer().at(idx),
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx));
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx) = ActivationOp{}(
|
||||
lds_tile[write_stage].get_thread_buffer().at(idx));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user