support index type

This commit is contained in:
coderfeli
2025-03-10 09:56:15 +00:00
parent 402a8443d9
commit dd42d8e8fa
8 changed files with 322 additions and 340 deletions

View File

@@ -182,7 +182,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, S<4, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ck::index_t, Nswizzle, true, A0DataType>;
// clang-format on
#endif

View File

@@ -162,7 +162,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,ck::index_t, false, false, A0DataType>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -41,7 +41,8 @@ template <typename ThreadGroup,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t GatherDim = 1,
typename ScatterIdxType = index_t,
index_t GatherDim = 1,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_mod8
{
@@ -59,7 +60,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<index_t, gather_num> &gather_offsets)
const StaticallyIndexedArray<ScatterIdxType, gather_num>& gather_offsets)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
@@ -178,25 +179,26 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1_gather<decltype(thread_slice_lengths),
SrcElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
GatherDim,
NumThreadScratch>;
SrcElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
ScatterIdxType,
GatherDim,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};

View File

@@ -42,8 +42,9 @@ template <typename ThreadGroup,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t ScatterDim = 1,
bool OutputScatter = true,
typename ScatterIdxType = index_t,
index_t ScatterDim = 1,
bool OutputScatter = true,
index_t ScatterWeightIdx = 3,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r3_scatter
@@ -51,14 +52,15 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
static constexpr index_t nDim =
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}) ; // Dirty HACK FELIX, TODO fix
static constexpr index_t mod_num =
ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
using Index = MultiIndex<nDim>;
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});
static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter(
const SrcDescs& src_descs,
@@ -108,13 +110,20 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto src_thread_slice_origins = generate_tuple(
[&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; },
[&](auto i) {
return src_block_slice_origins[i] +
src_thread_cluster_idx * thread_slice_lengths;
},
Number<nSrc>{});
const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index( OutputScatter ? ThreadGroup::GetThreadId() % mod_num : ThreadGroup::GetThreadId()));
make_multi_index(OutputScatter ? ThreadGroup::GetThreadId() % mod_num
: ThreadGroup::GetThreadId()));
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
[&](auto i) {
return dst_block_slice_origins[i] +
dst_thread_cluster_idx * thread_slice_lengths;
},
Number<nDst>{});
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
@@ -125,7 +134,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
@@ -141,16 +150,18 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
StaticallyIndexedArray<ScatterIdxType, scatter_num>& scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
threadwise_transfer_.RunWrite(
dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
threadwise_transfer_.RunWrite(
dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
}
}
@@ -159,8 +170,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
StaticallyIndexedArray<ScatterIdxType, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
@@ -206,24 +217,25 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths),
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVectors,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
ScatterDim,
OutputScatter,
ScatterWeightIdx,
NumThreadScratch>;
DstDatas,
SrcDescs,
DstDescs,
ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths),
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVectors,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
ScatterIdxType,
ScatterDim,
OutputScatter,
ScatterWeightIdx,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};

View File

@@ -66,8 +66,9 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
bool IsInputGemm = true,
typename ScatterIdxType = index_t,
bool NSwizzle = false,
bool IsInputGemm = true,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
@@ -86,59 +87,59 @@ struct DeviceMoeGemm
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm =
GridwiseMoeGemm<
ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
NSwizzle,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
using GridwiseGemm =
GridwiseMoeGemm<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ScatterIdxType,
NSwizzle,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
using Argument = typename GridwiseGemm::Argument;

View File

@@ -146,7 +146,8 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
typename ScatterIdxType = index_t,
bool NSwizzle = false,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
@@ -1211,7 +1212,7 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
StaticallyIndexedArray<ScatterIdxType, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -1243,37 +1244,37 @@ struct GridwiseMoeGemm
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
1,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
ScatterIdxType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
@@ -1523,16 +1524,16 @@ struct GridwiseMoeGemm
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1, //ScatterDim
true, //OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
ScatterIdxType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -1567,7 +1568,8 @@ struct GridwiseMoeGemm
const float *p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<ScatterIdxType, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
@@ -1717,7 +1719,8 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
StaticallyIndexedArray<ScatterIdxType, AMRepeats>
gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -1749,37 +1752,37 @@ struct GridwiseMoeGemm
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
1,
2>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
ScatterIdxType,
1,
2>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
@@ -2035,16 +2038,16 @@ struct GridwiseMoeGemm
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1, //ScatterDim
true, //OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
ScatterIdxType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -2079,7 +2082,8 @@ struct GridwiseMoeGemm
const float *p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<ScatterIdxType, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;

