Stream-K Reduction option as Runtime parameter and Compilation Error Fix (SK-Reduction) (#2145)

* reduction is passed as runtime parameter

* clang

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp

Co-authored-by: John Afaganis <john.afaganis@amd.com>

* Update include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp


* remove comment

---------
This commit is contained in:
Muhammed Emin Ozturk
2025-06-11 10:59:44 -07:00
committed by GitHub
parent 06e0b8436c
commit 6fad1c4874
7 changed files with 216 additions and 101 deletions

View File

@@ -15,6 +15,8 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
@@ -57,8 +59,9 @@ struct ProblemSizeStreamK_universal final
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;
ck::index_t Grid_size = -1; // defaults to max occupancy
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
ck::index_t Grid_size = -1; // defaults to max occupancy
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
ck::StreamKReductionStrategy reduction_strategy = ck::StreamKReductionStrategy::Atomic;
};
struct ProblemSizeSplitK final
@@ -173,7 +176,19 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
if(argc >= 11)
{
problem_size.Streamk_sel = std::stoi(argv[10]);
problem_size.Grid_size = std::stoi(argv[11]);
if(argc >= 12)
{
problem_size.Grid_size = std::stoi(argv[11]);
if(argc >= 13)
{
int reduction_strategy = std::stoi(argv[12]);
problem_size.reduction_strategy = reduction_strategy == 0
? ck::StreamKReductionStrategy::Atomic
: ck::StreamKReductionStrategy::Reduction;
}
}
}
}
else
@@ -185,7 +200,9 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
<< std::endl
<< "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
<< std::endl
<< "arg11: Grid_size(-1 for max occupancy)" << std::endl
<< "arg12: Reduction strategy (0: Atomic, 1: Reduction)" << std::endl;
return false;
}

View File

