Merge branch 'develop' into test_copy_fix

This commit is contained in:
Max Podkorytov
2025-07-16 11:23:57 -07:00
committed by GitHub
34 changed files with 849 additions and 262 deletions

View File

@@ -325,12 +325,50 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else
{
throw std::runtime_error("todo: only v1 & v2 support now");
}
}
#endif

View File

@@ -1112,7 +1112,7 @@ struct GridwiseMoeGemm
}
// check gridwise gemm pipeline
#if 1
#if 0
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)

View File

@@ -467,7 +467,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
return make_naive_tensor_descriptor_packed(
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
}
@@ -1474,7 +1474,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1567,7 +1567,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
@@ -2185,7 +2185,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
get_warp_local_1d_id() % NWave,
0,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -2289,7 +2289,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
get_warp_local_1d_id() % NWave,
0,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(

View File

@@ -1396,8 +1396,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
@@ -1427,8 +1427,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
@@ -1459,8 +1459,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
@@ -1490,8 +1490,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
@@ -1522,8 +1522,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
@@ -1553,8 +1553,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
@@ -1585,8 +1585,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
@@ -1616,8 +1616,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,

View File

@@ -8,6 +8,7 @@
#include <cstring>
#include <string>
#include <string_view>
#include <map>
namespace ck {
namespace internal {

View File

@@ -33,7 +33,7 @@ __device__ void block_sync_lds_direct_load()
{
#ifdef __gfx12__
asm volatile("\
s_wait_vmcnt 0x0 \n \
s_wait_loadcnt 0x0 \n \
s_wait_dscnt 0x0 \n \
s_barrier_signal -1 \n \
s_barrier_wait -1 \

View File

@@ -74,7 +74,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));

View File

@@ -196,7 +196,7 @@ struct GemmKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
// clang-format on
}

View File

@@ -57,7 +57,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
@@ -95,7 +95,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static constexpr auto
CK_TILE_HOST static auto
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
{
index_t grid_size = 0;

View File

@@ -1095,16 +1095,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
@@ -1119,16 +1119,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
@@ -1254,16 +1254,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
@@ -1289,16 +1289,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
#elif defined(__gfx908__) || defined(__gfx90a__)
CVecType c_vec{0.f};
static_for<0, 8, 1>{}([&](auto k) {
@@ -1580,7 +1580,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
#if defined(__gfx94__) or defined(__gfx95__)
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
@@ -1650,7 +1650,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
{
#if defined(__gfx94__) or defined(__gfx95__)
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
@@ -1709,7 +1709,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
{
#if defined(__gfx95__)
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
@@ -1767,8 +1767,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
else
{
#if defined(__gfx95__)
c_vec =
__builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast<long>(b_vec), c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;

View File

@@ -272,4 +272,137 @@ struct BlockReduce2dCrossWarpSync
}
};
template <typename Problem_, typename Policy_ = void>
struct BlockReduce2dTreeCrossWarpSync
{
using Problem = remove_cvref_t<Problem_>;
using BlockShape = typename Problem::BlockShape;
template <typename YDistributedTensor_>
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
{
constexpr index_t num_reduce_warps = [&]() {
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_warp = 0;
index_t len_ = 1;
static_for<0, NDimR, 1>{}([&](auto idim_r) {
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
len_ *= r_length;
}
});
return len_;
}();
return num_reduce_warps;
}
// return in byte
template <typename YDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
// we need to store all data from every wave into smem
// e.g. 2x2 reduce along N
// -------------> reduce N
// | w0 | w1 | ___> | w01 |
// | w2 | w3 | | w23 |
//
// -> store data from every wave into LDS
//
//
// -------------> reduce N
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
return num_warps * thread_buf_size * sizeof(DataType);
}
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
{
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
using DataType = typename YDistributedTensor_::DataType;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
if constexpr(num_reduce_warps == 1)
return;
// Each warp's lane 0 writes its partial results to shared memory
const index_t smem_offset = warp_id;
if(lane_id == 0)
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
// Store the i-th element of this warp's thread_buffer into SMEM
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
});
}
block_sync_lds();
// We let each warp holds a duplication to do reduction.
static_for<0, thread_buf_size, 1>{}([&](auto i) {
DataType v = 0;
if(lane_id < num_reduce_warps)
{
v = smem_ptr[lane_id + i * num_warps];
}
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// pull data from remote lane
const auto o =
__shfl_xor(v, number<lid_over_rid_derivative << istage.value>{}.value);
// reduce
v = reduce_func(v, o);
});
}
});
y_tensor.get_thread_buffer()(i) = v;
});
}
};
} // namespace ck_tile

