input output all ok

This commit is contained in:
coderfeli
2025-03-15 14:26:30 +00:00
parent d1e999c05c
commit da2659d502
11 changed files with 66 additions and 41 deletions

View File

@@ -158,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::long_index_t, A0DataType>;
// clang-format on
@@ -235,12 +235,12 @@ int main(int argc, char* argv[])
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size};
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// max_token_id.mData = {valid_size};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = 0;
expert_ids.mData[i] = eids[i];
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;

View File

@@ -159,7 +159,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>;
// clang-format on
#else
static constexpr ck::index_t MPerBlock = 128;
@@ -176,7 +176,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>;
// clang-format on
#endif

View File

@@ -164,7 +164,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;

View File

@@ -149,7 +149,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -41,6 +41,7 @@ template <typename ThreadGroup,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
typename IndexType,
index_t GatherDim = 1,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_gather
@@ -58,7 +59,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<long_index_t, gather_num>& gather_offsets)
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
@@ -190,6 +191,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
IndexType,
GatherDim,
NumThreadScratch>;

View File

@@ -42,6 +42,7 @@ template <typename ThreadGroup,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags,
typename IndexType,
index_t ScatterDim = 1,
bool OutputScatter = true,
index_t ScatterWeightIdx = 3,
@@ -149,7 +150,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
@@ -169,7 +170,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);
@@ -230,6 +231,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
IndexType,
ScatterDim,
OutputScatter,
ScatterWeightIdx,

View File

@@ -67,6 +67,7 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
bool IsInputGemm = true,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
@@ -132,6 +133,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
BlkGemmPipeSched,
BlkGemmPipelineVer,
NSwizzle,
IndexType,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,

View File

@@ -148,6 +148,7 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
@@ -298,7 +299,7 @@ struct GridwiseMoeGemm
}
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -491,7 +492,7 @@ struct GridwiseMoeGemm
template <typename ELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(long_index_t M, long_index_t MPad, long_index_t N, long_index_t NPad, long_index_t StrideC)
MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
@@ -1210,7 +1211,7 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<long_index_t, AMRepeats> gather_offsets;
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -1218,7 +1219,7 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
@@ -1226,7 +1227,7 @@ struct GridwiseMoeGemm
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto a_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride / BPackedSize,
@@ -1261,6 +1262,7 @@ struct GridwiseMoeGemm
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1519,6 +1521,7 @@ struct GridwiseMoeGemm
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
@@ -1528,8 +1531,16 @@ struct GridwiseMoeGemm
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
auto c_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// using BufferType = std::conditional_t<
// std::is_same_v<IndexType, long_index_t>,
// decltype(make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())),
// decltype(make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()))
// >;
auto c_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// BufferType c_grid_buf = std::is_same_v<IndexType, long_index_t> ?
// make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()) :
// make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -1563,7 +1574,7 @@ struct GridwiseMoeGemm
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<long_index_t, EMRepeats>
StaticallyIndexedArray<IndexType, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
@@ -1575,7 +1586,7 @@ struct GridwiseMoeGemm
block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
long_index_t token_offset = fused_token & 0xffffff;
IndexType token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
@@ -1588,7 +1599,7 @@ struct GridwiseMoeGemm
const float* p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
scatter_weights(m0) = weight;
});
@@ -1713,7 +1724,7 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<long_index_t, AMRepeats>
StaticallyIndexedArray<IndexType, AMRepeats>
gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
@@ -1722,7 +1733,7 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
@@ -1730,7 +1741,7 @@ struct GridwiseMoeGemm
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
const auto a_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride / BPackedSize,
@@ -1765,6 +1776,7 @@ struct GridwiseMoeGemm
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
2>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -2029,6 +2041,7 @@ struct GridwiseMoeGemm
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
@@ -2038,7 +2051,10 @@ struct GridwiseMoeGemm
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
auto c_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(
auto c_grid_buf = std::is_same_v<IndexType, long_index_t> ?
make_long_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()):
make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
@@ -2073,7 +2089,7 @@ struct GridwiseMoeGemm
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<long_index_t, EMRepeats>
StaticallyIndexedArray<IndexType, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
@@ -2099,7 +2115,7 @@ struct GridwiseMoeGemm
const float* p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
scatter_weights(m0) = weight;
});

View File

@@ -41,6 +41,7 @@ template <typename SliceLengths,
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
typename IndexType,
index_t GatherDim = 1,
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1_gather
@@ -88,7 +89,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<long_index_t, gather_num>& gather_offsets)
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op),
@@ -221,7 +222,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
auto gather_offset =
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
const IndexType ld_offset = src_coord_.GetOffset() + gather_offset;
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, true);
@@ -935,7 +936,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
StaticallyIndexedArray<long_index_t, gather_num> gather_offsets_;
StaticallyIndexedArray<IndexType, gather_num> gather_offsets_;
};
} // namespace ck

View File

@@ -43,6 +43,7 @@ template <typename SrcDatas,
index_t DstScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename IndexType,
index_t ScatterDim = 1,
bool OutputScatter = true,
index_t ScatterWeightIdx = 3,
@@ -412,7 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
OOBCheck(thread_scratch_id);
@@ -421,7 +422,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// loop over space-filling curve
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
long_index_t scatter_offset = 0;
IndexType scatter_offset = 0;
if constexpr(OutputScatter)
{
constexpr auto iScatter =
@@ -431,7 +432,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
long_index_t dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
@@ -490,7 +491,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);

View File

@@ -225,8 +225,8 @@ struct DynamicBuffer
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
#if 0
bool constexpr use_amd_buffer_addressing = true;
#if CK_USE_AMD_BUFFER_LOAD
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
@@ -349,7 +349,7 @@ struct DynamicBuffer
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
// if(i >= 2169041600)
*c_style_pointer_cast<X*>(p_data_ + i) = x;
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
}
}
@@ -380,12 +380,13 @@ struct DynamicBuffer
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
sizeof(IndexType) <= sizeof(int32_t) && (
is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
@@ -424,7 +425,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif