mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Merge branch 'develop' into test_copy_fix
This commit is contained in:
@@ -325,12 +325,50 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v2 support now");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1112,7 +1112,7 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
#if 1
|
||||
#if 0
|
||||
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
|
||||
|
||||
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
|
||||
|
||||
@@ -467,7 +467,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
|
||||
}
|
||||
@@ -1474,7 +1474,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1567,7 +1567,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
@@ -2185,7 +2185,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -2289,7 +2289,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
|
||||
@@ -1396,8 +1396,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1427,8 +1427,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1459,8 +1459,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1490,8 +1490,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1522,8 +1522,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1553,8 +1553,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1585,8 +1585,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1616,8 +1616,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <map>
|
||||
|
||||
namespace ck {
|
||||
namespace internal {
|
||||
|
||||
@@ -33,7 +33,7 @@ __device__ void block_sync_lds_direct_load()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
asm volatile("\
|
||||
s_wait_vmcnt 0x0 \n \
|
||||
s_wait_loadcnt 0x0 \n \
|
||||
s_wait_dscnt 0x0 \n \
|
||||
s_barrier_signal -1 \n \
|
||||
s_barrier_wait -1 \
|
||||
|
||||
@@ -74,7 +74,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
|
||||
@@ -196,7 +196,7 @@ struct GemmKernel
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||
@@ -95,7 +95,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
CK_TILE_HOST static auto
|
||||
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
|
||||
@@ -1095,16 +1095,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1119,16 +1119,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
@@ -1254,16 +1254,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
@@ -1289,16 +1289,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
@@ -1580,7 +1580,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
{
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
@@ -1650,7 +1650,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
|
||||
{
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1709,7 +1709,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1767,8 +1767,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
|
||||
else
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec =
|
||||
__builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
|
||||
a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
|
||||
@@ -272,4 +272,137 @@ struct BlockReduce2dCrossWarpSync
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2dTreeCrossWarpSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
|
||||
template <typename YDistributedTensor_>
|
||||
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
|
||||
{
|
||||
constexpr index_t num_reduce_warps = [&]() {
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_warp = 0;
|
||||
|
||||
index_t len_ = 1;
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
len_ *= r_length;
|
||||
}
|
||||
});
|
||||
return len_;
|
||||
}();
|
||||
return num_reduce_warps;
|
||||
}
|
||||
|
||||
// return in byte
|
||||
template <typename YDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// we need to store all data from every wave into smem
|
||||
// e.g. 2x2 reduce along N
|
||||
// -------------> reduce N
|
||||
// | w0 | w1 | ___> | w01 |
|
||||
// | w2 | w3 | | w23 |
|
||||
//
|
||||
// -> store data from every wave into LDS
|
||||
//
|
||||
//
|
||||
// -------------> reduce N
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
|
||||
if constexpr(num_reduce_warps == 1)
|
||||
return;
|
||||
|
||||
// Each warp's lane 0 writes its partial results to shared memory
|
||||
const index_t smem_offset = warp_id;
|
||||
if(lane_id == 0)
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
// Store the i-th element of this warp's thread_buffer into SMEM
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// We let each warp holds a duplication to do reduction.
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
DataType v = 0;
|
||||
if(lane_id < num_reduce_warps)
|
||||
{
|
||||
v = smem_ptr[lane_id + i * num_warps];
|
||||
}
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// pull data from remote lane
|
||||
const auto o =
|
||||
__shfl_xor(v, number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// reduce
|
||||
v = reduce_func(v, o);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i) = v;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
|
||||
|
||||
@@ -58,13 +58,14 @@ struct Rmsnorm2dFwd
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
@@ -150,6 +151,8 @@ struct Rmsnorm2dFwd
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveInvRms) n += "_rms";
|
||||
if (kTwoPass) n += "_2p";
|
||||
if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL) n += "_nsm";
|
||||
else if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE) n += "_t5ml";
|
||||
return n; }();
|
||||
|
||||
auto prec_str = [&] () {
|
||||
|
||||
@@ -69,6 +69,15 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dTreeCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief This T5Pass implements the RMSNorm2d forward pipeline as a variant
|
||||
* based on Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass using a T5 model-like
|
||||
* method.
|
||||
*
|
||||
* The T5 model, developed by Google, is a transformer-based architecture designed to perform
|
||||
* a variety of NLP tasks. The T5-like approach employed here is characterized by how RMS
|
||||
* normalization is handled, particularly where intermediate values are cast to BF16. This aims to
|
||||
* achieve a similar value distribution to that produced by the VLLM hip implementation, thereby
|
||||
* enhancing model accuracy.
|
||||
*
|
||||
* Note: While this implementation improves precision and can reduce discrepancies with VLLM, it is
|
||||
* not guaranteed to eliminate all differences or ensure uniform outcomes across every use case.
|
||||
*
|
||||
* This implementation is a variant based on the original one-pass and two-pass approaches,
|
||||
* allowing for both fused and non-fused add operations.
|
||||
*/
|
||||
|
||||
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
|
||||
struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_op"; // block per row
|
||||
else
|
||||
return "wpr_op"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename GammaWindow,
|
||||
typename YWindow,
|
||||
typename YResidualWindow,
|
||||
typename InvRmsWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename UnquantYWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
YWindow& y_window_,
|
||||
const YResidualWindow& y_residual_window_,
|
||||
InvRmsWindow& inv_rms_window,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window_,
|
||||
UnquantYWindow& unquant_y_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
const auto x_residual_window = make_tile_window(
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_tree_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dTreeCrossWarpSync<Problem>();
|
||||
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
|
||||
// load gamma (TODO: support no gamma?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
[[maybe_unused]] auto pre_out =
|
||||
make_static_distributed_tensor<YResidualDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
|
||||
// To make norm input align with residual output
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
pre_out(idx) = float_to_bf16<bf16_rounding_mode::standard>(acc(idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
pre_out(idx) = type_convert<YResidualDataType>(acc(idx));
|
||||
}
|
||||
acc(idx) = type_convert<ComputeDataType>(pre_out(idx));
|
||||
}
|
||||
});
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, pre_out);
|
||||
}
|
||||
}
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
|
||||
set_tile(square_sum, 0);
|
||||
if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
|
||||
{
|
||||
sweep_tile(
|
||||
acc,
|
||||
[&](auto idx_0, auto idx_1) {
|
||||
square_sum(idx_0) += acc[idx_0] * acc[idx_0] + acc[idx_1] * acc[idx_1];
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
square_sum = block_reduce2d(acc,
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_square_sum_func);
|
||||
}
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
|
||||
|
||||
if constexpr(kSaveInvRms)
|
||||
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
|
||||
|
||||
// rmsnorm computation
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
const auto tmp0 =
|
||||
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
|
||||
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
|
||||
type_convert<ComputeDataType>(tmp0) * gamma_);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
|
||||
rmsn(idx) = rmsn_;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
|
||||
rmsn(idx) = rmsn_;
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(
|
||||
unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, rmsn);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -117,10 +117,7 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
|
||||
},
|
||||
square_sum);
|
||||
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
|
||||
|
||||
if constexpr(kSaveInvRms)
|
||||
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
|
||||
|
||||
@@ -37,20 +37,37 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_Q
|
||||
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
|
||||
// clang-format on
|
||||
|
||||
enum class Rmsnorm2dSensitiveEnum
|
||||
{
|
||||
NO_SPECIFIC_MODEL = 0,
|
||||
// T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based
|
||||
// architecture designed for a variety of NLP tasks. This option mimics T5's approach to
|
||||
// RMSNorm, aiming to ensure similar value distributions and enhance accuracy.
|
||||
T5_MODEL_LIKE = 1,
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template<Rmsnorm2dSensitiveEnum> struct Rmsnorm2dSensitiveEnumName;
|
||||
template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL> { static constexpr const char * name = "nsm"; };
|
||||
template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE> { static constexpr const char * name = "t5ml"; };
|
||||
// clang-format on
|
||||
|
||||
template <bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
Rmsnorm2dFusedAddEnum kFusedAdd_,
|
||||
Rmsnorm2dFusedQuantEnum kFusedQuant_>
|
||||
Rmsnorm2dFusedQuantEnum kFusedQuant_,
|
||||
Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_>
|
||||
struct Rmsnorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user