Stream-K Reduction option as a Runtime parameter and Compilation Error Fix (SK-Reduction) (#2145)

* reduction strategy is passed as a runtime parameter

* clang

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp

Co-authored-by: John Afaganis <john.afaganis@amd.com>

* Update include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp


* remove comment

---------
This commit is contained in:
Muhammed Emin Ozturk
2025-06-11 10:59:44 -07:00
committed by GitHub
parent 06e0b8436c
commit 6fad1c4874
7 changed files with 216 additions and 101 deletions

View File

@@ -149,8 +149,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
{
hip_check_error(hipMemsetAsync(
@@ -198,26 +197,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
else
{
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
{
ave_time = launch_and_time_kernel(
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
else if(arg.reduction_strategy == StreamKReductionStrategy::Reduction)
{
char* workspace_semaphore =
reinterpret_cast<char*>(arg.p_workspace_) +
arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
sizeof(GemmAccDataType));
auto preprocess = [&]() {
hipMemsetAsync(
hipError_t status = hipMemsetAsync(
workspace_semaphore,
0,
// sizeof(uint32_t),
arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
stream_config.stream_id_);
// Check the status
hip_check_error(status);
};
ave_time = launch_and_time_kernel_with_preprocess(
@@ -437,8 +437,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(p_arg->reduction_strategy == StreamKReductionStrategy::Reduction)
{
return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
}
@@ -491,20 +490,22 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
static auto
MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic)
{
constexpr index_t minimum_occupancy =
@@ -705,26 +706,39 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
}
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
streamk_sel,
Grid_size,
reduction_strategy};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
@@ -736,7 +750,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
StrideB,
StrideC,
streamk_sel,
Grid_size);
Grid_size,
reduction_strategy);
}
// polymorphic