CK-Tile first draft of universal block gemm with interwave & intrawave scheduler (#1676)

* Block universal gemm.

* Universal block gemm with interwave scheduler - draft.

* Refactoring

* Move a/b_warp_tiles into BlockGemmImpl
* set BlockGemmImpl as a class member

* Change tile size for more suitable to memory bound cases.

* Introduce kKPerThread to WarpGemm

* Add documentation comment.

* Fix Interwave scheduler block gemm.

* Add compute/memory friendly tile configuration.

* Clean

* New tile configurations in gemm mem example.

* Add more static checks and fix loop order in block gemm.

* Add more static checks and use warp gemm mfma dispatcher.

* Add default scheduler block gemm.

* Remove logging in example.
This commit is contained in:
Adam Osewski
2024-11-26 08:45:14 +01:00
committed by GitHub
parent 440e28b08f
commit b6bcd76d88
11 changed files with 779 additions and 56 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -21,9 +21,10 @@ struct WarpGemmAtrributeMfma
using BVecType = typename Impl::BVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
@@ -86,9 +87,10 @@ struct WarpGemmAtrributeMfmaIterateK
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
@@ -197,9 +199,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
using BVecType = typename Impl::AVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
@@ -260,9 +263,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
using BVecType = typename Impl::AVecType;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
static constexpr index_t kKPerThread = Impl::kABKPerLane;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
@@ -330,9 +334,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
@@ -444,10 +449,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
static constexpr index_t kM = Impl::kN;
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
@@ -583,10 +589,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
using CVecType = typename Impl::CVecType;
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
static constexpr index_t kM = Impl::kM;
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,6 +14,11 @@ struct WarpGemmImpl
static constexpr index_t kM = WarpGemmAttribute::kM;
static constexpr index_t kN = WarpGemmAttribute::kN;
static constexpr index_t kK = WarpGemmAttribute::kK;
/// @brief The number of elements in K dimension processed by single thread in wavefront.
///
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
/// In such situation this value reflects this fact.
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
using ADataType = typename WarpGemmAttribute::ADataType;
using BDataType = typename WarpGemmAttribute::BDataType;