mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Stream-K Reduction option as Runtime parameter and Compilation Error Fix (SK-Reduction) (#2145)
* reduction is passed as a runtime parameter
* clang
* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
  Co-authored-by: John Afaganis <john.afaganis@amd.com>
* Update include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
* remove comment
---------
This commit is contained in:
committed by
GitHub
parent
06e0b8436c
commit
6fad1c4874
@@ -149,8 +149,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
|
||||
hip_check_error(hipMemsetAsync(
|
||||
@@ -198,26 +197,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
else
|
||||
{
|
||||
|
||||
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
|
||||
}
|
||||
else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
else if(arg.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
char* workspace_semaphore =
|
||||
reinterpret_cast<char*>(arg.p_workspace_) +
|
||||
arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
|
||||
sizeof(GemmAccDataType));
|
||||
auto preprocess = [&]() {
|
||||
hipMemsetAsync(
|
||||
hipError_t status = hipMemsetAsync(
|
||||
workspace_semaphore,
|
||||
0,
|
||||
// sizeof(uint32_t),
|
||||
arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
|
||||
stream_config.stream_id_);
|
||||
|
||||
// Check the status
|
||||
hip_check_error(status);
|
||||
};
|
||||
|
||||
ave_time = launch_and_time_kernel_with_preprocess(
|
||||
@@ -437,8 +437,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
|
||||
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(p_arg->reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
|
||||
}
|
||||
@@ -491,20 +490,22 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
static auto
|
||||
MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
@@ -705,26 +706,39 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
}
|
||||
}
|
||||
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
streamk_sel,
|
||||
Grid_size,
|
||||
reduction_strategy};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -736,7 +750,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
StrideB,
|
||||
StrideC,
|
||||
streamk_sel,
|
||||
Grid_size);
|
||||
Grid_size,
|
||||
reduction_strategy);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
Reference in New Issue
Block a user