mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
* Refactored `load_A_row_major` to follow scale mapping
* Refactored `load_A_col_major` to follow scale mapping
* Refactored `load_B_col_major` to follow scale mapping
* Verified non-scaled test
* Verified scaled tests
* Used ReferenceMXGemm for verification
* Updated license headers
[ROCm/composable_kernel commit: ffa13455a2]
1270 lines
75 KiB
C++
1270 lines
75 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck/ck.hpp"
|
|
|
|
#include "ck/library/utility/host_tensor.hpp"
|
|
#include "ck/library/utility/device_memory.hpp"
|
|
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
|
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
|
|
#include "ck/library/utility/host_tensor_generator.hpp"
|
|
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
|
|
#include "ck/library/utility/check_err.hpp"
|
|
|
|
namespace ck {
|
|
|
|
// MFMA instructions supported in this test
|
|
enum class MFMA_F8F6F4
|
|
{
|
|
F32_16x16x128 =
|
|
static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
|
|
F32_32x32x64 =
|
|
static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4), // V_MFMA_F32_32X32X64_F8F6F4
|
|
|
|
SCALE_F32_16x16x128 = static_cast<int>(
|
|
MfmaInstr::mfma_scale_f32_16x16x128f8f6f4), // V_MFMA_SCALE_F32_16X16X128_F8F6F4
|
|
SCALE_F32_32x32x64 = static_cast<int>(
|
|
MfmaInstr::mfma_scale_f32_32x32x64f8f6f4) // V_MFMA_SCALE_F32_32X32X64_F8F6F4
|
|
|
|
};
|
|
|
|
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
|
|
struct mfma_type_selector;
|
|
|
|
template <typename AFragT, typename BFragT, typename AccumFragT>
|
|
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
|
|
{
|
|
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
|
|
{
|
|
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
|
|
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
|
|
}
|
|
|
|
__device__ void operator()(AFragT const& fragA,
|
|
const int32_t scale_a,
|
|
BFragT const& fragB,
|
|
const int32_t scale_b,
|
|
AccumFragT& fragAcc)
|
|
{
|
|
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
|
|
op.template run<16, 16, AFragT, BFragT, AccumFragT>(
|
|
fragA, scale_a, fragB, scale_b, fragAcc);
|
|
}
|
|
};
|
|
|
|
template <typename AFragT, typename BFragT, typename AccumFragT>
|
|
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
|
|
{
|
|
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
|
|
{
|
|
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
|
|
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
|
|
}
|
|
|
|
__device__ void operator()(AFragT const& fragA,
|
|
const int32_t scale_a,
|
|
BFragT const& fragB,
|
|
const int32_t scale_b,
|
|
AccumFragT& fragAcc)
|
|
{
|
|
auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
|
|
op.template run<32, 32, AFragT, BFragT, AccumFragT>(
|
|
fragA, scale_a, fragB, scale_b, fragAcc);
|
|
}
|
|
};
|
|
|
|
template <typename VecT>
|
|
static constexpr int32_t vectorSize(const VecT&)
|
|
{
|
|
return scalar_type<VecT>::vector_size;
|
|
}
|
|
|
|
// Define a load function for input A blocks:
|
|
// Size: (BLOCK_M x BLOCK_K)
|
|
// - Data is in column major format
|
|
// - Rows are loaded in contiguous chunks that map to corresponding microscales
|
|
// - Each row is loaded in chunks of size 16 and each thread loads 32 elements
|
|
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
|
|
__device__ AFragT load_A_col_major(AType const* input_ptr)
|
|
{
|
|
// clang-format off
|
|
// Register Mapping for 16x128: || Register Mapping for 32x64:
|
|
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
|
|
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
|
|
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
|
// Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] |
|
|
// Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] |
|
|
// Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] |
|
|
// Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] |
|
|
// Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] |
|
|
// Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] |
|
|
// Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] |
|
|
// Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] |
|
|
// Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] |
|
|
// Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] |
|
|
// Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] |
|
|
// Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] |
|
|
// Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] |
|
|
// Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] |
|
|
// Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] |
|
|
// Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] |
|
|
// Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] |
|
|
// Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] |
|
|
// Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] |
|
|
// Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] |
|
|
// Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] |
|
|
// Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] |
|
|
// Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] |
|
|
// Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] |
|
|
// Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] |
|
|
// Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] |
|
|
// Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] |
|
|
// Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] |
|
|
// Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] |
|
|
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
|
|
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
|
|
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
|
|
// clang-format on
|
|
|
|
static constexpr int32_t WAVE_SIZE = 64;
|
|
|
|
// Here we want to load from rows of A in chunks of 16 elements each.
|
|
static constexpr uint32_t chunk_size = 16;
|
|
|
|
// each chunk is separated by offset
|
|
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;
|
|
|
|
// To start the loading process, let's visualize in 2D coords.
|
|
// Each thread will load 32 elements.
|
|
// We need to know where they start, and where the next elements are.
|
|
auto startCoord2D =
|
|
std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15}
|
|
(threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48}
|
|
|
|
auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
|
|
auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row
|
|
|
|
// Flatten to 1D col_major offsets.
|
|
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
|
|
|
// BLOCK_M is a stride in A matrix
|
|
auto startOffset = col_major(startCoord2D, BLOCK_M);
|
|
auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M);
|
|
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M);
|
|
|
|
using ARawT = typename scalar_type<AFragT>::type;
|
|
using AScalarFragT = vector_type<ARawT, vectorSize(AFragT{})>::type;
|
|
|
|
AScalarFragT fragA{};
|
|
|
|
#pragma unroll
|
|
for(int chunk = 0; chunk < 2; chunk++)
|
|
{
|
|
#pragma unroll
|
|
for(uint32_t i = 0; i < chunk_size; i++)
|
|
{
|
|
fragA[chunk * chunk_size + i] =
|
|
bit_cast<ARawT>(input_ptr[startOffset + chunk * kMajorOffset + i * kMinorOffset]);
|
|
}
|
|
}
|
|
|
|
return fragA;
|
|
}
|
|
|
|
// Define a load function for input A blocks:
|
|
// Size: (BLOCK_M x BLOCK_K)
|
|
// - Data is in row major format
|
|
// - Rows are loaded in contiguous chunks that map to corresponding microscales
|
|
// - Each row is loaded in chunks of size 16 and each thread loads 32 elements
|
|
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
|
|
__device__ AFragT load_A_row_major(AType const* input_ptr)
|
|
{
|
|
// clang-format off
|
|
// Register Mapping for 16x128: || Register Mapping for 32x64:
|
|
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
|
|
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
|
|
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
|
// Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] |
|
|
// Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] |
|
|
// Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] |
|
|
// Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] |
|
|
// Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] |
|
|
// Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] |
|
|
// Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] |
|
|
// Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] |
|
|
// Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] |
|
|
// Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] |
|
|
// Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] |
|
|
// Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] |
|
|
// Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] |
|
|
// Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] |
|
|
// Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] |
|
|
// Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] |
|
|
// Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] |
|
|
// Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] |
|
|
// Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] |
|
|
// Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] |
|
|
// Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] |
|
|
// Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] |
|
|
// Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] |
|
|
// Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] |
|
|
// Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] |
|
|
// Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] |
|
|
// Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] |
|
|
// Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] |
|
|
// Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] |
|
|
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
|
|
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
|
|
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
|
|
// clang-format on
|
|
|
|
static constexpr int32_t WAVE_SIZE = 64;
|
|
|
|
// Here we want to load from rows of A in chunks of 16 elements each.
|
|
static constexpr uint32_t chunk_size = 16;
|
|
|
|
// each chunk is separated by offset
|
|
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;
|
|
|
|
// To start the loading process, let's visualize in 2D coords.
|
|
// Each thread will load 32 elements.
|
|
// We need to know where they start, and where the next elements are.
|
|
auto startCoord2D =
|
|
std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15}
|
|
(threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48}
|
|
|
|
// auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
|
|
auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row
|
|
|
|
// Flatten to 1D row_major offsets.
|
|
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
|
|
|
|
// BLOCK_K is a stride in A matrix
|
|
auto startOffset = row_major(startCoord2D, BLOCK_K);
|
|
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K);
|
|
auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K);
|
|
|
|
using ARawT = typename scalar_type<AFragT>::type;
|
|
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
|
|
|
|
union
|
|
{
|
|
AFragT frag;
|
|
AScalarFragT chunks[2];
|
|
} fragA{};
|
|
|
|
auto* fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset);
|
|
fragA.chunks[0] = *fragPtr;
|
|
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset + kMajorOffset);
|
|
fragA.chunks[1] = *fragPtr;
|
|
|
|
return fragA.frag;
|
|
}
|
|
|
|
// Define a load function for scaled A blocks:
|
|
// Size: (BLOCK_M x BLOCK_K)
|
|
// ASSUMPTION:
|
|
// - The scale inputs distributed across 64 lanes.
|
|
template <typename AType,
|
|
typename AFragT,
|
|
typename ScaleType,
|
|
typename ScaleFragT,
|
|
int32_t BLOCK_M,
|
|
int32_t BLOCK_K,
|
|
int32_t BLOCK_X>
|
|
__device__ AFragT load_mx_A_row_major(AType const* input_ptr,
|
|
ScaleType const* scale_ptr,
|
|
ScaleFragT& fragX)
|
|
{
|
|
// clang-format off
|
|
// Register Mapping for 16x128: || Register Mapping for 32x64:
|
|
// Size | BLOCK_M | BLOCK_M | | BLOCK_M | BLOCK_M | | || Size | BLOCK_M | BLOCK_M | | |
|
|
// M | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || M | 0 ... 31 | 0 ... 31 | Vector | |
|
|
// Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| Scale |
|
|
// Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------|
|
|
// Reg 0 [0:7] | K0 | K16 | x(M,0) | K32 | K48 | x(M,1) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(M,0) |
|
|
// Reg 0 [8:15] | K1 | K17 | x(M,0) | K33 | K49 | x(M,1) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(M,0) |
|
|
// Reg 0 [16:23] | K2 | K18 | x(M,0) | K34 | K50 | x(M,1) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(M,0) |
|
|
// Reg 0 [24:31] | K3 | K19 | x(M,0) | K35 | K51 | x(M,1) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(M,0) |
|
|
// Reg 1 [0:7] | K4 | K20 | x(M,0) | K36 | K52 | x(M,1) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(M,0) |
|
|
// Reg 1 [8:15] | K5 | K21 | x(M,0) | K37 | K53 | x(M,1) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(M,0) |
|
|
// Reg 1 [16:23] | K6 | K22 | x(M,0) | K38 | K54 | x(M,1) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(M,0) |
|
|
// Reg 1 [24:31] | K7 | K23 | x(M,0) | K39 | K55 | x(M,1) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(M,0) |
|
|
// Reg 2 [0:7] | K8 | K24 | x(M,0) | K40 | K56 | x(M,1) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(M,0) |
|
|
// Reg 2 [8:15] | K9 | K25 | x(M,0) | K41 | K57 | x(M,1) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(M,0) |
|
|
// Reg 2 [16:23] | K10 | K26 | x(M,0) | K42 | K58 | x(M,1) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(M,0) |
|
|
// Reg 2 [24:31] | K11 | K27 | x(M,0) | K43 | K59 | x(M,1) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(M,0) |
|
|
// Reg 3 [0:7] | K12 | K28 | x(M,0) | K44 | K60 | x(M,1) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(M,0) |
|
|
// Reg 3 [8:15] | K13 | K29 | x(M,0) | K45 | K61 | x(M,1) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(M,0) |
|
|
// Reg 3 [16:23] | K14 | K30 | x(M,0) | K46 | K62 | x(M,1) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(M,0) |
|
|
// Reg 3 [24:31] | K15 | K31 | x(M,0) | K47 | K63 | x(M,1) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | x(M,0) |
|
|
// Reg 4 [0:7] | K64 | K80 | x(M,2) | K96 | K112 | x(M,3) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(M,1) |
|
|
// Reg 4 [8:15] | K65 | K81 | x(M,2) | K97 | K113 | x(M,3) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(M,1) |
|
|
// Reg 4 [16:23] | K66 | K82 | x(M,2) | K98 | K114 | x(M,3) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(M,1) |
|
|
// Reg 4 [24:31] | K67 | K83 | x(M,2) | K99 | K115 | x(M,3) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(M,1) |
|
|
// Reg 5 [0:7] | K68 | K84 | x(M,2) | K100 | K116 | x(M,3) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(M,1) |
|
|
// Reg 5 [8:15] | K69 | K85 | x(M,2) | K101 | K117 | x(M,3) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(M,1) |
|
|
// Reg 5 [16:23] | K70 | K86 | x(M,2) | K102 | K118 | x(M,3) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(M,1) |
|
|
// Reg 5 [24:31] | K71 | K87 | x(M,2) | K103 | K119 | x(M,3) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(M,1) |
|
|
// Reg 6 [0:7] | K72 | K88 | x(M,2) | K104 | K120 | x(M,3) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(M,1) |
|
|
// Reg 6 [8:15] | K73 | K89 | x(M,2) | K105 | K121 | x(M,3) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(M,1) |
|
|
// Reg 6 [16:23] | K74 | K90 | x(M,2) | K106 | K122 | x(M,3) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(M,1) |
|
|
// Reg 6 [24:31] | K75 | K91 | x(M,2) | K107 | K123 | x(M,3) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(M,1) |
|
|
// Reg 7 [0:7] | K76 | K92 | x(M,2) | K108 | K124 | x(M,3) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(M,1) |
|
|
// Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) |
|
|
// Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) |
|
|
// Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) |
|
|
// clang-format on
|
|
static constexpr uint32_t VW = vectorSize(AFragT{});
|
|
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
|
|
|
|
// To start the loading process, let's visualize in 2D coords.
|
|
// Each thread will load 1 element
|
|
// We need to know where they start
|
|
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
|
|
(threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col
|
|
|
|
// Flatten to 1D row_major offsets.
|
|
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
|
|
|
|
// BLOCK_K / BLOCK_X is a stride in xA matrix
|
|
auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X);
|
|
|
|
// obtain 8-bit exponent
|
|
fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
|
|
|
|
return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
|
|
}
|
|
|
|
// Define a load function for input B blocks:
|
|
// Size: (BLOCK_K x BLOCK_N)
|
|
// - Data is in col major format
|
|
// - Cols are loaded in contiguous chunks that map to corresponding microscales
|
|
// - Each col is loaded in chunks of size 16 and each thread loads 32 elements
|
|
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
|
|
__device__ BFragT load_B_col_major(BType const* input_ptr)
|
|
{
|
|
// clang-format off
|
|
// Register Mapping for 128x16: || Register Mapping for 64x32:
|
|
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
|
|
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
|
|
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
|
|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
|
|
// Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] |
|
|
// Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] |
|
|
// Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] |
|
|
// Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] |
|
|
// Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] |
|
|
// Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] |
|
|
// Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] |
|
|
// Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] |
|
|
// Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] |
|
|
// Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] |
|
|
// Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] |
|
|
// Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] |
|
|
// Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] |
|
|
// Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] |
|
|
// Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] |
|
|
// Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] |
|
|
// Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] |
|
|
// Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] |
|
|
// Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] |
|
|
// Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] |
|
|
// Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] |
|
|
// Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] |
|
|
// Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] |
|
|
// Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] |
|
|
// Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] |
|
|
// Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] |
|
|
// Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] |
|
|
// Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] |
|
|
// Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] |
|
|
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
|
|
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
|
|
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
|
|
// clang-format on
|
|
|
|
static constexpr int32_t WAVE_SIZE = 64;
|
|
|
|
// Here we want to load from cols of B in chunks of 16 elements each.
|
|
static constexpr uint32_t chunk_size = 16;
|
|
|
|
// each chunk is separated by an offset
|
|
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64
|
|
|
|
// To start the loading process, let's visualize in 2D coords.
|
|
// Each thread will load 32 elements.
|
|
// We need to know where they start, and where the next elements are.
|
|
auto startCoord2D =
|
|
std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48}
|
|
threadIdx.x % BLOCK_N); // Col {0-31} | {0-15}
|
|
|
|
// Flatten to 1D col_major offsets.
|
|
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
|
|
|
// auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols
|
|
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
|
|
|
|
// BLOCK_K is a stride in B matrix
|
|
auto startOffset = col_major(startCoord2D, BLOCK_K);
|
|
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K);
|
|
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K);
|
|
|
|
using BRawT = typename scalar_type<BFragT>::type;
|
|
using BScalarFragT = vector_type<BRawT, chunk_size>::type;
|
|
|
|
union
|
|
{
|
|
BFragT frag;
|
|
BScalarFragT chunks[2];
|
|
} fragB{};
|
|
|
|
auto* fragPtr = reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset);
|
|
fragB.chunks[0] = *fragPtr;
|
|
fragPtr = reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + kMajorOffset);
|
|
fragB.chunks[1] = *fragPtr;
|
|
|
|
return fragB.frag;
|
|
}
|
|
|
|
// Define a load function for scaled B blocks:
|
|
// Size: (BLOCK_K x BLOCK_N)
|
|
// ASSUMPTION:
|
|
// - The scale inputs distributed across 64 lanes.
|
|
template <typename BType,
|
|
typename BFragT,
|
|
typename ScaleType,
|
|
typename ScaleFragT,
|
|
int32_t BLOCK_K,
|
|
int32_t BLOCK_N,
|
|
int32_t BLOCK_X>
|
|
__device__ BFragT load_mx_B_col_major(BType const* input_ptr,
|
|
ScaleType const* scale_ptr,
|
|
ScaleFragT& fragX)
|
|
|
|
{
|
|
// clang-format off
|
|
// Register Mapping for 128x16: || Register Mapping for 64x32:
|
|
// Size | BLOCK_N | BLOCK_N | | BLOCK_N | BLOCK_N | | || Size | BLOCK_N | BLOCK_N | | |
|
|
// N | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || N | 0 ... 31 | 0 ... 31 | Vector | |
|
|
// Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| Scale |
|
|
// Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------|
|
|
// Reg 0 [0:7] | K0 | K16 | x(0,N) | K32 | K48 | x(1,N) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(0,N) |
|
|
// Reg 0 [8:15] | K1 | K17 | x(0,N) | K33 | K49 | x(1,N) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(0,N) |
|
|
// Reg 0 [16:23] | K2 | K18 | x(0,N) | K34 | K50 | x(1,N) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(0,N) |
|
|
// Reg 0 [24:31] | K3 | K19 | x(0,N) | K35 | K51 | x(1,N) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(0,N) |
|
|
// Reg 1 [0:7] | K4 | K20 | x(0,N) | K36 | K52 | x(1,N) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(0,N) |
|
|
// Reg 1 [8:15] | K5 | K21 | x(0,N) | K37 | K53 | x(1,N) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(0,N) |
|
|
// Reg 1 [16:23] | K6 | K22 | x(0,N) | K38 | K54 | x(1,N) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(0,N) |
|
|
// Reg 1 [24:31] | K7 | K23 | x(0,N) | K39 | K55 | x(1,N) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(0,N) |
|
|
// Reg 2 [0:7] | K8 | K24 | x(0,N) | K40 | K56 | x(1,N) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(0,N) |
|
|
// Reg 2 [8:15] | K9 | K25 | x(0,N) | K41 | K57 | x(1,N) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(0,N) |
|
|
// Reg 2 [16:23] | K10 | K26 | x(0,N) | K42 | K58 | x(1,N) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(0,N) |
|
|
// Reg 2 [24:31] | K11 | K27 | x(0,N) | K43 | K59 | x(1,N) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(0,N) |
|
|
// Reg 3 [0:7] | K12 | K28 | x(0,N) | K44 | K60 | x(1,N) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(0,N) |
|
|
// Reg 3 [8:15] | K13 | K29 | x(0,N) | K45 | K61 | x(1,N) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(0,N) |
|
|
// Reg 3 [16:23] | K14 | K30 | x(0,N) | K46 | K62 | x(1,N) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(0,N) |
|
|
// Reg 3 [24:31] | K15 | K31 | x(0,N) | K47 | K63 | x(1,N) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | x(0,N) |
|
|
// Reg 4 [0:7] | K64 | K80 | x(2,N) | K96 | K112 | x(3,N) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(1,N) |
|
|
// Reg 4 [8:15] | K65 | K81 | x(2,N) | K97 | K113 | x(3,N) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(1,N) |
|
|
// Reg 4 [16:23] | K66 | K82 | x(2,N) | K98 | K114 | x(3,N) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(1,N) |
|
|
// Reg 4 [24:31] | K67 | K83 | x(2,N) | K99 | K115 | x(3,N) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(1,N) |
|
|
// Reg 5 [0:7] | K68 | K84 | x(2,N) | K100 | K116 | x(3,N) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(1,N) |
|
|
// Reg 5 [8:15] | K69 | K85 | x(2,N) | K101 | K117 | x(3,N) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(1,N) |
|
|
// Reg 5 [16:23] | K70 | K86 | x(2,N) | K102 | K118 | x(3,N) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(1,N) |
|
|
// Reg 5 [24:31] | K71 | K87 | x(2,N) | K103 | K119 | x(3,N) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(1,N) |
|
|
// Reg 6 [0:7] | K72 | K88 | x(2,N) | K104 | K120 | x(3,N) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(1,N) |
|
|
// Reg 6 [8:15] | K73 | K89 | x(2,N) | K105 | K121 | x(3,N) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(1,N) |
|
|
// Reg 6 [16:23] | K74 | K90 | x(2,N) | K106 | K122 | x(3,N) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(1,N) |
|
|
// Reg 6 [24:31] | K75 | K91 | x(2,N) | K107 | K123 | x(3,N) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(1,N) |
|
|
// Reg 7 [0:7] | K76 | K92 | x(2,N) | K108 | K124 | x(3,N) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(1,N) |
|
|
// Reg 7 [8:15] | K77 | K93 | x(2,N) | K109 | K125 | x(3,N) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(1,N) |
|
|
// Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) |
|
|
// Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) |
|
|
|
|
// clang-format on
|
|
static constexpr uint32_t VW = vectorSize(BFragT{});
|
|
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
|
|
|
|
// To start the loading process, let's visualize in 2D coords.
|
|
// Each thread will load 1 element
|
|
// We need to know where to start
|
|
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row
|
|
threadIdx.x % BLOCK_N); // Col
|
|
|
|
// Flatten to 1D col_major offsets.
|
|
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
|
|
|
auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);
|
|
|
|
// obtain 8-bit exponent
|
|
fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
|
|
|
|
return load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(input_ptr);
|
|
}
|
|
|
|
// Define a store function for C
|
|
// Size: (BLOCK_M x BLOCK_N)
|
|
// ASSUMPTION:
|
|
// - We want contiguous BLOCK_N sized row neighbors in register.
|
|
// - Data is in col_major format
|
|
// This means:
|
|
// - From C we will load BLOCK_M rows of size BLOCK_N to satisfy our input data
|
|
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
|
|
struct store_C_col_major;
|
|
|
|
// Here we want to store a 16x16 block of data.
|
|
//
|
|
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N |
|
|
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 |
|
|
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
|
|
// Register Element ------------ ------------- ------------ -------------- Element
|
|
// Reg0 | M0 | M4 | M8 | M12 | v[0]
|
|
// Reg1 | M1 | M5 | M9 | M13 | v[1]
|
|
// Reg2 | M2 | M6 | M10 | M14 | v[2]
|
|
// Reg3 | M3 | M7 | M11 | M15 | v[3]
|
|
template <typename CType, typename CFragT>
|
|
struct store_C_col_major<CType, CFragT, 16, 16>
|
|
{
|
|
__device__ void operator()(CType* output, CFragT cFrag)
|
|
{
|
|
static constexpr uint32_t VW = vectorSize(cFrag); // 4
|
|
static constexpr uint32_t Dim = 16;
|
|
|
|
// Each thread will load 4 elements.
|
|
// We need to know where they start, and where the next elements are.
|
|
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
|
|
threadIdx.x % Dim); // Col
|
|
|
|
// Flatten to 1D col_major offsets.
|
|
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
|
|
|
auto startOffset = col_major(startCoord2D, 16);
|
|
|
|
auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
|
|
*fragPtr = cFrag;
|
|
}
|
|
};
|
|
|
|
// Here we want to store a 32x32 block of data.
|
|
// Register Mapping:
|
|
|
|
// Size | BLOCK_N | BLOCK_N |
|
|
// N | 0 ... 31 | 0 ... 31 |
|
|
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
|
|
// Register Element ------------ ------------- Element
|
|
// Reg0 | M0 | M4 | v[0]
|
|
// Reg1 | M1 | M5 | v[1]
|
|
// Reg2 | M2 | M6 | v[2]
|
|
// Reg3 | M3 | M7 | v[3]
|
|
// ____________ _____________
|
|
// Reg4 | M8 | M12 | v[4]
|
|
// Reg5 | M9 | M13 | v[5]
|
|
// Reg6 | M10 | M14 | v[6]
|
|
// Reg7 | M11 | M15 | v[7]
|
|
// ____________ _____________
|
|
// Reg8 | M16 | M20 | v[8]
|
|
// Reg9 | M17 | M21 | v[9]
|
|
// Reg10 | M18 | M22 | v[10]
|
|
// Reg11 | M19 | M23 | v[11]
|
|
// ____________ _____________
|
|
// Reg12 | M24 | M28 | v[12]
|
|
// Reg13 | M25 | M29 | v[13]
|
|
// Reg14 | M26 | M30 | v[14]
|
|
// Reg15 | M27 | M31 | v[15]
|
|
|
|
template <typename CType, typename CFragT>
|
|
struct store_C_col_major<CType, CFragT, 32, 32>
|
|
{
|
|
__device__ void operator()(CType* output, CFragT cFrag)
|
|
{
|
|
static constexpr uint32_t WAVE_SIZE = 64;
|
|
static constexpr uint32_t VW = 4;
|
|
static constexpr uint32_t Dim = 32;
|
|
static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
|
|
|
|
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
|
|
threadIdx.x % Dim); // Col
|
|
|
|
// Major step between 'chunks'
|
|
auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
|
|
|
|
// Flatten to 1D col_major offsets.
|
|
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
|
|
|
auto startOffset = col_major(startCoord2D, 32);
|
|
auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8
|
|
|
|
// we can vector store 4 contiguous elements at a time.
|
|
using CRawT = typename scalar_type<CFragT>::type;
|
|
using CScalarFragT = vector_type<CRawT, VW>::type;
|
|
union
|
|
{
|
|
CFragT frag;
|
|
CScalarFragT chunks[vectorSize(CFragT{}) / VW];
|
|
} fragC{cFrag}; // Initialize with input fragment
|
|
|
|
*(reinterpret_cast<CScalarFragT*>(output + startOffset)) = fragC.chunks[0];
|
|
*(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset)) = fragC.chunks[1];
|
|
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) =
|
|
fragC.chunks[2];
|
|
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) =
|
|
fragC.chunks[3];
|
|
}
|
|
};
|
|
|
|
// Define a store function for C
|
|
// Size: (BLOCK_M x BLOCK_N)
|
|
// ASSUMPTION:
|
|
// - We want contiguous BLOCK_N sized row neighbors in register.
|
|
// - Data is in row major format
|
|
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
|
|
struct store_C_row_major;
|
|
|
|
// Here we want to store a 16x16 block of data.
|
|
//
|
|
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N |
|
|
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 |
|
|
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
|
|
// Register Element ------------ ------------- ------------ -------------- Element
|
|
// Reg0 | M0 | M4 | M8 | M12 | v[0]
|
|
// Reg1 | M1 | M5 | M9 | M13 | v[1]
|
|
// Reg2 | M2 | M6 | M10 | M14 | v[2]
|
|
// Reg3 | M3 | M7 | M11 | M15 | v[3]
|
|
template <typename CType, typename CFragT>
|
|
struct store_C_row_major<CType, CFragT, 16, 16>
|
|
{
|
|
__device__ void operator()(CType* output, CFragT cFrag)
|
|
{
|
|
static constexpr uint32_t VW = vectorSize(cFrag); // 4
|
|
static constexpr uint32_t Dim = 16;
|
|
|
|
// Each thread will load 4 elements.
|
|
// We need to know where they start, and where the next elements are.
|
|
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
|
|
threadIdx.x % Dim); // Col
|
|
auto stepCoord2D = std::make_pair(1u, 0u);
|
|
|
|
// Flatten to 1D row_major offsets.
|
|
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
|
|
|
|
auto startOffset = row_major(startCoord2D, 16);
|
|
auto kOffset = row_major(stepCoord2D, 16);
|
|
|
|
auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
|
|
*fragPtr = cFrag;
|
|
|
|
// If you notice carefully, kOffset != 1.
|
|
// This means the following is vector is updated with 4 non-contiguous offsets,
|
|
// which the compiler will separate into 4 different global_store_dword instructions.
|
|
output[startOffset] = cFrag[0]; // v[0] = Reg 0
|
|
output[startOffset + kOffset] = cFrag[1]; // v[1] = Reg 1
|
|
output[startOffset + 2 * kOffset] = cFrag[2]; // v[2] = Reg 2
|
|
output[startOffset + 3 * kOffset] = cFrag[3]; // v[3] = Reg 3
|
|
}
|
|
};
|
|
|
|
// Here we want to store a 32x32 block of data.
|
|
// Register Mapping:
|
|
|
|
// Size | BLOCK_N | BLOCK_N |
|
|
// N | 0 ... 31 | 0 ... 31 |
|
|
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
|
|
// Register Element ------------ ------------- Element
|
|
// Reg0 | M0 | M4 | v[0]
|
|
// Reg1 | M1 | M5 | v[1]
|
|
// Reg2 | M2 | M6 | v[2]
|
|
// Reg3 | M3 | M7 | v[3]
|
|
// ____________ _____________
|
|
// Reg4 | M8 | M12 | v[4]
|
|
// Reg5 | M9 | M13 | v[5]
|
|
// Reg6 | M10 | M14 | v[6]
|
|
// Reg7 | M11 | M15 | v[7]
|
|
// ____________ _____________
|
|
// Reg8 | M16 | M20 | v[8]
|
|
// Reg9 | M17 | M21 | v[9]
|
|
// Reg10 | M18 | M22 | v[10]
|
|
// Reg11 | M19 | M23 | v[11]
|
|
// ____________ _____________
|
|
// Reg12 | M24 | M28 | v[12]
|
|
// Reg13 | M25 | M29 | v[13]
|
|
// Reg14 | M26 | M30 | v[14]
|
|
// Reg15 | M27 | M31 | v[15]
|
|
|
|
template <typename CType, typename CFragT>
|
|
struct store_C_row_major<CType, CFragT, 32, 32>
|
|
{
|
|
__device__ void operator()(CType* output, CFragT cFrag)
|
|
{
|
|
static constexpr uint32_t WAVE_SIZE = 64;
|
|
static constexpr uint32_t VW = 4; // This VW is per 'chunk'
|
|
static constexpr uint32_t Dim = 32; // BLOCK_N
|
|
static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
|
|
|
|
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
|
|
threadIdx.x % Dim); // Col
|
|
|
|
// Minor step for each 'chunk'
|
|
auto minorStepCoord2D = std::make_pair(1u, 0u);
|
|
|
|
// Major step between 'chunks'
|
|
auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
|
|
|
|
// Flatten to 1D row_major offsets.
|
|
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
|
|
|
|
auto startOffset = row_major(startCoord2D, 32);
|
|
auto kMinorOffset = row_major(minorStepCoord2D, 32);
|
|
auto kMajorOffset = row_major(majorStepCoord2D, 32);
|
|
|
|
output[startOffset] = cFrag[0]; // v[0] = Reg 0
|
|
output[startOffset + kMinorOffset] = cFrag[1]; // v[1] = Reg 1
|
|
output[startOffset + 2 * kMinorOffset] = cFrag[2]; // v[2] = Reg 2
|
|
output[startOffset + 3 * kMinorOffset] = cFrag[3]; // v[3] = Reg 3
|
|
output[startOffset + kMajorOffset] = cFrag[4]; // v[4] = Reg 4
|
|
output[startOffset + kMajorOffset + kMinorOffset] = cFrag[5]; // v[5] = Reg 5
|
|
output[startOffset + kMajorOffset + 2 * kMinorOffset] = cFrag[6]; // v[6] = Reg 6
|
|
output[startOffset + kMajorOffset + 3 * kMinorOffset] = cFrag[7]; // v[7] = Reg 7
|
|
output[startOffset + 2 * kMajorOffset] = cFrag[8]; // v[8] = Reg 8
|
|
output[startOffset + 2 * kMajorOffset + kMinorOffset] = cFrag[9]; // v[9] = Reg 9
|
|
output[startOffset + 2 * kMajorOffset + 2 * kMinorOffset] = cFrag[10]; // v[10] = Reg 10
|
|
output[startOffset + 2 * kMajorOffset + 3 * kMinorOffset] = cFrag[11]; // v[11] = Reg 11
|
|
output[startOffset + 3 * kMajorOffset] = cFrag[12]; // v[12] = Reg 12
|
|
output[startOffset + 3 * kMajorOffset + kMinorOffset] = cFrag[13]; // v[13] = Reg 13
|
|
output[startOffset + 3 * kMajorOffset + 2 * kMinorOffset] = cFrag[14]; // v[14] = Reg 14
|
|
output[startOffset + 3 * kMajorOffset + 3 * kMinorOffset] = cFrag[15]; // v[15] = Reg 15
|
|
}
|
|
};
|
|
|
|
template <typename AType,
|
|
typename BType,
|
|
typename CType,
|
|
typename AccType,
|
|
int32_t BLOCK_M,
|
|
int32_t BLOCK_N,
|
|
int32_t BLOCK_K>
|
|
__global__ void matmul(const AType* a, const BType* b, CType* c)
|
|
{
|
|
constexpr int WAVE_SIZE = 64;
|
|
assert(threadIdx.x < WAVE_SIZE);
|
|
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
|
|
|
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
|
|
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
|
|
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
|
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
|
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
|
|
|
// Create frags
|
|
auto fragA = AFragT{};
|
|
auto fragB = BFragT{};
|
|
auto fragC = CFragT{};
|
|
auto fragAcc = AccumFragT{0};
|
|
|
|
// Load the inputs.
|
|
// A = col major, BLOCK_M x BLOCK_K
|
|
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
|
// B = col major, BLOCK_K x BLOCK_N
|
|
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
|
|
|
|
// Matrix multiply-accumulate using MFMA units
|
|
// Accumulation intermediate = BLOCK_M x BLOCK_N
|
|
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
|
|
|
|
for(int i = 0; i < vectorSize(fragC); ++i)
|
|
{
|
|
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
|
|
}
|
|
|
|
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
|
|
storeC(c, fragC);
|
|
}
|
|
|
|
template <typename AType,
|
|
typename BType,
|
|
typename ScaleType,
|
|
typename CType,
|
|
typename AccType,
|
|
int32_t BLOCK_M,
|
|
int32_t BLOCK_N,
|
|
int32_t BLOCK_K,
|
|
int32_t BLOCK_X>
|
|
__global__ void
|
|
matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
|
|
{
|
|
constexpr int WAVE_SIZE = 64;
|
|
assert(threadIdx.x < WAVE_SIZE);
|
|
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
|
|
|
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
|
|
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
|
|
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
|
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
|
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
|
using ScaleFragT = int32_t;
|
|
|
|
// Create frags
|
|
auto fragA = AFragT{};
|
|
auto fragB = BFragT{};
|
|
auto fragC = CFragT{};
|
|
auto fragAcc = AccumFragT{0};
|
|
auto fragXa = ScaleFragT{0};
|
|
auto fragXb = ScaleFragT{0};
|
|
|
|
// Load the inputs.
|
|
// A = col major, BLOCK_M x BLOCK_K
|
|
fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
|
|
a, xa, fragXa);
|
|
|
|
// B = col major, BLOCK_K x BLOCK_N
|
|
fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
|
|
b, xb, fragXb);
|
|
|
|
// Scaled Matrix multiply-accumulate using MFMA units
|
|
// Accumulation intermediate = BLOCK_M x BLOCK_N
|
|
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
|
|
fragA, fragXa, fragB, fragXb, fragAcc);
|
|
|
|
for(int i = 0; i < vectorSize(fragC); ++i)
|
|
{
|
|
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
|
|
}
|
|
|
|
auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
|
|
storeC(c, fragC);
|
|
}
|
|
|
|
/**
|
|
* @brief Structure to hold dimension parameters for GEMM tensors.
|
|
*
|
|
* M Number of rows in matrix A and matrix C.
|
|
* N Number of columns in matrix B and matrix C.
|
|
* K Number of columns in matrix A and number of rows in matrix B.
|
|
* StrideA Stride (leading dimension) of matrix A.
|
|
* StrideB Stride (leading dimension) of matrix B.
|
|
* StrideC Stride (leading dimension) of matrix C.
|
|
*/
|
|
struct GemmParams
|
|
{
|
|
ck::index_t M = 16;
|
|
ck::index_t N = 16;
|
|
ck::index_t K = 128;
|
|
|
|
ck::index_t StrideA = -1;
|
|
ck::index_t StrideB = -1;
|
|
ck::index_t StrideC = -1;
|
|
};
|
|
|
|
namespace mxmfma_test {
|
|
template <typename ADataType, typename BDataType, typename ScaleType, typename CDataType>
|
|
void RunHostGEMM(const Tensor<ADataType>& A,
|
|
const Tensor<ScaleType>& a_scales,
|
|
const Tensor<BDataType>& B,
|
|
const Tensor<ScaleType>& b_scales,
|
|
Tensor<CDataType>& C)
|
|
{
|
|
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
|
|
|
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
|
|
BDataType,
|
|
CDataType,
|
|
float,
|
|
ScaleType,
|
|
PassThrough,
|
|
PassThrough,
|
|
PassThrough,
|
|
float,
|
|
float>;
|
|
auto ref_gemm = ReferenceGemmInstance{};
|
|
auto ref_invoker = ref_gemm.MakeInvoker();
|
|
|
|
auto ref_argument = ref_gemm.MakeArgument(
|
|
A, a_scales, B, b_scales, C, PassThrough{}, PassThrough{}, PassThrough{});
|
|
|
|
ref_invoker.Run(ref_argument);
|
|
}
|
|
|
|
template <typename KernelType,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename ScaleType,
|
|
typename CDataType>
|
|
bool RunDeviceGEMM(KernelType kernel,
|
|
const Tensor<ADataType>& A,
|
|
const Tensor<ScaleType>& a_scales,
|
|
const Tensor<BDataType>& B,
|
|
const Tensor<ScaleType>& b_scales,
|
|
Tensor<CDataType>& C)
|
|
{
|
|
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
|
|
DeviceMem a_scales_device_buf(sizeof(ScaleType) * a_scales.mDesc.GetElementSpaceSize());
|
|
DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
|
|
DeviceMem b_scales_device_buf(sizeof(ScaleType) * b_scales.mDesc.GetElementSpaceSize());
|
|
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
|
|
|
|
a_m_k_device_buf.ToDevice(A.mData.data());
|
|
a_scales_device_buf.ToDevice(a_scales.mData.data());
|
|
b_n_k_device_buf.ToDevice(B.mData.data());
|
|
b_scales_device_buf.ToDevice(b_scales.mData.data());
|
|
kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
|
static_cast<const ScaleType*>(a_scales_device_buf.GetDeviceBuffer()),
|
|
static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
|
static_cast<const ScaleType*>(b_scales_device_buf.GetDeviceBuffer()),
|
|
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
|
|
c_m_n_device_buf.FromDevice(C.mData.data());
|
|
|
|
return true;
|
|
}
|
|
|
|
template <typename DeviceMFMA,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename ScaleType,
|
|
typename CDataType,
|
|
typename ALayout,
|
|
typename BLayout,
|
|
typename CLayout,
|
|
index_t BLOCK_M,
|
|
index_t BLOCK_N,
|
|
index_t BLOCK_K,
|
|
index_t BLOCK_X>
|
|
struct TestMXMFMA
|
|
{
|
|
auto PrepareGemmTensors(const GemmParams& params, index_t init)
|
|
{
|
|
auto f_host_tensor_descriptor =
|
|
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
|
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
|
{
|
|
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
|
std::vector<std::size_t>({stride, 1}));
|
|
}
|
|
else
|
|
{
|
|
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
|
std::vector<std::size_t>({1, stride}));
|
|
}
|
|
};
|
|
|
|
Tensor<ADataType> a_m_k(
|
|
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
|
|
Tensor<ScaleType> a_scales(
|
|
f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
|
|
Tensor<BDataType> b_n_k(
|
|
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
|
|
Tensor<ScaleType> b_scales(
|
|
f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));
|
|
Tensor<CDataType> c_m_n_host_result(
|
|
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
|
Tensor<CDataType> c_m_n_device_result(
|
|
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
|
|
|
switch(init)
|
|
{
|
|
case 0:
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
|
a_scales.GenerateTensorValue(
|
|
GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/64
|
|
// NOTE: not all numbers are representable in FP8, BF8, etc.
|
|
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
|
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
|
|
break;
|
|
case 1:
|
|
// results in C = {K}
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
|
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
|
|
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
|
|
break;
|
|
case 2:
|
|
// expect small round off errors
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
|
a_scales.GenerateTensorValue(
|
|
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
|
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
|
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
|
|
break;
|
|
|
|
case 3:
|
|
// expect small round off errors
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(0, 1));
|
|
a_scales.GenerateTensorValue(
|
|
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(0, 1));
|
|
b_scales.GenerateTensorValue(
|
|
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
|
break;
|
|
default:
|
|
// all initial values are representable in FP8, BF8
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
|
|
a_scales.GenerateTensorValue(
|
|
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
|
|
b_scales.GenerateTensorValue(
|
|
GeneratorTensor_2<ScaleType>{122, 129}); // scales: [1/32,..., 2]
|
|
|
|
break;
|
|
}
|
|
|
|
return std::make_tuple(
|
|
a_m_k, a_scales, b_n_k, b_scales, c_m_n_host_result, c_m_n_device_result);
|
|
}
|
|
|
|
auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
|
|
{
|
|
// Arrange
|
|
GemmParams params;
|
|
params.M = BLOCK_M;
|
|
params.N = BLOCK_N;
|
|
params.K = BLOCK_K;
|
|
|
|
auto f_get_default_stride = [](std::size_t row,
|
|
std::size_t col,
|
|
ck::index_t stride,
|
|
auto layout) {
|
|
if(stride == -1)
|
|
{
|
|
// give a chance if stride is -1, return a default packed stride
|
|
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return static_cast<std::size_t>(col);
|
|
}
|
|
else
|
|
{
|
|
return static_cast<std::size_t>(row);
|
|
}
|
|
}
|
|
else
|
|
return static_cast<std::size_t>(stride);
|
|
};
|
|
|
|
params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
|
|
params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
|
|
params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
|
|
|
|
auto host_tensors = PrepareGemmTensors(params, init);
|
|
|
|
const Tensor<ADataType>& a = std::get<0>(host_tensors);
|
|
const Tensor<ScaleType>& a_scales = std::get<1>(host_tensors);
|
|
const Tensor<BDataType>& b = std::get<2>(host_tensors);
|
|
const Tensor<ScaleType>& b_scales = std::get<3>(host_tensors);
|
|
Tensor<CDataType>& c_host = std::get<4>(host_tensors);
|
|
Tensor<CDataType>& c_device = std::get<5>(host_tensors);
|
|
|
|
RunHostGEMM(a, a_scales, b, b_scales, c_host);
|
|
|
|
RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device);
|
|
|
|
bool res = false;
|
|
if constexpr(std::is_same<CDataType, float>::value ||
|
|
std::is_same<CDataType, half_t>::value)
|
|
{
|
|
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
|
}
|
|
else
|
|
{
|
|
std::cout << "UNSUPPORTED CDataType" << std::endl;
|
|
}
|
|
|
|
return res;
|
|
}
|
|
};
|
|
|
|
} // namespace mxmfma_test
|
|
|
|
namespace mfma_test {
|
|
template <typename GemmInstance,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename CDataType,
|
|
typename AElementwiseOperation,
|
|
typename BElementwiseOperation,
|
|
typename CElementwiseOperation>
|
|
void RunHostGEMM(const Tensor<ADataType>& A,
|
|
const Tensor<BDataType>& B,
|
|
Tensor<CDataType>& C,
|
|
AElementwiseOperation a_element_op,
|
|
BElementwiseOperation b_element_op,
|
|
CElementwiseOperation c_element_op)
|
|
{
|
|
auto ref_gemm = GemmInstance{};
|
|
auto ref_invoker = ref_gemm.MakeInvoker();
|
|
auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
|
|
|
|
ref_invoker.Run(ref_argument);
|
|
}
|
|
|
|
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
|
|
bool RunDeviceGEMM(KernelType kernel,
|
|
const Tensor<ADataType>& A,
|
|
const Tensor<BDataType>& B,
|
|
Tensor<CDataType>& C)
|
|
{
|
|
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
|
|
DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
|
|
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
|
|
|
|
a_m_k_device_buf.ToDevice(A.mData.data());
|
|
b_n_k_device_buf.ToDevice(B.mData.data());
|
|
kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
|
static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
|
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
|
|
c_m_n_device_buf.FromDevice(C.mData.data());
|
|
|
|
return true;
|
|
}
|
|
|
|
template <typename DeviceMFMA,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename CDataType,
|
|
typename GPUAccDataType,
|
|
typename CPUAccDataType,
|
|
typename ALayout,
|
|
typename BLayout,
|
|
typename CLayout,
|
|
index_t BLOCK_M,
|
|
index_t BLOCK_N,
|
|
index_t BLOCK_K>
|
|
struct TestMFMA
|
|
{
|
|
auto PrepareGemmTensors(const GemmParams& params, index_t init)
|
|
{
|
|
auto f_host_tensor_descriptor =
|
|
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
|
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
|
{
|
|
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
|
std::vector<std::size_t>({stride, 1}));
|
|
}
|
|
else
|
|
{
|
|
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
|
std::vector<std::size_t>({1, stride}));
|
|
}
|
|
};
|
|
|
|
Tensor<ADataType> a_m_k(
|
|
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
|
|
Tensor<BDataType> b_n_k(
|
|
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
|
|
Tensor<CDataType> c_m_n_host_result(
|
|
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
|
Tensor<CDataType> c_m_n_device_result(
|
|
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
|
|
|
switch(init)
|
|
{
|
|
case 0:
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
|
|
// NOTE: not all numbers are representable in FP8, BF8, etc.
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
|
break;
|
|
case 1:
|
|
// results in C = {K}
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
|
|
break;
|
|
case 2:
|
|
// expect small round off errors
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
|
|
break;
|
|
case 3:
|
|
// expect small round off errors
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
|
|
break;
|
|
default:
|
|
// all initial values are representable in FP8, BF8
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
|
|
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
|
|
|
|
break;
|
|
}
|
|
|
|
return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
|
|
}
|
|
|
|
auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
|
|
{
|
|
// Arrange
|
|
GemmParams params;
|
|
params.M = BLOCK_M;
|
|
params.N = BLOCK_N;
|
|
params.K = BLOCK_K;
|
|
|
|
auto f_get_default_stride = [](std::size_t row,
|
|
std::size_t col,
|
|
ck::index_t stride,
|
|
auto layout) {
|
|
if(stride == -1)
|
|
{
|
|
// give a chance if stride is -1, return a default packed stride
|
|
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return static_cast<std::size_t>(col);
|
|
}
|
|
else
|
|
{
|
|
return static_cast<std::size_t>(row);
|
|
}
|
|
}
|
|
else
|
|
return static_cast<std::size_t>(stride);
|
|
};
|
|
|
|
params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
|
|
params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
|
|
params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
|
|
|
|
auto host_tensors = PrepareGemmTensors(params, init);
|
|
|
|
const Tensor<ADataType>& a = std::get<0>(host_tensors);
|
|
const Tensor<BDataType>& b = std::get<1>(host_tensors);
|
|
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
|
|
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
|
|
|
|
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
|
|
|
auto a_element_op = PassThrough{};
|
|
auto b_element_op = PassThrough{};
|
|
auto c_element_op = PassThrough{};
|
|
|
|
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
|
BDataType,
|
|
CDataType,
|
|
CPUAccDataType,
|
|
PassThrough,
|
|
PassThrough,
|
|
PassThrough,
|
|
CPUAccDataType>;
|
|
|
|
RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);
|
|
|
|
RunDeviceGEMM(mfma_kernel, a, b, c_device);
|
|
|
|
bool res = false;
|
|
if constexpr(std::is_same<CDataType, float>::value ||
|
|
std::is_same<CDataType, half_t>::value)
|
|
{
|
|
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
|
}
|
|
else
|
|
{
|
|
std::cout << "UNSUPPORTED CDataType" << std::endl;
|
|
}
|
|
|
|
return res;
|
|
}
|
|
};
|
|
|
|
} // namespace mfma_test
|
|
} // namespace ck
|