mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
CK-Tile first draft of universal block gemm with interwave & intrawave scheduler (#1676)
* Block universal gemm.
* Universal block gemm with interwave scheduler - draft.
* Refactoring.
* Move a/b_warp_tiles into BlockGemmImpl.
* Set BlockGemmImpl as a class member.
* Change tile size to better suit memory-bound cases.
* Introduce kKPerThread to WarpGemm.
* Add documentation comment.
* Fix Interwave scheduler block gemm.
* Add compute/memory friendly tile configuration.
* Clean up.
* New tile configurations in gemm mem example.
* Add more static checks and fix loop order in block gemm.
* Add more static checks and use warp gemm mfma dispatcher.
* Add default scheduler block gemm.
* Remove logging in example.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -21,9 +21,10 @@ struct WarpGemmAtrributeMfma
|
||||
using BVecType = typename Impl::BVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
@@ -86,9 +87,10 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
@@ -197,9 +199,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
@@ -260,9 +263,10 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
using BVecType = typename Impl::AVecType;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
@@ -330,9 +334,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
@@ -444,10 +449,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
static constexpr index_t kM = Impl::kN;
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
@@ -583,10 +589,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
|
||||
using CVecType = typename Impl::CVecType;
|
||||
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,6 +14,11 @@ struct WarpGemmImpl
|
||||
static constexpr index_t kM = WarpGemmAttribute::kM;
|
||||
static constexpr index_t kN = WarpGemmAttribute::kN;
|
||||
static constexpr index_t kK = WarpGemmAttribute::kK;
|
||||
/// @brief The number of elements in K dimension processed by single thread in wavefront.
|
||||
///
|
||||
/// @note Note that WarpGemm may run MFMA instruction multiple times (on different K).
|
||||
/// In such situation this value reflects this fact.
|
||||
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
|
||||
|
||||
using ADataType = typename WarpGemmAttribute::ADataType;
|
||||
using BDataType = typename WarpGemmAttribute::BDataType;
|
||||
|
||||
Reference in New Issue
Block a user