Integration of a new pipeline for weight preshuffle into gemm examples (#2516)

* something khushbu can help with

* v1 v2 works with flatmm develop

* v0 v1 v2 numerical error gone

* Fixing numerical error, and interchange preshuffle configs to match with flatmm

* Refactor GEMM pipeline configurations and integrate preshuffle support

- Updated preshuffle pipeline definitions to include multiple versions (V1, V2, V3).
- Changed the pipeline constant from CK_TILE_PIPELINE_PRESHUFFLE to CK_TILE_PIPELINE_PRESHUFFLE_V3 in relevant configurations.
- Removed obsolete code and comments

* clang format

* fix vectorloadsize bug

* add the Preshuffle3

* update kwarp calculation in gemm utils

* update vector size A and B correctly in V2 pipeline; Added few more changes to align with dteng's branch

* fix: add CK_GFX950_SUPPORT macro for gfx950 detection

* default disable rotating buffer

* docs(CHANGELOG): update changelog for rocm 7.0

* Revert "docs(CHANGELOG): update changelog for rocm 7.0"

This reverts commit 2bc16fff84.

* Remove unused Preshuffle V3 pipeline and related code; update gemm function to use Preshuffle V2; clean up comments and formatting in various files.

* revert example/ck_tile/flatmm to its original state

* remove comment added by second author

* switch to xor ALDSDescriptor

* modify the MakeALdsDescriptor()

* temporary profiling script

* getting rid of line marker compiler error

* UniversalWeightPreshufflePipelineAgBgCrPolicy now derives from UniversalGemmBasePolicy

* add a minor fix for the config

* typo fix

* Fix formatting in lambda function for WeightPreshufflePipelineAGmemBGmemCRegV2

* revert change in include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp

* revert change in include/ck_tile/core/arch/amd_buffer_addressing.hpp

* reenable the GemmSpatiallyLocalTilePartitioner

* make GemmConfigPreshuffle_1 for v1 pipeline, GemmConfigPreshuffle_2 for v2 pipeline

* remove hardcoded true for preshuffle bool template argument

* rename script

* remove gemm_profilie.sh script

* merge conflict resolve

* clang formatted

* typo fix

* Remove duplicate include of block_gemm_areg_bsmem_creg_v2r1.hpp in gemm.hpp

* Remove commented-out code in UniversalWeightPreshufflePipelineAgBgCrPolicy

* Fix missing newline at end of file in run_gemm_example.inc

* Remove unused barrier call in BlockWeightPreshuffleASmemBSmemCRegV1

* addressing review comments

* removing debug code

* addressing review comments

* Revert "addressing review comments"

This reverts commit 29c45192ba.

* updating tile_engine code

* addressing review comments

---------

Co-authored-by: amd-khushbu <khuagarw@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Aviral Goel
2025-08-01 03:04:54 -04:00
committed by GitHub
parent 88d72178d6
commit 1441a0a7ee
13 changed files with 1231 additions and 187 deletions

View File

@@ -2,9 +2,15 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})

View File

