mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Integration of a new pipeline for weight preshuffle into gemm examples (#2516)
* something khushbu can help with * v1 v2 works with flatmm develop * v0 v1 v2 numerical error gone * Fixing numerical error, and interchange preshuffle configs to match with flatmm * Refactor GEMM pipeline configurations and integrate preshuffle support - Updated preshuffle pipeline definitions to include multiple versions (V1, V2, V3). - Changed the pipeline constant from CK_TILE_PIPELINE_PRESHUFFLE to CK_TILE_PIPELINE_PRESHUFFLE_V3 in relevant configurations. - Removed obsolete code and comments * clang format * fix vectorloadsize bug * add the Preshuffle3 * update kwarp calculation in gemm utils * update vector size A and B correctly in V2 pipeline; Added few more changes to align with dteng's branch * fix: add CK_GFX950_SUPPORT macro for gfx950 detection * default disable rotating buffer * docs(CHANGELOG): update changelog for rocm 7.0 * Revert "docs(CHANGELOG): update changelog for rocm 7.0" This reverts commit2bc16fff84. * Remove unused Preshuffle V3 pipeline and related code; update gemm function to use Preshuffle V2; clean up comments and formatting in various files. * revert example/ck_tile/flatmm to its original state * remove comment added by second author * switch to xor ALDSDescriptor * modify the MakeALdsDescriptor() * temporary profiling script * getting rid of line marker compiler error * UniversalWeightPreshufflePipelineAgBgCrPolicy now derives from UniversalGemmBasePolicy * add a minor fix for the config * typo fix * Fix formatting in lambda function for WeightPreshufflePipelineAGmemBGmemCRegV2 * revert change in include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp * revert change in include/ck_tile/core/arch/amd_buffer_addressing.hpp * reenable the GemmSpatiallyLocalTilePartitioner * make GemmConfigPreshuffle_1 for v1 pipeline, GemmConfigPreshuffle_2 for v2 pipeline * remove hardcoded true for preshuffle bool template argument * rename script * remove gemm_profilie.sh script * merge conflict resolve * clang formatted * typo fix * Remove duplicate include of block_gemm_areg_bsmem_creg_v2r1.hpp in gemm.hpp * Remove commented-out code in UniversalWeightPreshufflePipelineAgBgCrPolicy * Fix missing newline at end of file in run_gemm_example.inc * Remove unused barrier call in BlockWeightPreshuffleASmemBSmemCRegV1 * addressing review comments * removing debug code * addressing review comments * Revert "addressing review comments" This reverts commit29c45192ba. * updating tile_engine code * addressing review comments --------- Co-authored-by: amd-khushbu <khuagarw@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -2,9 +2,15 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
|
||||
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
|
||||
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
|
||||
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
|
||||
|
||||
@@ -14,12 +14,13 @@
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE 5
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
@@ -36,7 +37,7 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
@@ -231,7 +232,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_1 : public GemmConfigBase
|
||||
struct GemmConfigPreshuffle_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -247,13 +248,13 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
struct GemmConfigPreshuffle_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -263,15 +264,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
@@ -429,7 +430,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V1>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
@@ -438,6 +439,16 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline =
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
|
||||
@@ -279,13 +279,11 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshufle_1>(argc, argv);
|
||||
return !run_gemm_example<GemmConfigPreshuffle_2>(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
// Return a non-zero code to indicate failure
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -219,6 +219,7 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
|
||||
|
||||
0
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Executable file → Normal file
0
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Executable file → Normal file
@@ -32,6 +32,7 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
|
||||
@@ -48,8 +48,9 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Executable file → Normal file
0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Executable file → Normal file
@@ -112,7 +112,7 @@ struct GemmTile1DPartitioner
|
||||
* @param N GEMM's N dimension.
|
||||
* @return dim3 Structure holding grid's X,Y and Z dimensions.
|
||||
*/
|
||||
CK_TILE_HOST static auto
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
|
||||
{
|
||||
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
|
||||
@@ -9,77 +9,19 @@
|
||||
namespace ck_tile {
|
||||
|
||||
struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
: public UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
using BasePolicy = UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>;
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
if constexpr(MPerXdl == 16 && NPerXdl == 16)
|
||||
{
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
/*xor*/
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
@@ -87,8 +29,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
|
||||
number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
@@ -96,119 +38,29 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return a_lds_block_desc;
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum global memory vector load size.
|
||||
*
|
||||
* @tparam Problem The UniversalGemmPipelineProblem object.
|
||||
* @tparam DataType The tensor data type we're considering.
|
||||
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
|
||||
* @tparam XPerTile The contiguous Tile dimension size.
|
||||
* @return Maximum DRAM vector load size.
|
||||
*/
|
||||
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
PackedSize == 2)
|
||||
{
|
||||
return (PackedSize * 32 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 8 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
|
||||
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 4 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
|
||||
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 2 / sizeof(DataType));
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -426,7 +278,6 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle()
|
||||
{
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -276,12 +276,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
auto b_flat_dram_window =
|
||||
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
@@ -468,5 +467,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
107
script/gemm_profile.sh
Executable file
107
script/gemm_profile.sh
Executable file
@@ -0,0 +1,107 @@
|
||||
#!/bin/bash
|
||||
|
||||
BIN=./bin/tile_example_gemm_weight_preshuffle
|
||||
PREC=fp8
|
||||
VERBOSITY=2
|
||||
|
||||
# List of all (m, n, k) triplets
|
||||
ARGS_LIST=(
|
||||
"1 2048 5120"
|
||||
"1 5120 1024"
|
||||
"2 2048 5120"
|
||||
"2 5120 1024"
|
||||
"3 2048 5120"
|
||||
"3 5120 1024"
|
||||
"4 2048 5120"
|
||||
"4 5120 1024"
|
||||
"5 2048 5120"
|
||||
"5 5120 1024"
|
||||
"6 2048 5120"
|
||||
"6 5120 1024"
|
||||
"7 2048 5120"
|
||||
"7 5120 1024"
|
||||
"8 2048 5120"
|
||||
"8 5120 1024"
|
||||
"9 2048 5120"
|
||||
"9 5120 1024"
|
||||
"10 2048 5120"
|
||||
"10 5120 1024"
|
||||
"11 2048 5120"
|
||||
"11 5120 1024"
|
||||
"12 2048 5120"
|
||||
"12 5120 1024"
|
||||
"13 2048 5120"
|
||||
"13 5120 1024"
|
||||
"14 2048 5120"
|
||||
"14 5120 1024"
|
||||
"15 2048 5120"
|
||||
"15 5120 1024"
|
||||
"16 2048 5120"
|
||||
"16 5120 1024"
|
||||
"2048 5120 1024"
|
||||
"2048 5120 8192"
|
||||
"2048 7168 8192"
|
||||
"2048 8192 3584"
|
||||
"16384 7168 8192"
|
||||
"16384 8192 3584"
|
||||
)
|
||||
|
||||
# Output file
|
||||
OUTPUT_FILE="gemm_profile_results.csv"
|
||||
|
||||
# Output header
|
||||
echo "m,n,k,Pipeline,Time_ms,TFlops,GBps,Verification" > "$OUTPUT_FILE"
|
||||
|
||||
# Loop over each argument set
|
||||
for args in "${ARGS_LIST[@]}"; do
|
||||
read -r m n k <<< "$args"
|
||||
|
||||
echo "Testing: m=$m, n=$n, k=$k"
|
||||
OUTPUT=$($BIN -m=$m -n=$n -k=$k -prec=$PREC -v=$VERBOSITY 2>/dev/null)
|
||||
|
||||
# Extract pipeline information
|
||||
# Format: "Launching kernel with args: gemm_fp8_pipeline_AGmemBGmemCRegV2_128x256x256x256_16x16x128_16x16_0x0x0"
|
||||
PIPELINE=$(echo "$OUTPUT" | grep "Launching kernel with args:" | sed -n 's/.*Launching kernel with args: \(.*\)/\1/p')
|
||||
|
||||
# Extract TFlops and GB/s from the output
|
||||
# Format: "Run Gemm kernel with M=3840 N=4096 K=2048 ... : 0.042338 ms, 1521.67 TFlops, 1126.89 GB/s,"
|
||||
PERF_LINE=$(echo "$OUTPUT" | grep "TFlops")
|
||||
|
||||
# Extract verification result
|
||||
# Format: "The GPU verification result is: correct"
|
||||
VERIFICATION=$(echo "$OUTPUT" | grep "The GPU verification result is:" | sed -n 's/.*The GPU verification result is: \(.*\)/\1/p')
|
||||
|
||||
if [ -n "$PERF_LINE" ]; then
|
||||
# Extract execution time in ms
|
||||
TIME_MS=$(echo "$PERF_LINE" | grep -o '[0-9]\+\.[0-9]\+ ms' | grep -o '[0-9]\+\.[0-9]\+')
|
||||
# Extract TFlops value - more robust regex
|
||||
TFLOPS=$(echo "$PERF_LINE" | grep -o '[0-9]\+\.[0-9]\+ TFlops' | grep -o '[0-9]\+\.[0-9]\+')
|
||||
# Extract GB/s value - more robust regex
|
||||
GBPS=$(echo "$PERF_LINE" | grep -o '[0-9]\+\.[0-9]\+ GB/s' | grep -o '[0-9]\+\.[0-9]\+')
|
||||
|
||||
# Use extracted pipeline or default if not found
|
||||
if [ -z "$PIPELINE" ]; then
|
||||
PIPELINE="gemm_basic"
|
||||
fi
|
||||
|
||||
# Print to terminal
|
||||
echo " Pipeline: $PIPELINE"
|
||||
echo " Time: ${TIME_MS} ms"
|
||||
echo " TFlops: ${TFLOPS}"
|
||||
echo " GB/s: ${GBPS}"
|
||||
|
||||
|
||||
# Save to CSV file
|
||||
echo "$m,$n,$k,$PIPELINE,$TIME_MS,$TFLOPS,$GBPS,$VERIFICATION" >> "$OUTPUT_FILE"
|
||||
else
|
||||
echo " ERROR: Could not parse performance data"
|
||||
echo ""
|
||||
echo "$m,$n,$k,$PIPELINE,,,,$VERIFICATION" >> "$OUTPUT_FILE"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=========================================="
|
||||
echo "Profile completed!"
|
||||
echo "Results saved to: $OUTPUT_FILE"
|
||||
echo "Total tests run: ${#ARGS_LIST[@]}"
|
||||
echo "=========================================="
|
||||
Reference in New Issue
Block a user