View File

@@ -31,8 +31,8 @@ template <typename SliceLengths,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector_,
index_t DstScalarPerVector_,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
@@ -41,7 +41,8 @@ template <typename SliceLengths,
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
index_t GatherDim = 1,
typename ScatterIdxType = index_t,
index_t GatherDim = 1,
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1_gather
{
@@ -54,31 +55,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I10 = Number<10>{};
static constexpr auto I12 = Number<12>{};
static constexpr auto I13 = Number<13>{};
static constexpr auto I14 = Number<14>{};
static constexpr auto I16 = Number<16>{};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
static constexpr auto I0 = Number<0>{};
static constexpr index_t gather_num = SliceLengths{}.At(Number<GatherDim>{});
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather(
@@ -88,29 +65,26 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<index_t, gather_num> &gather_offsets)
const StaticallyIndexedArray<ScatterIdxType, gather_num>& gather_offsets)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op),
dst_element_op_(dst_element_op),
gather_offsets_(gather_offsets)
{
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
auto adjusted_origin_idx = [&]() {
Index idx;
static_for<0, nDim, 1>{}([&](auto i) {
idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
});
return idx;
}();
src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
@@ -134,15 +108,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim];
constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim];
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
@@ -210,19 +183,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
auto gather_offset = gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
auto gather_offset =
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
// maintain a container record is_src_valid, waiting for RunWrite use.
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize();//hack felix, todo use coord
//coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_) && (gather_offset < 32*512);
const bool is_src_valid =
ld_offset <
src_desc
.GetElementSpaceSize(); // hack felix, todo use coord
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
// src_coord_) && (gather_offset < 32*512);
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// if(threadIdx.x==0)
// printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number<ordered_gather_dim>{}](), src_coord_.GetOffset(), gather_offset );
auto src_vector_container =
src_vector_type{src_buf.template Get<src_vector_t>(ld_offset, true)};
@@ -236,22 +213,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack4_invocable_t,
decltype(src_element_op_)>::value)
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack2_invocable_t,
decltype(src_element_op_)>::value)
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
else
{
return 1;
}
return 1;
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
@@ -269,14 +241,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<dst_vector_t>(src_data_idx_seq,
op_r_v.template AsType<dst_vector_t>()[I0]);
// if(1) {
// using print_vec_t = typename vector_type<DstData, 1>::type;
// static_for<0, SrcScalarPerVector, 1>{}([&](auto idx) {
// printf("tid %d %f\n",threadIdx.x, type_convert<float>(src_vector_container.template AsType<print_vec_t>()[idx]));
// });
// }
constexpr auto move_on_dim = [&]() constexpr
auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
@@ -288,9 +254,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
move_on_dim_(i) &= i.value != ordered_gather_dim;
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, move_on_dim_(i), ordered_gather_dim);
});
return move_on_dim_;
@@ -298,9 +261,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
// if(threadIdx.x==0)
// printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]);
if constexpr(move_on_dim[i])
if(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
@@ -313,10 +274,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
// if(threadIdx.x==0)
// printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset());
});
});
// move src coordinate back to slice origin (or not)
@@ -349,7 +307,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// OOB Check
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -420,8 +378,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
"in-register transpose is not supported for pk_i4_t");
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
@@ -482,12 +438,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
}
else
{
constexpr auto packed_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
static_ford<SliceLengths>{}([&](auto idx) {
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
}
@@ -515,7 +466,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -609,16 +560,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset() / PackedSize,
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
// if(1) {
// using print_vec_t = typename vector_type<DstData, 1>::type;
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
// printf("tid %d off %d valid %d val %f\n",threadIdx.x, dst_coord_.GetOffset(), is_dst_valid, type_convert<float>(dst_vector_container.template AsType<print_vec_t>()[idx]));
// });
// }
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
@@ -669,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -714,7 +659,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i]; });
static_for<0, nDim, 1>{}([&](auto i) {
reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i];
});
return reset_src_data_step_;
}();
@@ -726,7 +673,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -811,7 +758,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -860,7 +807,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -871,7 +818,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -951,7 +898,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
StaticallyIndexedArray<ScatterIdxType, gather_num> gather_offsets_;
};
} // namespace ck

