mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Stream-K Reduction option as Runtime parameter and Compilation Error Fix (SK-Reduction) (#2145)
* reduction is passed as runtime parameter
* clang
* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp

  Co-authored-by: John Afaganis <john.afaganis@amd.com>
* Update include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
* remove comment

---------
This commit is contained in:
committed by
GitHub
parent
06e0b8436c
commit
6fad1c4874
@@ -15,6 +15,8 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
@@ -57,8 +59,9 @@ struct ProblemSizeStreamK_universal final
|
||||
ck::index_t StrideB = -1;
|
||||
ck::index_t StrideC = -1;
|
||||
|
||||
ck::index_t Grid_size = -1; // defaults to max occupancy
|
||||
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
|
||||
ck::index_t Grid_size = -1; // defaults to max occupancy
|
||||
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
|
||||
ck::StreamKReductionStrategy reduction_strategy = ck::StreamKReductionStrategy::Atomic;
|
||||
};
|
||||
|
||||
struct ProblemSizeSplitK final
|
||||
@@ -173,7 +176,19 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
|
||||
if(argc >= 11)
|
||||
{
|
||||
problem_size.Streamk_sel = std::stoi(argv[10]);
|
||||
problem_size.Grid_size = std::stoi(argv[11]);
|
||||
|
||||
if(argc >= 12)
|
||||
{
|
||||
problem_size.Grid_size = std::stoi(argv[11]);
|
||||
|
||||
if(argc >= 13)
|
||||
{
|
||||
int reduction_strategy = std::stoi(argv[12]);
|
||||
problem_size.reduction_strategy = reduction_strategy == 0
|
||||
? ck::StreamKReductionStrategy::Atomic
|
||||
: ck::StreamKReductionStrategy::Reduction;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -185,7 +200,9 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
|
||||
<< std::endl
|
||||
<< "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
|
||||
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
|
||||
<< std::endl
|
||||
<< "arg11: Grid_size(-1 for max occupancy)" << std::endl
|
||||
<< "arg12: Reduction strategy (0: Atomic, 1: Reduction)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
auto Grid_size = problem_size.Grid_size;
|
||||
auto Streamk_sel = problem_size.Streamk_sel;
|
||||
|
||||
auto reduction_strategy = problem_size.reduction_strategy;
|
||||
if(reduction_strategy == ck::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
std::cout << "Using Atomic reduction strategy" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Using Parallel reduction strategy" << std::endl;
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
@@ -152,7 +162,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
Grid_size,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
c_element_op,
|
||||
reduction_strategy);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -242,7 +253,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
<< " GB/s, " << gemm.GetTypeString()
|
||||
<< (reduction_strategy == ck::StreamKReductionStrategy::Atomic ? " (Atomic)"
|
||||
: " (Reduction)")
|
||||
<< std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -20,21 +21,22 @@ template <typename ALayout,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemm_Streamk_V2 : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t Streamk_sel,
|
||||
ck::index_t Grid_size,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t Streamk_sel,
|
||||
ck::index_t Grid_size,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
@@ -149,8 +149,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
|
||||
hip_check_error(hipMemsetAsync(
|
||||
@@ -198,26 +197,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
else
|
||||
{
|
||||
|
||||
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
|
||||
}
|
||||
else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
else if(arg.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
char* workspace_semaphore =
|
||||
reinterpret_cast<char*>(arg.p_workspace_) +
|
||||
arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
|
||||
sizeof(GemmAccDataType));
|
||||
auto preprocess = [&]() {
|
||||
hipMemsetAsync(
|
||||
hipError_t status = hipMemsetAsync(
|
||||
workspace_semaphore,
|
||||
0,
|
||||
// sizeof(uint32_t),
|
||||
arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
|
||||
stream_config.stream_id_);
|
||||
|
||||
// Check the status
|
||||
hip_check_error(status);
|
||||
};
|
||||
|
||||
ave_time = launch_and_time_kernel_with_preprocess(
|
||||
@@ -437,8 +437,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
|
||||
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(p_arg->reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
|
||||
}
|
||||
@@ -491,20 +490,22 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
static auto
|
||||
MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
@@ -705,26 +706,39 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
}
|
||||
}
|
||||
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
streamk_sel,
|
||||
Grid_size,
|
||||
reduction_strategy};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -736,7 +750,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
|
||||
StrideB,
|
||||
StrideC,
|
||||
streamk_sel,
|
||||
Grid_size);
|
||||
Grid_size,
|
||||
reduction_strategy);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -1415,12 +1415,11 @@ template <uint32_t MPerBlock_,
|
||||
index_t M01_ = 4>
|
||||
struct BlockToCTileMap_GemmStreamK_v2
|
||||
{
|
||||
static constexpr uint32_t min_k_iters_per_sk_block = 2;
|
||||
static constexpr uint32_t MPerBlock = MPerBlock_;
|
||||
static constexpr uint32_t NPerBlock = NPerBlock_;
|
||||
static constexpr uint32_t KPerBlock = KPerBlock_;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
|
||||
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
|
||||
static constexpr uint32_t min_k_iters_per_sk_block = 2;
|
||||
static constexpr uint32_t MPerBlock = MPerBlock_;
|
||||
static constexpr uint32_t NPerBlock = NPerBlock_;
|
||||
static constexpr uint32_t KPerBlock = KPerBlock_;
|
||||
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
|
||||
|
||||
//--------------------------------------
|
||||
// pass to device
|
||||
@@ -1433,10 +1432,17 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
MDiv k_iters_per_tile;
|
||||
MDiv equiv_tiles_big; // for reduction
|
||||
MDiv equiv_tiles_little; // for reduction
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
|
||||
// prefer construct on host
|
||||
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
|
||||
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
|
||||
uint32_t m,
|
||||
uint32_t n,
|
||||
uint32_t k,
|
||||
uint32_t grid_size = 1,
|
||||
uint32_t streamk_sel = 1,
|
||||
StreamKReductionStrategy reduction_strategy_ = StreamKReductionStrategy::Atomic)
|
||||
: reduction_strategy(reduction_strategy_)
|
||||
{
|
||||
|
||||
// total output tiles
|
||||
@@ -1546,7 +1552,7 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
// Using multiple blocks for parallel reduction
|
||||
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
|
||||
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if(reduction_strategy == ck::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// Add additional safety checks
|
||||
if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
|
||||
@@ -1589,7 +1595,7 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
|
||||
__host__ __device__ index_t get_grid_dims() const
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if(reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
|
||||
return reduction_start_block_idx + get_sk_tiles();
|
||||
|
||||
75
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Normal file → Executable file
75
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Normal file → Executable file
@@ -513,7 +513,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t Streamk_sel_,
|
||||
index_t Grid_size_)
|
||||
index_t Grid_size_,
|
||||
StreamKReductionStrategy reduction_strategy_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
@@ -522,6 +523,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
StrideC{StrideC_},
|
||||
Streamk_sel{Streamk_sel_},
|
||||
Grid_size{Grid_size_},
|
||||
reduction_strategy{reduction_strategy_}, // Initialize the member variable
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
KRead{CalculateKRead(K_, 1)},
|
||||
@@ -550,8 +552,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< "AK0:" << AK0 << ", "
|
||||
<< "BK0:" << BK0 << ", "
|
||||
<< "MBlock: " << MBlock << ", "
|
||||
<< "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel
|
||||
<< ", Grid size:" << Grid_size << "}" << std::endl;
|
||||
<< "NBlock: " << NBlock << ", "
|
||||
<< "Stream-K Selection:" << Streamk_sel << ", "
|
||||
<< "Grid size:" << Grid_size << ", "
|
||||
<< "Reduction Strategy:"
|
||||
<< (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic"
|
||||
: "Reduction")
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -562,6 +569,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideC;
|
||||
index_t Streamk_sel;
|
||||
mutable index_t Grid_size;
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
index_t KRead;
|
||||
@@ -585,13 +593,26 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t Streamk_sel_,
|
||||
index_t Grid_size_)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
|
||||
index_t Grid_size_,
|
||||
StreamKReductionStrategy reduction_strategy_)
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
Streamk_sel_,
|
||||
Grid_size_,
|
||||
reduction_strategy_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
block_2_ctile_map_streamk(
|
||||
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
|
||||
block_2_ctile_map_streamk(M_,
|
||||
N_,
|
||||
AK0Number * CalculateKPadded(K_, 1),
|
||||
Grid_size_,
|
||||
Streamk_sel_,
|
||||
reduction_strategy_)
|
||||
|
||||
{
|
||||
}
|
||||
@@ -1267,11 +1288,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel);
|
||||
problem.Streamk_sel,
|
||||
problem.reduction_strategy);
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block, is_reduction_block;
|
||||
index_t num_k_block_main_loop;
|
||||
@@ -1286,6 +1309,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -1301,8 +1325,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
@@ -1890,8 +1913,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
@@ -1903,8 +1925,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
else if(problem.reduction_strategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
@@ -1936,8 +1958,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
@@ -1952,8 +1973,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
@@ -2008,7 +2028,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel);
|
||||
problem.Streamk_sel,
|
||||
problem.reduction_strategy);
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -2027,8 +2048,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
@@ -2644,8 +2664,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
@@ -2657,8 +2676,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
else if(problem.reduction_strategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
@@ -2693,16 +2712,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
// make sure next loop LDS is ready for use
|
||||
block_sync_lds();
|
||||
}
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
|
||||
50
include/ck/utility/dynamic_buffer.hpp
Normal file → Executable file
50
include/ck/utility/dynamic_buffer.hpp
Normal file → Executable file
@@ -139,7 +139,8 @@ struct DynamicBuffer
|
||||
template <InMemoryDataOperationEnum Op,
|
||||
typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value ||
|
||||
!is_native_type<X>(),
|
||||
bool>::type = false>
|
||||
__host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x)
|
||||
{
|
||||
@@ -159,7 +160,37 @@ struct DynamicBuffer
|
||||
{
|
||||
auto tmp = this->template Get<X>(i, is_valid_element);
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
// handle bfloat addition
|
||||
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
|
||||
// Properly handle addition for all low-precision types
|
||||
if constexpr(is_same_v<scalar_t, bhalf_t> || is_same_v<scalar_t, half_t>)
|
||||
{
|
||||
if constexpr(is_scalar_type<X>::value)
|
||||
{
|
||||
// Scalar type: Convert to float, add, convert back
|
||||
auto result =
|
||||
type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
|
||||
this->template Set<X>(i, is_valid_element, result);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Vector type
|
||||
constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
const vector_type<scalar_t, vector_size> a_vector{tmp};
|
||||
const vector_type<scalar_t, vector_size> b_vector{x};
|
||||
|
||||
// Process each element of the vector in higher precision
|
||||
static_for<0, vector_size, 1>{}([&](auto idx) {
|
||||
auto result = type_convert<scalar_t>(
|
||||
type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
|
||||
type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
|
||||
this->template Set<scalar_t>(i + idx, is_valid_element, result);
|
||||
});
|
||||
}
|
||||
}
|
||||
#else
|
||||
// handle bfloat addition
|
||||
if constexpr(is_same_v<scalar_t, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_scalar_type<X>::value)
|
||||
@@ -187,6 +218,8 @@ struct DynamicBuffer
|
||||
{
|
||||
this->template Set<X>(i, is_valid_element, x + tmp);
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,9 +273,20 @@ struct DynamicBuffer
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
using vector_t = typename vector_type_maker<remove_cvref_t<T>, t_per_x>::type::type;
|
||||
vector_t tmp;
|
||||
|
||||
if constexpr(is_same_v<remove_cvref_t<X>, vector_t>)
|
||||
{
|
||||
tmp = x;
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_memcpy(&tmp, &x, sizeof(vector_t));
|
||||
}
|
||||
|
||||
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
tmp, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
}
|
||||
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
|
||||
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
|
||||
|
||||
Reference in New Issue
Block a user