mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout ## Motivation Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled layouts using the async (CompAsync) pipeline. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8def2c724
commit
8f27f65d44
@@ -57,6 +57,7 @@ add_subdirectory(add_rmsnorm2d_rdquant)
|
||||
# add_subdirectory(layernorm2d)
|
||||
# add_subdirectory(rmsnorm2d)
|
||||
add_subdirectory(gemm_block_scale)
|
||||
add_subdirectory(gemm_mx)
|
||||
add_subdirectory(utility)
|
||||
add_subdirectory(warp_gemm)
|
||||
add_subdirectory(reduce)
|
||||
|
||||
@@ -20,3 +20,21 @@ TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesCompAsync);
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompAsync16x16x128
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompAsync16x16x128<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
// Register the 16x16x128 fixture against its list of kernel configurations.
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompAsync16x16x128, KernelTypesCompAsync16x16x128);

// Runs one square problem (M = N = K = 1024) through RunSingle with all four boolean
// template flags disabled.
TYPED_TEST(TestCkTileGemmPipelineCompAsync16x16x128, QuickTest)
{
    // The three problem extents are identical, so a single constant is shared between them.
    constexpr int kExtent = 1024;

    this->template RunSingle<false, false, false, false>(
        /*M=*/kExtent, /*N=*/kExtent, /*K=*/kExtent, 0, 0, 0, 1);
}
|
||||
|
||||
@@ -29,6 +29,7 @@ using NonPersistent = std::false_type;
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I64 = ck_tile::number<64>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
|
||||
// clang-format off
|
||||
@@ -224,6 +225,23 @@ using CompAsyncConfig = std::tuple<ALayout,
|
||||
Intrawave,
|
||||
CompAsync>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
|
||||
using CompAsyncConfig16x16x128 = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I64, // MBlockTileSize
|
||||
I64, // NBlockTileSize
|
||||
I128, // KBlockTileSize
|
||||
I16, // MWarpTileSize
|
||||
I16, // NWarpTileSize
|
||||
I128, // KWarpTileSize
|
||||
Intrawave,
|
||||
CompAsync>;
|
||||
|
||||
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
|
||||
CompAsyncConfig<Row, Col, Row, F16>,
|
||||
CompAsyncConfig<Col, Row, Row, F16>,
|
||||
@@ -232,6 +250,10 @@ using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16
|
||||
CompAsyncConfig<Row, Col, Row, F8>,
|
||||
CompAsyncConfig<Col, Row, Row, F8>,
|
||||
CompAsyncConfig<Col, Col, Row, F8>>;
|
||||
|
||||
// Configurations run by the 16x16x128 CompAsync suite: ALayout = Row, BLayout = Col,
// CLayout = Row, once with F4 inputs and once with F8 inputs.
using KernelTypesCompAsync16x16x128 =
    ::testing::Types<CompAsyncConfig16x16x128<Row, Col, Row, F4>,
                     CompAsyncConfig16x16x128<Row, Col, Row, F8>>;
