mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Merge remote-tracking branch 'origin/lwpck-3447' into mmflat
This commit is contained in:
@@ -467,7 +467,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
|
||||
{
|
||||
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
|
||||
constexpr index_t NkSwizzleNumber = Number<WarpSize * KPack>{};
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
|
||||
}
|
||||
@@ -1474,7 +1474,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -1567,7 +1567,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
@@ -2185,7 +2185,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
@@ -2289,7 +2289,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % warpSize)));
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
|
||||
@@ -1396,8 +1396,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32>
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1427,8 +1427,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1459,8 +1459,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1490,8 +1490,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1522,8 +1522,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1553,8 +1553,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1585,8 +1585,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
@@ -1616,8 +1616,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <map>
|
||||
|
||||
namespace ck {
|
||||
namespace internal {
|
||||
|
||||
@@ -33,7 +33,7 @@ __device__ void block_sync_lds_direct_load()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
asm volatile("\
|
||||
s_wait_vmcnt 0x0 \n \
|
||||
s_wait_loadcnt 0x0 \n \
|
||||
s_wait_dscnt 0x0 \n \
|
||||
s_barrier_signal -1 \n \
|
||||
s_barrier_wait -1 \
|
||||
|
||||
@@ -74,7 +74,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
|
||||
@@ -196,7 +196,7 @@ struct GemmKernel
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||
@@ -95,7 +95,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
CK_TILE_HOST static auto
|
||||
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
|
||||
@@ -1095,16 +1095,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1119,16 +1119,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
@@ -1254,16 +1254,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
@@ -1289,16 +1289,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
@@ -1580,7 +1580,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
{
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
@@ -1650,7 +1650,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
|
||||
{
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1709,7 +1709,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1767,8 +1767,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
|
||||
else
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec =
|
||||
__builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
|
||||
a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
|
||||
Reference in New Issue
Block a user