mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
[rocm-libraries] ROCm/rocm-libraries#8554 (commit be9af54)
refactor(ck): mx gemm kernel unification ## Motivation CK tile currently has two separate MX GEMM kernels for gfx950 and gfx1250. This pull request refactors and modernizes the MX GEMM kernel and example to use new scale tensor handling, improved kernel argument structures, and updated pipeline and kernel APIs. The changes simplify the interface and improve type safety. JIRA ID ROCM-26313 ## Technical Details - Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused kernel - Unify comp async pipeline for GEMM and MX GEMM - Unify eight waves pipeline for GEMM and MX GEMM - Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops - Unify testing framework for MX GEMM - Add gfx950 tests for grouped MX GEMM ## Test Plan - `test_mx_gemm_async.cpp` for MX GEMM on gfx950 - `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
604c56bc0e
commit
d559ec00a8
@@ -7,8 +7,16 @@ if(CK_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_async test_mx_gemm_async.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_async PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_async_rcr test_mx_gemm_async_rcr.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_async_rcr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_async_rrr test_mx_gemm_async_rrr.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_async_rrr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_async_crr test_mx_gemm_async_crr.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_async_crr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_async_ccr test_mx_gemm_async_ccr.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_async_ccr PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
add_gtest_executable(test_ck_tile_mx_gemm_async_rcr_large_cases test_mx_gemm_async_rcr_large_cases.cpp)
|
||||
target_compile_options(test_ck_tile_mx_gemm_async_rcr_large_cases PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile MX GEMM tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_util.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using F4 = ck_tile::pk_fp4_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using B8 = ck_tile::bf8_t;
|
||||
|
||||
// clang-format off
|
||||
using MxTypes = ::testing::Types<std::tuple<F4, F4, MX_GemmConfig16, Row, Col, Row>,
|
||||
std::tuple<F4, F4, MX_GemmConfigEightWaves, Row, Col, Row>,
|
||||
std::tuple<F4, F4, MXfp4_GemmConfig16_Preshuffle, Row, Col, Row>,
|
||||
std::tuple<F4, F4, MXfp4_GemmConfig16_PermuteN, Row, Col, Row>,
|
||||
std::tuple<F8, F8, MX_GemmConfig16, Row, Col, Row>,
|
||||
std::tuple<F8, F8, MX_GemmConfigEightWaves, Row, Col, Row>,
|
||||
std::tuple<F8, F8, MXfp8_GemmConfig16_Preshuffle, Row, Col, Row>,
|
||||
std::tuple<F8, F8, MXfp8_GemmConfig16_PermuteN, Row, Col, Row>>;
|
||||
// clang-format on
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemm : public TestMxGemmUtil<TypeParam>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemm, MxTypes);
|
||||
|
||||
TYPED_TEST(TestMxGemm, Default)
|
||||
{
|
||||
this->Run(128, 512, 256);
|
||||
this->Run(256, 512, 512);
|
||||
this->Run(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// 32x32x64 MFMA warp tile: enables all four A/B layout combinations via ds_read_tr
|
||||
// transposed LDS loads. (16x16x128 stays Row/Col only above: KWarpTile=128 exceeds the
|
||||
// ds_read_tr subtile limit, which disables transpose loads.)
|
||||
// clang-format off
|
||||
using MxTypesTranspose = ::testing::Types<
|
||||
std::tuple<F4, F4, MXfp4_GemmConfig32, Row, Col, Row>,
|
||||
std::tuple<F4, F4, MXfp4_GemmConfig32, Row, Row, Row>,
|
||||
std::tuple<F4, F4, MXfp4_GemmConfig32, Col, Col, Row>,
|
||||
std::tuple<F4, F4, MXfp4_GemmConfig32, Col, Row, Row>,
|
||||
std::tuple<F8, F8, MXfp8_GemmConfig32, Row, Col, Row>,
|
||||
std::tuple<F8, F8, MXfp8_GemmConfig32, Row, Row, Row>,
|
||||
std::tuple<F8, F8, MXfp8_GemmConfig32, Col, Col, Row>,
|
||||
std::tuple<F8, F8, MXfp8_GemmConfig32, Col, Row, Row>,
|
||||
// bf8/bf8 and mixed fp8/bf8 exercise the float8 paths newly consolidated into the generic
|
||||
// 32x32x64 f8/f6/f4 dispatcher (previously distinct per-type code paths).
|
||||
std::tuple<B8, B8, MXfp8_GemmConfig32, Row, Col, Row>,
|
||||
std::tuple<B8, B8, MXfp8_GemmConfig32, Row, Row, Row>,
|
||||
std::tuple<B8, B8, MXfp8_GemmConfig32, Col, Col, Row>,
|
||||
std::tuple<B8, B8, MXfp8_GemmConfig32, Col, Row, Row>,
|
||||
std::tuple<F8, B8, MXfp8_GemmConfig32, Row, Col, Row>,
|
||||
std::tuple<F8, B8, MXfp8_GemmConfig32, Row, Row, Row>,
|
||||
std::tuple<F8, B8, MXfp8_GemmConfig32, Col, Col, Row>,
|
||||
std::tuple<F8, B8, MXfp8_GemmConfig32, Col, Row, Row>>;
|
||||
// clang-format on
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmTranspose : public TestMxGemmUtil<TypeParam>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmTranspose, MxTypesTranspose);
|
||||
|
||||
TYPED_TEST(TestMxGemmTranspose, BasicSizes)
|
||||
{
|
||||
this->Run(128, 128, 256);
|
||||
this->Run(128, 128, 512);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestMxGemmTranspose, MultiBlockMN)
|
||||
{
|
||||
this->Run(256, 128, 256);
|
||||
this->Run(128, 256, 256);
|
||||
this->Run(256, 256, 256);
|
||||
}
|
||||
|
||||
// Preshuffle split-K coverage. MxTypes already exercises the preshuffle configs on the
|
||||
// non-split-K shapes (TestMxGemm.Default); this fixture pins the split-K shapes to the
|
||||
// fp4/fp8 preshuffle configs.
|
||||
using MxTypesPreshuffle =
|
||||
::testing::Types<std::tuple<F4, F4, MXfp4_GemmConfig16_Preshuffle, Row, Col, Row>,
|
||||
std::tuple<F8, F8, MXfp8_GemmConfig16_Preshuffle, Row, Col, Row>>;
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmPreshuffle : public TestMxGemmUtil<TypeParam>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmPreshuffle, MxTypesPreshuffle);
|
||||
|
||||
// Split-K for the preshuffle pipeline: each k_id offsets the flat-B window and the
|
||||
// host-preshuffled A/B scale windows into its own K slice (and accumulates via atomic-add).
|
||||
// K is a multiple of K_Tile * k_batch (= 256 * k_batch); N is a multiple of 512 so the shapes
|
||||
// are valid for both the fp4 (N_Tile = 512) and fp8 (N_Tile = 256) preshuffle configs.
|
||||
TYPED_TEST(TestMxGemmPreshuffle, SplitK)
|
||||
{
|
||||
this->Run(128, 512, 512, /*k_batch=*/2);
|
||||
this->Run(128, 512, 1024, /*k_batch=*/2);
|
||||
this->Run(128, 512, 1024, /*k_batch=*/4);
|
||||
this->Run(256, 512, 2048, /*k_batch=*/4);
|
||||
}
|
||||
|
||||
// Regression coverage for the MX GEMM correctness fixes (PR #6663): num_loop == 3 hot-loop
|
||||
// dispatch, split-K, and M/N padding. Shapes are pinned to fp8 x MX_GemmConfig16 (M_Tile = 64,
|
||||
// N_Tile = 128, K_Tile = 256, default comp-async pipeline) so the regressions hit the intended
|
||||
// code path -- e.g. K = 768 gives num_loop = K / K_Tile = 3.
|
||||
using MxFp8Cfg16Types = ::testing::Types<std::tuple<F8, F8, MX_GemmConfig16, Row, Col, Row>>;
|
||||
|
||||
using MxFp8PadMNTypes =
|
||||
::testing::Types<std::tuple<F8, F8, MXfp8_GemmConfig16_PadMN, Row, Col, Row>>;
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmFp8Regression : public TestMxGemmUtil<TypeParam>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmFp8Regression, MxFp8Cfg16Types);
|
||||
|
||||
// num_loop == 3 must not enter the hot loop: with K_Tile = 256, K = 768 gives num_loop = 3,
|
||||
// which previously produced 5 gemm accumulations instead of 3 (deterministically wrong).
|
||||
TYPED_TEST(TestMxGemmFp8Regression, HotLoopTailNumLoopThree)
|
||||
{
|
||||
this->Run(64, 128, 768);
|
||||
this->Run(128, 256, 768);
|
||||
this->Run(256, 256, 768);
|
||||
}
|
||||
|
||||
// Split-K: exercises both the full_k_read and partial_k_read paths of SplitKBatchOffset together
|
||||
// with the per-split scale-window K offset and the atomic-add epilogue. K is a multiple of
|
||||
// K_Tile * k_batch and of WarpTile_K * k_batch (= 128 * k_batch) so every split lands on a
|
||||
// packed-scale boundary.
|
||||
TYPED_TEST(TestMxGemmFp8Regression, SplitK)
|
||||
{
|
||||
this->Run(128, 256, 512, /*k_batch=*/2);
|
||||
this->Run(128, 256, 1024, /*k_batch=*/2);
|
||||
this->Run(128, 256, 1024, /*k_batch=*/4);
|
||||
this->Run(256, 256, 2048, /*k_batch=*/4);
|
||||
}
|
||||
|
||||
// fp4 split-K (non-preshuffle). Same MX_GemmConfig16 tile shape as the fp8 regression above, so
|
||||
// the K alignment requirements are identical; this verifies the packed (BPackedSize = 2) A/B
|
||||
// pointer K-offset works under split-K + atomic-add for fp4.
|
||||
using MxF4Cfg16Types = ::testing::Types<std::tuple<F4, F4, MX_GemmConfig16, Row, Col, Row>>;
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmFp4Regression : public TestMxGemmUtil<TypeParam>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmFp4Regression, MxF4Cfg16Types);
|
||||
|
||||
TYPED_TEST(TestMxGemmFp4Regression, SplitK)
|
||||
{
|
||||
this->Run(128, 256, 512, /*k_batch=*/2);
|
||||
this->Run(128, 256, 1024, /*k_batch=*/2);
|
||||
this->Run(128, 256, 1024, /*k_batch=*/4);
|
||||
this->Run(256, 256, 2048, /*k_batch=*/4);
|
||||
}
|
||||
|
||||
template <typename TypeParam>
|
||||
class TestMxGemmFp8PadMN : public TestMxGemmUtil<TypeParam>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestMxGemmFp8PadMN, MxFp8PadMNTypes);
|
||||
|
||||
// M/N padding (kPadM = kPadN = true). M_Tile = 64, N_Tile = 128. Each of M and N must be >= its
|
||||
// block tile (the CShuffleEpilogue cannot safely run a single partial tile along either
|
||||
// dimension); K stays aligned because the MX async pipeline does not support K padding.
|
||||
TYPED_TEST(TestMxGemmFp8PadMN, MNPaddingAligned)
|
||||
{
|
||||
// Sanity: padding enabled but already-aligned M, N must not regress the normal path.
|
||||
this->Run(64, 128, 256);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestMxGemmFp8PadMN, MPadding)
|
||||
{
|
||||
// M has a full tile + partial trailing tile (N aligned to N_Tile = 128).
|
||||
this->Run(96, 128, 256);
|
||||
this->Run(160, 128, 256);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestMxGemmFp8PadMN, NPadding)
|
||||
{
|
||||
// N has a full tile + partial trailing tile (M aligned to M_Tile = 64).
|
||||
this->Run(64, 160, 256);
|
||||
this->Run(64, 224, 256);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestMxGemmFp8PadMN, MNPadding)
|
||||
{
|
||||
// Both M and N unaligned (full + partial trailing tiles).
|
||||
this->Run(96, 160, 256);
|
||||
this->Run(160, 224, 512);
|
||||
}
|
||||
22
test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp
Normal file
22
test/ck_tile/gemm_mx/test_mx_gemm_async_ccr.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_mx_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileMxGemmPipelineCompAsyncCCR
|
||||
: public TestCkTileMxGemmPipeline<T, TestCkTileMxGemmPipelineCompAsyncCCR<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncCCR
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncCCR, KernelTypesMxGemmCompAsyncCCR);
|
||||
|
||||
#include "test_mx_gemm_pipeline_tr_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
22
test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp
Normal file
22
test/ck_tile/gemm_mx/test_mx_gemm_async_crr.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_mx_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileMxGemmPipelineCompAsyncCRR
|
||||
: public TestCkTileMxGemmPipeline<T, TestCkTileMxGemmPipelineCompAsyncCRR<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncCRR
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncCRR, KernelTypesMxGemmCompAsyncCRR);
|
||||
|
||||
#include "test_mx_gemm_pipeline_tr_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
51
test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp
Normal file
51
test/ck_tile/gemm_mx/test_mx_gemm_async_rcr.cpp
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_mx_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileMxGemmPipelineCompAsyncRCR
|
||||
: public TestCkTileMxGemmPipeline<T, TestCkTileMxGemmPipelineCompAsyncRCR<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRCR
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRCR, KernelTypesMxGemmCompAsyncRCR);
|
||||
|
||||
#include "test_mx_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, MNPadding)
|
||||
{
|
||||
if constexpr(TestFixture::PipelineType == MxGemmPipelineType::WeightPreshuffle ||
|
||||
TestFixture::PipelineType == MxGemmPipelineType::CompEightWaves)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> Ms{96, 160, 224};
|
||||
std::vector<int> Ns{96, 160, 224};
|
||||
std::vector<int> Ks;
|
||||
// K must be multiple of ScaleBlockSize (16 or 32) and K_Tile
|
||||
for(auto K_count : {2, 3, 4})
|
||||
{
|
||||
Ks.push_back(K_count * TestFixture::K_Tile);
|
||||
}
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
for(int N : Ns)
|
||||
{
|
||||
for(int K : Ks)
|
||||
{
|
||||
this->template Run<true, true>(M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
29
test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp
Normal file
29
test/ck_tile/gemm_mx/test_mx_gemm_async_rcr_large_cases.cpp
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_mx_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileMxGemmPipelineCompAsyncRCR
|
||||
: public TestCkTileMxGemmPipeline<T, TestCkTileMxGemmPipelineCompAsyncRCR<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRCR
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRCR, KernelTypesMxGemmCompAsyncRCRLargeCases);
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, Large)
|
||||
{
|
||||
int M = 6422528;
|
||||
int N = 6144;
|
||||
int K = 1024;
|
||||
|
||||
this->RunAllGpu(M, N, K);
|
||||
}
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
22
test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp
Normal file
22
test/ck_tile/gemm_mx/test_mx_gemm_async_rrr.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_mx_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_mx_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileMxGemmPipelineCompAsyncRRR
|
||||
: public TestCkTileMxGemmPipeline<T, TestCkTileMxGemmPipelineCompAsyncRRR<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsyncRRR
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsyncRRR, KernelTypesMxGemmCompAsyncRRR);
|
||||
|
||||
#include "test_mx_gemm_pipeline_tr_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -1,169 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
|
||||
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
|
||||
{
|
||||
using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>;
|
||||
|
||||
MXGemmHostArgs(const void* a_ptr,
|
||||
const void* b_ptr,
|
||||
void* c_ptr_,
|
||||
ck_tile::index_t k_batch_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_,
|
||||
ScaleM scale_m_,
|
||||
ScaleN scale_n_)
|
||||
: Base({a_ptr},
|
||||
{b_ptr},
|
||||
{},
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
{stride_A_},
|
||||
{stride_B_},
|
||||
{},
|
||||
stride_C_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
|
||||
ScaleM scale_m;
|
||||
ScaleN scale_n;
|
||||
};
|
||||
|
||||
struct MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 512;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr ck_tile::index_t BContiguousItemsPerAccess = 16;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MX_GemmConfig16 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
struct MX_GemmConfigEightWaves : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 * K_Warp;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
struct MXfp4_GemmConfig16_Preshuffle : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr ck_tile::index_t BContiguousItemsPerAccess = 32;
|
||||
};
|
||||
|
||||
struct MXfp4_GemmConfig16_PermuteN : MXfp4_GemmConfig16_Preshuffle
|
||||
{
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
struct MXfp8_GemmConfig16_Preshuffle : MxGemmConfig
|
||||
{
|
||||
// For FP8 Preshuffle:
|
||||
// The theoretical functional minimum is N_Tile = N_Warp * N_Warp_Tile * NXdlPack = 4*16*2 =
|
||||
// 128 . For better performance, we would choose N_Repeat = 2 which would yield N_Tile = 128 * 2
|
||||
// = 256 . Note: If we use fewer waves, the minimum theoretical N_Tile can be even smaller,
|
||||
// reduced to N_Tile = 32 for 1 single wave.
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr bool Preshuffle = true;
|
||||
};
|
||||
|
||||
struct MxGemmConfig32 : MxGemmConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
};
|
||||
|
||||
struct MXfp4_GemmConfig32 : MxGemmConfig32
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
struct MXfp8_GemmConfig32 : MxGemmConfig32
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
// Variant with M/N padding enabled. Used to cover shapes where M/N are not multiples of
|
||||
// the respective block tiles (MX_GemmConfig16 has M_Tile = 64, N_Tile = 128). K is still
|
||||
// required to be a multiple of K_Tile -- the MX comp-async pipeline does not support K padding
|
||||
// (see MXGemmKernel::IsSupportedArgument).
|
||||
struct MXfp8_GemmConfig16_PadMN : MX_GemmConfig16
|
||||
{
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
};
|
||||
|
||||
struct MXfp8_GemmConfig16_PermuteN : MXfp8_GemmConfig16_Preshuffle
|
||||
{
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
@@ -1,143 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm_mx.hpp"
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
bool Splitk>
|
||||
float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
|
||||
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_gemm requires ADataType is a wider type than BDataType");
|
||||
|
||||
using MXPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
constexpr bool IsEightWave =
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8;
|
||||
using MXGemmPipeline = std::conditional_t<
|
||||
GemmConfig::Preshuffle,
|
||||
ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1<MXPipelineProblem>,
|
||||
std::conditional_t<IsEightWave,
|
||||
ck_tile::MXGemmPipelineAgBgCrCompAsyncEightWaves<MXPipelineProblem>,
|
||||
ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmEpilogue =
|
||||
std::conditional_t<GemmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
false, // FixedVectorSize_ (Default)
|
||||
1>>, // VectorSizeC_ (Default)
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ComputeDataType,
|
||||
ComputeDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
ck_tile::MXEpilogueTraits<GemmConfig>::BlockedXDLNPerWarp,
|
||||
false, // DoubleSmemBuffer_ (Default)
|
||||
ADataType, // AComputeDataType
|
||||
BDataType, // BComputeDataType
|
||||
!GemmConfig::Preshuffle>>>; // TilesPacked_ (because of
|
||||
// packed scales)
|
||||
|
||||
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(std::array<const void*, 1>{args.as_ptr},
|
||||
std::array<const void*, 1>{args.bs_ptr},
|
||||
std::array<const void*, 0>{},
|
||||
args.e_ptr,
|
||||
args.k_batch,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
std::array<ck_tile::index_t, 1>{args.stride_As},
|
||||
std::array<ck_tile::index_t, 1>{args.stride_Bs},
|
||||
std::array<ck_tile::index_t, 0>{},
|
||||
args.stride_E,
|
||||
args.scale_m,
|
||||
args.scale_n);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"MX GEMM: unsupported shape/configuration (set CK_TILE_LOGGING=1 for details).");
|
||||
}
|
||||
|
||||
const auto kernel = ck_tile::make_kernel<Kernel::kBlockPerCu>(
|
||||
Kernel{}, Kernel::GridSize(kargs), Kernel::BlockSize(), 0, kargs);
|
||||
|
||||
return ck_tile::launch_kernel(s, kernel);
|
||||
}
|
||||
@@ -10,19 +10,23 @@
|
||||
#include "test_mx_gemm_pipeline_util.hpp"
|
||||
#include "test_mx_gemm_pipeline_prec_types.hpp"
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave>;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using CompTDMV1 = ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompTDMV1>;
|
||||
using CompTDMV2 = ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompTDMV2>;
|
||||
using CompAsync = ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompAsync>;
|
||||
using CompEightWaves =
|
||||
ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompEightWaves>;
|
||||
using WeightPreshuffle =
|
||||
ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::WeightPreshuffle>;
|
||||
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I64 = ck_tile::number<64>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
using I512 = ck_tile::number<512>;
|
||||
|
||||
using ClusterEnable = std::true_type;
|
||||
using ClusterDisable = std::false_type;
|
||||
@@ -33,48 +37,89 @@ using ClusterDisable = std::false_type;
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, Scheduler, PipelineType, ScaleBlockSize
|
||||
using KernelTypesMxGemmCompTDMWmma = ::testing::Types<
|
||||
// --- Scale32 (WarpTile=32, CompTDMV1) ---
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
// --- Scale32 (WarpTile=32, CompTDMV2) ---
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F8, E4M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32>,
|
||||
std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E4M3, E4M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F4, E8M0, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F8, E4M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32>,
|
||||
std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32>,
|
||||
// --- Scale16 (WarpTile=16, CompTDMV1) ---
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV1, I16>, // RRR (non-RCR) layout
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV1, I16>, // RRR (non-RCR) layout
|
||||
// --- Scale16 (WarpTile=32, CompTDMV1) ---
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I16>, // RRR (non-RCR) layout
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I16>, // RRR (non-RCR) layout
|
||||
// --- Scale16 (WarpTile=16, CompTDMV2) ---
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, Intrawave, CompTDMV2, I16>, // RRR (non-RCR) layout
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I16, I16, CompTDMV2, I16>, // RRR (non-RCR) layout
|
||||
// --- Scale16 (WarpTile=32, CompTDMV2) ---
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I16>, // RRR (non-RCR) layout
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I16>, // RRR (non-RCR) layout
|
||||
// --- Scale32 cluster launch (from develop; ScaleBlockSize=I32 at idx 16, ClusterEnable at idx 17) ---
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, I32, ClusterEnable>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, I32, ClusterEnable>
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV1, I32, std::false_type, ClusterEnable>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, CompTDMV2, I32, std::false_type, ClusterEnable>
|
||||
>;
|
||||
|
||||
using KernelTypesMxGemmCompAsyncRCR = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I128, I16, I16, CompEightWaves, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I256, I128, I16, I16, CompEightWaves, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I512, I256, I16, I16, WeightPreshuffle, I32>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32, std::true_type>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I256, I256, I16, I16, WeightPreshuffle, I32, std::true_type>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Row, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Row, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>
|
||||
>;
|
||||
|
||||
using KernelTypesMxGemmCompAsyncRCRLargeCases = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I64, I64, I256, I16, I16, CompAsync, I32>
|
||||
>;
|
||||
|
||||
using KernelTypesMxGemmCompAsyncRRR = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Row, Row, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Row, Row, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Row, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>
|
||||
>;
|
||||
|
||||
using KernelTypesMxGemmCompAsyncCRR = ::testing::Types<
|
||||
std::tuple< Col, Row, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Col, Row, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Col, Row, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>
|
||||
>;
|
||||
|
||||
using KernelTypesMxGemmCompAsyncCCR = ::testing::Types<
|
||||
std::tuple< Col, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Col, Col, Row, F4, F4, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Col, Col, Row, BF8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>,
|
||||
std::tuple< Col, Col, Row, F8, BF8, E8M0, E8M0, F32, F16, I128, I128, I256, I32, I32, CompAsync, I32>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
27
test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc
Normal file
27
test/ck_tile/gemm_mx/test_mx_gemm_pipeline_tr_cases.inc
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, Regular)
|
||||
{
|
||||
std::vector<int> Ms{128, 256};
|
||||
std::vector<int> Ns{128, 256};
|
||||
std::vector<int> Ks;
|
||||
// K must be multiple of ScaleBlockSize (16 or 32) and K_Tile
|
||||
for(auto K_count : {1, 2, 3, 4})
|
||||
{
|
||||
Ks.push_back(K_count * TestFixture::K_Tile);
|
||||
}
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
for(int N : Ns)
|
||||
{
|
||||
for(int K : Ks)
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ TYPED_TEST(TEST_SUITE_NAME, SingleTile)
|
||||
TYPED_TEST(TEST_SUITE_NAME, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 4, 8, 16};
|
||||
constexpr int N = 64;
|
||||
constexpr int N = TestFixture::N_Tile;
|
||||
std::vector<int> Ks;
|
||||
// K must be multiple of ScaleBlockSize (16 or 32) and K_Tile
|
||||
for(auto K_count : {2, 3, 4})
|
||||
@@ -34,7 +34,7 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM)
|
||||
TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{32, 64, 128, 256};
|
||||
std::vector<int> Ns{96, 128}; // 96 tests non-tile-aligned N
|
||||
std::vector<int> Ns{TestFixture::N_Tile};
|
||||
std::vector<int> Ks;
|
||||
for(auto K_count : {2, 3, 4})
|
||||
{
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck/library/utility/gpu_verification.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
@@ -44,9 +45,9 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
||||
#endif
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
return 64;
|
||||
else
|
||||
return 32;
|
||||
return 128;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -68,10 +69,64 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
// Deterministic per-element hash RNG for GPU data init. Returns a float in [-3, 3).
|
||||
// The generic `fill_tensor_uniform_rand_fp_values` filler is NOT valid for ck_tile::pk_fp4_t
|
||||
// (it converts a single float and duplicates it into both nibbles, and special-cases only the
|
||||
// classic ck::f4x2_pk_t). We need two independent fp4 values per byte, so we fill directly.
|
||||
// The narrow [-3,3) range keeps the fp16 GEMM output from overflowing at K up to 4096 (with the
|
||||
// [0.25,1.0] scales used in RunAllGpu, worst case K*9 = 36864 < 65504).
|
||||
__device__ inline float mx_fp4_fill_rand(unsigned int seed, unsigned long long idx)
|
||||
{
|
||||
// splitmix64-style avalanche; deterministic given (seed, idx).
|
||||
unsigned long long z = (idx + 1ULL) * 0x9E3779B97F4A7C15ULL +
|
||||
static_cast<unsigned long long>(seed) * 0xD1B54A32D192ED03ULL;
|
||||
z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
|
||||
z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
|
||||
z ^= z >> 31;
|
||||
const float u =
|
||||
static_cast<float>((z >> 40) & 0xFFFFFFULL) / static_cast<float>(0x1000000); // [0,1)
|
||||
return u * 6.0f - 3.0f; // [-3,3)
|
||||
}
|
||||
|
||||
// Fill a packed-fp4 buffer with two independent, deterministic random fp4 values per byte.
|
||||
// `num_packed` is the number of pk_fp4_t elements (= total fp4 values / 2).
|
||||
__global__ void
|
||||
fill_pk_fp4_uniform_kernel(ck_tile::pk_fp4_t* __restrict__ ptr, long num_packed, unsigned int seed)
|
||||
{
|
||||
const long idx0 = static_cast<long>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const long nthr = static_cast<long>(gridDim.x) * blockDim.x;
|
||||
for(long i = idx0; i < num_packed; i += nthr)
|
||||
{
|
||||
const float lo_f = rintf(mx_fp4_fill_rand(seed, static_cast<unsigned long long>(i) * 2ULL));
|
||||
const float hi_f =
|
||||
rintf(mx_fp4_fill_rand(seed, static_cast<unsigned long long>(i) * 2ULL + 1ULL));
|
||||
const auto lo = ck_tile::float_to_mxfp4(lo_f, 1.0f);
|
||||
const auto hi = ck_tile::float_to_mxfp4(hi_f, 1.0f);
|
||||
ptr[i] = ck_tile::pk_fp4_t::_pack(lo, hi);
|
||||
}
|
||||
}
|
||||
|
||||
inline void fill_pk_fp4_uniform(ck_tile::pk_fp4_t* ptr,
|
||||
long num_packed,
|
||||
unsigned int seed,
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
constexpr int threads = 256;
|
||||
constexpr long max_blocks = 65536; // grid-stride cap
|
||||
const long needed = (num_packed + threads - 1) / threads;
|
||||
const long blocks = needed < max_blocks ? needed : max_blocks;
|
||||
fill_pk_fp4_uniform_kernel<<<dim3(static_cast<unsigned>(blocks)), dim3(threads), 0, stream>>>(
|
||||
ptr, num_packed, seed);
|
||||
ck_tile::hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
enum struct MxGemmPipelineType
|
||||
{
|
||||
CompTDMV1,
|
||||
CompTDMV2
|
||||
CompTDMV2,
|
||||
CompAsync,
|
||||
CompEightWaves,
|
||||
WeightPreshuffle
|
||||
};
|
||||
|
||||
template <MxGemmPipelineType PT, typename Problem>
|
||||
@@ -95,88 +150,93 @@ struct MxGemmPipelineTypeSelector<MxGemmPipelineType::CompTDMV2, Problem>
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; }
|
||||
};
|
||||
|
||||
template <MxGemmPipelineType PT, typename Problem>
|
||||
template <typename Problem>
|
||||
struct MxGemmPipelineTypeSelector<MxGemmPipelineType::CompAsync, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync<Problem>;
|
||||
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; }
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmPipelineTypeSelector<MxGemmPipelineType::CompEightWaves, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves<Problem>;
|
||||
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompEightWaves"; }
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmPipelineTypeSelector<MxGemmPipelineType::WeightPreshuffle, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
using pipeline = ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; }
|
||||
};
|
||||
|
||||
template <MxGemmPipelineType PT, typename Problem, bool PermuteN>
|
||||
struct MxGemmEpilogueTypeSelector
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::CompTDMV1, Problem, false>
|
||||
{
|
||||
using epilogue = ck_tile::TdmEpilogue<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::CompTDMV2, Problem, false>
|
||||
{
|
||||
using epilogue = ck_tile::TdmEpilogue<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::CompAsync, Problem, false>
|
||||
{
|
||||
using epilogue = ck_tile::CShuffleEpilogue<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::CompEightWaves, Problem, false>
|
||||
{
|
||||
using epilogue = ck_tile::CShuffleEpilogue<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem, bool PermuteN>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::WeightPreshuffle, Problem, PermuteN>
|
||||
{
|
||||
using epilogue = std::conditional_t<PermuteN,
|
||||
ck_tile::PermuteNEpilogue<Problem>,
|
||||
ck_tile::CShuffleEpilogue<Problem>>;
|
||||
};
|
||||
|
||||
template <MxGemmPipelineType PT>
|
||||
struct MxGemmPipelineDefaultParams
|
||||
{
|
||||
static constexpr bool PadM = false;
|
||||
static constexpr bool PadN = false;
|
||||
static constexpr bool PadK = false;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool Preshuffle = PT == MxGemmPipelineType::WeightPreshuffle;
|
||||
};
|
||||
|
||||
/// @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction.
|
||||
///
|
||||
/// Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific
|
||||
/// layout expected by the gfx1250 wmma instruction.
|
||||
///
|
||||
/// @tparam ScaleType Scale data type (e.g., e8m0_t)
|
||||
/// @tparam ScaleBlockSize The block size for microscaling (e.g., 32)
|
||||
/// @tparam KStride Whether K is the fast-moving dimension
|
||||
template <typename ScaleType, ck_tile::index_t ScaleBlockSize, bool KStride>
|
||||
void preShuffleScaleBuffer_gfx1250(const ScaleType* src,
|
||||
ScaleType* dst,
|
||||
ck_tile::index_t MN,
|
||||
ck_tile::index_t K)
|
||||
template <ck_tile::index_t N_Warp_Tile_,
|
||||
ck_tile::index_t K_Warp_Tile_,
|
||||
ck_tile::index_t N_Tile_,
|
||||
ck_tile::index_t N_Warp_,
|
||||
typename BDataType_>
|
||||
struct Config
|
||||
{
|
||||
static_assert((ScaleBlockSize == 32 || ScaleBlockSize == 16) && sizeof(ScaleType) == 1,
|
||||
"wrong! only support 8-bit scale with ScaleBlockSize=32 or 16");
|
||||
|
||||
// ScaleBlockSize == 16: the natural row-major scale layout already matches the gfx1250
|
||||
// wmma scale distribution (one e8m0 per 16 K-elements lands warp-aligned), so the
|
||||
// device-side shuffle is the identity transform for all K.
|
||||
if constexpr(ScaleBlockSize == 16)
|
||||
{
|
||||
for(ck_tile::index_t mn = 0; mn < MN; ++mn)
|
||||
for(ck_tile::index_t k = 0; k < K; ++k)
|
||||
{
|
||||
if constexpr(KStride)
|
||||
dst[mn * K + k] = src[mn * K + k];
|
||||
else
|
||||
dst[mn * K + k] = src[k * MN + mn];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr ck_tile::index_t MPerXdlops = 16;
|
||||
constexpr ck_tile::index_t KPerXdlops = 128;
|
||||
|
||||
int MNPack = 2;
|
||||
int KPack = 1;
|
||||
|
||||
int MNStep = MPerXdlops;
|
||||
int KStep = KPerXdlops / ScaleBlockSize;
|
||||
|
||||
int K0 = K / KPack / KStep;
|
||||
|
||||
for(int mn = 0; mn < MN; ++mn)
|
||||
{
|
||||
int iMNRepeat = mn / (MNStep * MNPack);
|
||||
int tempmn = mn % (MNStep * MNPack);
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int iKRepeat = k / (KStep * KPack);
|
||||
int tempk = k % (KStep * KPack);
|
||||
|
||||
int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) +
|
||||
(iKRepeat * KStep * KPack) * (MNStep * MNPack) +
|
||||
tempmn * (KStep * KPack) + tempk;
|
||||
|
||||
if constexpr(KStride)
|
||||
{
|
||||
dst[outputIndex] = src[mn * K + k];
|
||||
}
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + mn];
|
||||
}
|
||||
}
|
||||
}
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_;
|
||||
static constexpr ck_tile::index_t N_Tile = N_Tile_;
|
||||
static constexpr ck_tile::index_t N_Warp = N_Warp_;
|
||||
static constexpr ck_tile::index_t BContiguousItemsPerAccess =
|
||||
std::is_same_v<BDataType_, ck_tile::pk_fp4_t> ? 32 : 16;
|
||||
};
|
||||
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
@@ -191,8 +251,10 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
using BScaleDataType = std::tuple_element_t<6, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<7, Tuple>;
|
||||
using CDataType = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<14, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<15, Tuple>::value;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;
|
||||
static constexpr bool PermuteN =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 16, std::false_type>::value;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<9, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<10, Tuple>{};
|
||||
@@ -213,17 +275,21 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
static constexpr bool ClusterLaunch =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 17, std::false_type>::value;
|
||||
|
||||
static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<16, Tuple>{};
|
||||
static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<15, Tuple>{};
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp =
|
||||
PipelineType == MxGemmPipelineType::WeightPreshuffle
|
||||
? 1
|
||||
: (PipelineType == MxGemmPipelineType::CompEightWaves ? 4 : 2);
|
||||
static constexpr ck_tile::index_t N_Warp =
|
||||
PipelineType == MxGemmPipelineType::WeightPreshuffle ? 4 : 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
protected:
|
||||
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_mx_gemm(const ck_tile::MxGemmHostArgs<1, 1, 0>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
// if cluster launch is enabled, set cluster dim to 2x2x1
|
||||
constexpr ck_tile::index_t kClusterSizeM =
|
||||
std::conditional_t<ClusterLaunch, ck_tile::number<2>, ck_tile::number<1>>{};
|
||||
@@ -240,11 +306,13 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer
|
||||
|
||||
#if defined(CK_USE_GFX1250)
|
||||
constexpr ck_tile::index_t BlockedXDLNPerWarp = 1;
|
||||
constexpr bool TransposeC =
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor> &&
|
||||
M_Warp_Tile == N_Warp_Tile;
|
||||
#else
|
||||
constexpr bool TransposeC = false;
|
||||
#elif defined(CK_USE_GFX950)
|
||||
constexpr ck_tile::index_t BlockedXDLNPerWarp = Preshuffle ? 2 : 1;
|
||||
constexpr bool TransposeC = false;
|
||||
#endif
|
||||
static constexpr bool StructuredSparsity = false;
|
||||
static constexpr bool NumWaveGroup = 1;
|
||||
@@ -302,8 +370,26 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
using GemmPipeline =
|
||||
typename MxGemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
|
||||
|
||||
using GemmEpilogue = typename MxGemmEpilogueTypeSelector<
|
||||
PipelineType,
|
||||
using GemmEpilogueProblem = std::conditional_t<
|
||||
PipelineType == MxGemmPipelineType::WeightPreshuffle && PermuteN,
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
false, /*FixedVectorSize_*/
|
||||
1>, /*VectorSizeC_*/
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
@@ -320,13 +406,18 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
1, /*kNumWaveGroups_*/
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
DoubleSmemBuffer, /*DoubleSmemBuffer*/
|
||||
AComputeDataType, /*AComputeDataType_*/
|
||||
BComputeDataType /*BComputeDataType_*/>>::epilogue;
|
||||
1, /*kNumWaveGroups_*/
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
BlockedXDLNPerWarp, /*BlockedXDLN_PerWarp_*/
|
||||
DoubleSmemBuffer, /*DoubleSmemBuffer*/
|
||||
AComputeDataType, /*AComputeDataType_*/
|
||||
BComputeDataType, /*BComputeDataType_*/
|
||||
!preshuffle>>;
|
||||
|
||||
using GemmEpilogue = typename MxGemmEpilogueTypeSelector<PipelineType,
|
||||
GemmEpilogueProblem,
|
||||
PermuteN>::epilogue;
|
||||
|
||||
using Kernel = ck_tile::MxGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
@@ -360,12 +451,28 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if constexpr(!Derived::check_data_type())
|
||||
{
|
||||
GTEST_SKIP() << "Unsupported data type combination for mx_gemm pipeline test.";
|
||||
}
|
||||
// for TDM it's used tdm_epilogue which don't support split-k
|
||||
if constexpr(PipelineType == MxGemmPipelineType::CompTDMV1 ||
|
||||
PipelineType == MxGemmPipelineType::CompTDMV2 ||
|
||||
std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor> ||
|
||||
std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// Only do k_batch = 1
|
||||
k_batches_ = {1};
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise, use k_batch = 1 and 2
|
||||
k_batches_ = {1, 2};
|
||||
}
|
||||
}
|
||||
|
||||
template <bool PadM = MxGemmPipelineDefaultParams<PipelineType>::PadM,
|
||||
@@ -381,7 +488,15 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
{
|
||||
if constexpr(Derived::check_data_type())
|
||||
{
|
||||
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, 1);
|
||||
for(auto kb : k_batches_)
|
||||
{
|
||||
// skip test when split k' number is not evenly distributed
|
||||
if((K / K_Tile) % kb != 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -422,16 +537,18 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
// so M must be padded to at least MNPack * MPerXdlops = 32.
|
||||
constexpr index_t ScaleShuffleAlign = 32;
|
||||
const index_t scale_padded_M = integer_least_multiple(
|
||||
static_cast<index_t>(M),
|
||||
static_cast<index_t>(ck_tile::max(M_Warp_Tile, ScaleShuffleAlign)));
|
||||
static_cast<index_t>(M), static_cast<index_t>(ck_tile::max(M_Tile, ScaleShuffleAlign)));
|
||||
|
||||
HostTensor<AScaleDataType> scale_a(
|
||||
{static_cast<std::size_t>(scale_padded_M), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
|
||||
// scale_b uses N as first dimension (col-major like B)
|
||||
const index_t scale_padded_N = integer_least_multiple(
|
||||
static_cast<index_t>(N), static_cast<index_t>(ck_tile::max(N_Tile, ScaleShuffleAlign)));
|
||||
// Pre-shuffle interleaves 2 K-lanes (MNPack=2) with MPerXdlops=16 stride,
|
||||
// so N must be padded to at least MNPack * NPerXdlops = 32.
|
||||
HostTensor<BScaleDataType> scale_b(
|
||||
{static_cast<std::size_t>(N), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(scale_padded_N), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
|
||||
// Fill data
|
||||
@@ -485,38 +602,112 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
}
|
||||
|
||||
// Pre-shuffle scale buffers for the hardware
|
||||
#if defined(CK_USE_GFX1250)
|
||||
static constexpr index_t NXdlPackEff = 1;
|
||||
|
||||
HostTensor<AScaleDataType> scale_a_shuffled(
|
||||
{static_cast<std::size_t>(scale_padded_M), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
|
||||
HostTensor<BScaleDataType> scale_b_shuffled(
|
||||
{static_cast<std::size_t>(N), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(scale_padded_N), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
|
||||
// Pre-shuffle for gfx1250 (WaveSize=32, WMMA)
|
||||
// Scales start in natural tensor layout and are pre-shuffled into the device layout
|
||||
// for both scale block sizes (the shuffle is the identity for ScaleBlockSize==16,
|
||||
// whose natural layout already matches the warp scale distribution).
|
||||
preShuffleScaleBuffer_gfx1250<AScaleDataType, ScaleBlockSize, true>(
|
||||
ck_tile::preShuffleScaleBuffer_gfx1250<AScaleDataType, ScaleBlockSize, true>(
|
||||
scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k);
|
||||
preShuffleScaleBuffer_gfx1250<BScaleDataType, ScaleBlockSize, true>(
|
||||
scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k);
|
||||
ck_tile::preShuffleScaleBuffer_gfx1250<BScaleDataType, ScaleBlockSize, true>(
|
||||
scale_b.mData.data(), scale_b_shuffled.mData.data(), scale_padded_N, num_scale_k);
|
||||
#elif defined(CK_USE_GFX950)
|
||||
constexpr ck_tile::index_t MPerXdl = M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl = N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl = K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl);
|
||||
constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl);
|
||||
constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff =
|
||||
(MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff =
|
||||
(NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff =
|
||||
(KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
|
||||
constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
HostTensor<AScaleDataType> scale_a_shuffled(
|
||||
{static_cast<std::size_t>(scale_padded_M / MXdlPackEff * 2),
|
||||
static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2), static_cast<std::size_t>(1)});
|
||||
|
||||
HostTensor<BScaleDataType> scale_b_shuffled(
|
||||
{static_cast<std::size_t>(scale_padded_N / NXdlPackEff * 2),
|
||||
static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2), static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::preShuffleScaleBuffer_gfx950<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k, true);
|
||||
|
||||
if constexpr(PipelineType == MxGemmPipelineType::WeightPreshuffle && PermuteN)
|
||||
{
|
||||
ck_tile::preShuffleScaleBufferPermuteN_gfx950<N_Warp, N_Tile, XdlMNThread>(
|
||||
scale_b.mData.data(),
|
||||
scale_b_shuffled.mData.data(),
|
||||
scale_padded_N,
|
||||
num_scale_k,
|
||||
true);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
preShuffleScaleBuffer_gfx950<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_b.mData.data(),
|
||||
scale_b_shuffled.mData.data(),
|
||||
scale_padded_N,
|
||||
num_scale_k,
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Allocate device memory
|
||||
DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
|
||||
DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());
|
||||
|
||||
// Upload data to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
using GemmConfig = Config<N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp, BDataType>;
|
||||
|
||||
const auto b_host_for_dev = [&]() {
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
if constexpr(PermuteN)
|
||||
{
|
||||
return ck_tile::shuffle_b_permuteN<GemmConfig, BDataType, NXdlPackEff>(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::shuffle_b<GemmConfig>(b_k_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return b_k_n;
|
||||
}
|
||||
}();
|
||||
DeviceMem b_k_n_dev_buf(b_host_for_dev.get_element_space_size_in_bytes());
|
||||
b_k_n_dev_buf.ToDevice(b_host_for_dev.data());
|
||||
|
||||
// Create MxGemmHostArgs
|
||||
ck_tile::MxGemmHostArgs<1, 1, 0> args(
|
||||
{static_cast<const void*>(a_m_k_dev_buf.GetDeviceBuffer())},
|
||||
@@ -546,7 +737,14 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
{static_cast<std::size_t>(1), static_cast<std::size_t>(num_scale_k)});
|
||||
// Copy scale_b data (our scale_b is (N, num_scale_k) row-major,
|
||||
// reference expects (num_scale_k, N) col-major, which is the same memory layout)
|
||||
std::copy(scale_b.mData.begin(), scale_b.mData.end(), scale_b_ref.mData.begin());
|
||||
// Truncate scale_a to actual N (not padded)
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < num_scale_k; ++k)
|
||||
{
|
||||
scale_b_ref(k, n) = scale_b(n, k);
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate scale_a to actual M (not padded)
|
||||
HostTensor<AScaleDataType> scale_a_ref(
|
||||
@@ -582,4 +780,275 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
rtol_atol.at(number<1>{}));
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
// All-GPU validation path for the fp4 (pk_fp4_t) MX GEMM.
|
||||
//
|
||||
// Unlike Run(), this never materializes the A/B/C tensors on the host:
|
||||
// - A/B are generated directly on device with a deterministic fp4 fill.
|
||||
// - the reference is computed on device by reference_mx_gemm_gpu.
|
||||
// - the comparison is done on device by ck::profiler::gpu_verify.
|
||||
// Only the tiny e8m0 scales touch the host (for pre-shuffle + an unshuffled copy that the
|
||||
// device reference consumes).
|
||||
void RunAllGpu(const int M, const int N, const int K, const int kbatch = 1)
|
||||
{
|
||||
if constexpr(!Derived::check_data_type())
|
||||
return;
|
||||
|
||||
static_assert(std::is_same_v<ADataType, ck_tile::pk_fp4_t> &&
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
|
||||
"RunAllGpu currently supports pk_fp4_t A/B only.");
|
||||
// The GPU reference (reference_mx_gemm_gpu) hardcodes these layouts; guard so it cannot be
|
||||
// silently misused with a layout it does not handle.
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor> &&
|
||||
std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor> &&
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>,
|
||||
"RunAllGpu / reference_mx_gemm_gpu assume RowMajor-A, ColumnMajor-B, "
|
||||
"RowMajor-C.");
|
||||
|
||||
static_assert(PipelineType != MxGemmPipelineType::WeightPreshuffle);
|
||||
|
||||
#if !defined(CK_USE_GFX950)
|
||||
(void)M;
|
||||
(void)N;
|
||||
(void)K;
|
||||
(void)kbatch;
|
||||
GTEST_SKIP() << "RunAllGpu requires CK_USE_GFX950.";
|
||||
#else
|
||||
using namespace ck_tile::literals;
|
||||
constexpr long kIntMax = 2147483647L; // INT_MAX
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return col;
|
||||
else
|
||||
return row;
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
constexpr ck_tile::index_t psize = ck_tile::numeric_traits<ADataType>::PackedSize; // 2
|
||||
static_assert(psize == 2,
|
||||
"RunAllGpu byte-sizing and reference_mx_gemm_kernel's a_ptr[a_lin/2] "
|
||||
"addressing assume pk_fp4_t PackedSize == 2.");
|
||||
|
||||
bool pass = true;
|
||||
long total_MN = 0;
|
||||
|
||||
// Strides are K/N here (small); keep them as index_t to match the kernel args, and
|
||||
// make the size_t->index_t narrowing explicit.
|
||||
const ck_tile::index_t stride_A =
|
||||
static_cast<ck_tile::index_t>(f_get_default_stride(M, K, 0, ALayout{})); // K
|
||||
const ck_tile::index_t stride_B =
|
||||
static_cast<ck_tile::index_t>(f_get_default_stride(K, N, 0, BLayout{})); // K
|
||||
const ck_tile::index_t stride_C =
|
||||
static_cast<ck_tile::index_t>(f_get_default_stride(M, N, 0, CLayout{})); // N
|
||||
|
||||
ASSERT_EQ(K % ScaleBlockSize, 0) << "K must be a multiple of ScaleBlockSize for MX GEMM";
|
||||
const ck_tile::index_t num_scale_k = K / ScaleBlockSize;
|
||||
ASSERT_EQ(num_scale_k % (K_Warp_Tile / ScaleBlockSize), 0)
|
||||
<< "K must be a multiple of K_Warp_Tile (" << K_Warp_Tile
|
||||
<< ") for MX GEMM. Pad the scale data.";
|
||||
const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple(
|
||||
static_cast<ck_tile::index_t>(M), static_cast<ck_tile::index_t>(M_Tile));
|
||||
|
||||
// int32-safety: the property under test for the M-decomposition. The predicate is
|
||||
// "largest 0-based element offset fits in a signed 32-bit int", i.e. offset <= INT_MAX.
|
||||
const long MN = static_cast<long>(M) * N;
|
||||
const long A_elems = static_cast<long>(M) * K;
|
||||
const long B_elems = static_cast<long>(K) * N;
|
||||
const long C_off = static_cast<long>(M - 1) * stride_C + (N - 1);
|
||||
const long A_off = static_cast<long>(M - 1) * stride_A + (K - 1);
|
||||
const long B_off = static_cast<long>(N - 1) * stride_B + (K - 1);
|
||||
const long c_bytes = MN * static_cast<long>(sizeof(CDataType));
|
||||
std::cout << "[int32-safety] M=" << M << " N=" << N << " K=" << K << " M*N=" << MN
|
||||
<< " A_elems=" << A_elems << " B_elems=" << B_elems << " C_off=" << C_off
|
||||
<< " A_off=" << A_off << " B_off=" << B_off << " C_bytes=" << c_bytes
|
||||
<< " (INT_MAX=" << kIntMax << ")" << std::endl;
|
||||
// Note (not an assert): the C *byte* span can exceed INT_MAX even when the element
|
||||
// count is int32-safe. We deliberately let the run proceed -- if any internal byte
|
||||
// offset overflows, gpu_verify will flag it, which is exactly what we want to discover.
|
||||
if(c_bytes > kIntMax)
|
||||
std::cout << "[int32-safety][note] C byte span (" << c_bytes
|
||||
<< ") exceeds INT_MAX; if verification fails, byte-offset overflow is the "
|
||||
"prime suspect."
|
||||
<< std::endl;
|
||||
ASSERT_LE(B_off, kIntMax) << "B offset exceeds INT_MAX";
|
||||
total_MN += MN;
|
||||
|
||||
const long a_bytes = (A_elems + psize - 1) / psize;
|
||||
const long b_bytes = (B_elems + psize - 1) / psize;
|
||||
|
||||
// Bound peak device memory (A + B + 2*C + scales). Skip cleanly rather
|
||||
// than aborting via hip_check_error if the device cannot hold test shapes.
|
||||
{
|
||||
std::size_t free_b = 0, total_b = 0;
|
||||
ck_tile::hip_check_error(hipMemGetInfo(&free_b, &total_b));
|
||||
const std::size_t need = static_cast<std::size_t>(a_bytes) +
|
||||
static_cast<std::size_t>(b_bytes) +
|
||||
2u * static_cast<std::size_t>(c_bytes) + (64u << 20);
|
||||
if(free_b < need)
|
||||
GTEST_SKIP() << "insufficient device memory (need " << need << " B, free " << free_b
|
||||
<< " B)";
|
||||
}
|
||||
|
||||
auto a_dev = std::make_unique<ck_tile::DeviceMem>(static_cast<std::size_t>(a_bytes));
|
||||
auto b_dev = std::make_unique<ck_tile::DeviceMem>(static_cast<std::size_t>(b_bytes));
|
||||
auto c_dev = std::make_unique<ck_tile::DeviceMem>(static_cast<std::size_t>(c_bytes));
|
||||
auto c_ref_dev = std::make_unique<ck_tile::DeviceMem>(static_cast<std::size_t>(c_bytes));
|
||||
c_dev->SetZero();
|
||||
c_ref_dev->SetZero();
|
||||
|
||||
// GPU fill A/B (deterministic, fp4-correct). Same device buffers feed both the kernel
|
||||
// and the reference, so the fill need not bit-match any host RNG.
|
||||
fill_pk_fp4_uniform(
|
||||
reinterpret_cast<ADataType*>(a_dev->GetDeviceBuffer()), a_bytes, 11939u);
|
||||
fill_pk_fp4_uniform(
|
||||
reinterpret_cast<BDataType*>(b_dev->GetDeviceBuffer()), b_bytes, 11940u);
|
||||
ck_tile::hip_check_error(hipDeviceSynchronize()); // surface fill faults at the fill site
|
||||
|
||||
// e8m0 scales (tiny, host-built). The range is
|
||||
// deliberately narrow ([0.25,1.0] scales, [-3,3) fp4 fill) so that K up to 4096 cannot
|
||||
// overflow the fp16 output (worst case K*9 = 36864 < 65504); gpu_verify counts matched
|
||||
// infinities as errors, so an overflow would otherwise be a false failure.
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a(
|
||||
{static_cast<std::size_t>(scale_padded_M), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
ck_tile::HostTensor<BScaleDataType> scale_b(
|
||||
{static_cast<std::size_t>(N), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
{
|
||||
std::mt19937 gen(11941u);
|
||||
std::uniform_real_distribution<float> dist(0.25f, 1.0f);
|
||||
for(auto& s : scale_a.mData)
|
||||
s = AScaleDataType{dist(gen)};
|
||||
for(auto& s : scale_b.mData)
|
||||
s = BScaleDataType{dist(gen)};
|
||||
}
|
||||
|
||||
// gfx950 scale pre-shuffle. NOTE: this must stay in sync with the identical block in
|
||||
// Run() -- the kernel-input layout and the reference-input layout must agree.
|
||||
constexpr ck_tile::index_t MPerXdl = M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl = N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl = K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl);
|
||||
constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl);
|
||||
constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff =
|
||||
(MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff =
|
||||
(NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff =
|
||||
(KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
|
||||
constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a_shuffled(
|
||||
{static_cast<std::size_t>(scale_padded_M / MXdlPackEff * 2),
|
||||
static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2), static_cast<std::size_t>(1)});
|
||||
ck_tile::HostTensor<BScaleDataType> scale_b_shuffled(
|
||||
{static_cast<std::size_t>(N / NXdlPackEff * 2),
|
||||
static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2), static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::preShuffleScaleBuffer_gfx950<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k, true);
|
||||
ck_tile::preShuffleScaleBuffer_gfx950<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true);
|
||||
|
||||
// Device scale buffers: shuffled feed the kernel, unshuffled feed the reference.
|
||||
auto scale_a_shuf_dev = std::make_unique<ck_tile::DeviceMem>(
|
||||
scale_a_shuffled.get_element_space_size_in_bytes());
|
||||
auto scale_b_shuf_dev = std::make_unique<ck_tile::DeviceMem>(
|
||||
scale_b_shuffled.get_element_space_size_in_bytes());
|
||||
scale_a_shuf_dev->ToDevice(scale_a_shuffled.data());
|
||||
scale_b_shuf_dev->ToDevice(scale_b_shuffled.data());
|
||||
|
||||
auto scale_a_ref_dev =
|
||||
std::make_unique<ck_tile::DeviceMem>(scale_a.get_element_space_size_in_bytes());
|
||||
auto scale_b_ref_dev =
|
||||
std::make_unique<ck_tile::DeviceMem>(scale_b.get_element_space_size_in_bytes());
|
||||
scale_a_ref_dev->ToDevice(scale_a.data());
|
||||
scale_b_ref_dev->ToDevice(scale_b.data());
|
||||
|
||||
// Launch kernel
|
||||
ck_tile::MxGemmHostArgs<1, 1, 0> args(
|
||||
{static_cast<const void*>(a_dev->GetDeviceBuffer())},
|
||||
{static_cast<const void*>(scale_a_shuf_dev->GetDeviceBuffer())},
|
||||
{static_cast<const void*>(b_dev->GetDeviceBuffer())},
|
||||
{static_cast<const void*>(scale_b_shuf_dev->GetDeviceBuffer())},
|
||||
{},
|
||||
c_dev->GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
{stride_A},
|
||||
{stride_B},
|
||||
{},
|
||||
stride_C);
|
||||
|
||||
invoke_mx_gemm<false, false, false, false>(args, ck_tile::stream_config{nullptr, false});
|
||||
|
||||
ck_tile::hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
// GPU reference on the same device A/B buffers.
|
||||
ck_tile::reference_mx_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AScaleDataType,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
reinterpret_cast<const ADataType*>(a_dev->GetDeviceBuffer()),
|
||||
reinterpret_cast<const BDataType*>(b_dev->GetDeviceBuffer()),
|
||||
reinterpret_cast<const AScaleDataType*>(scale_a_ref_dev->GetDeviceBuffer()),
|
||||
reinterpret_cast<const BScaleDataType*>(scale_b_ref_dev->GetDeviceBuffer()),
|
||||
reinterpret_cast<CDataType*>(c_ref_dev->GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
num_scale_k,
|
||||
ScaleBlockSize);
|
||||
ck_tile::hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
// GPU verify with explicit MX tolerance (auto tolerance defaults too tight for MX).
|
||||
const float max_acc = ck::profiler::gpu_reduce_max<CDataType>(c_ref_dev->GetDeviceBuffer(),
|
||||
static_cast<std::size_t>(MN));
|
||||
// The reference must be non-degenerate, else error_count==0 is a vacuous pass.
|
||||
ASSERT_GT(max_acc, 0.0f) << "GPU reference output is all-zero";
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(K, kbatch, max_acc);
|
||||
const auto res = ck::profiler::gpu_verify<CDataType>(c_dev->GetDeviceBuffer(),
|
||||
c_ref_dev->GetDeviceBuffer(),
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}),
|
||||
static_cast<std::size_t>(MN));
|
||||
|
||||
// Positive liveness check on the *device* output. res.all_zero ANDs device- and
|
||||
// reference-zeroness, and the reference is never zero here, so it cannot detect a no-op
|
||||
// kernel on its own -- reduce the device buffer directly.
|
||||
const float c_dev_absmax = ck::profiler::gpu_reduce_max<CDataType>(
|
||||
c_dev->GetDeviceBuffer(), static_cast<std::size_t>(MN));
|
||||
|
||||
std::cout << "[verify] errors=" << res.error_count << " max_error=" << res.max_error
|
||||
<< " c_dev_absmax=" << c_dev_absmax << " max_acc=" << max_acc
|
||||
<< " rtol=" << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " atol=" << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
|
||||
EXPECT_EQ(res.error_count, 0ull) << "produced mismatched results";
|
||||
EXPECT_GT(c_dev_absmax, 0.0f) << "produced an all-zero device output";
|
||||
pass &= (res.error_count == 0 && c_dev_absmax > 0.0f);
|
||||
|
||||
std::cout << "[int32-safety] aggregate total_M*N=" << total_MN << " (INT_MAX=" << kIntMax
|
||||
<< ") -> decomposition is the variable under test" << std::endl;
|
||||
EXPECT_TRUE(pass);
|
||||
#endif // CK_USE_GFX950
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "ck_tile/host/mx_processing.hpp"
|
||||
#include "test_mx_gemm_config.hpp"
|
||||
#include "test_mx_gemm_instance.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<
|
||||
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol_mx(ck_tile::index_t K, float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(K);
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value, K);
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestMxGemmUtil : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using GemmConfig = std::tuple_element_t<2, Tuple>;
|
||||
using ALayout = std::tuple_element_t<3, Tuple>;
|
||||
using BLayout = std::tuple_element_t<4, Tuple>;
|
||||
using CLayout = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::fp16_t;
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
|
||||
void
|
||||
Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t k_batch = 1)
|
||||
{
|
||||
const ck_tile::index_t scale_k_size = K / 32;
|
||||
const ck_tile::index_t stride_A =
|
||||
ck_tile::get_default_stride(M, K, 0, is_row_major(ALayout{}));
|
||||
const ck_tile::index_t stride_B =
|
||||
ck_tile::get_default_stride(K, N, 0, is_row_major(BLayout{}));
|
||||
const ck_tile::index_t stride_C =
|
||||
ck_tile::get_default_stride(M, N, 0, is_row_major(CLayout{}));
|
||||
// Scales use fixed layouts independent of A/B layout:
|
||||
// scale A is row-major [M, K/32], and scale B is column-major [K/32, N].
|
||||
const ck_tile::index_t stride_scale_a =
|
||||
ck_tile::get_default_stride(M, scale_k_size, 0, ck_tile::bool_constant<true>{});
|
||||
const ck_tile::index_t stride_scale_b =
|
||||
ck_tile::get_default_stride(scale_k_size, N, 0, ck_tile::bool_constant<false>{});
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<CDataType> c_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::HostTensor<ScaleType> scale_a_host(ck_tile::host_tensor_descriptor(
|
||||
M, scale_k_size, stride_scale_a, ck_tile::bool_constant<true>{}));
|
||||
ck_tile::HostTensor<ScaleType> scale_b_host(ck_tile::host_tensor_descriptor(
|
||||
scale_k_size, N, stride_scale_b, ck_tile::bool_constant<false>{}));
|
||||
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
|
||||
// e8m0_t is basically an exponent of float32
|
||||
ck_tile::HostTensor<float> pow2(scales.get_lengths());
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{
|
||||
range_min, range_max, fill_seed(gen)}(pow2);
|
||||
scales.ForEach([&](auto& self, const auto& i) {
|
||||
self(i) = static_cast<ScaleType>(std::exp2(pow2(i)));
|
||||
});
|
||||
};
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, fill_seed(gen)}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, fill_seed(gen)}(b_host);
|
||||
gen_scales(scale_a_host, -2, 2);
|
||||
gen_scales(scale_b_host, -2, 2);
|
||||
|
||||
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
|
||||
constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl = GemmConfig::K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp =
|
||||
GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl);
|
||||
constexpr ck_tile::index_t NIterPerWarp =
|
||||
GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl);
|
||||
constexpr ck_tile::index_t KIterPerWarp = GemmConfig::K_Tile / KPerXdl;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff =
|
||||
(MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff =
|
||||
(NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff =
|
||||
(KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
|
||||
constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
// Pack scales into int32_t for GPU consumption
|
||||
auto scale_a_packed =
|
||||
ck_tile::packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host,
|
||||
true);
|
||||
auto scale_b_packed =
|
||||
ck_tile::packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host,
|
||||
false);
|
||||
|
||||
const auto b_host_for_device = [&]() {
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN)
|
||||
return ck_tile::shuffle_b_permuteN<GemmConfig, BDataType, NXdlPackEff>(b_host);
|
||||
else
|
||||
return ck_tile::shuffle_b<GemmConfig>(b_host);
|
||||
else
|
||||
return b_host;
|
||||
}();
|
||||
|
||||
const auto scale_a_host_for_device = [&]() {
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
return ck_tile::preShuffleScale<GemmConfig::N_Warp_Tile>(scale_a_host, true);
|
||||
else
|
||||
return scale_a_packed;
|
||||
}();
|
||||
|
||||
constexpr ck_tile::index_t XdlNThread = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerBlock = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t NWarp = GemmConfig::N_Warp;
|
||||
|
||||
const auto scale_b_host_for_device = [&]() {
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN)
|
||||
return ck_tile::preShuffleScalePermuteN<NWarp, NPerBlock, XdlNThread>(
|
||||
scale_b_host, false);
|
||||
else
|
||||
return ck_tile::preShuffleScale<XdlNThread>(scale_b_host, false);
|
||||
else
|
||||
return scale_b_packed;
|
||||
}();
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_dev_buf(b_host_for_device.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(
|
||||
scale_a_host_for_device.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(
|
||||
scale_b_host_for_device.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_dev_buf.ToDevice(b_host_for_device.data());
|
||||
c_dev_buf.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_host_for_device.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_host_for_device.data());
|
||||
|
||||
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
MXGemmHostArgs<ScaleM, ScaleN> args(a_dev_buf.GetDeviceBuffer(),
|
||||
b_dev_buf.GetDeviceBuffer(),
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
scale_m,
|
||||
scale_n);
|
||||
|
||||
mx_gemm_calc<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
true,
|
||||
false>(args, ck_tile::stream_config{});
|
||||
|
||||
c_dev_buf.FromDevice(c_host.data());
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_ref.SetZero();
|
||||
ck_tile::
|
||||
reference_mx_gemm<ADataType, BDataType, ScaleType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_host, c_ref, scale_a_host, scale_b_host);
|
||||
|
||||
const float max_accumulated_value = ck_tile::type_convert<float>(c_ref.max());
|
||||
const auto rtol_atol = calculate_rtol_atol_mx<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, max_accumulated_value);
|
||||
const double rtol = rtol_atol.at(ck_tile::number<0>{});
|
||||
const double atol = rtol_atol.at(ck_tile::number<1>{});
|
||||
|
||||
bool pass = ck_tile::check_err(c_host, c_ref, "MX GEMM: Incorrect results!", rtol, atol);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
@@ -11,13 +11,33 @@ endif()
|
||||
|
||||
# Currently TDM is only supported on gfx1250
|
||||
if(GPU_TARGETS MATCHES "gfx1250")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_mx_tdm test_mx_grouped_gemm.cpp)
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_mx_tdm test_mx_grouped_gemm_wmma_tdm.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_mx_tdm PRIVATE --save-temps)
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_mx_flatmm_tdm test_grouped_gemm_mx_flatmm_tdm.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_mx_flatmm_tdm
|
||||
PRIVATE ${GROUPED_MX_FLATMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx950")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_mx_comp_async test_mx_grouped_gemm_comp_async.cpp)
|
||||
# target_compile_options(test_ck_tile_grouped_gemm_mx_tdm PRIVATE --save-temps)
|
||||
|
||||
# Large-tensor / decomposition cases. Built so it does not bitrot, but NOT
|
||||
# registered with ctest (plain add_executable, no add_test) -> excluded from the default CI
|
||||
# test pass; run explicitly (mirrors the CK *_large_cases convention). Allocates multi-GB
|
||||
# device buffers to exercise the int32 element-count / decomposition boundary.
|
||||
set(_mx_large_cases_src test_mx_grouped_gemm_comp_async_large_cases.cpp)
|
||||
set_source_files_properties(${_mx_large_cases_src} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(test_ck_tile_grouped_gemm_mx_comp_async_large_cases ${_mx_large_cases_src})
|
||||
set_property(TARGET test_ck_tile_grouped_gemm_mx_comp_async_large_cases
|
||||
PROPERTY HIP_ARCHITECTURES ${SUPPORTED_GPU_TARGETS})
|
||||
target_compile_options(test_ck_tile_grouped_gemm_mx_comp_async_large_cases
|
||||
PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_ck_tile_grouped_gemm_mx_comp_async_large_cases
|
||||
PRIVATE gtest_main getopt::getopt)
|
||||
add_dependencies(tests test_ck_tile_grouped_gemm_mx_comp_async_large_cases)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx950|gfx1250")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_mx_flatmm_non_tdm
|
||||
test_grouped_gemm_mx_flatmm_non_tdm.cpp)
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_mx_grouped_gemm_util.hpp"
|
||||
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave>;
|
||||
using CompTDMV1 = ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompTDMV1>;
|
||||
using CompTDMV2 = ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompTDMV2>;
|
||||
template <ck_tile::index_t N>
|
||||
using ScaleBS = ck_tile::integral_constant<ck_tile::index_t, N>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, Persistent, Scheduler, PipelineType, ScaleBlockSize
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>,
|
||||
std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV1, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>,
|
||||
std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, Intrawave, CompTDMV2, ScaleBS<32>>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGroupedGemm, KernelTypes);
|
||||
|
||||
#include "test_mx_grouped_gemm_ut_cases.inc"
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_mx_grouped_gemm_util.hpp"
|
||||
#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileMxGemmPipelineCompAsync
|
||||
: public TestCkTileMxGroupedGemm<T, TestCkTileMxGemmPipelineCompAsync<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsync
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsync, KernelTypesMxGemmCompAsync);
|
||||
|
||||
#include "test_mx_grouped_gemm_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Large-tensor / decomposition cases for the fp4 (a4w4) grouped MX GEMM (ROCM-22075).
|
||||
//
|
||||
// This is a SEPARATE executable from test_ck_tile_grouped_gemm_mx_comp_async and is intentionally
|
||||
// NOT registered with ctest (its CMake target uses add_executable, not add_gtest_executable), so
|
||||
// it is excluded from the default CI test pass. It is run explicitly (mirroring the CK
|
||||
// *_large_cases / RUN_*_LARGE_CASES_TESTS convention) because it allocates multi-GB device buffers
|
||||
// (per-group C ~2.5 GB) to exercise the int32 element-count / decomposition boundary.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_mx_grouped_gemm_util.hpp"
|
||||
#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileMxGemmPipelineCompAsync
|
||||
: public TestCkTileMxGroupedGemm<T, TestCkTileMxGemmPipelineCompAsync<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompAsync
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompAsync, KernelTypesMxGemmCompAsync);
|
||||
|
||||
#include "test_mx_grouped_gemm_largeM_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
// Large-tensor decomposition / INT_MAX element-count validation for fp4 (a4w4) grouped MX GEMM
|
||||
// These cases exercise the all-GPU validation path (RunAllGpu): GPU data
|
||||
// init, GPU reference, GPU compare. They apply only to the fp4 non-persistent CompAsync kernel
|
||||
// row; for every other type in the suite they skip (RunAllGpu is fp4-only and the discarded
|
||||
// if-constexpr branch is never instantiated for non-fp4 types).
|
||||
//
|
||||
// These tests are deliberately NOT registered with ctest (see the dedicated *_large_cases target
|
||||
// in CMakeLists.txt) so they do not run in the default CI test pass; they are run explicitly
|
||||
// (mirroring the CK *_large_cases / RUN_*_LARGE_CASES_TESTS convention).
|
||||
|
||||
// Stage A: cross-validate the new GPU reference against the trusted host reference on identical
|
||||
// small shapes. Both Run() (host reference + check_err) and RunAllGpu() (GPU reference +
|
||||
// gpu_verify) must pass before the GPU reference can be trusted at large scale.
|
||||
TYPED_TEST(TEST_SUITE_NAME, GpuRefCrossValidate_Small)
|
||||
{
|
||||
using ADataType = typename TestFixture::ADataType;
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp4_t> &&
|
||||
TestFixture::PipelineType == MxGemmPipelineType::CompAsync &&
|
||||
!TestFixture::Persistent)
|
||||
{
|
||||
const int group_count = 2;
|
||||
const int kbatch = 1;
|
||||
const std::vector<int> Ms{256, 512};
|
||||
const std::vector<int> Ns{512, 1024};
|
||||
const std::vector<int> Ks{512, 512};
|
||||
|
||||
// Trusted host reference path first, then the new all-GPU path on the same shapes.
|
||||
this->Run(Ms, Ns, Ks, kbatch, group_count);
|
||||
this->RunAllGpu(Ms, Ns, Ks, kbatch, group_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
GTEST_SKIP() << "GPU-reference cases apply only to fp4 non-persistent CompAsync.";
|
||||
}
|
||||
}
|
||||
|
||||
// Stage B: the decomposition / int32-overflow case. Minimal shape that crosses the boundary:
|
||||
// 2 groups, each per-group M*N = 100352*12288 = 1,233,125,376 < INT_MAX (and C byte span
|
||||
// 2,466,250,752 > INT_MAX), aggregate total M*N = 2,466,250,752 > INT_MAX. This proves the
|
||||
// host M-decomposition is what keeps every per-group buffer int32-safe (a single fused tensor
|
||||
// would overflow the int32 element count), while the kernel still addresses C correctly even
|
||||
// though the per-group C *byte* span exceeds INT_MAX. K is kept small (compute is irrelevant to
|
||||
// the addressing property under test).
|
||||
TYPED_TEST(TEST_SUITE_NAME, LargeM_decomposition_int32)
|
||||
{
|
||||
using ADataType = typename TestFixture::ADataType;
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp4_t> &&
|
||||
TestFixture::PipelineType == MxGemmPipelineType::CompAsync &&
|
||||
!TestFixture::Persistent)
|
||||
{
|
||||
const int group_count = 2;
|
||||
const int kbatch = 1;
|
||||
const std::vector<int> Ms(group_count, 100352);
|
||||
const std::vector<int> Ns(group_count, 12288);
|
||||
const std::vector<int> Ks(group_count, 512);
|
||||
|
||||
this->RunAllGpu(Ms, Ns, Ks, kbatch, group_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
GTEST_SKIP() << "GPU-reference cases apply only to fp4 non-persistent CompAsync.";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_mx_grouped_gemm_util.hpp"
|
||||
|
||||
using F4 = ck_tile::pk_fp4_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using CompTDMV1 = ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompTDMV1>;
|
||||
using CompTDMV2 = ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompTDMV2>;
|
||||
using CompAsync = ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompAsync>;
|
||||
using CompEightWaves =
|
||||
ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::CompEightWaves>;
|
||||
using WeightPreshuffle =
|
||||
ck_tile::integral_constant<MxGemmPipelineType, MxGemmPipelineType::WeightPreshuffle>;
|
||||
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I64 = ck_tile::number<64>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
using I512 = ck_tile::number<512>;
|
||||
|
||||
template <ck_tile::index_t N>
|
||||
using ScaleBS = ck_tile::integral_constant<ck_tile::index_t, N>;
|
||||
|
||||
// clang-format off
|
||||
// MX GEMM kernel types using TDM pipeline with scale support
|
||||
// Tuple format:
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AScaleDataType, BScaleDataType, AccDataType, CDataType, Persistent, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, PipelineType
|
||||
using KernelTypesMxGemmCompTDMWmma = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>,
|
||||
std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV1, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>,
|
||||
std::tuple< Col, Row, Row, F8, BF8, E8M0, E8M0, F32, F16, False, I64, I64, I128, I32, I32, CompTDMV2, ScaleBS<32>>
|
||||
>;
|
||||
|
||||
using KernelTypesMxGemmCompAsync = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I64, I64, I256, I16, I16, CompAsync, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I256, I128, I16, I16, CompEightWaves, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, False, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, False, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, True, I128, I256, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>,
|
||||
std::tuple< Row, Col, Row, F4, F4, E8M0, E8M0, F32, F16, True, I128, I512, I256, I16, I16, WeightPreshuffle, ScaleBS<32>, True>
|
||||
>;
|
||||
// clang-format on
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileMxGroupedGemm, Basic)
|
||||
TYPED_TEST(TEST_SUITE_NAME, Basic)
|
||||
{
|
||||
const int group_count = 4;
|
||||
const int kbatch = 1;
|
||||
@@ -14,8 +14,8 @@ TYPED_TEST(TestCkTileMxGroupedGemm, Basic)
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
Ns.push_back(512 + 512 * i);
|
||||
Ks.push_back(512 + TestFixture::K_Tile * i);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, kbatch, group_count);
|
||||
|
||||
@@ -14,11 +14,44 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck/library/utility/gpu_verification.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
#if defined(CK_USE_GFX1250)
|
||||
constexpr bool is_8bit = std::is_same_v<PrecType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<PrecType, ck_tile::bf8_t> ||
|
||||
std::is_same_v<PrecType, ck_tile::int8_t>;
|
||||
constexpr bool is_mxtype =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::pk_fp4_t>;
|
||||
if constexpr(M_Warp_Tile == 32 && is_mxtype)
|
||||
{
|
||||
return 128;
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_8bit ? 64 : 32;
|
||||
}
|
||||
#else
|
||||
return 16;
|
||||
#endif
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 64;
|
||||
else
|
||||
return 128;
|
||||
#endif
|
||||
}
|
||||
|
||||
enum struct MxGemmPipelineType
|
||||
{
|
||||
CompTDMV1,
|
||||
CompTDMV2
|
||||
CompTDMV2,
|
||||
CompAsync,
|
||||
CompEightWaves,
|
||||
WeightPreshuffle
|
||||
};
|
||||
|
||||
template <MxGemmPipelineType PT, typename Problem>
|
||||
@@ -42,61 +75,137 @@ struct MxGemmPipelineTypeSelector<MxGemmPipelineType::CompTDMV2, Problem>
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Pre-shuffle scale buffer for gfx1250 wmma mx scale instruction.
|
||||
*
|
||||
* Reorganizes the scale data from row-major (MN x K) layout to the hardware-specific
|
||||
* layout expected by the gfx1250 wmma instruction.
|
||||
*
|
||||
* @tparam ScaleType Scale data type (e.g., e8m0_t)
|
||||
* @tparam ScaleBlockSize The block size for microscaling (e.g., 32)
|
||||
* @tparam KStride Whether K is the fast-moving dimension
|
||||
*/
|
||||
template <typename ScaleType, ck_tile::index_t ScaleBlockSize, bool KStride>
|
||||
void preShuffleScaleBuffer_gfx1250(const ScaleType* src,
|
||||
ScaleType* dst,
|
||||
ck_tile::index_t MN,
|
||||
ck_tile::index_t K)
|
||||
template <typename Problem>
|
||||
struct MxGemmPipelineTypeSelector<MxGemmPipelineType::CompAsync, Problem>
|
||||
{
|
||||
static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1,
|
||||
"wrong! only support 8-bit scale with ScaleBlockSize=32");
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync<Problem>;
|
||||
|
||||
constexpr ck_tile::index_t MPerXdlops = 16;
|
||||
constexpr ck_tile::index_t KPerXdlops = 128;
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; }
|
||||
};
|
||||
|
||||
int MNPack = 2;
|
||||
int KPack = 1;
|
||||
template <typename Problem>
|
||||
struct MxGemmPipelineTypeSelector<MxGemmPipelineType::CompEightWaves, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsyncEightWaves<Problem>;
|
||||
|
||||
int MNStep = MPerXdlops;
|
||||
int KStep = KPerXdlops / ScaleBlockSize;
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrCompEightWaves"; }
|
||||
};
|
||||
|
||||
int K0 = K / KPack / KStep;
|
||||
template <typename Problem>
|
||||
struct MxGemmPipelineTypeSelector<MxGemmPipelineType::WeightPreshuffle, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
using pipeline = ck_tile::MXGemmPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
|
||||
for(int mn = 0; mn < MN; ++mn)
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; }
|
||||
};
|
||||
|
||||
template <MxGemmPipelineType PT, typename Problem, bool PermuteN>
|
||||
struct MxGemmEpilogueTypeSelector
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::CompTDMV1, Problem, false>
|
||||
{
|
||||
using epilogue = ck_tile::TdmEpilogue<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::CompTDMV2, Problem, false>
|
||||
{
|
||||
using epilogue = ck_tile::TdmEpilogue<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::CompAsync, Problem, false>
|
||||
{
|
||||
using epilogue = ck_tile::CShuffleEpilogue<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::CompEightWaves, Problem, false>
|
||||
{
|
||||
using epilogue = ck_tile::CShuffleEpilogue<Problem>;
|
||||
};
|
||||
|
||||
template <typename Problem, bool PermuteN>
|
||||
struct MxGemmEpilogueTypeSelector<MxGemmPipelineType::WeightPreshuffle, Problem, PermuteN>
|
||||
{
|
||||
using epilogue = std::conditional_t<PermuteN,
|
||||
ck_tile::PermuteNEpilogue<Problem>,
|
||||
ck_tile::CShuffleEpilogue<Problem>>;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t N_Warp_Tile_,
|
||||
ck_tile::index_t K_Warp_Tile_,
|
||||
ck_tile::index_t N_Tile_,
|
||||
ck_tile::index_t N_Warp_,
|
||||
typename BDataType_>
|
||||
struct Config
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_;
|
||||
static constexpr ck_tile::index_t N_Tile = N_Tile_;
|
||||
static constexpr ck_tile::index_t N_Warp = N_Warp_;
|
||||
static constexpr ck_tile::index_t BContiguousItemsPerAccess =
|
||||
std::is_same_v<BDataType_, ck_tile::pk_fp4_t> ? 32 : 16;
|
||||
};
|
||||
|
||||
// Deterministic per-element hash RNG for GPU data init. Returns a float in [-3, 3).
|
||||
// The generic `fill_tensor_uniform_rand_fp_values` filler is NOT valid for ck_tile::pk_fp4_t
|
||||
// (it converts a single float and duplicates it into both nibbles, and special-cases only the
|
||||
// classic ck::f4x2_pk_t). We need two independent fp4 values per byte, so we fill directly.
|
||||
// The narrow [-3,3) range keeps the fp16 GEMM output from overflowing at K up to 4096 (with the
|
||||
// [0.25,1.0] scales used in RunAllGpu, worst case K*9 = 36864 < 65504).
|
||||
__device__ inline float mx_fp4_fill_rand(unsigned int seed, unsigned long long idx)
|
||||
{
|
||||
// splitmix64-style avalanche; deterministic given (seed, idx).
|
||||
unsigned long long z = (idx + 1ULL) * 0x9E3779B97F4A7C15ULL +
|
||||
static_cast<unsigned long long>(seed) * 0xD1B54A32D192ED03ULL;
|
||||
z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
|
||||
z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
|
||||
z ^= z >> 31;
|
||||
const float u =
|
||||
static_cast<float>((z >> 40) & 0xFFFFFFULL) / static_cast<float>(0x1000000); // [0,1)
|
||||
return u * 6.0f - 3.0f; // [-3,3)
|
||||
}
|
||||
|
||||
// Fill a packed-fp4 buffer with two independent, deterministic random fp4 values per byte.
|
||||
// `num_packed` is the number of pk_fp4_t elements (= total fp4 values / 2).
|
||||
__global__ void
|
||||
fill_pk_fp4_uniform_kernel(ck_tile::pk_fp4_t* __restrict__ ptr, long num_packed, unsigned int seed)
|
||||
{
|
||||
const long idx0 = static_cast<long>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const long nthr = static_cast<long>(gridDim.x) * blockDim.x;
|
||||
for(long i = idx0; i < num_packed; i += nthr)
|
||||
{
|
||||
int iMNRepeat = mn / (MNStep * MNPack);
|
||||
int tempmn = mn % (MNStep * MNPack);
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int iKRepeat = k / (KStep * KPack);
|
||||
int tempk = k % (KStep * KPack);
|
||||
|
||||
int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) +
|
||||
(iKRepeat * KStep * KPack) * (MNStep * MNPack) +
|
||||
tempmn * (KStep * KPack) + tempk;
|
||||
|
||||
if constexpr(KStride)
|
||||
{
|
||||
dst[outputIndex] = src[mn * K + k];
|
||||
}
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + mn];
|
||||
}
|
||||
const float lo_f = rintf(mx_fp4_fill_rand(seed, static_cast<unsigned long long>(i) * 2ULL));
|
||||
const float hi_f =
|
||||
rintf(mx_fp4_fill_rand(seed, static_cast<unsigned long long>(i) * 2ULL + 1ULL));
|
||||
const auto lo = ck_tile::float_to_mxfp4(lo_f, 1.0f);
|
||||
const auto hi = ck_tile::float_to_mxfp4(hi_f, 1.0f);
|
||||
ptr[i] = ck_tile::pk_fp4_t::_pack(lo, hi);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
inline void fill_pk_fp4_uniform(ck_tile::pk_fp4_t* ptr,
|
||||
long num_packed,
|
||||
unsigned int seed,
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
constexpr int threads = 256;
|
||||
constexpr long max_blocks = 65536; // grid-stride cap
|
||||
const long needed = (num_packed + threads - 1) / threads;
|
||||
const long blocks = needed < max_blocks ? needed : max_blocks;
|
||||
fill_pk_fp4_uniform_kernel<<<dim3(static_cast<unsigned>(blocks)), dim3(threads), 0, stream>>>(
|
||||
ptr, num_packed, seed);
|
||||
ck_tile::hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
@@ -111,9 +220,36 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
using CDataType = std::tuple_element_t<8, Tuple>;
|
||||
using PersistentType = std::tuple_element_t<9, Tuple>;
|
||||
static constexpr bool Persistent = PersistentType::value;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<10, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<11, Tuple>::value;
|
||||
static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<12, Tuple>::value;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<15, Tuple>::value;
|
||||
static constexpr ck_tile::index_t ScaleBlockSize = std::tuple_element_t<16, Tuple>::value;
|
||||
static constexpr bool PermuteN =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 17, std::false_type>::value;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<10, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<11, Tuple>{};
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<12, Tuple>{};
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<13, Tuple>{};
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<14, Tuple>{};
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::max(
|
||||
get_k_warp_tile<ADataType, M_Warp_Tile>(), get_k_warp_tile<BDataType, N_Warp_Tile>());
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp =
|
||||
PipelineType == MxGemmPipelineType::WeightPreshuffle
|
||||
? 1
|
||||
: (PipelineType == MxGemmPipelineType::CompEightWaves ? 4 : 2);
|
||||
static constexpr ck_tile::index_t N_Warp =
|
||||
PipelineType == MxGemmPipelineType::WeightPreshuffle ? 4 : 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool Preshuffle = PipelineType == MxGemmPipelineType::WeightPreshuffle;
|
||||
|
||||
// No D tensors for this test
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
@@ -123,42 +259,27 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
using AComputeDataType = ADataType;
|
||||
using BComputeDataType = BDataType;
|
||||
|
||||
struct GroupedGemKernelParam_Wmma
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
static const bool kPadK = false;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 64;
|
||||
static const ck_tile::index_t N_Tile = 64;
|
||||
static const ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static const ck_tile::index_t M_Warp = 2;
|
||||
static const ck_tile::index_t N_Warp = 2;
|
||||
static const ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
};
|
||||
|
||||
using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>;
|
||||
std::size_t get_workspace_size(const std::vector<mx_grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::MxGemmTransKernelArg<>);
|
||||
}
|
||||
|
||||
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
bool invoke_mx_grouped_gemm(const std::vector<mx_grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr bool preshuffle = false;
|
||||
constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer
|
||||
#if defined(CK_USE_GFX1250)
|
||||
constexpr ck_tile::index_t BlockedXDLNPerWarp = 1;
|
||||
constexpr bool TransposeC =
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor> &&
|
||||
GroupedGemKernelParam::M_Warp_Tile == GroupedGemKernelParam::N_Warp_Tile;
|
||||
M_Warp_Tile == N_Warp_Tile;
|
||||
#elif defined(CK_USE_GFX950)
|
||||
constexpr ck_tile::index_t BlockedXDLNPerWarp = Preshuffle ? 2 : 1;
|
||||
constexpr bool TransposeC = false;
|
||||
#endif
|
||||
static constexpr bool StructuredSparsity = false;
|
||||
static constexpr bool NumWaveGroup = 1;
|
||||
|
||||
@@ -166,21 +287,15 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -189,7 +304,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
StructuredSparsity,
|
||||
Persistent,
|
||||
NumWaveGroup,
|
||||
preshuffle>;
|
||||
Preshuffle>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::MxGemmPipelineProblem<ADataType,
|
||||
@@ -209,7 +324,26 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
using GemmPipeline =
|
||||
typename MxGemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
|
||||
|
||||
using GemmEpilogue = ck_tile::TdmEpilogue<
|
||||
using GemmEpilogueProblem = std::conditional_t<
|
||||
Preshuffle && PermuteN,
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
false, /*FixedVectorSize_*/
|
||||
1>, /*VectorSizeC_*/
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
@@ -220,19 +354,24 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
1, /*kNumWaveGroups_*/
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
DoubleSmemBuffer, /*DoubleSmemBuffer*/
|
||||
AComputeDataType, /*AComputeDataType_*/
|
||||
BComputeDataType /*BComputeDataType_*/>>;
|
||||
1, /*kNumWaveGroups_*/
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
BlockedXDLNPerWarp, /*BlockedXDLN_PerWarp_*/
|
||||
DoubleSmemBuffer, /*DoubleSmemBuffer*/
|
||||
AComputeDataType, /*AComputeDataType_*/
|
||||
BComputeDataType, /*BComputeDataType_*/
|
||||
!Preshuffle>>;
|
||||
|
||||
using GemmEpilogue = typename MxGemmEpilogueTypeSelector<PipelineType,
|
||||
GemmEpilogueProblem,
|
||||
PermuteN>::epilogue;
|
||||
|
||||
using Kernel = ck_tile::MxGroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
@@ -266,7 +405,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
|
||||
ck_tile::ignore =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
@@ -303,36 +442,9 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
static constexpr bool check_data_type()
|
||||
{
|
||||
|
||||
// Validate scale type / data type combination
|
||||
constexpr bool a_is_f4 = std::is_same_v<ADataType, ck_tile::pk_fp4_t>;
|
||||
constexpr bool b_is_f4 = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
|
||||
constexpr bool a_scale_e8m0 = std::is_same_v<AScaleDataType, ck_tile::e8m0_t>;
|
||||
constexpr bool b_scale_e8m0 = std::is_same_v<BScaleDataType, ck_tile::e8m0_t>;
|
||||
if constexpr(!a_is_f4 && !a_scale_e8m0)
|
||||
return false;
|
||||
if constexpr(!b_is_f4 && !b_scale_e8m0)
|
||||
return false;
|
||||
|
||||
// Check hardware WMMA support for the fixed warp tile (32x32x128)
|
||||
#if defined(CK_USE_GFX1250)
|
||||
return ck_tile::has_wmma_traits_v<ck_tile::gfx125_t,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GroupedGemKernelParam_Wmma::M_Warp_Tile,
|
||||
GroupedGemKernelParam_Wmma::N_Warp_Tile,
|
||||
GroupedGemKernelParam_Wmma::K_Warp_Tile>;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if constexpr(!check_data_type())
|
||||
if constexpr(!Derived::check_data_type())
|
||||
{
|
||||
GTEST_SKIP() << "Unsupported data type / layout combination for mx_grouped_gemm.";
|
||||
}
|
||||
@@ -345,7 +457,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
const int kbatch = 1,
|
||||
const int group_count = 16)
|
||||
{
|
||||
if constexpr(!check_data_type())
|
||||
if constexpr(!Derived::check_data_type())
|
||||
return;
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
@@ -445,8 +557,42 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << " KBatch: " << kbatch << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
// For pk_fp4_t each byte packs two 4-bit elements; the generic filler
|
||||
// converts a single float and duplicates it into both nibbles.
|
||||
// Generate two independent random values per byte instead.
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
std::mt19937 gen(11939);
|
||||
std::uniform_real_distribution<float> dis(-5.f, 5.f);
|
||||
for(auto& elem : a_m_k_tensors[i].mData)
|
||||
{
|
||||
auto lo = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f);
|
||||
auto hi = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f);
|
||||
elem = ck_tile::pk_fp4_t::_pack(lo, hi);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(
|
||||
a_m_k_tensors[i]);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
std::mt19937 gen(11940);
|
||||
std::uniform_real_distribution<float> dis(-5.f, 5.f);
|
||||
for(auto& elem : b_k_n_tensors[i].mData)
|
||||
{
|
||||
auto lo = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f);
|
||||
auto hi = ck_tile::float_to_mxfp4(std::round(dis(gen)), 1.f);
|
||||
elem = ck_tile::pk_fp4_t::_pack(lo, hi);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(
|
||||
b_k_n_tensors[i]);
|
||||
}
|
||||
|
||||
// K must be a multiple of ScaleBlockSize
|
||||
if(K % ScaleBlockSize != 0)
|
||||
@@ -454,15 +600,13 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
GTEST_SKIP() << "K must be multiple of ScaleBlockSize for MX GEMM";
|
||||
}
|
||||
const ck_tile::index_t num_scale_k = K / ScaleBlockSize;
|
||||
if(num_scale_k % (GroupedGemKernelParam_Wmma::K_Warp_Tile / ScaleBlockSize) != 0)
|
||||
if(num_scale_k % (K_Warp_Tile / ScaleBlockSize) != 0)
|
||||
{
|
||||
GTEST_SKIP() << "K must be a multiple of K_Warp_Tile ("
|
||||
<< GroupedGemKernelParam_Wmma::K_Warp_Tile
|
||||
GTEST_SKIP() << "K must be a multiple of K_Warp_Tile (" << K_Warp_Tile
|
||||
<< ") for MX GEMM. Pad the scale data.";
|
||||
}
|
||||
const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple(
|
||||
static_cast<ck_tile::index_t>(M),
|
||||
static_cast<ck_tile::index_t>(GroupedGemKernelParam_Wmma::M_Warp_Tile));
|
||||
static_cast<ck_tile::index_t>(M), static_cast<ck_tile::index_t>(M_Tile));
|
||||
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a(
|
||||
{static_cast<std::size_t>(scale_padded_M), static_cast<std::size_t>(num_scale_k)},
|
||||
@@ -515,6 +659,9 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
}
|
||||
|
||||
// Pre-shuffle scale buffers for the hardware
|
||||
#if defined(CK_USE_GFX1250)
|
||||
constexpr ck_tile::index_t NXdlPackEff = 1;
|
||||
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a_shuffled(
|
||||
{static_cast<std::size_t>(scale_padded_M), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
@@ -523,11 +670,6 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
{static_cast<std::size_t>(N), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
|
||||
std::cout << " scale_a: [scale_padded_M = " << scale_padded_M
|
||||
<< ", num_scale_k = " << num_scale_k << "]." << std::endl;
|
||||
std::cout << " scale_b: [N = " << N << ", num_scale_k = " << num_scale_k << "]."
|
||||
<< std::endl;
|
||||
|
||||
// Pre-shuffle for gfx1250 (WaveSize=32, WMMA)
|
||||
preShuffleScaleBuffer_gfx1250<AScaleDataType, ScaleBlockSize, true>(
|
||||
scale_a.mData.data(), scale_a_shuffled.mData.data(), scale_padded_M, num_scale_k);
|
||||
@@ -536,19 +678,95 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
// where N is the fast-changing dimension for col-major B
|
||||
preShuffleScaleBuffer_gfx1250<BScaleDataType, ScaleBlockSize, true>(
|
||||
scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k);
|
||||
#elif defined(CK_USE_GFX950)
|
||||
constexpr ck_tile::index_t MPerXdl = M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl = N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl = K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl);
|
||||
constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl);
|
||||
constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff =
|
||||
(MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff =
|
||||
(NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff =
|
||||
(KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
|
||||
constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a_shuffled(
|
||||
{static_cast<std::size_t>(scale_padded_M / MXdlPackEff * 2),
|
||||
static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2),
|
||||
static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::HostTensor<BScaleDataType> scale_b_shuffled(
|
||||
{static_cast<std::size_t>(N / NXdlPackEff * 2),
|
||||
static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2),
|
||||
static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::
|
||||
preShuffleScaleBuffer_gfx950<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_a.mData.data(),
|
||||
scale_a_shuffled.mData.data(),
|
||||
scale_padded_M,
|
||||
num_scale_k,
|
||||
true);
|
||||
|
||||
if constexpr(Preshuffle && PermuteN)
|
||||
{
|
||||
ck_tile::preShuffleScaleBufferPermuteN_gfx950<N_Warp, N_Tile, XdlMNThread>(
|
||||
scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
preShuffleScaleBuffer_gfx950<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true);
|
||||
}
|
||||
#endif
|
||||
|
||||
std::cout << " scale_a: [scale_padded_M = " << scale_padded_M
|
||||
<< ", num_scale_k = " << num_scale_k << "]." << std::endl;
|
||||
std::cout << " scale_b: [N = " << N << ", num_scale_k = " << num_scale_k << "]."
|
||||
<< std::endl;
|
||||
|
||||
scale_a_tensors.push_back(scale_a_shuffled);
|
||||
scale_b_tensors.push_back(scale_b_shuffled);
|
||||
|
||||
using GemmConfig = Config<N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp, BDataType>;
|
||||
|
||||
const auto b_host_for_dev = [&]() {
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
if constexpr(PermuteN)
|
||||
{
|
||||
return ck_tile::shuffle_b_permuteN<GemmConfig, BDataType, NXdlPackEff>(
|
||||
b_k_n_tensors[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::shuffle_b<GemmConfig>(b_k_n_tensors[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return b_k_n_tensors[i];
|
||||
}
|
||||
}();
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_host_for_dev.get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_host_for_dev.data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
@@ -584,7 +802,7 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
|
||||
if(!invoke_mx_grouped_gemm<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
|
||||
if(!invoke_mx_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer()))
|
||||
@@ -628,4 +846,321 @@ class TestCkTileMxGroupedGemm : public ::testing::Test
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
// All-GPU validation path for the fp4 (pk_fp4_t) MX grouped GEMM.
|
||||
//
|
||||
// Unlike Run(), this never materializes the (potentially 39 GB) A/B/C tensors on the host:
|
||||
// - A/B are generated directly on device with a deterministic fp4 fill.
|
||||
// - the reference is computed on device by reference_mx_gemm_gpu.
|
||||
// - the comparison is done on device by ck::profiler::gpu_verify.
|
||||
// Only the tiny e8m0 scales touch the host (for pre-shuffle + an unshuffled copy that the
|
||||
// device reference consumes). Groups are processed one at a time to bound peak device memory
|
||||
// and to make any fault attributable to a specific group.
|
||||
//
|
||||
// Per group it logs and asserts that every int32-addressed quantity (M*N, A/B/C worst-case
|
||||
// element offsets) stays < INT_MAX -- this is the property under test for the host-side
|
||||
// M-decomposition that keeps each per-group buffer kernel-safe.
|
||||
void RunAllGpu(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const int kbatch = 1,
|
||||
const int group_count = 16)
|
||||
{
|
||||
if constexpr(!Derived::check_data_type())
|
||||
return;
|
||||
|
||||
static_assert(std::is_same_v<ADataType, ck_tile::pk_fp4_t> &&
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
|
||||
"RunAllGpu currently supports pk_fp4_t A/B only.");
|
||||
// The GPU reference (reference_mx_gemm_gpu) hardcodes these layouts; guard so it cannot be
|
||||
// silently misused with a layout it does not handle.
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor> &&
|
||||
std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor> &&
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>,
|
||||
"RunAllGpu / reference_mx_gemm_gpu assume RowMajor-A, ColumnMajor-B, "
|
||||
"RowMajor-C.");
|
||||
|
||||
#if !defined(CK_USE_GFX950)
|
||||
(void)Ms;
|
||||
(void)Ns;
|
||||
(void)Ks;
|
||||
(void)kbatch;
|
||||
(void)group_count;
|
||||
GTEST_SKIP() << "RunAllGpu requires CK_USE_GFX950.";
|
||||
#else
|
||||
using namespace ck_tile::literals;
|
||||
constexpr long kIntMax = 2147483647L; // INT_MAX
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return col;
|
||||
else
|
||||
return row;
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
constexpr ck_tile::index_t psize = ck_tile::numeric_traits<ADataType>::PackedSize; // 2
|
||||
static_assert(psize == 2,
|
||||
"RunAllGpu byte-sizing and reference_mx_gemm_kernel's a_ptr[a_lin/2] "
|
||||
"addressing assume pk_fp4_t PackedSize == 2.");
|
||||
|
||||
bool pass = true;
|
||||
long total_MN = 0;
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
// Strides are K/N here (small); keep them as index_t to match the kernel args, and
|
||||
// make the size_t->index_t narrowing explicit.
|
||||
const ck_tile::index_t stride_A =
|
||||
static_cast<ck_tile::index_t>(f_get_default_stride(M, K, 0, ALayout{})); // K
|
||||
const ck_tile::index_t stride_B =
|
||||
static_cast<ck_tile::index_t>(f_get_default_stride(K, N, 0, BLayout{})); // K
|
||||
const ck_tile::index_t stride_C =
|
||||
static_cast<ck_tile::index_t>(f_get_default_stride(M, N, 0, CLayout{})); // N
|
||||
|
||||
// Per-group shape guards. Fatal (ASSERT), not GTEST_SKIP: a skip mid-loop would
|
||||
// silently report success for only a prefix of the groups already validated.
|
||||
ASSERT_EQ(K % ScaleBlockSize, 0)
|
||||
<< "group " << i << ": K must be a multiple of ScaleBlockSize for MX GEMM";
|
||||
const ck_tile::index_t num_scale_k = K / ScaleBlockSize;
|
||||
ASSERT_EQ(num_scale_k % (K_Warp_Tile / ScaleBlockSize), 0)
|
||||
<< "group " << i << ": K must be a multiple of K_Warp_Tile (" << K_Warp_Tile
|
||||
<< ") for MX GEMM. Pad the scale data.";
|
||||
const ck_tile::index_t scale_padded_M = ck_tile::integer_least_multiple(
|
||||
static_cast<ck_tile::index_t>(M), static_cast<ck_tile::index_t>(M_Tile));
|
||||
|
||||
// int32-safety: the property under test for the M-decomposition. The predicate is
|
||||
// "largest 0-based element offset fits in a signed 32-bit int", i.e. offset <= INT_MAX.
|
||||
const long MN = static_cast<long>(M) * N;
|
||||
const long A_elems = static_cast<long>(M) * K;
|
||||
const long B_elems = static_cast<long>(K) * N;
|
||||
const long C_off = static_cast<long>(M - 1) * stride_C + (N - 1);
|
||||
const long A_off = static_cast<long>(M - 1) * stride_A + (K - 1);
|
||||
const long B_off = static_cast<long>(N - 1) * stride_B + (K - 1);
|
||||
const long c_bytes = MN * static_cast<long>(sizeof(CDataType));
|
||||
std::cout << "[int32-safety] group " << i << " M=" << M << " N=" << N << " K=" << K
|
||||
<< " M*N=" << MN << " A_elems=" << A_elems << " B_elems=" << B_elems
|
||||
<< " C_off=" << C_off << " A_off=" << A_off << " B_off=" << B_off
|
||||
<< " C_bytes=" << c_bytes << " (INT_MAX=" << kIntMax << ")" << std::endl;
|
||||
// Note (not an assert): the C *byte* span can exceed INT_MAX even when the element
|
||||
// count is int32-safe. We deliberately let the run proceed -- if any internal byte
|
||||
// offset overflows, gpu_verify will flag it, which is exactly what we want to discover.
|
||||
if(c_bytes > kIntMax)
|
||||
std::cout
|
||||
<< "[int32-safety][note] group " << i << " C byte span (" << c_bytes
|
||||
<< ") exceeds INT_MAX; if verification fails, byte-offset overflow is the "
|
||||
"prime suspect."
|
||||
<< std::endl;
|
||||
ASSERT_LE(MN - 1, kIntMax) << "group " << i << " max C element index exceeds INT_MAX";
|
||||
ASSERT_LE(C_off, kIntMax) << "group " << i << " C offset exceeds INT_MAX";
|
||||
ASSERT_LE(A_off, kIntMax) << "group " << i << " A offset exceeds INT_MAX";
|
||||
ASSERT_LE(B_off, kIntMax) << "group " << i << " B offset exceeds INT_MAX";
|
||||
total_MN += MN;
|
||||
|
||||
// Device buffers (no big host tensors). Round byte counts up: a stray odd fp4 element
|
||||
// still occupies a full packed byte.
|
||||
const long a_bytes = (A_elems + psize - 1) / psize;
|
||||
const long b_bytes = (B_elems + psize - 1) / psize;
|
||||
|
||||
// Bound peak device memory (A + B + 2*C + scales/workspace slack). Skip cleanly rather
|
||||
// than aborting via hip_check_error if the device cannot hold one group.
|
||||
{
|
||||
std::size_t free_b = 0, total_b = 0;
|
||||
ck_tile::hip_check_error(hipMemGetInfo(&free_b, &total_b));
|
||||
const std::size_t need = static_cast<std::size_t>(a_bytes) +
|
||||
static_cast<std::size_t>(b_bytes) +
|
||||
2u * static_cast<std::size_t>(c_bytes) + (64u << 20);
|
||||
if(free_b < need)
|
||||
GTEST_SKIP() << "group " << i << ": insufficient device memory (need " << need
|
||||
<< " B, free " << free_b << " B)";
|
||||
}
|
||||
|
||||
auto a_dev = std::make_unique<ck_tile::DeviceMem>(static_cast<std::size_t>(a_bytes));
|
||||
auto b_dev = std::make_unique<ck_tile::DeviceMem>(static_cast<std::size_t>(b_bytes));
|
||||
auto c_dev = std::make_unique<ck_tile::DeviceMem>(static_cast<std::size_t>(c_bytes));
|
||||
auto c_ref_dev =
|
||||
std::make_unique<ck_tile::DeviceMem>(static_cast<std::size_t>(c_bytes));
|
||||
c_dev->SetZero();
|
||||
c_ref_dev->SetZero();
|
||||
|
||||
// GPU fill A/B (deterministic, fp4-correct). Same device buffers feed both the kernel
|
||||
// and the reference, so the fill need not bit-match any host RNG. Fold the group index
|
||||
// into the seed so each group gets a distinct data pattern.
|
||||
fill_pk_fp4_uniform(reinterpret_cast<ADataType*>(a_dev->GetDeviceBuffer()),
|
||||
a_bytes,
|
||||
11939u + static_cast<unsigned int>(i));
|
||||
fill_pk_fp4_uniform(reinterpret_cast<BDataType*>(b_dev->GetDeviceBuffer()),
|
||||
b_bytes,
|
||||
11940u + static_cast<unsigned int>(i));
|
||||
ck_tile::hip_check_error(
|
||||
hipDeviceSynchronize()); // surface fill faults at the fill site
|
||||
|
||||
// e8m0 scales (tiny, host-built, fixed per-group seed for determinism). The range is
|
||||
// deliberately narrow ([0.25,1.0] scales, [-3,3) fp4 fill) so that K up to 4096 cannot
|
||||
// overflow the fp16 output (worst case K*9 = 36864 < 65504); gpu_verify counts matched
|
||||
// infinities as errors, so an overflow would otherwise be a false failure.
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a(
|
||||
{static_cast<std::size_t>(scale_padded_M), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
ck_tile::HostTensor<BScaleDataType> scale_b(
|
||||
{static_cast<std::size_t>(N), static_cast<std::size_t>(num_scale_k)},
|
||||
{static_cast<std::size_t>(num_scale_k), static_cast<std::size_t>(1)});
|
||||
{
|
||||
std::mt19937 gen(11941u + static_cast<unsigned int>(i));
|
||||
std::uniform_real_distribution<float> dist(0.25f, 1.0f);
|
||||
for(auto& s : scale_a.mData)
|
||||
s = AScaleDataType{dist(gen)};
|
||||
for(auto& s : scale_b.mData)
|
||||
s = BScaleDataType{dist(gen)};
|
||||
}
|
||||
|
||||
// gfx950 scale pre-shuffle. NOTE: this must stay in sync with the identical block in
|
||||
// Run() -- the kernel-input layout and the reference-input layout must agree.
|
||||
constexpr ck_tile::index_t MPerXdl = M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl = N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl = K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * MPerXdl);
|
||||
constexpr ck_tile::index_t NIterPerWarp = N_Tile / (N_Warp * NPerXdl);
|
||||
constexpr ck_tile::index_t KIterPerWarp = K_Tile / KPerXdl;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff =
|
||||
(MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff =
|
||||
(NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff =
|
||||
(KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
|
||||
constexpr ck_tile::index_t XdlMNThread = M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
ck_tile::HostTensor<AScaleDataType> scale_a_shuffled(
|
||||
{static_cast<std::size_t>(scale_padded_M / MXdlPackEff * 2),
|
||||
static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2),
|
||||
static_cast<std::size_t>(1)});
|
||||
ck_tile::HostTensor<BScaleDataType> scale_b_shuffled(
|
||||
{static_cast<std::size_t>(N / NXdlPackEff * 2),
|
||||
static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2)},
|
||||
{static_cast<std::size_t>(num_scale_k / KXdlPackEff * 2),
|
||||
static_cast<std::size_t>(1)});
|
||||
|
||||
ck_tile::
|
||||
preShuffleScaleBuffer_gfx950<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_a.mData.data(),
|
||||
scale_a_shuffled.mData.data(),
|
||||
scale_padded_M,
|
||||
num_scale_k,
|
||||
true);
|
||||
ck_tile::
|
||||
preShuffleScaleBuffer_gfx950<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(
|
||||
scale_b.mData.data(), scale_b_shuffled.mData.data(), N, num_scale_k, true);
|
||||
|
||||
// Device scale buffers: shuffled feed the kernel, unshuffled feed the reference.
|
||||
auto scale_a_shuf_dev = std::make_unique<ck_tile::DeviceMem>(
|
||||
scale_a_shuffled.get_element_space_size_in_bytes());
|
||||
auto scale_b_shuf_dev = std::make_unique<ck_tile::DeviceMem>(
|
||||
scale_b_shuffled.get_element_space_size_in_bytes());
|
||||
scale_a_shuf_dev->ToDevice(scale_a_shuffled.data());
|
||||
scale_b_shuf_dev->ToDevice(scale_b_shuffled.data());
|
||||
|
||||
auto scale_a_ref_dev =
|
||||
std::make_unique<ck_tile::DeviceMem>(scale_a.get_element_space_size_in_bytes());
|
||||
auto scale_b_ref_dev =
|
||||
std::make_unique<ck_tile::DeviceMem>(scale_b.get_element_space_size_in_bytes());
|
||||
scale_a_ref_dev->ToDevice(scale_a.data());
|
||||
scale_b_ref_dev->ToDevice(scale_b.data());
|
||||
|
||||
// Launch the grouped kernel for this single group.
|
||||
std::vector<mx_grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.push_back(mx_grouped_gemm_kargs(a_dev->GetDeviceBuffer(),
|
||||
scale_a_shuf_dev->GetDeviceBuffer(),
|
||||
b_dev->GetDeviceBuffer(),
|
||||
scale_b_shuf_dev->GetDeviceBuffer(),
|
||||
{/*ds_ptr*/},
|
||||
c_dev->GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{/*stride_Ds*/},
|
||||
stride_C));
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
if(!invoke_mx_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer()))
|
||||
{
|
||||
ADD_FAILURE() << "invoke_mx_grouped_gemm failed for group " << i;
|
||||
pass = false;
|
||||
continue; // DeviceMem frees cleanly at loop end; keep validating other groups
|
||||
}
|
||||
ck_tile::hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
// GPU reference on the same device A/B buffers.
|
||||
ck_tile::reference_mx_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AScaleDataType,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
reinterpret_cast<const ADataType*>(a_dev->GetDeviceBuffer()),
|
||||
reinterpret_cast<const BDataType*>(b_dev->GetDeviceBuffer()),
|
||||
reinterpret_cast<const AScaleDataType*>(scale_a_ref_dev->GetDeviceBuffer()),
|
||||
reinterpret_cast<const BScaleDataType*>(scale_b_ref_dev->GetDeviceBuffer()),
|
||||
reinterpret_cast<CDataType*>(c_ref_dev->GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
num_scale_k,
|
||||
ScaleBlockSize);
|
||||
ck_tile::hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
// GPU verify with explicit MX tolerance (auto tolerance defaults too tight for MX).
|
||||
const float max_acc = ck::profiler::gpu_reduce_max<CDataType>(
|
||||
c_ref_dev->GetDeviceBuffer(), static_cast<std::size_t>(MN));
|
||||
// The reference must be non-degenerate, else error_count==0 is a vacuous pass.
|
||||
ASSERT_GT(max_acc, 0.0f) << "group " << i << ": GPU reference output is all-zero";
|
||||
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_acc);
|
||||
const auto res = ck::profiler::gpu_verify<CDataType>(c_dev->GetDeviceBuffer(),
|
||||
c_ref_dev->GetDeviceBuffer(),
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}),
|
||||
static_cast<std::size_t>(MN));
|
||||
|
||||
// Positive liveness check on the *device* output. res.all_zero ANDs device- and
|
||||
// reference-zeroness, and the reference is never zero here, so it cannot detect a no-op
|
||||
// kernel on its own -- reduce the device buffer directly.
|
||||
const float c_dev_absmax = ck::profiler::gpu_reduce_max<CDataType>(
|
||||
c_dev->GetDeviceBuffer(), static_cast<std::size_t>(MN));
|
||||
|
||||
std::cout << "[verify] group " << i << " errors=" << res.error_count
|
||||
<< " max_error=" << res.max_error << " c_dev_absmax=" << c_dev_absmax
|
||||
<< " max_acc=" << max_acc << " rtol=" << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " atol=" << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
|
||||
EXPECT_EQ(res.error_count, 0ull) << "group " << i << " produced mismatched results";
|
||||
EXPECT_GT(c_dev_absmax, 0.0f) << "group " << i << " produced an all-zero device output";
|
||||
pass &= (res.error_count == 0 && c_dev_absmax > 0.0f);
|
||||
// a_dev/b_dev/c_dev/... freed here (unique_ptr) before the next group.
|
||||
}
|
||||
|
||||
std::cout << "[int32-safety] aggregate total_M*N=" << total_MN << " (INT_MAX=" << kIntMax
|
||||
<< ") -> decomposition is the variable under test" << std::endl;
|
||||
EXPECT_TRUE(pass);
|
||||
#endif // CK_USE_GFX950
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_mx_grouped_gemm_util.hpp"
|
||||
#include "test_mx_grouped_gemm_pipeline_kernel_types.hpp"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileMxGemmPipelineCompTDMWmma
|
||||
: public TestCkTileMxGroupedGemm<T, TestCkTileMxGemmPipelineCompTDMWmma<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type()
|
||||
{
|
||||
using Base = TestCkTileMxGroupedGemm<T, TestCkTileMxGemmPipelineCompTDMWmma<T>>;
|
||||
|
||||
if constexpr(!is_valid_mx_scale_combination<typename Base::ADataType,
|
||||
typename Base::AScaleDataType,
|
||||
typename Base::BDataType,
|
||||
typename Base::BScaleDataType>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(CK_USE_GFX1250)
|
||||
using DeviceIp = ck_tile::gfx125_t;
|
||||
#else
|
||||
#error "Unsupported architecture for WMMA MX GEMM"
|
||||
#endif
|
||||
|
||||
return ck_tile::has_wmma_traits_v<DeviceIp,
|
||||
typename Base::ADataType,
|
||||
typename Base::BDataType,
|
||||
typename Base::AccDataType,
|
||||
ck_tile::constant<Base::M_Warp_Tile>::value,
|
||||
ck_tile::constant<Base::N_Warp_Tile>::value,
|
||||
ck_tile::constant<Base::K_Warp_Tile>::value>;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType>
|
||||
static constexpr bool is_valid_mx_scale_combination()
|
||||
{
|
||||
constexpr bool a_is_f4 = std::is_same_v<ADataType, ck_tile::pk_fp4_t>;
|
||||
constexpr bool b_is_f4 = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
|
||||
constexpr bool a_scale_e8m0 = std::is_same_v<AScaleDataType, ck_tile::e8m0_t>;
|
||||
constexpr bool b_scale_e8m0 = std::is_same_v<BScaleDataType, ck_tile::e8m0_t>;
|
||||
|
||||
// Non-F4 must use E8M0 scale
|
||||
if constexpr(!a_is_f4 && !a_scale_e8m0)
|
||||
return false;
|
||||
if constexpr(!b_is_f4 && !b_scale_e8m0)
|
||||
return false;
|
||||
|
||||
// Both E8M0 -> always valid
|
||||
if constexpr(a_scale_e8m0 && b_scale_e8m0)
|
||||
return true;
|
||||
|
||||
// Both non-E8M0 -> must match (both are F4 by rule 1)
|
||||
if constexpr(!a_scale_e8m0 && !b_scale_e8m0)
|
||||
return std::is_same_v<AScaleDataType, BScaleDataType>;
|
||||
|
||||
// One side non-E8M0: the E8M0 side must not be F4
|
||||
if constexpr(!a_scale_e8m0)
|
||||
return !b_is_f4;
|
||||
if constexpr(!b_scale_e8m0)
|
||||
return !a_is_f4;
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileMxGemmPipelineCompTDMWmma
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileMxGemmPipelineCompTDMWmma, KernelTypesMxGemmCompTDMWmma);
|
||||
|
||||
#include "test_mx_grouped_gemm_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
Reference in New Issue
Block a user