View File

@@ -5,6 +5,7 @@
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"

View File

@@ -58,13 +58,14 @@ struct Rmsnorm2dFwd
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
@@ -150,6 +151,8 @@ struct Rmsnorm2dFwd
if (kPadN) n += "_pn";
if (kSaveInvRms) n += "_rms";
if (kTwoPass) n += "_2p";
if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL) n += "_nsm";
else if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE) n += "_t5ml";
return n; }();
auto prec_str = [&] () {

View File

@@ -69,6 +69,15 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
return BlockReduce2dCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dTreeCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{

View File

@@ -0,0 +1,228 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
/**
* @brief This T5Pass implements the RMSNorm2d forward pipeline as a variant
* based on Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass using a T5 model-like
* method.
*
* The T5 model, developed by Google, is a transformer-based architecture designed to perform
* a variety of NLP tasks. The T5-like approach employed here is characterized by how RMS
* normalization is handled, particularly where intermediate values are cast to BF16. This aims to
* achieve a similar value distribution to that produced by the VLLM hip implementation, thereby
* enhancing model accuracy.
*
* Note: While this implementation improves precision and can reduce discrepancies with VLLM, it is
* not guaranteed to eliminate all differences or ensure uniform outcomes across every use case.
*
* This implementation is a variant based on the original one-pass and two-pass approaches,
* allowing for both fused and non-fused add operations.
*/
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_op"; // block per row
else
return "wpr_op"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename UnquantYWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
YWindow& y_window_,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window_,
UnquantYWindow& unquant_y_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
Epilogue) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_tree_cross_warp_sync =
Policy::template GetBlockReduce2dTreeCrossWarpSync<Problem>();
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
[[maybe_unused]] auto pre_out =
make_static_distributed_tensor<YResidualDataType>(x.get_tile_distribution());
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
// To make norm input align with residual output
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
{
pre_out(idx) = float_to_bf16<bf16_rounding_mode::standard>(acc(idx));
}
else
{
pre_out(idx) = type_convert<YResidualDataType>(acc(idx));
}
acc(idx) = type_convert<ComputeDataType>(pre_out(idx));
}
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, pre_out);
}
}
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
set_tile(square_sum, 0);
if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
{
sweep_tile(
acc,
[&](auto idx_0, auto idx_1) {
square_sum(idx_0) += acc[idx_0] * acc[idx_0] + acc[idx_1] * acc[idx_1];
},
sequence<1, 2>{});
}
else
{
square_sum = block_reduce2d(acc,
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
}
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
{
const auto tmp0 =
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
type_convert<ComputeDataType>(tmp0) * gamma_);
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
rmsn(idx) = rmsn_;
}
else
{
const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
rmsn(idx) = rmsn_;
}
});
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
if constexpr(kSaveUnquant)
{
Epilogue{}(
unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
if constexpr(kSaveUnquant)
{
Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
}
}
else
{
Epilogue{}(y_window_, rmsn);
}
}
};
} // namespace ck_tile

View File

@@ -117,10 +117,7 @@ struct Rmsnorm2dFwdPipelineOnePass
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
},
square_sum);
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));

View File

@@ -37,20 +37,37 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_Q
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on
enum class Rmsnorm2dSensitiveEnum
{
NO_SPECIFIC_MODEL = 0,
// T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based
// architecture designed for a variety of NLP tasks. This option mimics T5's approach to
// RMSNorm, aiming to ensure similar value distributions and enhance accuracy.
T5_MODEL_LIKE = 1,
};
// clang-format off
template<Rmsnorm2dSensitiveEnum> struct Rmsnorm2dSensitiveEnumName;
template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL> { static constexpr const char * name = "nsm"; };
template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE> { static constexpr const char * name = "t5ml"; };
// clang-format on
template <bool kPadN_,
bool kSaveInvRms_,
bool kSaveUnquant_,
bool kTwoPass_,
Rmsnorm2dFusedAddEnum kFusedAdd_,
Rmsnorm2dFusedQuantEnum kFusedQuant_>
Rmsnorm2dFusedQuantEnum kFusedQuant_,
Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_>
struct Rmsnorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kSaveUnquant = kSaveUnquant_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kSaveUnquant = kSaveUnquant_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
};
} // namespace ck_tile