@@ -14,12 +14,13 @@
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(__gfx950__)
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
@@ -36,7 +37,7 @@ constexpr ck_tile::index_t get_k_warp_tile()
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(__gfx950__)
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
@@ -231,7 +232,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
struct GemmConfigPreshuffle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -247,13 +248,13 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
struct GemmConfigPreshuffle_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -263,15 +264,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
@@ -429,7 +430,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V1>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
@@ -438,6 +439,16 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -279,13 +279,11 @@ int main(int argc, char* argv[])
{
try
{
return !run_gemm_example<GemmConfigPreshufle_1>(argc, argv);
return !run_gemm_example<GemmConfigPreshuffle_2>(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
// Return a non-zero code to indicate failure
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}

View File

@@ -219,6 +219,7 @@ int run_flatmm_example(int argc, char* argv[])
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(

0
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp Executable file → Normal file
View File

View File

@@ -32,6 +32,7 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV1<Problem>
{

View File

@@ -48,8 +48,9 @@
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"

0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp Executable file → Normal file
View File

View File

@@ -112,7 +112,7 @@ struct GemmTile1DPartitioner
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST static auto
CK_TILE_HOST_DEVICE static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;

View File

@@ -9,77 +9,19 @@
namespace ck_tile {
struct UniversalWeightPreshufflePipelineAgBgCrPolicy
: public UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
using BasePolicy = UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>;
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
if constexpr(MPerXdl == 16 && NPerXdl == 16)
{
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
else
{
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
/*xor*/
#if 0
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = GetSmemPackA<Problem>();
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer =
@@ -87,8 +29,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
number<kMPerBlock / MLdsLayer>{},
number<kKPack>{}),
number<kMPerBlock / MLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
@@ -96,119 +38,29 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
number<kKPerBlock / kKPack * MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
number<kKPerBlock / kKPack * MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform(
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(
make_merge_transform(
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
#endif
}
/**
* @brief Get the maximum global memory vector load size.
*
* @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @tparam XPerTile The contiguous Tile dimension size.
* @return Maximum DRAM vector load size.
*/
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Assume DataType is even!
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
PackedSize == 2)
{
return (PackedSize * 32 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (PackedSize * 16 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
{
return (PackedSize * 8 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
{
return (PackedSize * 4 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
{
return (PackedSize * 2 / sizeof(DataType));
}
else
{
return PackedSize;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
}
}
template <typename Problem>
@@ -426,7 +278,6 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle()
{
// using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,

View File

@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
namespace ck_tile {
@@ -276,12 +276,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
auto b_flat_dram_window =
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
@@ -468,5 +467,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
p_smem);
}
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

107
script/gemm_profile.sh Executable file
View File

@@ -0,0 +1,107 @@
#!/bin/bash
BIN=./bin/tile_example_gemm_weight_preshuffle
PREC=fp8
VERBOSITY=2
# List of all (m, n, k) triplets
ARGS_LIST=(
"1 2048 5120"
"1 5120 1024"
"2 2048 5120"
"2 5120 1024"
"3 2048 5120"
"3 5120 1024"
"4 2048 5120"
"4 5120 1024"
"5 2048 5120"
"5 5120 1024"
"6 2048 5120"
"6 5120 1024"
"7 2048 5120"
"7 5120 1024"
"8 2048 5120"
"8 5120 1024"
"9 2048 5120"
"9 5120 1024"
"10 2048 5120"
"10 5120 1024"
"11 2048 5120"
"11 5120 1024"
"12 2048 5120"
"12 5120 1024"
"13 2048 5120"
"13 5120 1024"
"14 2048 5120"
"14 5120 1024"
"15 2048 5120"
"15 5120 1024"
"16 2048 5120"
"16 5120 1024"
"2048 5120 1024"
"2048 5120 8192"
"2048 7168 8192"
"2048 8192 3584"
"16384 7168 8192"
"16384 8192 3584"
)
# Output file
OUTPUT_FILE="gemm_profile_results.csv"
# Output header
echo "m,n,k,Pipeline,Time_ms,TFlops,GBps,Verification" > "$OUTPUT_FILE"
# Loop over each argument set
for args in "${ARGS_LIST[@]}"; do
read -r m n k <<< "$args"
echo "Testing: m=$m, n=$n, k=$k"
OUTPUT=$($BIN -m=$m -n=$n -k=$k -prec=$PREC -v=$VERBOSITY 2>/dev/null)
# Extract pipeline information
# Format: "Launching kernel with args: gemm_fp8_pipeline_AGmemBGmemCRegV2_128x256x256x256_16x16x128_16x16_0x0x0"
PIPELINE=$(echo "$OUTPUT" | grep "Launching kernel with args:" | sed -n 's/.*Launching kernel with args: \(.*\)/\1/p')
# Extract TFlops and GB/s from the output
# Format: "Run Gemm kernel with M=3840 N=4096 K=2048 ... : 0.042338 ms, 1521.67 TFlops, 1126.89 GB/s,"
PERF_LINE=$(echo "$OUTPUT" | grep "TFlops")
# Extract verification result
# Format: "The GPU verification result is: correct"
VERIFICATION=$(echo "$OUTPUT" | grep "The GPU verification result is:" | sed -n 's/.*The GPU verification result is: \(.*\)/\1/p')
if [ -n "$PERF_LINE" ]; then
# Extract execution time in ms
TIME_MS=$(echo "$PERF_LINE" | grep -o '[0-9]\+\.[0-9]\+ ms' | grep -o '[0-9]\+\.[0-9]\+')
# Extract TFlops value - more robust regex
TFLOPS=$(echo "$PERF_LINE" | grep -o '[0-9]\+\.[0-9]\+ TFlops' | grep -o '[0-9]\+\.[0-9]\+')
# Extract GB/s value - more robust regex
GBPS=$(echo "$PERF_LINE" | grep -o '[0-9]\+\.[0-9]\+ GB/s' | grep -o '[0-9]\+\.[0-9]\+')
# Use extracted pipeline or default if not found
if [ -z "$PIPELINE" ]; then
PIPELINE="gemm_basic"
fi
# Print to terminal
echo " Pipeline: $PIPELINE"
echo " Time: ${TIME_MS} ms"
echo " TFlops: ${TFLOPS}"
echo " GB/s: ${GBPS}"
# Save to CSV file
echo "$m,$n,$k,$PIPELINE,$TIME_MS,$TFLOPS,$GBPS,$VERIFICATION" >> "$OUTPUT_FILE"
else
echo " ERROR: Could not parse performance data"
echo ""
echo "$m,$n,$k,$PIPELINE,,,,$VERIFICATION" >> "$OUTPUT_FILE"
fi
done
echo "=========================================="
echo "Profile completed!"
echo "Results saved to: $OUTPUT_FILE"
echo "Total tests run: ${#ARGS_LIST[@]}"
echo "=========================================="