mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Merge commit '608232ce82636e7c9ab8dec55dc7507c6792fb65' into develop
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -36,6 +36,9 @@ tags
|
||||
# Editors
|
||||
.vscode
|
||||
|
||||
# CMake formatting configuration (local)
|
||||
.cmake-format.yaml
|
||||
|
||||
# Cline
|
||||
.cline*
|
||||
|
||||
|
||||
@@ -29,4 +29,4 @@ RUN groupadd -g 109 render && \
|
||||
git sparse-checkout set projects/hipblaslt shared/origami && \
|
||||
cd projects/hipblaslt && \
|
||||
git show --oneline -s && \
|
||||
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx90a;gfx942;gfx950" -j 128 --skip_rocroller
|
||||
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller
|
||||
|
||||
@@ -69,7 +69,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -128,7 +133,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
|
||||
|
||||
@@ -36,17 +36,13 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
// TODO: Is this 32K value gfx9 arch specific?
|
||||
static constexpr index_t MinMemInFlyBytes = 32768;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
|
||||
static constexpr index_t WgpPerCU = ck_tile::max(4 * get_warp_size() / BlockSize, 1);
|
||||
static constexpr index_t FullMemBandPrefetchStages =
|
||||
integer_divide_ceil(MinMemInFlyBytes / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) / APackedSize +
|
||||
NPerBlock * sizeof(BDataType) / BPackedSize) *
|
||||
KPerBlock);
|
||||
static constexpr index_t PrefetchStages =
|
||||
FullMemBandPrefetchStages >= 2
|
||||
? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
|
||||
: 2;
|
||||
static constexpr index_t PrefetchStages = ck_tile::clamp(FullMemBandPrefetchStages, 2, 8);
|
||||
|
||||
static constexpr index_t LocalPrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = PrefetchStages;
|
||||
|
||||
@@ -80,6 +80,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
using Base::PrefetchStages;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -165,6 +168,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename ADramWindow, typename ABlockTile_, typename DramTileWindowStep>
|
||||
CK_TILE_DEVICE static void
|
||||
LoadAndConvertATile(ABlockTile_& a_block_tile,
|
||||
ADramWindow& a_dram_window,
|
||||
const DramTileWindowStep& dram_tile_window_step)
|
||||
{
|
||||
using DestDataType = typename ABlockTile_::DataType;
|
||||
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
move_tile_window(a_dram_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
@@ -177,11 +193,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
|
||||
index_t m,
|
||||
[[maybe_unused]] index_t m,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
(void)m; // unused variable
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
@@ -197,11 +212,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
|
||||
static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!");
|
||||
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Aq block window has incorrect lengths for defined AqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
@@ -217,7 +228,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
// A/B tiles in LDS - using the same approach as regular gemm pipeline
|
||||
auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
|
||||
auto ab_lds_blocks = Base::template GetABLdsTensorViews<BDataType, BDataType>(p_smem);
|
||||
auto& a_lds_block = ab_lds_blocks.at(I0{});
|
||||
auto& b_lds_block = ab_lds_blocks.at(I1{});
|
||||
|
||||
@@ -249,7 +260,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
decltype(make_static_distributed_tensor<BDataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
using AQBlockTile =
|
||||
@@ -272,7 +283,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
|
||||
|
||||
// Global prefetch initialization - DRAM to VGPRs
|
||||
Base::GlobalPrefetch(
|
||||
LoadAndConvertATile(
|
||||
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
|
||||
@@ -282,10 +293,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS prefill - VGPRs to LDS
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
@@ -293,10 +304,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
@@ -306,9 +317,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
}
|
||||
// Additional prefetching for memory pipeline - DRAM to VGPRs
|
||||
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
LoadAndConvertATile(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
@@ -325,16 +336,17 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile,
|
||||
aq_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
// Prepare next iteration data
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(
|
||||
a_shuffle_tmp,
|
||||
@@ -348,7 +360,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -365,9 +377,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
LoadAndConvertATile(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
@@ -381,20 +393,89 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
}
|
||||
|
||||
// Tail handling
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(
|
||||
c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window);
|
||||
auto HotLoopTail = [&](auto tail_num) {
|
||||
static_for<0, tail_num - 1, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile,
|
||||
aq_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window);
|
||||
// no second block_sync_lds because it's interwave
|
||||
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp,
|
||||
a_block_tiles.get(number<prefetch_idx + 1>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx + 1>{}));
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp,
|
||||
b_block_tiles.get(number<prefetch_idx + 1>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx + 1>{}));
|
||||
}
|
||||
});
|
||||
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I1{}), a_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I1{}), b_element_func);
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile,
|
||||
aq_block_tiles.get(number<tail_num - 1>{}),
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window);
|
||||
};
|
||||
|
||||
if constexpr(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(
|
||||
c_block_tile, aq_block_tiles.get(I1{}), a_lds_gemm_window, b_lds_gemm_window);
|
||||
c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Two)
|
||||
{
|
||||
HotLoopTail(number<2>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Three)
|
||||
{
|
||||
HotLoopTail(number<3>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Four)
|
||||
{
|
||||
HotLoopTail(number<4>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Five)
|
||||
{
|
||||
HotLoopTail(number<5>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Six)
|
||||
{
|
||||
HotLoopTail(number<6>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Seven)
|
||||
{
|
||||
HotLoopTail(number<7>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
HotLoopTail(number<PrefetchStages>{});
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
@@ -413,7 +494,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
return PipelineImpl<GemmPipelineScheduler::Interwave>{}
|
||||
.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](const BDataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
aq_dram_block_window_tmp,
|
||||
|
||||
@@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz});
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout);
|
||||
}
|
||||
else if constexpr(is_same<decltype(layout), tensor_layout::convolution::NHWC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::KYXC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::NHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_});
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz});
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout);
|
||||
}
|
||||
else if constexpr(is_same<decltype(layout), tensor_layout::convolution::NHWC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::KYXC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::NHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_});
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user