@@ -21,6 +21,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto Grid_size = problem_size.Grid_size;
auto Streamk_sel = problem_size.Streamk_sel;
auto reduction_strategy = problem_size.reduction_strategy;
if(reduction_strategy == ck::StreamKReductionStrategy::Atomic)
{
std::cout << "Using Atomic reduction strategy" << std::endl;
}
else
{
std::cout << "Using Parallel reduction strategy" << std::endl;
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
@@ -152,7 +162,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
Grid_size,
a_element_op,
b_element_op,
c_element_op);
c_element_op,
reduction_strategy);
if(!gemm.IsSupportedArgument(argument))
{
@@ -242,7 +253,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
<< " GB/s, " << gemm.GetTypeString()
<< (reduction_strategy == ck::StreamKReductionStrategy::Atomic ? " (Atomic)"
: " (Reduction)")
<< std::endl;
}
return pass;
}

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
namespace ck {
namespace tensor_operation {
@@ -20,21 +21,22 @@ template <typename ALayout,
typename CElementwiseOperation>
struct DeviceGemm_Streamk_V2 : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t Streamk_sel,
ck::index_t Grid_size,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t Streamk_sel,
ck::index_t Grid_size,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

View File

@@ -149,8 +149,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
{
hip_check_error(hipMemsetAsync(
@@ -198,26 +197,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
else
{
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
if(arg.reduction_strategy == StreamKReductionStrategy::Atomic)
{
ave_time = launch_and_time_kernel(
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
else if(arg.reduction_strategy == StreamKReductionStrategy::Reduction)
{
char* workspace_semaphore =
reinterpret_cast<char*>(arg.p_workspace_) +
arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
sizeof(GemmAccDataType));
auto preprocess = [&]() {
hipMemsetAsync(
hipError_t status = hipMemsetAsync(
workspace_semaphore,
0,
// sizeof(uint32_t),
arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
stream_config.stream_id_);
// Check the status
hip_check_error(status);
};
ave_time = launch_and_time_kernel_with_preprocess(
@@ -437,8 +437,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(p_arg->reduction_strategy == StreamKReductionStrategy::Reduction)
{
return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType));
}
@@ -491,20 +490,22 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
static auto
MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic)
{
constexpr index_t minimum_occupancy =
@@ -705,26 +706,39 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
}
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size};
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
streamk_sel,
Grid_size,
reduction_strategy};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
@@ -736,7 +750,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
StrideB,
StrideC,
streamk_sel,
Grid_size);
Grid_size,
reduction_strategy);
}
// polymorphic

View File

@@ -1415,12 +1415,11 @@ template <uint32_t MPerBlock_,
index_t M01_ = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
@@ -1433,10 +1432,17 @@ struct BlockToCTileMap_GemmStreamK_v2
MDiv k_iters_per_tile;
MDiv equiv_tiles_big; // for reduction
MDiv equiv_tiles_little; // for reduction
StreamKReductionStrategy reduction_strategy;
// prefer construct on host
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
uint32_t m,
uint32_t n,
uint32_t k,
uint32_t grid_size = 1,
uint32_t streamk_sel = 1,
StreamKReductionStrategy reduction_strategy_ = StreamKReductionStrategy::Atomic)
: reduction_strategy(reduction_strategy_)
{
// total output tiles
@@ -1546,7 +1552,7 @@ struct BlockToCTileMap_GemmStreamK_v2
// Using multiple blocks for parallel reduction
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if(reduction_strategy == ck::StreamKReductionStrategy::Reduction)
{
// Add additional safety checks
if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
@@ -1589,7 +1595,7 @@ struct BlockToCTileMap_GemmStreamK_v2
__host__ __device__ index_t get_grid_dims() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if(reduction_strategy == StreamKReductionStrategy::Reduction)
{
// return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
return reduction_start_block_idx + get_sk_tiles();

View File

@@ -513,7 +513,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideB_,
index_t StrideC_,
index_t Streamk_sel_,
index_t Grid_size_)
index_t Grid_size_,
StreamKReductionStrategy reduction_strategy_)
: M{M_},
N{N_},
K{K_},
@@ -522,6 +523,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
StrideC{StrideC_},
Streamk_sel{Streamk_sel_},
Grid_size{Grid_size_},
reduction_strategy{reduction_strategy_}, // Initialize the member variable
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KRead{CalculateKRead(K_, 1)},
@@ -550,8 +552,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel
<< ", Grid size:" << Grid_size << "}" << std::endl;
<< "NBlock: " << NBlock << ", "
<< "Stream-K Selection:" << Streamk_sel << ", "
<< "Grid size:" << Grid_size << ", "
<< "Reduction Strategy:"
<< (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic"
: "Reduction")
<< "}" << std::endl;
}
index_t M;
@@ -562,6 +569,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideC;
index_t Streamk_sel;
mutable index_t Grid_size;
StreamKReductionStrategy reduction_strategy;
index_t MPadded;
index_t NPadded;
index_t KRead;
@@ -585,13 +593,26 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideB_,
index_t StrideC_,
index_t Streamk_sel_,
index_t Grid_size_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
index_t Grid_size_,
StreamKReductionStrategy reduction_strategy_)
: Problem{M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
Streamk_sel_,
Grid_size_,
reduction_strategy_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
block_2_ctile_map_streamk(
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
block_2_ctile_map_streamk(M_,
N_,
AK0Number * CalculateKPadded(K_, 1),
Grid_size_,
Streamk_sel_,
reduction_strategy_)
{
}
@@ -1267,11 +1288,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
problem.Streamk_sel,
problem.reduction_strategy);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
@@ -1286,6 +1309,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
@@ -1301,8 +1325,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
@@ -1890,8 +1913,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
@@ -1903,8 +1925,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
else if(problem.reduction_strategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
@@ -1936,8 +1958,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
});
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
@@ -1952,8 +1973,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
@@ -2008,7 +2028,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
problem.Streamk_sel,
problem.reduction_strategy);
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
@@ -2027,8 +2048,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
@@ -2644,8 +2664,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
@@ -2657,8 +2676,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
else if(problem.reduction_strategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
@@ -2693,16 +2712,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{

50
include/ck/utility/dynamic_buffer.hpp Normal file → Executable file
View File

@@ -139,7 +139,8 @@ struct DynamicBuffer
template <InMemoryDataOperationEnum Op,
typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cvref_t<T>>::type>::value ||
!is_native_type<X>(),
bool>::type = false>
__host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x)
{
@@ -159,7 +160,37 @@ struct DynamicBuffer
{
auto tmp = this->template Get<X>(i, is_valid_element);
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
// handle bfloat addition
#if defined(__gfx942__) || defined(__gfx950__)
// Properly handle addition for all low-precision types
if constexpr(is_same_v<scalar_t, bhalf_t> || is_same_v<scalar_t, half_t>)
{
if constexpr(is_scalar_type<X>::value)
{
// Scalar type: Convert to float, add, convert back
auto result =
type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
this->template Set<X>(i, is_valid_element, result);
}
else
{
// Vector type
constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
const vector_type<scalar_t, vector_size> a_vector{tmp};
const vector_type<scalar_t, vector_size> b_vector{x};
// Process each element of the vector in higher precision
static_for<0, vector_size, 1>{}([&](auto idx) {
auto result = type_convert<scalar_t>(
type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
this->template Set<scalar_t>(i + idx, is_valid_element, result);
});
}
}
#else
// handle bfloat addition
if constexpr(is_same_v<scalar_t, bhalf_t>)
{
if constexpr(is_scalar_type<X>::value)
@@ -187,6 +218,8 @@ struct DynamicBuffer
{
this->template Set<X>(i, is_valid_element, x + tmp);
}
#endif
}
}
@@ -240,9 +273,20 @@ struct DynamicBuffer
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
using vector_t = typename vector_type_maker<remove_cvref_t<T>, t_per_x>::type::type;
vector_t tmp;
if constexpr(is_same_v<remove_cvref_t<X>, vector_t>)
{
tmp = x;
}
else
{
__builtin_memcpy(&tmp, &x, sizeof(vector_t));
}
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
tmp, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&