mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
support index type
This commit is contained in:
@@ -182,7 +182,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<4, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ck::index_t, Nswizzle, true, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,ck::index_t, false, false, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -41,7 +41,8 @@ template <typename ThreadGroup,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun,
|
||||
index_t GatherDim = 1,
|
||||
typename ScatterIdxType = index_t,
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
{
|
||||
@@ -59,7 +60,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num> &gather_offsets)
|
||||
const StaticallyIndexedArray<ScatterIdxType, gather_num>& gather_offsets)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src_element_op,
|
||||
@@ -178,25 +179,26 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v3r1_gather<decltype(thread_slice_lengths),
|
||||
SrcElementwiseOperation,
|
||||
DstElementwiseOperation,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun,
|
||||
GatherDim,
|
||||
NumThreadScratch>;
|
||||
SrcElementwiseOperation,
|
||||
DstElementwiseOperation,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun,
|
||||
ScatterIdxType,
|
||||
GatherDim,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
@@ -42,8 +42,9 @@ template <typename ThreadGroup,
|
||||
index_t DstScalarPerVector,
|
||||
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
typename ScatterIdxType = index_t,
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
index_t ScatterWeightIdx = 3,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
@@ -51,14 +52,15 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
static constexpr index_t nDim =
|
||||
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
|
||||
|
||||
static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}) ; // Dirty HACK FELIX, TODO fix
|
||||
static constexpr index_t mod_num =
|
||||
ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix
|
||||
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
|
||||
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});
|
||||
static constexpr index_t scatter_num = thread_slice_lengths.At(Number<ScatterDim>{});
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter(
|
||||
const SrcDescs& src_descs,
|
||||
@@ -108,13 +110,20 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
const auto src_thread_slice_origins = generate_tuple(
|
||||
[&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; },
|
||||
[&](auto i) {
|
||||
return src_block_slice_origins[i] +
|
||||
src_thread_cluster_idx * thread_slice_lengths;
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index( OutputScatter ? ThreadGroup::GetThreadId() % mod_num : ThreadGroup::GetThreadId()));
|
||||
make_multi_index(OutputScatter ? ThreadGroup::GetThreadId() % mod_num
|
||||
: ThreadGroup::GetThreadId()));
|
||||
const auto dst_thread_slice_origins = generate_tuple(
|
||||
[&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
|
||||
[&](auto i) {
|
||||
return dst_block_slice_origins[i] +
|
||||
dst_thread_cluster_idx * thread_slice_lengths;
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
|
||||
@@ -125,7 +134,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
@@ -141,16 +150,18 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
StaticallyIndexedArray<ScatterIdxType, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
|
||||
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
|
||||
threadwise_transfer_.RunWrite(
|
||||
dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
|
||||
else
|
||||
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
|
||||
threadwise_transfer_.RunWrite(
|
||||
dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,8 +170,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
|
||||
StaticallyIndexedArray<ScatterIdxType, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
@@ -206,24 +217,25 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v7r3_scatter<SrcDatas,
|
||||
DstDatas,
|
||||
SrcDescs,
|
||||
DstDescs,
|
||||
ElementwiseOperation,
|
||||
DstInMemOps,
|
||||
decltype(thread_slice_lengths),
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVectors,
|
||||
DstScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
ScatterDim,
|
||||
OutputScatter,
|
||||
ScatterWeightIdx,
|
||||
NumThreadScratch>;
|
||||
DstDatas,
|
||||
SrcDescs,
|
||||
DstDescs,
|
||||
ElementwiseOperation,
|
||||
DstInMemOps,
|
||||
decltype(thread_slice_lengths),
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVectors,
|
||||
DstScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
ScatterIdxType,
|
||||
ScatterDim,
|
||||
OutputScatter,
|
||||
ScatterWeightIdx,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
@@ -66,8 +66,9 @@ template <typename ALayout,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
typename ScatterIdxType = index_t,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
@@ -86,59 +87,59 @@ struct DeviceMoeGemm
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm =
|
||||
GridwiseMoeGemm<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
NSwizzle,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
using GridwiseGemm =
|
||||
GridwiseMoeGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ScatterIdxType,
|
||||
NSwizzle,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
|
||||
@@ -146,7 +146,8 @@ template <typename ALayout,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
bool NSwizzle = false,
|
||||
typename ScatterIdxType = index_t,
|
||||
bool NSwizzle = false,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ADataType,
|
||||
@@ -1211,7 +1212,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
StaticallyIndexedArray<ScatterIdxType, AMRepeats> gather_offsets;
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
@@ -1243,37 +1244,37 @@ struct GridwiseMoeGemm
|
||||
// dummy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
LDSTypeA,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
|
||||
ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
LDSTypeA,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
ScatterIdxType,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
@@ -1523,16 +1524,16 @@ struct GridwiseMoeGemm
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
1, //ScatterDim
|
||||
true, //OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
ScatterIdxType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
@@ -1567,7 +1568,8 @@ struct GridwiseMoeGemm
|
||||
const float *p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<ScatterIdxType, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
|
||||
@@ -1717,7 +1719,8 @@ struct GridwiseMoeGemm
|
||||
|
||||
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
StaticallyIndexedArray<ScatterIdxType, AMRepeats>
|
||||
gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
@@ -1749,37 +1752,37 @@ struct GridwiseMoeGemm
|
||||
// dummy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
LDSTypeA,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
1,
|
||||
2>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
|
||||
ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
LDSTypeA,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
ScatterIdxType,
|
||||
1,
|
||||
2>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
@@ -2035,16 +2038,16 @@ struct GridwiseMoeGemm
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
1, //ScatterDim
|
||||
true, //OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
ScatterIdxType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
@@ -2079,7 +2082,8 @@ struct GridwiseMoeGemm
|
||||
const float *p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<ScatterIdxType, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
|
||||
|
||||
@@ -31,8 +31,8 @@ template <typename SliceLengths,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector_,
|
||||
index_t DstScalarPerVector_,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
@@ -41,7 +41,8 @@ template <typename SliceLengths,
|
||||
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
|
||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||
// save addr computation
|
||||
index_t GatherDim = 1,
|
||||
typename ScatterIdxType = index_t,
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
{
|
||||
@@ -54,31 +55,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
static constexpr auto I8 = Number<8>{};
|
||||
static constexpr auto I10 = Number<10>{};
|
||||
static constexpr auto I12 = Number<12>{};
|
||||
static constexpr auto I13 = Number<13>{};
|
||||
static constexpr auto I14 = Number<14>{};
|
||||
static constexpr auto I16 = Number<16>{};
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
|
||||
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr index_t gather_num = SliceLengths{}.At(Number<GatherDim>{});
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather(
|
||||
@@ -88,29 +65,26 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num> &gather_offsets)
|
||||
const StaticallyIndexedArray<ScatterIdxType, gather_num>& gather_offsets)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
|
||||
src_element_op_(src_element_op),
|
||||
dst_element_op_(dst_element_op),
|
||||
gather_offsets_(gather_offsets)
|
||||
{
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
{
|
||||
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
|
||||
"SrcData != DstData");
|
||||
|
||||
static_assert(
|
||||
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
|
||||
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
|
||||
|
||||
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
|
||||
auto adjusted_origin_idx = [&]() {
|
||||
Index idx;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
|
||||
});
|
||||
return idx;
|
||||
}();
|
||||
src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
@@ -134,15 +108,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
|
||||
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim];
|
||||
constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim];
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
@@ -210,19 +183,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
auto gather_offset = gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
|
||||
auto gather_offset =
|
||||
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
|
||||
|
||||
// maintain a container record is_src_valid, waiting for RunWrite use.
|
||||
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize();//hack felix, todo use coord
|
||||
//coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_) && (gather_offset < 32*512);
|
||||
const bool is_src_valid =
|
||||
ld_offset <
|
||||
src_desc
|
||||
.GetElementSpaceSize(); // hack felix, todo use coord
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
|
||||
// src_coord_) && (gather_offset < 32*512);
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
// if(threadIdx.x==0)
|
||||
// printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number<ordered_gather_dim>{}](), src_coord_.GetOffset(), gather_offset );
|
||||
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_buf.template Get<src_vector_t>(ld_offset, true)};
|
||||
|
||||
@@ -236,22 +213,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
else if constexpr(is_detected<is_pack4_invocable_t,
|
||||
decltype(src_element_op_)>::value)
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
else if constexpr(is_detected<is_pack2_invocable_t,
|
||||
decltype(src_element_op_)>::value)
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
@@ -269,14 +241,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<dst_vector_t>(src_data_idx_seq,
|
||||
op_r_v.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// if(1) {
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// static_for<0, SrcScalarPerVector, 1>{}([&](auto idx) {
|
||||
// printf("tid %d %f\n",threadIdx.x, type_convert<float>(src_vector_container.template AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
|
||||
auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
@@ -288,9 +254,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
move_on_dim_(i) &= i.value != ordered_gather_dim;
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, move_on_dim_(i), ordered_gather_dim);
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
@@ -298,9 +261,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
();
|
||||
// move src coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
// if(threadIdx.x==0)
|
||||
// printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]);
|
||||
if constexpr(move_on_dim[i])
|
||||
if(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
@@ -313,10 +274,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
// if(threadIdx.x==0)
|
||||
// printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset());
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
@@ -349,7 +307,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
|
||||
// OOB Check
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -420,8 +378,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
|
||||
"in-register transpose is not supported for pk_i4_t");
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
@@ -482,12 +438,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto packed_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
|
||||
|
||||
static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
}
|
||||
@@ -515,7 +466,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -609,16 +560,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
|
||||
// copy data from dst_vector_container to dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset() / PackedSize,
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// if(1) {
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
|
||||
// printf("tid %d off %d valid %d val %f\n",threadIdx.x, dst_coord_.GetOffset(), is_dst_valid, type_convert<float>(dst_vector_container.template AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
@@ -669,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -714,7 +659,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i]; });
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i];
|
||||
});
|
||||
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
@@ -726,7 +673,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -811,7 +758,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -860,7 +807,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -871,7 +818,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
{
|
||||
// 1st stage of transforms
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -951,7 +898,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
DstCoord dst_coord_;
|
||||
const SrcElementwiseOperation src_element_op_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
|
||||
StaticallyIndexedArray<ScatterIdxType, gather_num> gather_offsets_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -43,8 +43,9 @@ template <typename SrcDatas,
|
||||
index_t DstScalarPerVector,
|
||||
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
typename ScatterIdxType = index_t,
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
index_t ScatterWeightIdx = 3,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
@@ -61,8 +62,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
static constexpr index_t nSrc = SrcDescs::Size();
|
||||
static constexpr index_t nDst = DstDescs::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
static constexpr index_t scatter_num = SliceLengths{}.At(Number<ScatterDim>{});
|
||||
using Index = MultiIndex<nDim>;
|
||||
static constexpr index_t scatter_num = SliceLengths{}.At(Number<ScatterDim>{});
|
||||
|
||||
// return a tuple of coordiantes for a tuple of tensor
|
||||
template <typename Descs,
|
||||
@@ -127,7 +128,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
{
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
|
||||
// printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x, dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1], dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3], dst_coords_(i).GetOffset());
|
||||
// printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x,
|
||||
// dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1],
|
||||
// dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3],
|
||||
// dst_coords_(i).GetOffset());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -154,7 +158,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// loop over space-filling curve
|
||||
@@ -173,14 +177,19 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
src_coords_[i]);
|
||||
|
||||
oob_val = oob_val & is_src_valid;
|
||||
if (i.value == ScatterWeightIdx)
|
||||
if(i.value == ScatterWeightIdx)
|
||||
{
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1, "scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
|
||||
"scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter =
|
||||
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights(Number<iScatter>{}));
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights(Number<iScatter>{}); });
|
||||
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value,
|
||||
// scatter_weights(Number<iScatter>{}));
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
|
||||
src_vectors(i).template AsType<float>()(j) =
|
||||
scatter_weights(Number<iScatter>{});
|
||||
});
|
||||
}
|
||||
else if constexpr(SrcScalarPerVectors{}[i] == 1)
|
||||
{
|
||||
@@ -189,7 +198,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
const auto tmp =
|
||||
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
|
||||
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x,
|
||||
// i.value, src_coords_[i].GetOffset(), tmp);
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
|
||||
}
|
||||
@@ -415,7 +425,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
StaticallyIndexedArray<ScatterIdxType, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
OOBCheck(thread_scratch_id);
|
||||
@@ -423,36 +433,37 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
|
||||
// loop over space-filling curve
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
auto scatter_offset = 0;
|
||||
if constexpr (OutputScatter)
|
||||
if constexpr(OutputScatter)
|
||||
{
|
||||
constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
constexpr auto iScatter =
|
||||
DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
scatter_offset = scatter_offsets(Number<iScatter>{});
|
||||
}
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(), scatter_offset );
|
||||
// printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(),
|
||||
// scatter_offset );
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_offset,
|
||||
is_dst_valid,
|
||||
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
|
||||
// static_for<0, 1, 1>{}([&](auto idx) {
|
||||
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid,
|
||||
// type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset,
|
||||
// is_dst_valid, type_convert<float>(dst_vectors[i].template
|
||||
// AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
});
|
||||
@@ -468,18 +479,20 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
|
||||
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
|
||||
// ordered_gather_dim);
|
||||
});
|
||||
|
||||
return step_;
|
||||
}
|
||||
();
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
move_tensor_coordinate(dst_descs[i],
|
||||
dst_coords_(i),
|
||||
make_tensor_coordinate_step(dst_descs[i], forward_step_scatter));
|
||||
move_tensor_coordinate(
|
||||
dst_descs[i],
|
||||
dst_coords_(i),
|
||||
make_tensor_coordinate_step(dst_descs[i], forward_step_scatter));
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -508,8 +521,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
|
||||
StaticallyIndexedArray<ScatterIdxType, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
RunWrite(dst_descs, dst_bufs, scatter_offsets);
|
||||
@@ -535,15 +548,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step = DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
|
||||
constexpr auto reset_step =
|
||||
DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
|
||||
auto reset_step_scatter = [&]() constexpr
|
||||
{
|
||||
Index step_;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
|
||||
|
||||
step_(i) =
|
||||
(i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
|
||||
// ordered_gather_dim);
|
||||
});
|
||||
|
||||
return step_;
|
||||
@@ -683,18 +699,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
auto adjusted_step_idx_scatter = [&]()
|
||||
{
|
||||
auto adjusted_step_idx_scatter = [&]() {
|
||||
Index step_;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
|
||||
step_(i) =
|
||||
(i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
|
||||
});
|
||||
|
||||
return step_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
|
||||
const auto adjusted_step =
|
||||
make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
|
||||
|
||||
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user