[rocm-libraries] ROCm/rocm-libraries#4407 (commit adde219)

[CK][CK TILE] Add `has_hot_loop` check for pipeline v1

## Motivation

Add a `has_hot_loop` check for pipeline v1 (v1 basic and v1 basic async).
Enable more tests which have been fixed by this change.

## Technical Details

Previously, the hot loop was executed unconditionally, without checking the number of loop iterations: `BlockHasHotloop` always returned `true`, so the main loop body ran even when `num_loop` did not exceed the prefetch stages. This change makes `BlockHasHotloop` return `num_loop > PrefetchStages` and dispatches the pipeline on a compile-time `HasHotLoop` flag via `TailHandler`.

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

Passed

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
AICK-651
AICK-663
This commit is contained in:
Bartłomiej Kocot
2026-02-11 13:43:01 +00:00
committed by assistant-librarian[bot]
parent e88f139c6c
commit 2dd2f114b3
28 changed files with 352 additions and 240 deletions

View File

@@ -44,15 +44,20 @@ struct BaseGemmPipelineAgBgCrCompAsync
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher resource usage.
// It forces to store these values in SGPR.
// Compiler cannot deduce if one path is used for all threads
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
// Handle all the valid cases.
if(has_hot_loop)
if(has_hot_loop_first_lane)
{
if(tail_number == TailNumber::Three)
if(tail_number_first_lane == TailNumber::Three)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number == TailNumber::Two)
else if(tail_number_first_lane == TailNumber::Two)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Two>{});
@@ -60,12 +65,12 @@ struct BaseGemmPipelineAgBgCrCompAsync
}
else
{
if(tail_number == TailNumber::Three)
if(tail_number_first_lane == TailNumber::Three)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number == TailNumber::Two)
else if(tail_number_first_lane == TailNumber::Two)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
@@ -430,7 +435,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
Base::GlobalPrefetchAsync(
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
if(HasHotLoop)
if constexpr(HasHotLoop)
{
// we have had 3 global prefetches so far, indexed (0, 1, 2).
index_t i_global_read = amd_wave_read_first_lane(3);

View File

@@ -46,6 +46,12 @@ struct BaseGemmPipelineAgBgCrCompV3
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher resource usage.
// It forces to store these values in SGPR.
// Compiler cannot deduce if one path is used for all threads
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
constexpr auto scenarios = []() {
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
return std::array<std::pair<bool, ck_tile::TailNumber>, 5>{
@@ -62,7 +68,8 @@ struct BaseGemmPipelineAgBgCrCompV3
std::make_pair(false, TailNumber::Even),
};
}();
if(has_hot_loop == scenarios[I].first && tail_number == scenarios[I].second)
if(has_hot_loop_first_lane == scenarios[I].first &&
tail_number_first_lane == scenarios[I].second)
return run_func(bool_constant<scenarios[I].first>{}, constant<scenarios[I].second>{});
else if constexpr(I + 1 < scenarios.size())
return TailHandler<I + 1>(run_func, has_hot_loop, tail_number);

View File

@@ -47,15 +47,20 @@ struct BaseGemmPipelineAgBgCrCompV4
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher resource usage.
// It forces to store these values in SGPR.
// Compiler cannot deduce if one path is used for all threads
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
// Handle all the valid cases.
if(has_hot_loop)
if(has_hot_loop_first_lane)
{
if(tail_number == TailNumber::Three)
if(tail_number_first_lane == TailNumber::Three)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number == TailNumber::Two)
else if(tail_number_first_lane == TailNumber::Two)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Two>{});
@@ -63,12 +68,12 @@ struct BaseGemmPipelineAgBgCrCompV4
}
else
{
if(tail_number == TailNumber::Three)
if(tail_number_first_lane == TailNumber::Three)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Three>{});
}
else if(tail_number == TailNumber::Two)
else if(tail_number_first_lane == TailNumber::Two)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});

View File

@@ -43,15 +43,20 @@ struct BaseGemmPipelineAgBgCrCompV6
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher resource usage.
// It forces to store these values in SGPR.
// Compiler cannot deduce if one path is used for all threads
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
// Handle all the valid cases.
if(has_hot_loop)
if(has_hot_loop_first_lane)
{
if(tail_number == TailNumber::Odd)
if(tail_number_first_lane == TailNumber::Odd)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else if(tail_number == TailNumber::Even)
else if(tail_number_first_lane == TailNumber::Even)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
@@ -59,12 +64,12 @@ struct BaseGemmPipelineAgBgCrCompV6
}
else
{
if(tail_number == TailNumber::Odd)
if(tail_number_first_lane == TailNumber::Odd)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else if(tail_number == TailNumber::Even)
else if(tail_number_first_lane == TailNumber::Even)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Even>{});
@@ -567,7 +572,7 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
if(HasHotLoop)
if constexpr(HasHotLoop)
{
index_t i = 0;
do

View File

@@ -93,9 +93,14 @@ struct BaseGemmPipelineAgBgCrMem
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
// Use amd_wave_read_first_lane to avoid higher resource usage.
// It forces to store these values in SGPR.
// Compiler cannot deduce if one path is used for all threads
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
// Wrap the hot_loop dispatch first.
auto tail_dispatch = [&](auto tail_num_constant) {
if(has_hot_loop)
if(has_hot_loop_first_lane)
{
return run_func(bool_constant<true>{}, tail_num_constant);
}
@@ -106,7 +111,7 @@ struct BaseGemmPipelineAgBgCrMem
};
#define CHECK_TAIL_NUMBER(TAIL_NUMBER, PREFETCH_VALUE) \
else if(tail_number == TailNumber::TAIL_NUMBER) \
else if(tail_number_first_lane == TailNumber::TAIL_NUMBER) \
{ \
if constexpr(PrefetchStages > PREFETCH_VALUE) \
{ \
@@ -114,11 +119,11 @@ struct BaseGemmPipelineAgBgCrMem
} \
}
// Handle all the valid cases.
if(tail_number == TailNumber::One)
if(tail_number_first_lane == TailNumber::One)
{
return tail_dispatch(integral_constant<TailNumber, TailNumber::One>{});
}
else if(tail_number == TailNumber::Full)
else if(tail_number_first_lane == TailNumber::Full)
{
return tail_dispatch(integral_constant<TailNumber, TailNumber::Full>{});
}

View File

@@ -16,6 +16,7 @@ namespace ck_tile {
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
@@ -117,7 +118,8 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
{
using Base = PipelineImplBase;
template <typename AsDramBlockWindowTmp,
template <bool HasHotLoop,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
@@ -268,25 +270,28 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
block_sync_lds_direct_load();
index_t iCounter = num_loop - 1;
while(iCounter > 0)
if constexpr(HasHotLoop)
{
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
block_sync_lds();
block_sync_lds();
Base::GlobalPrefetchAsync(
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
Base::GlobalPrefetchAsync(
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
Base::GlobalPrefetchAsync(
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
// GEMM i
block_gemm(c_block_tile, a_block_tile, b_block_tile);
// GEMM i
block_gemm(c_block_tile, a_block_tile, b_block_tile);
block_sync_lds_direct_load();
block_sync_lds_direct_load();
iCounter--;
iCounter--;
}
}
// tail
@@ -311,12 +316,18 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_) {
constexpr bool hot_loop = hot_loop_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop);
}
template <typename ADramBlockWindowTmp,
@@ -349,12 +360,17 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_) {
constexpr bool hot_loop = hot_loop_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop);
}
};

