mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
[GEMM] Add macros for multiple optimization options
This commit is contained in:
@@ -13,6 +13,11 @@ set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
# Optional kernel-variant selector: when the `kernel` variable is defined, its value is
# printed and turned into a private compile definition KERNEL_<value>=1 on the target.
if(DEFINED kernel)
|
||||
message("Compiling with Kernel: ${kernel}")
|
||||
# The resulting macro name has to match one of the KERNEL_* names the sources test for
# (config.h in this change checks KERNEL_A through KERNEL_G).
target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel}=1)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
|
||||
@@ -6,10 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#define mfma_m32_n32_k8 0
|
||||
#define mfma_m32_n32_k16 0
|
||||
#define mfma_m16_n16_k16 0
|
||||
#define mfma_m16_n16_k32 1
|
||||
#include "config.h"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -20,8 +17,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if mfma_m32_n32_k8
|
||||
#pragma message ("mfma m32 n32 k8")
|
||||
#if defined(USING_MFMA_32x32x_8x2)
|
||||
#pragma message ("mfma m32 n32 k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -34,8 +31,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#elif mfma_m32_n32_k16
|
||||
#pragma message ("mfma m32 n32 k16")
|
||||
#elif defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message ("mfma m32 n32 k8")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -48,7 +45,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#elif mfma_m16_n16_k16
|
||||
|
||||
#elif defined(USING_MFMA_16x16x16)
|
||||
#pragma message("mfma m16 n16 k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
@@ -62,7 +60,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#elif mfma_m16_n16_k32
|
||||
#elif defined(USING_MFMA_16x16x_16x2)
|
||||
#pragma message("mfma m16 n16 k32")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
|
||||
@@ -114,7 +114,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
#if 1
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message ("prefetch")
|
||||
// prefetch
|
||||
// global read 0
|
||||
|
||||
@@ -7,10 +7,7 @@
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "block_gemm_asmem_bsmem_creg.hpp"
|
||||
|
||||
#define BANK_CONFLICT_K_FIRST 0
|
||||
#define PADDING_K_FIRST 0
|
||||
#define PADDING_MN_FIRST 0
|
||||
#define XOR 1
|
||||
#include "config.h"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -26,7 +23,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
#if BANK_CONFLICT_K_FIRST
|
||||
#if defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message ("BANK_CONFLICT: K_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
@@ -41,7 +38,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif PADDING_K_FIRST
|
||||
#elif defined(PADDING_K_FIRST)
|
||||
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
@@ -56,7 +53,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif PADDING_MN_FIRST
|
||||
#elif defined(PADDING_MN_FIRST)
|
||||
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
@@ -71,7 +68,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif XOR
|
||||
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
|
||||
#pragma message ("BANK_CONFLICT: XOR")
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
@@ -125,7 +122,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
#if BANK_CONFLICT_K_FIRST
|
||||
#if defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message ("BANK_CONFLICT: K_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
@@ -140,7 +137,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif PADDING_K_FIRST
|
||||
#elif defined(PADDING_K_FIRST)
|
||||
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
@@ -155,7 +152,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif PADDING_MN_FIRST
|
||||
#elif defined(PADDING_MN_FIRST)
|
||||
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
|
||||
@@ -170,7 +167,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
#elif XOR
|
||||
#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE)
|
||||
#pragma message ("BANK_CONFLICT: XOR")
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
|
||||
31
example/ck_tile/99_toy_example/02_gemm/config.h
Normal file
31
example/ck_tile/99_toy_example/02_gemm/config.h
Normal file
@@ -0,0 +1,31 @@
|
||||
|
||||
// Maps a single kernel selector macro (KERNEL_A ... KERNEL_G) to the set of feature
// macros that the GEMM example sources test with #if defined(...) / #ifdef / #ifndef.
// Exactly one branch of the chain below is taken; if no KERNEL_* macro is defined,
// NAIVE_IMPLEMENTATION is defined instead.
// NOTE(review): this header has no include guard and is included from several files;
// every macro here is defined with an empty body, so repeated inclusion re-defines
// them identically.
// KERNEL_A: only the 32x32x(8x2) MFMA selection.
#if defined(KERNEL_A)
|
||||
#define USING_MFMA_32x32x_8x2
|
||||
// KERNEL_B: only the 16x16x16 MFMA selection.
#elif defined(KERNEL_B)
|
||||
#define USING_MFMA_16x16x16
|
||||
// KERNEL_C: only the 16x16x(16x2) MFMA selection.
#elif defined(KERNEL_C)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
// KERNEL_D: KERNEL_C's macro plus the XOR-based bank-conflict-free macro.
#elif defined(KERNEL_D)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
// KERNEL_E: KERNEL_D's macros plus ADJUST_BLOCK_TILE_SHAPE.
#elif defined(KERNEL_E)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
#define ADJUST_BLOCK_TILE_SHAPE
|
||||
// KERNEL_F: KERNEL_E's macros plus ENABLE_PREFETCH and ENABLE_INSTRUCTION_SCH.
#elif defined(KERNEL_F)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
#define ADJUST_BLOCK_TILE_SHAPE
|
||||
#define ENABLE_PREFETCH
|
||||
#define ENABLE_INSTRUCTION_SCH
|
||||
// KERNEL_G: KERNEL_F's macros plus ENABLE_CACHE_AWARE_WG_SCH.
#elif defined(KERNEL_G)
|
||||
#define USING_MFMA_16x16x_16x2
|
||||
#define USING_XOR_BASED_BANK_CONFLICT_FREE
|
||||
#define ADJUST_BLOCK_TILE_SHAPE
|
||||
#define ENABLE_PREFETCH
|
||||
#define ENABLE_INSTRUCTION_SCH
|
||||
#define ENABLE_CACHE_AWARE_WG_SCH
|
||||
// Fallback when no KERNEL_* selector was supplied.
#else
|
||||
#define NAIVE_IMPLEMENTATION
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
|
||||
#include "config.h"
|
||||
#include "gemm.hpp"
|
||||
|
||||
/*
|
||||
@@ -47,6 +48,43 @@ int main(int argc, char* argv[])
|
||||
K = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
// Startup banner: reports which kernel variant this binary was compiled as and lists
// the optimizations that variant enables. The branches mirror the KERNEL_* chain in
// config.h, so the text must stay in step with the feature macros defined there
// (KERNEL_A selects USING_MFMA_32x32x_8x2, hence the 32x32 wording below).
#if defined(KERNEL_A)
|
||||
printf("*** KernelA test *** \n");
|
||||
printf(" --> Using mfma_32x32x(8x2)\n");
|
||||
#elif defined(KERNEL_B)
|
||||
printf("*** KernelB test *** \n");
|
||||
printf(" --> Using mfma_16x16x16\n");
|
||||
#elif defined(KERNEL_C)
|
||||
printf("*** KernelC test *** \n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
#elif defined(KERNEL_D)
|
||||
printf("*** KernelD test *** \n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank conflict-free\n");
|
||||
#elif defined(KERNEL_E)
|
||||
printf("*** KernelE test ***\n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank conflict-free\n");
|
||||
printf(" --> Adjust block tile shape\n");
|
||||
#elif defined(KERNEL_F)
|
||||
printf("*** KernelF test ***\n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank conflict-free\n");
|
||||
printf(" --> Adjust block tile shape\n");
|
||||
printf(" --> Enable prefetch\n");
|
||||
printf(" --> Enable instruction schedule\n");
|
||||
#elif defined(KERNEL_G)
|
||||
printf("*** KernelG test ***\n");
|
||||
printf(" --> Using mfma_16x16x(16x2)\n");
|
||||
printf(" --> XOR-based bank conflict-free\n");
|
||||
printf(" --> Adjust block tile shape\n");
|
||||
printf(" --> Enable prefetch\n");
|
||||
printf(" --> Enable instruction schedule\n");
|
||||
printf(" --> Enable cache-aware thread blocks schedule\n");
|
||||
// No KERNEL_* selector defined: config.h falls back to NAIVE_IMPLEMENTATION.
#else
|
||||
printf("*** Naive implementation test ***\n");
|
||||
#endif
|
||||
|
||||
const ck_tile::index_t Lda = K;
|
||||
const ck_tile::index_t Ldb = K;
|
||||
const ck_tile::index_t Ldc = N;
|
||||
@@ -82,7 +120,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
#if 1
|
||||
#ifdef ADJUST_BLOCK_TILE_SHAPE
|
||||
#pragma message ("(Increase KperBlock, reduce MperBlock) -> increase Grid size")
|
||||
constexpr ck_tile::index_t kGemmMPerBlock = 128;
|
||||
constexpr ck_tile::index_t kGemmKPerBlock = 64;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "config.h"
|
||||
#include "grid_gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -28,7 +29,7 @@ struct GridGemmProblem
|
||||
using CElementFunction = CElementFunction_;
|
||||
};
|
||||
|
||||
#ifndef INSTRUCTION_SCHEDULE
|
||||
#ifndef ENABLE_INSTRUCTION_SCH
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
struct TileGemmShape
|
||||
{
|
||||
@@ -85,7 +86,7 @@ struct Gemm
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0,
|
||||
index_t N0)
|
||||
{
|
||||
#if 1
|
||||
#if defined(ENABLE_CACHE_AWARE_WG_SCH)
|
||||
#pragma message ("Cache-aware work group sch")
|
||||
return [=](index_t block_1d_id) {
|
||||
constexpr index_t M01 = 4;
|
||||
@@ -147,7 +148,7 @@ struct Gemm
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifndef INSTRUCTION_SCHEDULE
|
||||
#ifndef ENABLE_INSTRUCTION_SCH
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
|
||||
{
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#define INSTRUCTION_SCHEDULE
|
||||
|
||||
#ifdef INSTRUCTION_SCHEDULE
|
||||
#ifdef ENABLE_INSTRUCTION_SCH
|
||||
#include "instruction_schedule/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "instruction_schedule/gemm_pipeline_problem.hpp"
|
||||
#include "instruction_schedule/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
@@ -59,7 +58,7 @@ struct GridGemm
|
||||
auto b_block_window = make_tile_window(
|
||||
b_grid, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {iN, 0});
|
||||
|
||||
#ifndef INSTRUCTION_SCHEDULE
|
||||
#ifndef ENABLE_INSTRUCTION_SCH
|
||||
#pragma message ("disable instruction scheduling")
|
||||
// Block GEMM pipeline w/o instruction scheduling
|
||||
constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline<Problem>();
|
||||
|
||||
Reference in New Issue
Block a user