mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-28 02:31:26 +00:00
Merge from internal (#1857)
* enable batched_gemm_softmax_gemm_perm_wmma for gfx12 * disable instances with blocksize=256 in attention examples * debuggging * debug * fixed lds_enabled * debugging * Fix and add limit to skiplds feature * Enable skipLds feature and fix compilation bugs * add ck_tile definitions for gfx12 * fix clang format and test/wmma_op * updage instances cmake for gfx12 * disable the test_wmma_op on gfx12 * fix the builds for gfx950 * add gfx12 and gfx950 to default target list * clean-up cmake file * Initial introduction of OFP8 data types. * Renamed FP8 and BF8 tests into FP8_FNUZ and BF8_FNUZ. * Implementation of ConvertFP32Nearest in test_fp8_ocp. * Remove dependence on possibly undeclared alias. * Implement FP8OCP test for stochastic rounding mode. * Implement FP8OCP tests for half_t type conversions. * enable bf16 atomic add on gfx950 * Implement ConvertFP32Nearest test. * Implement ConvertFP32Stochastic test. * Implement ConvertFP16Nearest and ConvertFP16Stochastic tests. * Refactoring. Move FP8 definitions into a separate header file. * Enable easy switching between architectures. * Fix compilation error for gfx942 architecture. * Add fp4 type with constants * only builf gfx950 branch for gfx950 target by default * Enable OCP build of example_gemm_xdl_fp8. * Fix formatting. * fix the build logic for gfx950 * Improve GEMM example verbosity. * Add constexpr where applicable. * fix the logic of enabling XDL and WMMA instances * Improve GEMM example verbosity. * Enable build of example_gemm_xdl_fp8_bf8 test. * Fix tests for gfx1101 architecture. * Build DPP examples only on gfx103 and gfx11 architectures. * Optionaly run either CPU or GPU verifications with GEMM examples. * Extend GeneratorTensor_Sequential to produce values of prescribed data types. * Add missing constructor. * Add scale type and mxfp conversions * Update conversions * Add conversion tests * Fix typo * Improve infrastructure for OFP8 data type support. * BUGFIX. Should not use FP8 as Compute/Accum data type. * Add custom target for grouped_convnd_bwd_weight tests. * Can build `tests` target on gfx950. * Bugfixes on gfx1101 architecture. * Fix dependencies. * Add stochastic rounding tests * Provide single point of truth for FP8 INF and NAN checks * Prevent instantiation of operators that are not supported by FP8 data types * Add FP8 type selection into client_axample CMakeLists.txt * Prevent sccache server from shutting down during build * Fix test success reporting logic * Change default verification method to CPU. GPU verification takes too much time to complete on the emulator. * Add scale <-> float conversions * Add scaled conversions with tests * Add device conversions * Make sure all tests and examples are built for gfx950 * Facilitate testing of FP8 data types on the emulator * Introduce two new tensor generators * Enable instances built for gfx94 to be built on gfx950 * Verify 35_splitk_gemm on floating point numbers. splitk gemm appears to be losing precision VS reference implementation when FP numbers are involved. * Format * Verify 04_gemm_add_add_fastgelu on floating point numbers * Verify 20_grouped_conv_bwd_weight on floating point numbers * Verify 38_grouped_conv_bwd_data_multiple_d on floating point numbers * Verify more tests on floating point data * Fix data types and improve testing verbocity. * Add fp4 vectors * Add debug tests * Upgrade to NPI 573 build docker. * Skip on gemm_universal tests. The tests take too long to complete on the emulator. Need to see if it is possible to reduce the scope of the testing to just FP8 data types. * Add new mfma instructions and examples * Add preprocessor directives for gfx950 specific code * Fix gfx1101 build * Document test availability * Re-enable fp8 gemms for gfx94/95 * Cherry-pick GEMM Universal tests for FP8 data types * Cleanup * Add vector types and tests * Add check_err function * Add tensor generators * CK_USE_GFX94 has already been set on this branch * Fix * Address formatting issues and leftovers * Make fail/pass logic consistent within 01_gemm folder Removed multiple negations in fail/pass logic to propagate `true` as the success indicator. * Fix GPU verification reporting logic. * Update year in copyright notice. * Cleanup * Use `enum class` instead of `enum` * Remove set_property for FP8 tests * Add vector conversions * Fix * Fix linker errror * Clean up * Fix gfx950 conversions * Clean up * Fix more gfx950 conversions * Fix even more gfx950 conversions * Narrowing the scope of PR to OCP FP8 enablement only * Add tests for OCP FP8 vector_type storage * Fix client examples build * Fix typo * Update e8m0 casting * Rename E8M0 type * Update unpack method * Cleanup merge artifacts * Enable gemm kernel on all gfx9 architectures (#227) * clean-up * Implement `non_native_vector_base` with `ext_vector_type` array. (#232) * Enable support of 1, 2, 4, and 8-byte custom types in CK. * Fix pool tests for OCP FP8 data type * Fix build * Add ckProfiler gemm instances for new mfma instructions and fix ckProfiler build on gfx950 * fix clang format * Add new mfma instructions and examples * Add preprocessor directives for gfx950 specific code * Add ckProfiler gemm instances for new mfma instructions and fix ckProfiler build on gfx950 * fix clang format * Fix clang format for the newly merged files * Use the existing example instances for fp16 bf16 and int8 * Remove comment on new mfma instructions in MfmaInstr * Update include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * merge from public repo * Fix ck build * Fix ck build * Use double for max_abs_in_val * Move scaled_type_convert functions to a separate header (#251) * re-enable building mha lib and gemm_universal_f8 instances for gfx950 * Update library/src/tensor_operation_instance/gpu/CMakeLists.txt Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * fix typo for CK_USE_OCP_FP8 * fix typo for CK_USE_OCP_FP8 * Add FP6 and BF6 types (#261) * Add a rounding flag * Add FP6 and BF6 * Add tests Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * Clean up --------- Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * fix one more typo * Refactor E8M0 scale implementation (#262) * Refactor E8M0 scale implementation * Add MXFP6 and MXBF6 conversion methods (#270) * Add conversions * Add tests * Add docstrings * Add scaled conversions * Add fp6/bf6 tests * Remove misleading fp4 test case * Add docstrings * Clean up * Address comments * Set stricter tolerances for RNE tests * Add missing tests * Add native conversions to float * Revert "Add native conversions to float" This reverts commit 09467111f73b753c8cc3d597533b187940353dab. * Update copyright years * replace the fp6 with bf6 convert calls in test_bf6 * fix test_bf6 * enable smfmac test * [MX FP8] Add Scaled Type Convert Functions for OCP FP8/BF8 data types (#271) * Move scaled_type_convert functions to a separate header * Introduce MX data tests * Build MX tests only on relevant architectures * Refactor E8M0 scale implementation * Fix `config.h` typo * Cleanup deprecated symbols * Refactor `amd_ck_fp8.hpp` * `scaled_type_convert` for `f8_ocp_t` * Implement test for MX FP8 scaled type convert * Implement test for MX BF8 scaled type convert * Scaled type convert for vectors of 2 FP8 elements * Scaled type convert for vectors of 16 FP8 elements * Implementation of scaled conversion from F32 to F8 * Add tests for scaled conversions from FP32 to FP8 * Add documentation to the test functions * Implementation of scaled conversion from F32x2 to F8x2 * Implementation of scaled conversion from F32x16 to F8x16 * Implementation of scaled conversion from F32x32 to F8x32 * Implementation of scaled conversion from F8x32 to F32x32 * Verified on the emulator * MX FP GEMM - Example Template (#277) Temporarily uses `DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3` kernel and 128x128 scaling matrices. Must be modified to use MX-native GEMM kernell with 16 or 32 component vectors per scale. Verified on the emulator. * Add vector support * Add tests * Add missing type aliases * Fix test naming * only build mx example for gfx950 * disable CK_USE_AMD_MFMA_GFX950 by default * fic build for multiple archs * fix typo * fix typo * Update unpack signature * Fix merge * Add size checks in pack function * Add a flag * Add conversions * Fix build logic * Update pack/unpack methods * Remove unneeded AsType accessors * Add docstrings * Add a flag to config file * Test the functionality of V_MFMA_F32_16X16X128_F8F6F4 and V_MFMA_F32_32X32X64_F8F6F4 instructions. (#293) * Introduced MFMA tests * Verified f8f6f4 MFMA Instructions * Move flag logic to scaled_type_convert header * Use pointers instead of array indices * Fix a typo * Update tests and pack functions * Fix gemm gemm on gfx950 * Fix clang format * restore the default gput target lists * fix the jenkinsfile * add missing ifdef --------- Co-authored-by: Jing Zhang <jizhan@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Jun Liu <Liu.Jun@amd.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: root <root@banff-cyxtera-s83-2.ctr.dcgpu> Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Co-authored-by: jefyang1 <146495389+jefyang1@users.noreply.github.com> Co-authored-by: jefyang1 <Jeffreyj.Yang@amd.com>
This commit is contained in:
9
test/mx_mfma_op/CMakeLists.txt
Normal file
9
test/mx_mfma_op/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
add_custom_target(test_mx_mfma)
|
||||
|
||||
add_gtest_executable(test_mx_mfma_op mx_mfma_op.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_mx_mfma_op PRIVATE utility)
|
||||
endif()
|
||||
add_dependencies(test_mx_mfma test_mx_mfma_op)
|
||||
|
||||
|
||||
65
test/mx_mfma_op/mx_mfma_op.cpp
Normal file
65
test/mx_mfma_op/mx_mfma_op.cpp
Normal file
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "mx_mfma_op.hpp"
|
||||
|
||||
using ck::e8m0_bexp_t;
|
||||
using ck::f8_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
/**
|
||||
* @brief Run the test for the given MFMA instruction
|
||||
*
|
||||
* @param init - selects initialization algorithm for A and B tensors
|
||||
*/
|
||||
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
|
||||
bool run_mfma_test(ck::index_t init)
|
||||
{
|
||||
using ALayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using AccType = float; // only MFMA_F32 instructions supported
|
||||
using CPUAccType = AccType;
|
||||
|
||||
ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
|
||||
constexpr auto BLOCK_M = mfma_instr.m_per_blk;
|
||||
constexpr auto BLOCK_N = mfma_instr.n_per_blk;
|
||||
constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
|
||||
|
||||
const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
|
||||
AType,
|
||||
BType,
|
||||
CType,
|
||||
AccType,
|
||||
CPUAccType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K>{}(mx_mfma_kernel, init);
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
TEST(MFMA, FP8MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 0;
|
||||
auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MFMA, FP8MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 0;
|
||||
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
567
test/mx_mfma_op/mx_mfma_op.hpp
Normal file
567
test/mx_mfma_op/mx_mfma_op.hpp
Normal file
@@ -0,0 +1,567 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// MFMA instructions supported in this test
|
||||
enum class MFMA_F8F6F4
|
||||
{
|
||||
F32_16x16x128 =
|
||||
static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
|
||||
F32_32x32x64 =
|
||||
static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4
|
||||
};
|
||||
|
||||
template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
|
||||
struct mfma_type_selector;
|
||||
|
||||
template <typename AFragT, typename BFragT, typename AccumFragT>
|
||||
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
|
||||
{
|
||||
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
|
||||
{
|
||||
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
|
||||
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AFragT, typename BFragT, typename AccumFragT>
|
||||
struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
|
||||
{
|
||||
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
|
||||
{
|
||||
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
|
||||
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename VecT>
|
||||
static constexpr int32_t vectorSize(const VecT&)
|
||||
{
|
||||
return scalar_type<VecT>::vector_size;
|
||||
}
|
||||
|
||||
// Define a load function for input A blocks:
|
||||
// Size: (BLOCK_M x BLOCK_K)
|
||||
// ASSUMPTION:
|
||||
// - We want contiguous BLOCK_M sized column neighbors in register.
|
||||
// - Data is in col_major format
|
||||
// This means:
|
||||
// - From A we will load K columns of size BLOCK_M to satisfy our input data
|
||||
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
|
||||
__device__ AFragT load_A_col_major(AType const* input_ptr)
|
||||
{
|
||||
// clang-format off
|
||||
// Register Mapping for 16x128: || Register Mapping for 32x64:
|
||||
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M |
|
||||
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
|
||||
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
|
||||
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
|
||||
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
|
||||
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
|
||||
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
|
||||
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
|
||||
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
|
||||
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
|
||||
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
|
||||
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
|
||||
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
|
||||
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
|
||||
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
|
||||
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
|
||||
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
|
||||
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
|
||||
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
|
||||
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
|
||||
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
|
||||
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
|
||||
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
|
||||
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
|
||||
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
|
||||
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
|
||||
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
|
||||
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
|
||||
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
|
||||
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
|
||||
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
|
||||
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
|
||||
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
|
||||
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
|
||||
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
|
||||
// clang-format on
|
||||
|
||||
// Here we want to load a BLOCK_M x BLOCK_K block of data.
|
||||
static constexpr uint32_t VW = vectorSize(AFragT{});
|
||||
using ARawT = typename scalar_type<AFragT>::type;
|
||||
using AScalarFragT = vector_type<ARawT, VW>::type;
|
||||
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 32 elements.
|
||||
// We need to know where they start, and where the next elements are.
|
||||
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
|
||||
(threadIdx.x / BLOCK_M) * VW); // Col
|
||||
auto stepCoord2D = std::make_pair(0u, 1u);
|
||||
|
||||
// Flatten to 1D col_major offsets.
|
||||
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
||||
|
||||
// BLOCK_M is a stride in A matrix
|
||||
auto startOffset = col_major(startCoord2D, BLOCK_M);
|
||||
auto kOffset = col_major(stepCoord2D, BLOCK_M);
|
||||
|
||||
// kOffset == BLOCK_M
|
||||
// This means every BLOCK_M element is loaded into output vector
|
||||
auto fragA = AScalarFragT{};
|
||||
#pragma unroll VW
|
||||
for(uint32_t i = 0; i < VW; i++)
|
||||
{
|
||||
fragA[i] = bit_cast<ARawT>(input_ptr[startOffset + i * kOffset]);
|
||||
}
|
||||
|
||||
return fragA;
|
||||
}
|
||||
|
||||
// Define a load function for input B blocks:
|
||||
// Size: (BLOCK_K x BLOCK_N)
|
||||
// ASSUMPTION:
|
||||
// - We want contiguous BLOCK_N sized row neighbors in register.
|
||||
// - Data is in row_major format
|
||||
// This means:
|
||||
// - From B we will load K rows of size BLOCK_N to satisfy our input data
|
||||
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
|
||||
__device__ BFragT load_B_col_major(BType const* input_ptr)
|
||||
{
|
||||
// clang-format off
|
||||
// Register Mapping for 128x16: || Register Mapping for 64x32:
|
||||
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N |
|
||||
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || N | 0 ... 31 | 0 ... 31 |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
|
||||
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
|
||||
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
|
||||
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
|
||||
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
|
||||
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
|
||||
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
|
||||
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
|
||||
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
|
||||
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
|
||||
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
|
||||
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
|
||||
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
|
||||
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
|
||||
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
|
||||
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
|
||||
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
|
||||
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
|
||||
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
|
||||
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
|
||||
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
|
||||
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
|
||||
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
|
||||
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
|
||||
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
|
||||
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
|
||||
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
|
||||
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
|
||||
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
|
||||
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
|
||||
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
|
||||
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
|
||||
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
|
||||
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
|
||||
// clang-format on
|
||||
|
||||
// Here we want to load a BLOCK_K x BLOCK_N block of data.
|
||||
static constexpr uint32_t VW = vectorSize(BFragT{});
|
||||
|
||||
// To start the loading process, let's visualize in 2D coords.
|
||||
// Each thread will load 32 elements.
|
||||
// We need to know where they start, and where the next elements are.
|
||||
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
|
||||
threadIdx.x % BLOCK_N); // Col
|
||||
|
||||
// Flatten to 1D col_major offsets.
|
||||
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
||||
|
||||
auto startOffset = col_major(startCoord2D, BLOCK_K);
|
||||
|
||||
auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
|
||||
return *fragPtr;
|
||||
}
|
||||
|
||||
// Define a store function for C
|
||||
// Size: (BLOCK_M x BLOCK_N)
|
||||
// ASSUMPTION:
|
||||
// - We want contiguous BLOCK_N sized row neighbors in register.
|
||||
// - Data is in col_major format
|
||||
// This means:
|
||||
// - From C we will load BLOCK_M rows of size BLOCK_N to satisfy our input data
|
||||
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
|
||||
struct store_C_col_major;
|
||||
|
||||
// Here we want to store a 16x16 block of data.
|
||||
//
|
||||
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N |
|
||||
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 |
|
||||
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
|
||||
// Register Element ------------ ------------- ------------ -------------- Element
|
||||
// Reg0 | M0 | M4 | M8 | M12 | v[0]
|
||||
// Reg1 | M1 | M5 | M9 | M13 | v[1]
|
||||
// Reg2 | M2 | M6 | M10 | M14 | v[2]
|
||||
// Reg3 | M3 | M7 | M11 | M15 | v[3]
|
||||
template <typename CType, typename CFragT>
|
||||
struct store_C_col_major<CType, CFragT, 16, 16>
|
||||
{
|
||||
__device__ void operator()(CType* output, CFragT cFrag)
|
||||
{
|
||||
static constexpr uint32_t VW = vectorSize(cFrag); // 4
|
||||
static constexpr uint32_t Dim = 16;
|
||||
|
||||
// Each thread will load 4 elements.
|
||||
// We need to know where they start, and where the next elements are.
|
||||
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
|
||||
threadIdx.x % Dim); // Col
|
||||
|
||||
// Flatten to 1D col_major offsets.
|
||||
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
||||
|
||||
auto startOffset = col_major(startCoord2D, 16);
|
||||
|
||||
auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
|
||||
*fragPtr = cFrag;
|
||||
}
|
||||
};
|
||||
|
||||
// Here we want to store a 32x32 block of data.
|
||||
// Register Mapping:
|
||||
|
||||
// Size | BLOCK_N | BLOCK_N |
|
||||
// N | 0 ... 31 | 0 ... 31 |
|
||||
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
|
||||
// Register Element ------------ ------------- Element
|
||||
// Reg0 | M0 | M4 | v[0]
|
||||
// Reg1 | M1 | M5 | v[1]
|
||||
// Reg2 | M2 | M6 | v[2]
|
||||
// Reg3 | M3 | M7 | v[3]
|
||||
// ____________ _____________
|
||||
// Reg4 | M8 | M12 | v[4]
|
||||
// Reg5 | M9 | M13 | v[5]
|
||||
// Reg6 | M10 | M14 | v[6]
|
||||
// Reg7 | M11 | M15 | v[7]
|
||||
// ____________ _____________
|
||||
// Reg8 | M16 | M20 | v[8]
|
||||
// Reg9 | M17 | M21 | v[9]
|
||||
// Reg10 | M18 | M22 | v[10]
|
||||
// Reg11 | M19 | M23 | v[11]
|
||||
// ____________ _____________
|
||||
// Reg12 | M24 | M28 | v[12]
|
||||
// Reg13 | M25 | M29 | v[13]
|
||||
// Reg14 | M26 | M30 | v[14]
|
||||
// Reg15 | M27 | M31 | v[15]
|
||||
|
||||
template <typename CType, typename CFragT>
|
||||
struct store_C_col_major<CType, CFragT, 32, 32>
|
||||
{
|
||||
__device__ void operator()(CType* output, CFragT cFrag)
|
||||
{
|
||||
static constexpr uint32_t WAVE_SIZE = 64;
|
||||
static constexpr uint32_t VW = 4;
|
||||
static constexpr uint32_t Dim = 32;
|
||||
static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
|
||||
|
||||
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
|
||||
threadIdx.x % Dim); // Col
|
||||
|
||||
// Major step between 'chunks'
|
||||
auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
|
||||
|
||||
// Flatten to 1D col_major offsets.
|
||||
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
|
||||
|
||||
auto startOffset = col_major(startCoord2D, 32);
|
||||
auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8
|
||||
|
||||
// we can vector store 4 contiguous elements at a time.
|
||||
using CRawT = typename scalar_type<CFragT>::type;
|
||||
using CScalarFragT = vector_type<CRawT, VW>::type;
|
||||
union
|
||||
{
|
||||
CFragT frag;
|
||||
CScalarFragT chunks[vectorSize(CFragT{}) / VW];
|
||||
} fragC{cFrag}; // Initialize with input fragment
|
||||
|
||||
*(reinterpret_cast<CScalarFragT*>(output + startOffset)) = fragC.chunks[0];
|
||||
*(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset)) = fragC.chunks[1];
|
||||
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) =
|
||||
fragC.chunks[2];
|
||||
*(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) =
|
||||
fragC.chunks[3];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename AccType,
|
||||
int32_t BLOCK_M,
|
||||
int32_t BLOCK_N,
|
||||
int32_t BLOCK_K>
|
||||
__global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
{
|
||||
constexpr int WAVE_SIZE = 64;
|
||||
assert(threadIdx.x < WAVE_SIZE);
|
||||
assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
|
||||
|
||||
using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
|
||||
using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
|
||||
using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
|
||||
using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
|
||||
|
||||
// Create frags
|
||||
auto fragA = AFragT{};
|
||||
auto fragB = BFragT{};
|
||||
auto fragC = CFragT{};
|
||||
auto fragAcc = AccumFragT{0};
|
||||
|
||||
// Load the inputs.
|
||||
// A = col major, BLOCK_M x BLOCK_K
|
||||
fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
|
||||
// B = col major, BLOCK_K x BLOCK_N
|
||||
fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
|
||||
|
||||
// Matrix multiply-accumulate using MFMA units
|
||||
// Accumulation intermediate = BLOCK_M x BLOCK_N
|
||||
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
|
||||
|
||||
for(int i = 0; i < vectorSize(fragC); ++i)
|
||||
{
|
||||
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
|
||||
}
|
||||
|
||||
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
|
||||
storeC(c, fragC);
|
||||
}
|
||||
/**
|
||||
* @brief Structure to hold dimension parameters for GEMM tensors.
|
||||
*
|
||||
* M Number of rows in matrix A and matrix C.
|
||||
* N Number of columns in matrix B and matrix C.
|
||||
* K Number of columns in matrix A and number of rows in matrix B.
|
||||
* StrideA Stride (leading dimension) of matrix A.
|
||||
* StrideB Stride (leading dimension) of matrix B.
|
||||
* StrideC Stride (leading dimension) of matrix C.
|
||||
*/
|
||||
struct GemmParams
|
||||
{
|
||||
ck::index_t M = 16;
|
||||
ck::index_t N = 16;
|
||||
ck::index_t K = 128;
|
||||
|
||||
ck::index_t StrideA = -1;
|
||||
ck::index_t StrideB = -1;
|
||||
ck::index_t StrideC = -1;
|
||||
};
|
||||
|
||||
namespace mfma_test {
|
||||
template <typename GemmInstance,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
void RunHostGEMM(const Tensor<ADataType>& A,
|
||||
const Tensor<BDataType>& B,
|
||||
Tensor<CDataType>& C,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
auto ref_gemm = GemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
|
||||
bool RunDeviceGEMM(KernelType kernel,
|
||||
const Tensor<ADataType>& A,
|
||||
const Tensor<BDataType>& B,
|
||||
Tensor<CDataType>& C)
|
||||
{
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(A.mData.data());
|
||||
b_n_k_device_buf.ToDevice(B.mData.data());
|
||||
kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
|
||||
c_m_n_device_buf.FromDevice(C.mData.data());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename DeviceMFMA,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GPUAccDataType,
|
||||
typename CPUAccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
index_t BLOCK_M,
|
||||
index_t BLOCK_N,
|
||||
index_t BLOCK_K>
|
||||
struct TestMFMA
|
||||
{
|
||||
auto PrepareGemmTensors(const GemmParams& params, index_t init)
|
||||
{
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(
|
||||
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_n_k(
|
||||
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(
|
||||
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
||||
|
||||
switch(init)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
|
||||
// NOTE: not all numbers are representable in FP8, BF8, etc.
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
break;
|
||||
case 1:
|
||||
// results in C = {K}
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
|
||||
break;
|
||||
case 2:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
|
||||
break;
|
||||
case 3:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
|
||||
break;
|
||||
default:
|
||||
// all initial values are representable in FP8, BF8
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
|
||||
}
|
||||
|
||||
auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
|
||||
{
|
||||
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
|
||||
<< ", CLayout = " << CLayout{}.name << std::endl;
|
||||
|
||||
// Arrange
|
||||
GemmParams params;
|
||||
params.M = BLOCK_M;
|
||||
params.N = BLOCK_N;
|
||||
params.K = BLOCK_K;
|
||||
|
||||
auto f_get_default_stride = [](std::size_t row,
|
||||
std::size_t col,
|
||||
ck::index_t stride,
|
||||
auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
|
||||
params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
|
||||
params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
|
||||
|
||||
auto host_tensors = PrepareGemmTensors(params, init);
|
||||
|
||||
const Tensor<ADataType>& a = std::get<0>(host_tensors);
|
||||
const Tensor<BDataType>& b = std::get<1>(host_tensors);
|
||||
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
|
||||
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
auto a_element_op = PassThrough{};
|
||||
auto b_element_op = PassThrough{};
|
||||
auto c_element_op = PassThrough{};
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
CPUAccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
RunDeviceGEMM(mfma_kernel, a, b, c_device);
|
||||
|
||||
bool res = false;
|
||||
if constexpr(std::is_same<CDataType, float>::value ||
|
||||
std::is_same<CDataType, half_t>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device.mData, c_host.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "UNSUPPORTED CDataType" << std::endl;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mfma_test
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user