mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
MX GEMM - New GEMM pipeline for MX data types (#2059)
* Allow selection of mfma_scale instructions * Read B tensor from LDS to VGPR in chunks of 16 in MFMA order * Add constexpr and synchronize return type for `get_exponent_value` * Pass scales by reference and add comments to `mfma_scale_f32_32x32x64` * Add support for microscaling instructions in `XdlopsGemm` * Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper * Remove software implementation of MX GEMM * Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction * Update README * Updated CHANGELOG * Remove unused static methods
This commit is contained in:
committed by
GitHub
parent
d55c9cb313
commit
7106976a72
@@ -30,48 +30,69 @@ enum class MFMA_F8F6F4
|
||||
|
||||
};
|
||||
|
||||
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
|
||||
template <int32_t BLOCK_M, int32_t BLOCK_N>
|
||||
struct mfma_type_selector;
|
||||
|
||||
template <typename AFragT, typename BFragT, typename AccumFragT>
|
||||
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
|
||||
template <>
|
||||
struct mfma_type_selector<16, 16>
|
||||
{
|
||||
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
|
||||
template <typename AFragT, typename BFragT, typename AccumFragT>
|
||||
__device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
|
||||
{
|
||||
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
|
||||
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
|
||||
}
|
||||
|
||||
__device__ void operator()(AFragT const& fragA,
|
||||
const int32_t scale_a,
|
||||
BFragT const& fragB,
|
||||
const int32_t scale_b,
|
||||
AccumFragT& fragAcc)
|
||||
{
|
||||
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
|
||||
op.template run<16, 16, AFragT, BFragT, AccumFragT>(
|
||||
fragA, scale_a, fragB, scale_b, fragAcc);
|
||||
op.template run<16, 16>(fragA, fragB, fragAcc);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AFragT, typename BFragT, typename AccumFragT>
|
||||
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
|
||||
template <>
|
||||
struct mfma_type_selector<32, 32>
|
||||
{
|
||||
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
|
||||
template <typename AFragT, typename BFragT, typename AccumFragT>
|
||||
__device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
|
||||
{
|
||||
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
|
||||
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
|
||||
op.template run<32, 32>(fragA, fragB, fragAcc);
|
||||
}
|
||||
};
|
||||
|
||||
__device__ void operator()(AFragT const& fragA,
|
||||
const int32_t scale_a,
|
||||
template <int32_t BLOCK_M, int32_t BLOCK_N>
|
||||
struct mfma_scale_type_selector;
|
||||
|
||||
template <>
|
||||
struct mfma_scale_type_selector<16, 16>
|
||||
{
|
||||
template <typename AFragT,
|
||||
typename AScaleFragT,
|
||||
typename BFragT,
|
||||
typename BScaleFragT,
|
||||
typename AccumFragT>
|
||||
__device__ static void run(AFragT const& fragA,
|
||||
AScaleFragT const& scale_a,
|
||||
BFragT const& fragB,
|
||||
const int32_t scale_b,
|
||||
BScaleFragT const& scale_b,
|
||||
AccumFragT& fragAcc)
|
||||
{
|
||||
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
|
||||
op.template run<16, 16>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_scale_type_selector<32, 32>
|
||||
{
|
||||
template <typename AFragT,
|
||||
typename AScaleFragT,
|
||||
typename BFragT,
|
||||
typename BScaleFragT,
|
||||
typename AccumFragT>
|
||||
__device__ static void run(AFragT const& fragA,
|
||||
AScaleFragT const& scale_a,
|
||||
BFragT const& fragB,
|
||||
BScaleFragT const& scale_b,
|
||||
AccumFragT& fragAcc)
|
||||
{
|
||||
auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
|
||||
op.template run<32, 32, AFragT, BFragT, AccumFragT>(
|
||||
fragA, scale_a, fragB, scale_b, fragAcc);
|
||||
op.template run<32, 32>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -334,8 +355,7 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
|
||||
// BLOCK_K / BLOCK_X is a stride in xA matrix
|
||||
auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X);
|
||||
|
||||
// obtain 8-bit exponent
|
||||
fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
|
||||
fragX = scale_ptr[startOffset];
|
||||
|
||||
return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
|
||||
}
|
||||
@@ -502,7 +522,7 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr,
|
||||
auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);
|
||||
|
||||
// obtain 8-bit exponent
|
||||
fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
|
||||
fragX = scale_ptr[startOffset];
|
||||
|
||||
return load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(input_ptr);
|
||||
}
|
||||
@@ -773,7 +793,8 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
|
||||
// Matrix multiply-accumulate using MFMA units
|
||||
// Accumulation intermediate = BLOCK_M x BLOCK_N
|
||||
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
|
||||
using mfma = mfma_type_selector<BLOCK_M, BLOCK_N>;
|
||||
mfma::template run<>(fragA, fragB, fragAcc);
|
||||
|
||||
for(int i = 0; i < vectorSize(fragC); ++i)
|
||||
{
|
||||
@@ -805,29 +826,34 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using ScaleFragT = int32_t;
|
||||
using AScaleFragT = vector_type<ScaleType, 1>::type;
|
||||
using BScaleFragT = vector_type<ScaleType, 1>::type;
|
||||
|
||||
// Create frags
|
||||
auto fragA = AFragT{};
|
||||
auto fragB = BFragT{};
|
||||
auto fragC = CFragT{};
|
||||
auto fragAcc = AccumFragT{0};
|
||||
auto fragXa = ScaleFragT{0};
|
||||
auto fragXb = ScaleFragT{0};
|
||||
auto fragXa = AScaleFragT{};
|
||||
auto fragXb = BScaleFragT{};
|
||||
|
||||
// Load the inputs.
|
||||
// A = col major, BLOCK_M x BLOCK_K
|
||||
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
|
||||
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
|
||||
a, xa, fragXa);
|
||||
|
||||
// B = col major, BLOCK_K x BLOCK_N
|
||||
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
|
||||
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
|
||||
b, xb, fragXb);
|
||||
|
||||
// Scaled Matrix multiply-accumulate using MFMA units
|
||||
// Accumulation intermediate = BLOCK_M x BLOCK_N
|
||||
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
|
||||
fragA, fragXa, fragB, fragXb, fragAcc);
|
||||
using mfma = mfma_scale_type_selector<BLOCK_M, BLOCK_N>;
|
||||
mfma::template run<>(fragA,
|
||||
fragXa.template AsType<ScaleType>(),
|
||||
fragB,
|
||||
fragXb.template AsType<ScaleType>(),
|
||||
fragAcc);
|
||||
|
||||
for(int i = 0; i < vectorSize(fragC); ++i)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user