View File

@@ -19,7 +19,10 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
@@ -27,9 +30,21 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
// Use amd_wave_read_first_lane to avoid higher resource usage.
// It forces to store these values in SGPR.
// Compiler cannot deduce if one path is used for all threads
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
if(has_hot_loop_first_lane)
{
return run_func(ck_tile::bool_constant<true>{});
}
else
{
return run_func(ck_tile::bool_constant<false>{});
}
}
};
@@ -39,6 +54,7 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
@@ -137,7 +153,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
{
using Base = PipelineImplBase;
template <typename AsDramBlockWindowTmp,
template <bool HasHotLoop,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
@@ -216,6 +233,14 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
// Block GEMM
auto block_gemm = BlockGemm();
@@ -238,10 +263,10 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
// move to 1
// Move each A — the enhanced function move_tile_window is executed, which takes a
// tuple as input.
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
// Move each B — the enhanced function move_tile_window is executed, which takes a
// tuple as input.
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -273,54 +298,57 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
}
}
index_t iCounter = num_loop - 1;
while(iCounter > 0)
if constexpr(HasHotLoop)
{
// global read i + 1
elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
block_sync_lds();
elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// move to i + 2
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
if constexpr(is_a_col_major)
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
}
else
{
store_tile(a_copy_lds_window, elementwise_As_res);
}
// global read i + 1
elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
block_sync_lds();
// LDS write i + 1
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
}
else
{
store_tile(b_copy_lds_window, elementwise_Bs_res);
}
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
iCounter--;
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// move to i + 2
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
// LDS write i + 1
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
}
else
{
store_tile(a_copy_lds_window, elementwise_As_res);
}
// LDS write i + 1
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
}
else
{
store_tile(b_copy_lds_window, elementwise_Bs_res);
}
iCounter--;
}
}
// tail
@@ -340,7 +368,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
{
using Base = PipelineImplBase;
template <typename AsDramBlockWindowTmp,
template <bool HasHotLoop,
typename AsDramBlockWindowTmp,
typename BsDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction,
@@ -476,50 +505,53 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
}
}
index_t iCounter = num_loop - 1;
while(iCounter > 0)
if constexpr(HasHotLoop)
{
// global read i + 1
elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
block_sync_lds();
elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// move to i + 2
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
if constexpr(is_a_col_major)
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
}
else
{
store_tile(a_copy_lds_window, elementwise_As_res);
}
// global read i + 1
elementwise_As_res =
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
block_sync_lds();
elementwise_Bs_res =
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
// LDS write i + 1
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
}
else
{
store_tile(b_copy_lds_window, elementwise_Bs_res);
}
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
iCounter--;
// move to i + 2
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
}
else
{
store_tile(a_copy_lds_window, elementwise_As_res);
}
// LDS write i + 1
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
}
else
{
store_tile(b_copy_lds_window, elementwise_Bs_res);
}
iCounter--;
}
}
// tail
@@ -543,13 +575,18 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.operator()(
a_dram_block_window_tmp,
[](auto& e, const ADataType & a) { e = a; },
b_dram_block_window_tmp,
[](auto& e, const BDataType & b) { e = b; },
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_) {
constexpr bool hot_loop = hot_loop_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop);
}
template <typename ADramBlockWindowTmp,
@@ -582,12 +619,17 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
const auto RunPipeline = [&](auto hot_loop_) {
constexpr bool hot_loop = hot_loop_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop);
}
};

View File

@@ -1135,6 +1135,7 @@ struct GroupedConvolutionBackwardDataKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// Disable Async for other archs than gfx950
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)

View File

@@ -906,6 +906,7 @@ struct GroupedConvolutionBackwardWeightKernel
__shared__ char smem_ptr[GetSmemSize()];
// Disable Async for other archs than gfx950
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)

View File

@@ -1149,6 +1149,7 @@ struct GroupedConvolutionForwardKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// Disable Async for other archs than gfx950
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)