Integration of a new pipeline for weight preshuffle into gemm examples (#2516)

* something khushbu can help with

* v1 v2 works with flatmm develop

* v0 v1 v2 numerical error gone

* Fixing numerical error, and interchange preshuffle configs to match with flatmm

* Refactor GEMM pipeline configurations and integrate preshuffle support

- Updated preshuffle pipeline definitions to include multiple versions (V1, V2, V3).
- Changed the pipeline constant from CK_TILE_PIPELINE_PRESHUFFLE to CK_TILE_PIPELINE_PRESHUFFLE_V3 in relevant configurations.
- Removed obsolete code and comments

* clang format

* fix vectorloadsize bug

* add the Preshuffle3

* update kwarp calculation in gemm utils

* update vector size A and B correctly in V2 pipeline; Added few more changes to align with dteng's branch

* fix: add CK_GFX950_SUPPORT macro for gfx950 detection

* default disable rotating buffer

* docs(CHANGELOG): update changelog for rocm 7.0

* Revert "docs(CHANGELOG): update changelog for rocm 7.0"

This reverts commit 2bc16fff84.

* Remove unused Preshuffle V3 pipeline and related code; update gemm function to use Preshuffle V2; clean up comments and formatting in various files.

* revert example/ck_tile/flatmm to its original state

* remove comment added by second author

* switch to xor ALDSDescriptor

* modify the MakeALdsDescriptor()

* temporary profiling script

* getting rid of line marker compiler error

* UniversalWeightPreshufflePipelineAgBgCrPolicy now derives from UniversalGemmBasePolicy

* add a minor fix for the config

* typo fix

* Fix formatting in lambda function for WeightPreshufflePipelineAGmemBGmemCRegV2

* revert change in include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp

* revert change in include/ck_tile/core/arch/amd_buffer_addressing.hpp

* reenable the GemmSpatiallyLocalTilePartitioner

* make GemmConfigPreshuffle_1 for v1 pipeline, GemmConfigPreshuffle_2 for v2 pipeline

* remove hardcoded true for preshuffle bool template argument

* rename script

* remove gemm_profilie.sh script

* merge conflict resolve

* clang formatted

* typo fix

* Remove duplicate include of block_gemm_areg_bsmem_creg_v2r1.hpp in gemm.hpp

* Remove commented-out code in UniversalWeightPreshufflePipelineAgBgCrPolicy

* Fix missing newline at end of file in run_gemm_example.inc

* Remove unused barrier call in BlockWeightPreshuffleASmemBSmemCRegV1

* addressing review comments

* removing debug code

* addressing review comments

* Revert "addressing review comments"

This reverts commit 29c45192ba.

* updating tile_engine code

* addressing review comments

---------

Co-authored-by: amd-khushbu <khuagarw@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>

[ROCm/composable_kernel commit: 1441a0a7ee]
This commit is contained in:
Aviral Goel
2025-08-01 03:04:54 -04:00
committed by GitHub
parent 7a6efa0194
commit 8bf96c18c6
13 changed files with 1231 additions and 187 deletions

View File

@@ -2,9 +2,15 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})

View File

@@ -14,12 +14,13 @@
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(__gfx950__)
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
@@ -36,7 +37,7 @@ constexpr ck_tile::index_t get_k_warp_tile()
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(__gfx950__)
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
@@ -231,7 +232,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
struct GemmConfigPreshuffle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -247,13 +248,13 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
struct GemmConfigPreshuffle_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -263,15 +264,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
@@ -429,7 +430,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V1>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
@@ -438,6 +439,16 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -279,13 +279,11 @@ int main(int argc, char* argv[])
{
try
{
return !run_gemm_example<GemmConfigPreshufle_1>(argc, argv);
return !run_gemm_example<GemmConfigPreshuffle_2>(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
// Return a non-zero code to indicate failure
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}

View File

@@ -219,6 +219,7 @@ int run_flatmm_example(int argc, char* argv[])
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(