mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4407 (commit adde219)
[CK][CK TILE] Add has hot loop check for pipeline v1 ## Motivation Add has hot loop check for pipeline v1 (v1 basic and v1 basic async). Enable more tests which have been fixed by this change. ## Technical Details Hot loop has been executed without num loop check. ## Test Plan test_grouped_convnd_fwd_tile ## Test Result Passed ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-651 AICK-663
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e88f139c6c
commit
2dd2f114b3
@@ -85,6 +85,13 @@ __device__ inline auto amd_wave_read_first_lane(const Object& obj)
|
||||
return out;
|
||||
}
|
||||
|
||||
// Overload for host to return the same value
|
||||
template <typename T>
|
||||
__host__ inline T amd_wave_read_first_lane(T v)
|
||||
{
|
||||
return v;
|
||||
}
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
struct __attribute__((packed)) buffer_resource
|
||||
|
||||
@@ -81,6 +81,13 @@ __device__ inline auto amd_wave_read_first_lane(const Object& obj)
|
||||
return out;
|
||||
}
|
||||
|
||||
// Overload for host to return the same value
|
||||
template <typename T>
|
||||
__host__ inline T amd_wave_read_first_lane(T v)
|
||||
{
|
||||
return v;
|
||||
}
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
struct __attribute__((packed)) buffer_resource
|
||||
|
||||
@@ -44,15 +44,20 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
@@ -60,12 +65,12 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
@@ -430,7 +435,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
if(HasHotLoop)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
// we have had 3 global prefetches so far, indexed (0, 1, 2).
|
||||
index_t i_global_read = amd_wave_read_first_lane(3);
|
||||
|
||||
@@ -46,6 +46,12 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
|
||||
constexpr auto scenarios = []() {
|
||||
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
|
||||
return std::array<std::pair<bool, ck_tile::TailNumber>, 5>{
|
||||
@@ -62,7 +68,8 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
std::make_pair(false, TailNumber::Even),
|
||||
};
|
||||
}();
|
||||
if(has_hot_loop == scenarios[I].first && tail_number == scenarios[I].second)
|
||||
if(has_hot_loop_first_lane == scenarios[I].first &&
|
||||
tail_number_first_lane == scenarios[I].second)
|
||||
return run_func(bool_constant<scenarios[I].first>{}, constant<scenarios[I].second>{});
|
||||
else if constexpr(I + 1 < scenarios.size())
|
||||
return TailHandler<I + 1>(run_func, has_hot_loop, tail_number);
|
||||
|
||||
@@ -47,15 +47,20 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
@@ -63,12 +68,12 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
|
||||
@@ -43,15 +43,20 @@ struct BaseGemmPipelineAgBgCrCompV6
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
if(tail_number_first_lane == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
else if(tail_number_first_lane == TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
@@ -59,12 +64,12 @@ struct BaseGemmPipelineAgBgCrCompV6
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
if(tail_number_first_lane == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
else if(tail_number_first_lane == TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
@@ -567,7 +572,7 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
|
||||
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
|
||||
|
||||
if(HasHotLoop)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
|
||||
@@ -93,9 +93,14 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Wrap the hot_loop dispatch first.
|
||||
auto tail_dispatch = [&](auto tail_num_constant) {
|
||||
if(has_hot_loop)
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, tail_num_constant);
|
||||
}
|
||||
@@ -106,7 +111,7 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
};
|
||||
|
||||
#define CHECK_TAIL_NUMBER(TAIL_NUMBER, PREFETCH_VALUE) \
|
||||
else if(tail_number == TailNumber::TAIL_NUMBER) \
|
||||
else if(tail_number_first_lane == TailNumber::TAIL_NUMBER) \
|
||||
{ \
|
||||
if constexpr(PrefetchStages > PREFETCH_VALUE) \
|
||||
{ \
|
||||
@@ -114,11 +119,11 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
} \
|
||||
}
|
||||
// Handle all the valid cases.
|
||||
if(tail_number == TailNumber::One)
|
||||
if(tail_number_first_lane == TailNumber::One)
|
||||
{
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::One>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Full)
|
||||
else if(tail_number_first_lane == TailNumber::Full)
|
||||
{
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
@@ -117,7 +118,8 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
@@ -268,25 +270,28 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_block_tile, b_block_tile);
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_block_tile, b_block_tile);
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
block_sync_lds_direct_load();
|
||||
|
||||
iCounter--;
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
@@ -311,12 +316,18 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
|
||||
a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -349,12 +360,17 @@ struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCReg
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -19,7 +19,10 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
@@ -27,9 +30,21 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
return run_func(ck_tile::bool_constant<true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_func(ck_tile::bool_constant<false>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -39,6 +54,7 @@ struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
@@ -137,7 +153,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
@@ -216,6 +233,14 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
@@ -238,10 +263,10 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
// move to 1
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
@@ -273,54 +298,57 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
iCounter--;
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
@@ -340,7 +368,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
@@ -476,50 +505,53 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
iCounter--;
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
@@ -543,13 +575,18 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType & b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
|
||||
a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -582,12 +619,17 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.operator()(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1135,6 +1135,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
// Disable Async for other archs than gfx950
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
@@ -906,6 +906,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Disable Async for other archs than gfx950
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
@@ -1149,6 +1149,7 @@ struct GroupedConvolutionForwardKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Disable Async for other archs than gfx950
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
Reference in New Issue
Block a user