mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
MX GEMM - New GEMM pipeline for MX data types (#2059)
* Allow selection of mfma_scale instructions * Read B tensor from LDS to VGPR in chunks of 16 in MFMA order * Add constexpr and synchronize return type for `get_exponent_value` * Pass scales by reference and add comments to `mfma_scale_f32_32x32x64` * Add support for microscaling instructions in `XdlopsGemm` * Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper * Remove software implementation of MX GEMM * Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction * Update README * Updated CHANGELOG * Remove unused static methods
This commit is contained in:
committed by
GitHub
parent
d55c9cb313
commit
7106976a72
@@ -0,0 +1,363 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
{
|
||||
using ComputeTypeA = ADataType;
|
||||
using ComputeTypeB = BDataType;
|
||||
using AccType = float; // for now only support V_MFMA_SCALE_F32
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, TransposeC, true>{};
|
||||
|
||||
static constexpr index_t AMmaKStride = KPack;
|
||||
static constexpr index_t BMmaKStride = KPack;
|
||||
|
||||
//> store rows/cols into thread registers in chunks of 16
|
||||
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
|
||||
static constexpr index_t KThreadChunk = 16;
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
static constexpr index_t KPerInnerLoop = KPack;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
using HotLoopInstList =
|
||||
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
A_K1,
|
||||
B_K1,
|
||||
A_K1,
|
||||
B_K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
xdlops_gemm.KPerXdlops>;
|
||||
|
||||
static_assert(KPerThread % KPack == 0,
|
||||
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccType,
|
||||
MRepeat * NRepeat,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_;
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = ThisThreadBlock::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
|
||||
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
|
||||
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
|
||||
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
|
||||
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
|
||||
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
|
||||
|
||||
/**
|
||||
* @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base.
|
||||
*
|
||||
* This constructor initializes the thread copy objects for matrices A and B.
|
||||
* It also performs several compile-time checks to ensure the correctness of the
|
||||
* matrix tile descriptors.
|
||||
*
|
||||
* @param a_origin The origin data index for matrix A.
|
||||
* @param b_origin The origin data index for matrix B.
|
||||
*
|
||||
* @note The constructor includes static assertions to ensure that:
|
||||
* - The matrix tile descriptors for A and B are known at compile-time.
|
||||
* - The number of threads in the thread block matches the product of MWaves, NWaves, and
|
||||
* WaveSize.
|
||||
* - The dimensions of the block are divisible by the product of the corresponding XDL and
|
||||
* repeat dimensions.
|
||||
*/
|
||||
__host__ __device__
|
||||
BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
|
||||
{
|
||||
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
|
||||
|
||||
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
|
||||
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
|
||||
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
|
||||
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
|
||||
}
|
||||
|
||||
// XDL output supporting C_xdl = A_xdl * B_xdl
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
|
||||
|
||||
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
|
||||
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
|
||||
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
|
||||
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
|
||||
|
||||
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
|
||||
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
|
||||
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
|
||||
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
|
||||
{
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerXDL>{},
|
||||
Number<NPerXDL>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
// XDL output supporting C_xdl = A_xdl * B_xdl
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerXDL>{},
|
||||
Number<NPerXDL>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1,
|
||||
Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerXDL>{},
|
||||
Number<NPerXDL>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
c_block_desc_g_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_G_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
|
||||
{
|
||||
const auto G = c_grid_desc_g_m_n.GetLength(I0);
|
||||
const auto M = c_grid_desc_g_m_n.GetLength(I1);
|
||||
const auto N = c_grid_desc_g_m_n.GetLength(I2);
|
||||
|
||||
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
|
||||
c_grid_desc_g_m_n,
|
||||
make_tuple(make_pass_through_transform(G),
|
||||
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
|
||||
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
|
||||
|
||||
protected:
|
||||
// M1, N1 as double buffer index
|
||||
// Read buffer + Compute buffer
|
||||
// A[M0, M1, M2, KPack]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
|
||||
make_tuple(
|
||||
Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
|
||||
|
||||
// B[N0, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
|
||||
make_tuple(
|
||||
Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
|
||||
|
||||
// C[M, N, NumRegXdlops]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KThreadChunk>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KThreadChunk>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -7,6 +7,35 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
/**
|
||||
* @brief Define matrix data types that have hardware support for MX GEMMs
|
||||
*/
|
||||
template <typename T>
|
||||
static constexpr bool is_scale_mfma_data_type()
|
||||
{
|
||||
return is_same_v<T, f8_ocp_t> || is_same_v<T, bf8_ocp_t> || is_same_v<T, f6_t> ||
|
||||
is_same_v<T, bf6_t> || is_same_v<T, f4_t>;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Define scale data types that have hardware support for MX GEMMs
|
||||
*/
|
||||
template <typename T>
|
||||
static constexpr bool is_scale_mfma_scale_type()
|
||||
{
|
||||
return is_same_v<T, e8m0_bexp_t>;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Combination of data types that have hardware support for MX GEMMs
|
||||
*/
|
||||
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
|
||||
static constexpr bool scale_mfma_hw_support()
|
||||
{
|
||||
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
|
||||
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
|
||||
}
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSche,
|
||||
index_t ThreadBlockSize,
|
||||
@@ -34,6 +63,8 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t KPack>
|
||||
constexpr auto BlockGemmMXPipeline_Selector()
|
||||
{
|
||||
|
||||
// Hardware MX GEMM pipeline
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_v1_mx<BlkGemmPipeSche,
|
||||
@@ -43,8 +74,6 @@ constexpr auto BlockGemmMXPipeline_Selector()
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
@@ -62,7 +91,7 @@ constexpr auto BlockGemmMXPipeline_Selector()
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
|
||||
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -20,8 +20,6 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
@@ -46,8 +44,6 @@ template <index_t ThreadBlockSize,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
@@ -69,8 +65,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
@@ -85,46 +79,43 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
@@ -134,7 +125,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
@@ -151,15 +141,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Tuple4 = typename Base::Tuple4;
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple4 = typename Base::Tuple4;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block size
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
@@ -172,45 +173,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
return TailNumber::Full;
|
||||
}
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_m, xdlops_a_idx[I1], xdlops_gemm.KPerXdlops * xdlops_a_idx[I0]);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_n, xdlops_b_idx[I1], xdlops_gemm.KPerXdlops * xdlops_b_idx[I0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Constructor for BlockwiseGemmXdlops_pipeline_v1_mx.
|
||||
*
|
||||
* The primary purpose of this constructor is to modify default initialization of the base class
|
||||
* with the origin data index suitable for microscaling.
|
||||
*
|
||||
* @param a_origin The origin data index for matrix A.
|
||||
* @param b_origin The origin data index for matrix B.
|
||||
*
|
||||
*/
|
||||
__host__ __device__
|
||||
BlockwiseGemmXdlops_pipeline_v1_mx(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: Base(a_origin, b_origin)
|
||||
{
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
@@ -258,9 +220,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
@@ -276,49 +238,31 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_assert(xdlops_gemm.mfma_instr.num_groups_per_blk *
|
||||
xdlops_gemm.mfma_instr.group_size ==
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
"Assume num_regs_per_blk == num_groups_per_blk * group_size");
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
|
||||
auto a_scale_thread_buf_group =
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_group.GetElementSpaceSize());
|
||||
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_group,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_group);
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i));
|
||||
a_scale_thread_buf(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_group[Number<i>{}];
|
||||
});
|
||||
// go to the next group
|
||||
a_scale_thread_buf(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0));
|
||||
}); // g
|
||||
|
||||
// restore row id and advance to the next scale
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size *
|
||||
xdlops_gemm.mfma_instr.num_groups_per_blk,
|
||||
1));
|
||||
}); // k0
|
||||
|
||||
// restore column id and advance to the next set of rows
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
}); // m0
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
@@ -326,15 +270,32 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, I0),
|
||||
b_scale_thread_buf);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, 0));
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_buf(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
@@ -345,8 +306,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
@@ -363,141 +322,166 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// k indexes mapping to threads for 32x32x64:
|
||||
// t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
|
||||
// t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc.
|
||||
// k = 0 k = 1
|
||||
|
||||
// k indexes mapping to threads for 16x16x128:
|
||||
// t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc.
|
||||
// t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc.
|
||||
// t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
|
||||
// t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
|
||||
// k = 0 k = 1
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
|
||||
constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<b_k_step>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
// read block data in chunks to assemble correct thread vectors
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto b_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_buf[Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_buf[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
// m = 1:MPerXDL
|
||||
// n = 1:NPerXDL
|
||||
// k = 1:KPack
|
||||
// c(m,n) += a(m,k)*b(k,n)
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
|
||||
// one scale per k0
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}(
|
||||
[&](auto g) {
|
||||
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}(
|
||||
[&](auto r) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(
|
||||
make_tuple(m0, k0, g, r));
|
||||
|
||||
constexpr auto reg_offset =
|
||||
g * xdlops_gemm.mfma_instr.group_size + r;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, n0, reg_offset));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<reg_offset>{}] *
|
||||
type_convert<AccDataType>(
|
||||
b_scale_thread_buf[Number<b_scale_offset>{}]) *
|
||||
type_convert<AccDataType>(
|
||||
a_scale_thread_buf[Number<a_scale_offset>{}]);
|
||||
});
|
||||
});
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
|
||||
auto a_scale_thread_buf_group =
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_group.GetElementSpaceSize());
|
||||
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_group,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_group);
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r));
|
||||
a_scale_thread_buf(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_group[Number<r>{}];
|
||||
});
|
||||
// go to the next group
|
||||
a_scale_thread_buf(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0));
|
||||
}); // g
|
||||
|
||||
// restore row id and advance to the next scale
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size *
|
||||
xdlops_gemm.mfma_instr.num_groups_per_blk,
|
||||
1));
|
||||
}); // k0
|
||||
|
||||
// restore column id and advance to the next set of rows
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
}); // m0
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, I0),
|
||||
b_scale_thread_buf);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, 0));
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_buf(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
@@ -507,7 +491,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
@@ -517,94 +500,107 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
|
||||
constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
// read block data in chunks to assemble correct thread
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<b_k_step>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
// read block data in chunks to assemble correct thread
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto b_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<b_k_step_chunk>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
|
||||
// one scale per k0
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0));
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
|
||||
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r));
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
|
||||
constexpr auto reg_offset =
|
||||
g * xdlops_gemm.mfma_instr.group_size + r;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<reg_offset>{}] *
|
||||
type_convert<AccDataType>(
|
||||
b_scale_thread_buf[Number<b_scale_offset>{}]) *
|
||||
type_convert<AccDataType>(
|
||||
a_scale_thread_buf[Number<a_scale_offset>{}]);
|
||||
});
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_buf[Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_buf[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved here
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{},
|
||||
Number<KRepeat>{},
|
||||
Number<xdlops_gemm.mfma_instr.num_groups_per_blk>{},
|
||||
Number<xdlops_gemm.mfma_instr.group_size>{}));
|
||||
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from a_scale_grid to a_scale_thread
|
||||
static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<xdlops_gemm.mfma_instr.group_size>{}, Number<1>{}));
|
||||
static constexpr auto a_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved here
|
||||
static constexpr auto b_scale_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, Number<KRepeat>{}));
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from b_scale_grid to b_scale_thread_buf
|
||||
static constexpr auto b_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
|
||||
Reference in New Issue
Block a user