|
||||
|
||||
// clang-format off
|
||||
|
||||
using KernelTypesCompV6 = ::testing::Types<
|
||||
|
||||
@@ -7,6 +7,7 @@ using INT32 = ck_tile::int32_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F4 = ck_tile::pk_fp4_t;
|
||||
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
|
||||
17
test/ck_tile/gemm_mx/CMakeLists.txt
Normal file
17
test/ck_tile/gemm_mx/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Compile flags shared by every MX GEMM test binary; the OCP FP8 define is appended only when
# CK_USE_OCP_FP8 is set.
set(TEST_MX_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template)
if(CK_USE_OCP_FP8)
    list(APPEND TEST_MX_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

# The tests are registered only when GPU_TARGETS matches "gfx95"; otherwise a DEBUG-level
# message records that they were skipped.
if(GPU_TARGETS MATCHES "gfx95")
    # One executable per element type, built from test_mx_gemm_<type>.cpp (fp4 first, then fp8).
    foreach(mx_dtype IN ITEMS fp4 fp8)
        add_gtest_executable(test_ck_tile_mx_gemm_${mx_dtype} test_mx_gemm_${mx_dtype}.cpp)
        target_compile_options(test_ck_tile_mx_gemm_${mx_dtype} PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
    endforeach()
else()
    message(DEBUG "Skipping ck_tile MX GEMM tests for current target")
endif()
|
||||
95
test/ck_tile/gemm_mx/test_mx_gemm_config.hpp
Normal file
95
test/ck_tile/gemm_mx/test_mx_gemm_config.hpp
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
|
||||
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
|
||||
{
|
||||
using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>;
|
||||
|
||||
MXGemmHostArgs(const void* a_ptr,
|
||||
const void* b_ptr,
|
||||
void* c_ptr_,
|
||||
ck_tile::index_t k_batch_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_,
|
||||
ScaleM scale_m_,
|
||||
ScaleN scale_n_)
|
||||
: Base({a_ptr},
|
||||
{b_ptr},
|
||||
{},
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
{stride_A_},
|
||||
{stride_B_},
|
||||
{},
|
||||
stride_C_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
|
||||
ScaleM scale_m;
|
||||
ScaleN scale_n;
|
||||
};
|
||||
|
||||
struct MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 512;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool Preshuffle = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp4_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
struct MXfp8_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
30
test/ck_tile/gemm_mx/test_mx_gemm_fp4.cpp
Normal file
30
test/ck_tile/gemm_mx/test_mx_gemm_fp4.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_util.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using MxFp4Types = ::testing::Types<
|
||||
std::tuple<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, MXfp4_GemmConfig16, Row, Col, Row>>;
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmFp4 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
|
||||
std::tuple_element_t<1, TypeParam>,
|
||||
std::tuple_element_t<2, TypeParam>,
|
||||
std::tuple_element_t<3, TypeParam>,
|
||||
std::tuple_element_t<4, TypeParam>,
|
||||
std::tuple_element_t<5, TypeParam>>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmFp4, MxFp4Types);

// Runs the fixture's Run() on three problem sizes, in the order listed.
TYPED_TEST(TestMxGemmFp4, BasicSizes)
{
    // Each row is {M, N, K}.
    constexpr ck_tile::index_t kProblems[][3] = {{64, 64, 256}, {128, 128, 256}, {64, 128, 512}};

    for(const auto& mnk : kProblems)
    {
        this->Run(mnk[0], mnk[1], mnk[2]);
    }
}
|
||||
30
test/ck_tile/gemm_mx/test_mx_gemm_fp8.cpp
Normal file
30
test/ck_tile/gemm_mx/test_mx_gemm_fp8.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_util.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using MxFp8Types =
|
||||
::testing::Types<std::tuple<ck_tile::fp8_t, ck_tile::fp8_t, MXfp8_GemmConfig16, Row, Col, Row>>;
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmFp8 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
|
||||
std::tuple_element_t<1, TypeParam>,
|
||||
std::tuple_element_t<2, TypeParam>,
|
||||
std::tuple_element_t<3, TypeParam>,
|
||||
std::tuple_element_t<4, TypeParam>,
|
||||
std::tuple_element_t<5, TypeParam>>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmFp8, MxFp8Types);

// Runs the fixture's Run() on three problem sizes, in the order listed.
TYPED_TEST(TestMxGemmFp8, BasicSizes)
{
    // Each row is {M, N, K}.
    constexpr ck_tile::index_t kProblems[][3] = {{64, 64, 256}, {128, 128, 256}, {64, 128, 512}};

    for(const auto& mnk : kProblems)
    {
        this->Run(mnk[0], mnk[1], mnk[2]);
    }
}
|
||||
97
test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp
Normal file
97
test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
bool Splitk>
|
||||
float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
|
||||
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
using MXPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC>>;
|
||||
|
||||
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(std::array<const void*, 1>{args.as_ptr},
|
||||
std::array<const void*, 1>{args.bs_ptr},
|
||||
std::array<const void*, 0>{},
|
||||
args.e_ptr,
|
||||
args.k_batch,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
std::array<ck_tile::index_t, 1>{args.stride_As},
|
||||
std::array<ck_tile::index_t, 1>{args.stride_Bs},
|
||||
std::array<ck_tile::index_t, 0>{},
|
||||
args.stride_E,
|
||||
args.scale_m,
|
||||
args.scale_n);
|
||||
|
||||
const auto kernel = ck_tile::make_kernel<Kernel::kBlockPerCu>(
|
||||
Kernel{}, Kernel::GridSize(kargs), Kernel::BlockSize(), 0, kargs);
|
||||
|
||||
return ck_tile::launch_kernel(s, kernel);
|
||||
}
|
||||
137
test/ck_tile/gemm_mx/test_mx_gemm_util.hpp
Normal file
137
test/ck_tile/gemm_mx/test_mx_gemm_util.hpp
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_instance.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<
|
||||
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol_mx(ck_tile::index_t K, float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(K);
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value, K);
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
class TestMxGemmUtil : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::fp16_t;
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
|
||||
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234)
|
||||
{
|
||||
const ck_tile::index_t scale_k_size = K / 32;
|
||||
const ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, 0, is_row_major(ALayout{}));
|
||||
const ck_tile::index_t stride_B =
|
||||
ck_tile::get_default_stride(K, N, 0, is_row_major(BLayout{}));
|
||||
const ck_tile::index_t stride_C =
|
||||
ck_tile::get_default_stride(M, N, 0, is_row_major(CLayout{}));
|
||||
const ck_tile::index_t stride_scale_a =
|
||||
ck_tile::get_default_stride(M, scale_k_size, 0, is_row_major(ALayout{}));
|
||||
const ck_tile::index_t stride_scale_b =
|
||||
ck_tile::get_default_stride(scale_k_size, N, 0, is_row_major(BLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<CDataType> c_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::HostTensor<ScaleType> scale_a_host(ck_tile::host_tensor_descriptor(
|
||||
M, scale_k_size, stride_scale_a, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(ck_tile::host_tensor_descriptor(
|
||||
scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, seed++}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, seed++}(b_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_dev_buf.ToDevice(b_host.data());
|
||||
c_dev_buf.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_host.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_host.data());
|
||||
|
||||
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
MXGemmHostArgs<ScaleM, ScaleN> args(a_dev_buf.GetDeviceBuffer(),
|
||||
b_dev_buf.GetDeviceBuffer(),
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
scale_m,
|
||||
scale_n);
|
||||
|
||||
mx_gemm_calc<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
true,
|
||||
false>(args, ck_tile::stream_config{nullptr, true, 1, 0, 1, true, true, 50});
|
||||
|
||||
c_dev_buf.FromDevice(c_host.data());
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_ref.SetZero();
|
||||
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_ref, scale_a_host, scale_b_host);
|
||||
|
||||
const float max_accumulated_value = ck_tile::type_convert<float>(c_ref.max());
|
||||
const auto rtol_atol = calculate_rtol_atol_mx<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, max_accumulated_value);
|
||||
const double rtol = rtol_atol.at(ck_tile::number<0>{});
|
||||
const double atol = rtol_atol.at(ck_tile::number<1>{});
|
||||
|
||||
bool pass = ck_tile::check_err(c_host, c_ref, "MX GEMM: Incorrect results!", rtol, atol);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user