MX GEMM - New GEMM pipeline for MX data types (#2059)

* Allow selection of mfma_scale instructions

* Read B tensor from LDS to VGPR in chunks of 16 in MFMA order

* Add constexpr and synchronize return type for `get_exponent_value`

* Pass scales by reference and add comments to `mfma_scale_f32_32x32x64`

* Add support for microscaling instructions in `XdlopsGemm`

* Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper

* Remove software implementation of MX GEMM

* Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction

* Update README

* Update CHANGELOG

* Remove unused static methods

[ROCm/composable_kernel commit: 7106976a72]
Author: Andriy Roshchenko
Date: 2025-04-15 17:17:07 -06:00
Committed by: GitHub
Parent: 1a8132e9f9
Commit: 5e2bd20672
19 changed files with 1007 additions and 608 deletions

View File

@@ -13,6 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
* Added support for Stream-K version of mixed fp8/bf16 GEMM.
* Added GEMM pipeline for microscaling (MX) data types.
* Added support for FP16 2:4 structured sparsity to universal GEMM.
### Optimized

View File

@@ -1,10 +1,5 @@
add_custom_target(example_gemm_mx)
add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale)
add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale)
add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale)

View File

@@ -10,16 +10,16 @@ Custom verification parameters:
# arg4: verbosity (0=no info, 1=verbose info)
# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC
# arg11: KBatch
./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1
./bin/example_gemm_mx_fp8 1 1 0 1
```
Custom tensor shapes:
```bash
./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1
./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1
```
Default invocation:
```bash
# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0
./bin/example_gemm_mx_fp8_fp8_scale
# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0
./bin/example_gemm_mx_fp8
```

View File

@@ -95,7 +95,7 @@ bool parse_cmd_args(int argc,
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
<< "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl
<< "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl
<< "arg11: KBatch" << std::endl;
return false;
}
@@ -103,7 +103,8 @@ bool parse_cmd_args(int argc,
return true;
}
template <typename ADataType,
template <typename DeviceOpInstance,
typename ADataType,
typename BDataType,
typename XDataType,
typename CDataType,
@@ -115,65 +116,9 @@ template <typename ADataType,
typename CElementOp,
typename AccDataType,
typename CShuffleDataType,
ck::index_t MXVectorSize>
ck::index_t ScaleBlockSize>
bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config)
{
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
static constexpr ck::index_t ScaleBlockSize = MXVectorSize;
static constexpr ck::index_t KPerBlock = 64;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XDataType, // AScaleDataType
BDataType, // BDataType
XDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
MXVectorSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
auto M = problem_size.M;
auto N = problem_size.N;
@@ -230,8 +175,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{}));
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
@@ -428,8 +373,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
if(config.time_kernel)
{
std::size_t flop = std::size_t(2) * M * N * K +
std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale
// Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of
// partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
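
These two counters typically drive the reported throughput numbers. A minimal sketch of the conversion, assuming `ave_time` holds the measured kernel time in milliseconds as in other CK examples (the variable name and value here are illustrative):

```cpp
// Sketch only: converting the counters above into printed metrics.
float ave_time   = /* measured kernel time in ms */ 1.0f;
float tflops     = static_cast<float>(flop) / 1.E9 / ave_time; // TFLOP/s
float gb_per_sec = num_btype / 1.E6 / ave_time;                // GB/s
```
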
@@ -445,7 +392,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
return res_verified;
}
template <typename ADataType,
template <typename DeviceOpInstance,
typename ADataType,
typename BDataType,
typename XDataType,
typename CDataType,
@@ -464,7 +412,8 @@ bool run_mx_gemm_example(int argc, char* argv[])
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) &&
run_mx_gemm<ADataType,
run_mx_gemm<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
CDataType,

View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
using XDataType = ck::e8m0_bexp_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256;
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XDataType, // AScaleDataType
BDataType, // BDataType
XDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}
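
The numerics this example exercises follow the MX convention: each group of `ScaleBlockSize` (32) consecutive K elements of A and of B shares one e8m0 scale, and the scaled block dot products accumulate in float, matching the FLOP comment above (2K for the dot product plus 2K/ScaleBlockSize for the scale handling). A host-side sketch of that reference semantics, not the library's verification code (all inputs assumed pre-converted to float):

```cpp
// Reference MX dot product for one (m, n) output element; sketch only.
float mx_dot(const float* a, const float* a_scale, // a_scale has K / 32 entries
             const float* b, const float* b_scale,
             int K, int scale_block = 32)
{
    float acc = 0.f;
    for(int kb = 0; kb < K / scale_block; ++kb)
    {
        float partial = 0.f; // unscaled dot product of one 32-element block
        for(int k = kb * scale_block; k < (kb + 1) * scale_block; ++k)
            partial += a[k] * b[k];
        acc += a_scale[kb] * b_scale[kb] * partial; // apply both block scales
    }
    return acc;
}
```
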

View File

@@ -1,42 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
using XDataType = ck::e8m0_bexp_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t mx_vector_size = 32; // scaling block size
int main(int argc, char* argv[])
{
return run_mx_gemm_example<ADataType,
BDataType,
XDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
mx_vector_size>(argc, argv)
? 0
: -1;
}

View File

@@ -1,42 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
using XDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t mx_vector_size = 32; // scaling block size
int main(int argc, char* argv[])
{
return run_mx_gemm_example<ADataType,
BDataType,
XDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
mx_vector_size>(argc, argv)
? 0
: -1;
}

View File

@@ -1,42 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
using XDataType = ck::f8_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t mx_vector_size = 32; // scaling block size
int main(int argc, char* argv[])
{
return run_mx_gemm_example<ADataType,
BDataType,
XDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
mx_vector_size>(argc, argv)
? 0
: -1;
}

View File

@@ -0,0 +1,363 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmXdlops_mx_pipeline_base
{
using ComputeTypeA = ADataType;
using ComputeTypeB = BDataType;
using AccType = float; // for now, only V_MFMA_SCALE_F32 is supported
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
static constexpr index_t WaveSize = 64;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, TransposeC, true>{};
static constexpr index_t AMmaKStride = KPack;
static constexpr index_t BMmaKStride = KPack;
//> store rows/cols into thread registers in chunks of 16
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
static constexpr index_t KThreadChunk = 16;
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
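
For concreteness, a worked example of the K partitioning above under assumed instruction parameters (the 32x32x64 scale MFMA with K1PerXdlops = 32, hence K0PerXdlops = 2; these values are not shown in this diff), with KPerBlock = 256 and KPack = 32:

```cpp
// Standalone illustration only; the instruction parameters are assumptions.
constexpr int KPerBlock   = 256;
constexpr int K0PerXdlops = 2;  // KPerXdlops / K1PerXdlops = 64 / 32
constexpr int KPack       = 32;
constexpr int KPerThread  = KPerBlock / K0PerXdlops; // 128 k elements per lane
constexpr int KRepeat     = KPerThread / KPack;      // 4 calls to xdlops_gemm.Run
static_assert(KPerThread % KPack == 0, "KPack must divide KPerThread");
```
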
using HotLoopInstList =
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
A_K1,
B_K1,
A_K1,
B_K1,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops>;
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccType,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base.
*
* This constructor initializes the thread copy objects for matrices A and B.
* It also performs several compile-time checks to ensure the correctness of the
* matrix tile descriptors.
*
* @param a_origin The origin data index for matrix A.
* @param b_origin The origin data index for matrix B.
*
* @note The constructor includes static assertions to ensure that:
* - The matrix tile descriptors for A and B are known at compile-time.
* - The number of threads in the thread block matches the product of MWaves, NWaves, and
* WaveSize.
* - The dimensions of the block are divisible by the product of the corresponding XDL and
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck

View File

@@ -7,6 +7,35 @@
namespace ck {
/**
* @brief Define matrix data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_data_type()
{
return is_same_v<T, f8_ocp_t> || is_same_v<T, bf8_ocp_t> || is_same_v<T, f6_t> ||
is_same_v<T, bf6_t> || is_same_v<T, f4_t>;
}
/**
* @brief Define scale data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_scale_type()
{
return is_same_v<T, e8m0_bexp_t>;
}
/**
* @brief Combination of data types that have hardware support for MX GEMMs
*/
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
static constexpr bool scale_mfma_hw_support()
{
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
}
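
A hypothetical call site: the three traits compose into a single compile-time gate, much like the check `DeviceGemmMX_Xdl_CShuffleV3` performs below.

```cpp
// Hypothetical usage sketch; the device op performs an equivalent check.
static_assert(ck::scale_mfma_hw_support<ck::f8_ocp_t,
                                        ck::bf8_ocp_t,
                                        ck::e8m0_bexp_t,
                                        ck::e8m0_bexp_t>(),
              "this data/scale type combination has no hardware scale-MFMA support");
```
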
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
@@ -34,6 +63,8 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t KPack>
constexpr auto BlockGemmMXPipeline_Selector()
{
// Hardware MX GEMM pipeline
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmXdlops_pipeline_v1_mx<BlkGemmPipeSche,
@@ -43,8 +74,6 @@ constexpr auto BlockGemmMXPipeline_Selector()
AScaleDataType,
BDataType,
BScaleDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
@@ -62,7 +91,7 @@ constexpr auto BlockGemmMXPipeline_Selector()
}
else
{
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
}
}

View File

@@ -3,7 +3,7 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
namespace ck {
@@ -20,8 +20,6 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
@@ -46,8 +44,6 @@ template <index_t ThreadBlockSize,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
@@ -69,8 +65,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
AScaleDataType,
BDataType,
BScaleDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
@@ -85,46 +79,43 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
@@ -134,7 +125,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
using Base::xdlops_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
@@ -151,15 +141,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::KThreadChunk;
using Tuple4 = typename Base::Tuple4;
using AccType = typename Base::AccType;
using Tuple4 = typename Base::Tuple4;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr auto ScalesPerKBlockSize =
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block size
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col are processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
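
Plugging in hypothetical numbers (ScaleBlockSize = 32, KPerBlock = 256 as in the shipped example, and KPack = 32, K0PerXdlops = 2, num_input_blks = 2 assumed for the 32x32x64 instruction), these constants work out to:

```cpp
// Standalone illustration only; instruction parameters are assumptions.
constexpr int ScalesPerKBlockSize         = 256 / 32;      // 8 mx-vectors per K block
constexpr int ScalesPerXdlopsRun          = (32 * 2) / 32; // 2 per xdlops_gemm.Run
constexpr int ScalesPerXdlopsRunPerThread = 2 / 2;         // 1 scale read per thread
```
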
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
@@ -172,45 +173,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
return TailNumber::Full;
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], xdlops_gemm.KPerXdlops * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], xdlops_gemm.KPerXdlops * xdlops_b_idx[I0]);
}
/**
* @brief Constructor for BlockwiseGemmXdlops_pipeline_v1_mx.
*
* The primary purpose of this constructor is to modify default initialization of the base class
* with the origin data index suitable for microscaling.
*
* @param a_origin The origin data index for matrix A.
* @param b_origin The origin data index for matrix B.
*
*/
__host__ __device__
BlockwiseGemmXdlops_pipeline_v1_mx(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: Base(a_origin, b_origin)
{
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
@@ -258,9 +220,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
const BScaleGridBuffer& b_scale_grid_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
@@ -276,49 +238,31 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_assert(xdlops_gemm.mfma_instr.num_groups_per_blk *
xdlops_gemm.mfma_instr.group_size ==
xdlops_gemm.GetRegSizePerXdlops(),
"Assume num_regs_per_blk == num_groups_per_blk * group_size");
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
auto a_scale_thread_buf_group =
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_group.GetElementSpaceSize());
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_group,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_group);
a_scale_thread_buf_copy);
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i));
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_group[Number<i>{}];
});
// go to the next group
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0));
}); // g
// restore row id and advance to the next scale
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size *
xdlops_gemm.mfma_instr.num_groups_per_blk,
1));
}); // k0
// restore column id and advance to the next set of rows
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
}); // m0
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
@@ -326,15 +270,32 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, 0));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_buf(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
@@ -345,8 +306,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// Initialize C
c_thread_buf.Clear();
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
// main body
if constexpr(HasMainLoop)
{
@@ -363,141 +322,166 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
// k indexes mapping to threads for 32x32x64:
// t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
// t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc.
// k = 0 k = 1
// k indexes mapping to threads for 16x16x128:
// t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc.
// t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc.
// t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc.
// t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc.
// k = 0 k = 1
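
The chunked mapping tabulated above can be reproduced with a standalone sketch (num_input_blks = 2, K1PerXdlops = 32, and 32 lanes per input block are assumed for the 32x32x64 case):

```cpp
// Prints the k ranges each lane loads, chunk by chunk; illustration only.
#include <cstdio>

int main()
{
    constexpr int KThreadChunk        = 16;
    constexpr int num_input_blks      = 2;  // assumed for 32x32x64
    constexpr int K1PerXdlops         = 32; // k elements per lane per xdlops call
    constexpr int KPerXdlops          = 64;
    constexpr int num_threads_per_blk = 32; // lanes per input block (assumed)

    for(int tid : {0, 32})
    {
        const int blk = tid / num_threads_per_blk; // input block fed by this lane
        std::printf("t%-2d:", tid);
        for(int k = 0; k < 2; ++k) // two KRepeat steps
            for(int chunk = 0; chunk < K1PerXdlops / KThreadChunk; ++chunk)
            {
                const int start = k * KPerXdlops + blk * KThreadChunk +
                                  chunk * KThreadChunk * num_input_blks;
                std::printf(" [%3d..%3d]", start, start + KThreadChunk - 1);
            }
        std::printf("\n");
    }
    // t0 : [  0.. 15] [ 32.. 47] [ 64.. 79] [ 96..111]
    // t32: [ 16.. 31] [ 48.. 63] [ 80.. 95] [112..127]
}
```
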
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
c_thread_buf_per_scale.Clear();
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
// m = 1:MPerXDL
// n = 1:NPerXDL
// k = 1:KPack
// c(m,n) += a(m,k)*b(k,n)
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
// one scale per k0
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0));
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}(
[&](auto g) {
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}(
[&](auto r) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(
make_tuple(m0, k0, g, r));
constexpr auto reg_offset =
g * xdlops_gemm.mfma_instr.group_size + r;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, reg_offset));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<reg_offset>{}] *
type_convert<AccDataType>(
b_scale_thread_buf[Number<b_scale_offset>{}]) *
type_convert<AccDataType>(
a_scale_thread_buf[Number<a_scale_offset>{}]);
});
});
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
auto a_scale_thread_buf_group =
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_group.GetElementSpaceSize());
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_group,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_group);
a_scale_thread_buf_copy);
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r));
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_group[Number<r>{}];
});
// go to the next group
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0));
}); // g
// restore row id and advance to the next scale
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size *
xdlops_gemm.mfma_instr.num_groups_per_blk,
1));
}); // k0
// restore column id and advance to the next set of rows
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
}); // m0
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, 0));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_buf(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
@@ -507,7 +491,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
@@ -517,94 +500,107 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
c_thread_buf_per_scale.Clear();
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
// one scale per k0
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0));
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
constexpr auto reg_offset =
g * xdlops_gemm.mfma_instr.group_size + r;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<reg_offset>{}] *
type_convert<AccDataType>(
b_scale_thread_buf[Number<b_scale_offset>{}]) *
type_convert<AccDataType>(
a_scale_thread_buf[Number<a_scale_offset>{}]);
});
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
// TODO: make this field protected when a_scale_thread_copy_ is moved here
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{},
Number<KRepeat>{},
Number<xdlops_gemm.mfma_instr.num_groups_per_blk>{},
Number<xdlops_gemm.mfma_instr.group_size>{}));
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Used to copy data from a_scale_grid to a_scale_thread
static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed(
make_tuple(Number<xdlops_gemm.mfma_instr.group_size>{}, Number<1>{}));
static constexpr auto a_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved here
static constexpr auto b_scale_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, Number<KRepeat>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Used to copy data from b_scale_grid to b_scale_thread_buf
static constexpr auto b_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
protected:
using Base::a_thread_copy_;

View File

@@ -694,14 +694,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
static_assert((is_same_v<ADataType, f8_t> || is_same_v<ADataType, bf8_t> ||
is_same_v<ADataType, f6_t> || is_same_v<ADataType, bf6_t> ||
is_same_v<ADataType, f4_t>)&&(is_same_v<BDataType, f8_t> ||
is_same_v<BDataType, bf8_t> ||
is_same_v<BDataType, f6_t> ||
is_same_v<BDataType, bf6_t> ||
is_same_v<BDataType, f4_t>),
static_assert(is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>(),
"Only microscaling formats are supported for ADataType and BDataType");
static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported");
@@ -711,6 +704,11 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(!IsValidCompilationParameter())
{
return false;
}
if(!ck::is_xdl_supported())
{
return false;

View File

@@ -159,16 +159,22 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
? true
: false;
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma = false;
static constexpr auto is_scale_mfma = true;
//> KPack is at least the k_per_blk of selected mfma
//
// Should be a multiple of k_per_blk.
// TODO: Move this to blockwise pipeline base
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
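
A hedged worked example of this selection: with AK1 = BK1 = 16 in the shipped example (so lcm_AK1_BK1 = 16) and assuming the scaled f8f6f4 MFMA reports k_per_blk = 32, KPack resolves to 32:

```cpp
// Standalone illustration; k_per_blk = 32 for the scaled MFMA is an assumption.
constexpr ck::index_t lcm_example       = 16; // lcm(AK1 = 16, BK1 = 16)
constexpr ck::index_t k_per_blk_assumed = 32; // selected_mfma.k_per_blk (assumed)
constexpr ck::index_t kpack_example =
    ck::math::max(lcm_example, k_per_blk_assumed); // 32
```
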
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -1088,10 +1094,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
static_assert(KPerBlock % ScaleBlockSize == 0,
"KPerBlock should be multiple of ScaleBlockSize");
static_assert(KPerBlock / ScaleBlockSize == BlockwiseGemmPipe::KRepeat,
"Single call to xdlops_gemm::Run should process exactly ScaleBlockSize "
"elements in k dimension");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
@@ -1476,61 +1478,63 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
// NXdlPerWave == NRepeat
// MXdlPerWave == MRepeat
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
// Initial thread mapping for MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MWaves=NWaves=2
// Initial thread mapping for:
// BlockSize = 256
// MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
// For each [m0, n0] tile, there are 4 waves:
// tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
// tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
// tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
// tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
auto a_thread_offset_m =
MPerXdl * ((get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) / MWaves) +
mfma.selected_mfma.group_size *
((get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / MPerXdl);
auto a_thread_offset_k = KPerThread * (get_thread_local_1d_id() % MPerXdl) / MPerXdl;
// BlockSize = 128
// MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
// For each [m0, n0] tile, there are 2 waves:
// tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
// tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
auto b_thread_offset_n =
get_thread_local_1d_id() % NPerXdl +
(get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl;
auto b_thread_offset_k = KPerThread * (get_thread_local_1d_id() % NPerXdl) / NPerXdl;
// TODO: Document initial thread mapping for more combinations of parameters
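
The wave tables above follow from a simple decomposition of the flat thread id; a standalone sketch for the BlockSize = 256, MWaves = NWaves = 2 case:

```cpp
#include <cstdio>

int main()
{
    constexpr int WaveSize = 64, NWaves = 2;
    for(int tid : {0, 64, 128, 192}) // first lane of each wave
    {
        const int wave     = tid / WaveSize;
        const int waveId_m = wave / NWaves; // row in the wave grid
        const int waveId_n = wave % NWaves; // column in the wave grid
        std::printf("tid %3d -> waveId = [%d, %d]\n", tid, waveId_m, waveId_n);
    }
    // tid   0 -> waveId = [0, 0]
    // tid  64 -> waveId = [0, 1]
    // tid 128 -> waveId = [1, 0]
    // tid 192 -> waveId = [1, 1]
}
```
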
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak), // SrcDesc
decltype(BlockwiseGemmPipe::a_scale_thread_desc_group), // DstDesc
Sequence<mfma.selected_mfma.group_size, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
0, // SrcVectorDim
1, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock + a_thread_offset_m,
a_thread_offset_k / ScaleBlockSize));
const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
Sequence<1, BlockwiseGemmPipe::KRepeat>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
BlockwiseGemmPipe::KRepeat, // SrcScalarPerVector
1,
false>(b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n,
b_thread_offset_k / ScaleBlockSize));
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
mfma.selected_mfma.num_threads_per_blk;
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak),
decltype(BlockwiseGemmPipe::a_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(
a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k));
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1,
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,

View File

@@ -211,8 +211,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
* @tparam SrcVectorDim The dimension along which vectorized access is performed in the source
* tensor.
* @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor.
* @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source
* tensor.
* @tparam SrcScalarStrideInVector Not used.
* @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run
* or rolled back one step in MoveSrcSliceWindow
* @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for

View File

@@ -845,15 +845,24 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
template <index_t MPerXdlops,
index_t NPerXdlops,
class FloatA,
class ScaleA,
class FloatB,
class ScaleB,
class FloatC>
__device__ void run(const FloatA& a,
const int32_t scale_a,
const ScaleA& scale_a,
const FloatB& b,
const int32_t scale_b,
const ScaleB& scale_b,
FloatC& reg_c) const
{
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect a single scale at this point.");
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect a single scale at this point.");
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
a, scale_a, b, scale_b, reg_c);
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
}
};
@@ -874,15 +883,24 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
template <index_t MPerXdlops,
index_t NPerXdlops,
class FloatA,
class ScaleA,
class FloatB,
class ScaleB,
class FloatC>
__device__ void run(const FloatA& a,
const int32_t scale_a,
const ScaleA& scale_a,
const FloatB& b,
const int32_t scale_b,
const ScaleB& scale_b,
FloatC& reg_c) const
{
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect a single scale at this point.");
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect a single scale at this point.");
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
a, scale_a, b, scale_b, reg_c);
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
}
};
@@ -890,14 +908,16 @@ template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type,
bool is_single_rate_mfma = false>
bool is_single_rate_mfma = false,
bool is_scale_mfma = false>
struct MfmaSelector
{
template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_,
bool is_single_rate_mfma_ = false>
bool is_single_rate_mfma_ = false,
bool is_scale_mfma_ = false>
static constexpr auto GetMfma();
template <>
@@ -1103,12 +1123,24 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_32x32x16f8f8;
}
template <>
constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
template <>
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<bf8_t, 32, 32>()
{
@@ -1145,8 +1177,12 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
static constexpr auto selected_mfma = mfma_type<
GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type,
MPerXdlops,
NPerXdlops,
additional_type,
is_single_rate_mfma,
is_scale_mfma>()>{};
__host__ __device__ constexpr MfmaSelector()
{
@@ -1194,7 +1230,8 @@ template <typename base_type,
index_t NPerXdlops,
index_t KPack,
typename additional_type = base_type,
bool TransposeC = false>
bool TransposeC = false,
bool is_scale_mfma = false>
struct XdlopsGemm
{
static constexpr auto I0 = Number<0>{};
@@ -1225,7 +1262,7 @@ struct XdlopsGemm
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
}
// XDL output supporting C = A * B
@@ -1368,6 +1405,27 @@ struct XdlopsGemm
});
}
template <class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
__device__ void Run(const FloatA& p_a_wave,
const ScaleA& a_scale_thread,
const FloatB& p_b_wave,
const ScaleB& b_scale_thread,
FloatC& p_c_thread) const
{
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
}
else
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
}
});
}
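A scalar host model of this loop, with plain floats standing in for the k_per_blk-wide fragments (names are illustrative):

```cpp
// Models Run above: KPack is consumed in KPack / k_per_blk chunks, each chunk
// carrying its own scale, and TransposeC swaps operands together with scales.
#include <array>
#include <cstdio>

template <bool TransposeC, int KPack, int KPerBlk>
void run_model(const std::array<float, KPack / KPerBlk>& a,
               const std::array<float, KPack / KPerBlk>& a_scale,
               const std::array<float, KPack / KPerBlk>& b,
               const std::array<float, KPack / KPerBlk>& b_scale,
               float& c)
{
    for(int k = 0; k < KPack / KPerBlk; ++k)
    {
        if constexpr(!TransposeC)
            c += (a_scale[k] * a[k]) * (b_scale[k] * b[k]);
        else
            c += (b_scale[k] * b[k]) * (a_scale[k] * a[k]);
    }
}

int main()
{
    std::array<float, 2> a{1, 2}, xa{1, 2}, b{3, 4}, xb{1, 0.5f};
    float c = 0;
    run_model<false, 128, 64>(a, xa, b, xb, c); // two chunks of 64
    std::printf("c = %g\n", c);                 // 1*3 + (2*2)*(0.5*4) = 11
}
```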
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
__device__ static auto GetBlkIdx()
@@ -1455,7 +1513,8 @@ struct XdlopsGemm
KPack <= 4) ||
(is_same<base_type, int8_t>::value && KPack <= 8))
? true
: false > {};
: false,
is_scale_mfma > {};
static constexpr auto mfma_instr = mfma.selected_mfma;
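
The selection logic added here boils down to one extra boolean routing f8 shapes to the scaled instruction. A minimal compile-time sketch of the pattern (the enum values mirror the diff; everything else is simplified):

```cpp
#include <cstdio>

enum class MfmaInstr
{
    mfma_f32_16x16x32f8f8,
    mfma_scale_f32_16x16x128f8f6f4,
};

// Sketch of MfmaSelector: is_scale_mfma routes to the microscaling MFMA.
template <int MPerXdlops, int NPerXdlops, bool is_scale_mfma = false>
struct MfmaSelectorSketch
{
    static constexpr MfmaInstr selected =
        is_scale_mfma ? MfmaInstr::mfma_scale_f32_16x16x128f8f6f4
                      : MfmaInstr::mfma_f32_16x16x32f8f8;
};

int main()
{
    static_assert(MfmaSelectorSketch<16, 16>::selected ==
                  MfmaInstr::mfma_f32_16x16x32f8f8);
    static_assert(MfmaSelectorSketch<16, 16, true>::selected ==
                  MfmaInstr::mfma_scale_f32_16x16x128f8f6f4);
    std::puts("selector sketch ok");
}
```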

View File

@@ -520,9 +520,9 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const int32_t& scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
const int32_t& scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
@@ -538,6 +538,14 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
scale_a,
0, // OPSEL
scale_b);
// XXX: Note on the scale_a and scale_b parameters:
// If the compiler detects that one or both scales are constant values, it will treat
// that constant as an F32 constant. I.e., if scale_a was at some point declared as
// `e8m0_bexp_t a_scale{1.0f}`, the instruction only works if the scale_a parameter is
// assigned the value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
// XXX: Note on the OPSEL parameters: the instruction always takes byte 0 as the scale
// value even when OPSEL is set otherwise.
#else
ignore = reg_a;
ignore = scale_a;
@@ -556,9 +564,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const int32_t& scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
const int32_t& scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
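
The note above on constant scales is worth a concrete illustration. Here is a host-side sketch of the two encodings; the `e8m0` struct is a stand-in for `e8m0_bexp_t`, bias 127 with value = 2^(e-127) is an assumption about the format, and C++20 is required for `std::bit_cast`:

```cpp
#include <bit>
#include <cmath>
#include <cstdint>
#include <cstdio>

struct e8m0 { uint8_t data; }; // stand-in: stores only a biased exponent

int main()
{
    e8m0 a_scale{127}; // encodes 1.0f, i.e. 2^(127-127)

    // What get_exponent_value(a_scale) yields - the raw exponent byte:
    int32_t as_exponent = a_scale.data;
    // What a compiler-folded constant scale must look like - F32 bits:
    int32_t as_f32_bits =
        std::bit_cast<int32_t>(std::ldexp(1.0f, a_scale.data - 127));

    std::printf("raw exponent operand: %d\n", as_exponent);  // 127
    std::printf("f32 bit pattern:      0x%08x\n",
                static_cast<uint32_t>(as_f32_bits));         // 0x3f800000
}
```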

View File

@@ -67,10 +67,10 @@ struct e8m0_bexp_t
namespace utils {
template <typename T>
__host__ __device__ inline int get_exponent_value(T x);
__host__ __device__ inline constexpr int32_t get_exponent_value(T x);
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
__host__ __device__ inline constexpr int32_t get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
return x.data;
}
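
For e8m0 the stored byte already is the biased exponent, so the specialization returns it unchanged. A quick model (bias-127 decoding is an assumption about the MX e8m0 format):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

struct e8m0_model { uint8_t data; };

constexpr int32_t get_exponent_value_model(e8m0_model x) { return x.data; }

int main()
{
    for(int b : {126, 127, 128})
    {
        e8m0_model x{static_cast<uint8_t>(b)};
        std::printf("byte %3d -> exponent %3d -> scale %g\n",
                    b, get_exponent_value_model(x),
                    std::ldexp(1.0, get_exponent_value_model(x) - 127));
    }
}
```

This prints scales 0.5, 1 and 2 for bytes 126, 127 and 128.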

View File

@@ -32,13 +32,13 @@ template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
__host__ __device__ inline constexpr int32_t get_exponent_value(T x)
{
x >>= NumericUtils<T>::mant;
x &= ((1 << NumericUtils<T>::exp) - 1);
return static_cast<int>(x);
return static_cast<int32_t>(x);
}
template <typename T>
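
A worked instance of the generic shift-and-mask path for IEEE fp32 bits (assuming `NumericUtils<float>::mant == 23` and `::exp == 8`, the standard single-precision layout; C++20 for `std::bit_cast`):

```cpp
#include <bit>
#include <cstdint>
#include <cstdio>

constexpr int32_t get_exponent_value_fp32(uint32_t x)
{
    constexpr int mant = 23, exp = 8; // IEEE-754 single-precision layout
    x >>= mant;
    x &= ((1u << exp) - 1);
    return static_cast<int32_t>(x);
}

int main()
{
    std::printf("exp(0.5f) = %d\n", get_exponent_value_fp32(std::bit_cast<uint32_t>(0.5f))); // 126
    std::printf("exp(1.0f) = %d\n", get_exponent_value_fp32(std::bit_cast<uint32_t>(1.0f))); // 127
    std::printf("exp(2.0f) = %d\n", get_exponent_value_fp32(std::bit_cast<uint32_t>(2.0f))); // 128
}
```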

View File

@@ -30,48 +30,69 @@ enum class MFMA_F8F6F4
};
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
template <int32_t BLOCK_M, int32_t BLOCK_N>
struct mfma_type_selector;
template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
template <>
struct mfma_type_selector<16, 16>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
template <typename AFragT, typename BFragT, typename AccumFragT>
__device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
}
__device__ void operator()(AFragT const& fragA,
const int32_t scale_a,
BFragT const& fragB,
const int32_t scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(
fragA, scale_a, fragB, scale_b, fragAcc);
op.template run<16, 16>(fragA, fragB, fragAcc);
}
};
template <typename AFragT, typename BFragT, typename AccumFragT>
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
template <>
struct mfma_type_selector<32, 32>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
template <typename AFragT, typename BFragT, typename AccumFragT>
__device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
op.template run<32, 32>(fragA, fragB, fragAcc);
}
};
__device__ void operator()(AFragT const& fragA,
const int32_t scale_a,
template <int32_t BLOCK_M, int32_t BLOCK_N>
struct mfma_scale_type_selector;
template <>
struct mfma_scale_type_selector<16, 16>
{
template <typename AFragT,
typename AScaleFragT,
typename BFragT,
typename BScaleFragT,
typename AccumFragT>
__device__ static void run(AFragT const& fragA,
AScaleFragT const& scale_a,
BFragT const& fragB,
const int32_t scale_b,
BScaleFragT const& scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
op.template run<16, 16>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc);
}
};
template <>
struct mfma_scale_type_selector<32, 32>
{
template <typename AFragT,
typename AScaleFragT,
typename BFragT,
typename BScaleFragT,
typename AccumFragT>
__device__ static void run(AFragT const& fragA,
AScaleFragT const& scale_a,
BFragT const& fragB,
BScaleFragT const& scale_b,
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(
fragA, scale_a, fragB, scale_b, fragAcc);
op.template run<32, 32>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc);
}
};
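
The refactor moves the fragment types from the class template to a member function template, so call sites no longer spell them out. A host analog of the pattern (stand-in types, not the real CK fragments):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

template <int BLOCK_M, int BLOCK_N>
struct selector_sketch
{
    // Fragment types are deduced where run() is called.
    template <typename AFragT, typename BFragT, typename AccumFragT>
    static void run(const AFragT& a, const BFragT& b, AccumFragT& acc)
    {
        for(std::size_t i = 0; i < acc.size(); ++i)
            acc[i] += a[i] * b[i]; // placeholder for the MFMA op
    }
};

int main()
{
    std::array<float, 4> a{1, 2, 3, 4}, b{4, 3, 2, 1}, acc{};
    selector_sketch<16, 16>::run(a, b, acc); // frag types deduced here
    std::printf("%g %g %g %g\n", acc[0], acc[1], acc[2], acc[3]); // 4 6 6 4
}
```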
@@ -334,8 +355,7 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
// BLOCK_K / BLOCK_X is the stride in the xA matrix
auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X);
// obtain 8-bit exponent
fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
fragX = scale_ptr[startOffset];
return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
}
@@ -502,7 +522,7 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr,
auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);
// obtain 8-bit exponent
fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
fragX = scale_ptr[startOffset];
return load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(input_ptr);
}
@@ -773,7 +793,8 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
// Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
using mfma = mfma_type_selector<BLOCK_M, BLOCK_N>;
mfma::template run<>(fragA, fragB, fragAcc);
for(int i = 0; i < vectorSize(fragC); ++i)
{
@@ -805,29 +826,34 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
using ScaleFragT = int32_t;
using AScaleFragT = vector_type<ScaleType, 1>::type;
using BScaleFragT = vector_type<ScaleType, 1>::type;
// Create frags
auto fragA = AFragT{};
auto fragB = BFragT{};
auto fragC = CFragT{};
auto fragAcc = AccumFragT{0};
auto fragXa = ScaleFragT{0};
auto fragXb = ScaleFragT{0};
auto fragXa = AScaleFragT{};
auto fragXb = BScaleFragT{};
// Load the inputs.
// A = row major, BLOCK_M x BLOCK_K
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, AScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
a, xa, fragXa);
// B = col major, BLOCK_K x BLOCK_N
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, BScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
b, xb, fragXb);
// Scaled Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
fragA, fragXa, fragB, fragXb, fragAcc);
using mfma = mfma_scale_type_selector<BLOCK_M, BLOCK_N>;
mfma::template run<>(fragA,
fragXa.template AsType<ScaleType>(),
fragB,
fragXb.template AsType<ScaleType>(),
fragAcc);
for(int i = 0; i < vectorSize(fragC); ++i)
{
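For reference, the arithmetic the scaled path is expected to reproduce: every element of A is weighted by 2^(xa-127) and every element of B by 2^(xb-127) before accumulation. A scalar host check (bias 127 assumed; names are illustrative):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Reference for one scaled block: C = (2^(xa-127) * A) . (2^(xb-127) * B)
float scaled_dot(const std::vector<float>& a, uint8_t xa,
                 const std::vector<float>& b, uint8_t xb)
{
    const float sa = std::ldexp(1.0f, xa - 127);
    const float sb = std::ldexp(1.0f, xb - 127);
    float acc = 0.0f;
    for(std::size_t k = 0; k < a.size(); ++k)
        acc += (sa * a[k]) * (sb * b[k]);
    return acc;
}

int main()
{
    std::vector<float> a{1, 2, 3, 4}, b{4, 3, 2, 1};
    // xa = 128 (x2) and xb = 126 (x0.5) cancel, matching the unscaled dot:
    std::printf("scaled: %g, plain: %g\n",
                scaled_dot(a, 128, b, 126), scaled_dot(a, 127, b, 127));
}
```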