View File

@@ -43,8 +43,9 @@ template <typename SrcDatas,
index_t DstScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
index_t ScatterDim = 1,
bool OutputScatter = true,
typename ScatterIdxType = index_t,
index_t ScatterDim = 1,
bool OutputScatter = true,
index_t ScatterWeightIdx = 3,
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v7r3_scatter
@@ -61,8 +62,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
using Index = MultiIndex<nDim>;
static constexpr index_t scatter_num = SliceLengths{}.At(Number<ScatterDim>{});
using Index = MultiIndex<nDim>;
static constexpr index_t scatter_num = SliceLengths{}.At(Number<ScatterDim>{});
// return a tuple of coordiantes for a tuple of tensor
template <typename Descs,
@@ -127,7 +128,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
{
static_for<0, nDst, 1>{}([&](auto i) {
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
// printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x, dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1], dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3], dst_coords_(i).GetOffset());
// printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x,
// dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1],
// dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3],
// dst_coords_(i).GetOffset());
});
}
@@ -154,7 +158,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
@@ -173,14 +177,19 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
src_coords_[i]);
oob_val = oob_val & is_src_valid;
if (i.value == ScatterWeightIdx)
if(i.value == ScatterWeightIdx)
{
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1, "scatter weight dim, should only one vec");
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
"scatter weight dim, should only one vec");
constexpr auto iScatter =
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights(Number<iScatter>{}));
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights(Number<iScatter>{}); });
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value,
// scatter_weights(Number<iScatter>{}));
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
src_vectors(i).template AsType<float>()(j) =
scatter_weights(Number<iScatter>{});
});
}
else if constexpr(SrcScalarPerVectors{}[i] == 1)
{
@@ -189,7 +198,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
const auto tmp =
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x,
// i.value, src_coords_[i].GetOffset(), tmp);
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
}
@@ -415,7 +425,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
StaticallyIndexedArray<ScatterIdxType, scatter_num>& scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
OOBCheck(thread_scratch_id);
@@ -423,36 +433,37 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// loop over space-filling curve
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
auto scatter_offset = 0;
if constexpr (OutputScatter)
if constexpr(OutputScatter)
{
constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
constexpr auto iScatter =
DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
scatter_offset = scatter_offsets(Number<iScatter>{});
}
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
// if(threadIdx.x==0)
// printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(), scatter_offset );
// printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(),
// scatter_offset );
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_offset,
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
// static_for<0, 1, 1>{}([&](auto idx) {
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
// using print_vec_t = typename vector_type<DstData, 1>::type;
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid,
// type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset,
// is_dst_valid, type_convert<float>(dst_vectors[i].template
// AsType<print_vec_t>()[idx]));
// });
// }
});
@@ -468,18 +479,20 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
// ordered_gather_dim);
});
return step_;
}
();
static_for<0, nDst, 1>{}([&](auto i) {
move_tensor_coordinate(dst_descs[i],
dst_coords_(i),
make_tensor_coordinate_step(dst_descs[i], forward_step_scatter));
move_tensor_coordinate(
dst_descs[i],
dst_coords_(i),
make_tensor_coordinate_step(dst_descs[i], forward_step_scatter));
});
}
});
@@ -508,8 +521,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
StaticallyIndexedArray<ScatterIdxType, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
@@ -535,15 +548,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
}
else
{
constexpr auto reset_step = DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
constexpr auto reset_step =
DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
auto reset_step_scatter = [&]() constexpr
{
Index step_;
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
step_(i) =
(i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
// ordered_gather_dim);
});
return step_;
@@ -683,18 +699,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
auto adjusted_step_idx_scatter = [&]()
{
auto adjusted_step_idx_scatter = [&]() {
Index step_;
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
step_(i) =
(i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
});
return step_;
}
();
}();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
const auto adjusted_step =
make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
}