mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
input output all ok
This commit is contained in:
@@ -158,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::long_index_t, A0DataType>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -235,12 +235,12 @@ int main(int argc, char* argv[])
|
||||
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
max_token_id.mData = {valid_size};
|
||||
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// max_token_id.mData = {valid_size};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = 0;
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
|
||||
@@ -159,7 +159,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
@@ -176,7 +176,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
|
||||
@@ -164,7 +164,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -41,6 +41,7 @@ template <typename ThreadGroup,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun,
|
||||
typename IndexType,
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v4r1_gather
|
||||
@@ -58,7 +59,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<long_index_t, gather_num>& gather_offsets)
|
||||
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src_element_op,
|
||||
@@ -190,6 +191,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun,
|
||||
IndexType,
|
||||
GatherDim,
|
||||
NumThreadScratch>;
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ template <typename ThreadGroup,
|
||||
index_t DstScalarPerVector,
|
||||
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
typename IndexType,
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
index_t ScatterWeightIdx = 3,
|
||||
@@ -149,7 +150,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
@@ -169,7 +170,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
@@ -230,6 +231,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
DstScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
IndexType,
|
||||
ScatterDim,
|
||||
OutputScatter,
|
||||
ScatterWeightIdx,
|
||||
|
||||
@@ -67,6 +67,7 @@ template <typename ALayout,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
@@ -132,6 +133,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
NSwizzle,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
|
||||
@@ -148,6 +148,7 @@ template <typename ALayout,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
bool NSwizzle = false,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ADataType,
|
||||
@@ -298,7 +299,7 @@ struct GridwiseMoeGemm
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
|
||||
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
|
||||
IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
@@ -491,7 +492,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
template <typename ELayout>
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(long_index_t M, long_index_t MPad, long_index_t N, long_index_t NPad, long_index_t StrideC)
|
||||
MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
@@ -1210,7 +1211,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<long_index_t, AMRepeats> gather_offsets;
|
||||
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
@@ -1218,7 +1219,7 @@ struct GridwiseMoeGemm
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
gather_offsets(m0) = token_offset * problem.K;
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
|
||||
@@ -1226,7 +1227,7 @@ struct GridwiseMoeGemm
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
const auto a_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize,
|
||||
@@ -1261,6 +1262,7 @@ struct GridwiseMoeGemm
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
IndexType,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -1519,6 +1521,7 @@ struct GridwiseMoeGemm
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
@@ -1528,8 +1531,16 @@ struct GridwiseMoeGemm
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
auto c_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// using BufferType = std::conditional_t<
|
||||
// std::is_same_v<IndexType, long_index_t>,
|
||||
// decltype(make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())),
|
||||
// decltype(make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()))
|
||||
// >;
|
||||
auto c_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// BufferType c_grid_buf = std::is_same_v<IndexType, long_index_t> ?
|
||||
// make_long_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()) :
|
||||
// make_dynamic_buffer<AddressSpaceEnum::Global>(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
@@ -1563,7 +1574,7 @@ struct GridwiseMoeGemm
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<long_index_t, EMRepeats>
|
||||
StaticallyIndexedArray<IndexType, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
@@ -1575,7 +1586,7 @@ struct GridwiseMoeGemm
|
||||
block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
long_index_t token_offset = fused_token & 0xffffff;
|
||||
IndexType token_offset = fused_token & 0xffffff;
|
||||
float weight = token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
|
||||
: 0.0;
|
||||
@@ -1588,7 +1599,7 @@ struct GridwiseMoeGemm
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
@@ -1713,7 +1724,7 @@ struct GridwiseMoeGemm
|
||||
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
|
||||
token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<long_index_t, AMRepeats>
|
||||
StaticallyIndexedArray<IndexType, AMRepeats>
|
||||
gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
@@ -1722,7 +1733,7 @@ struct GridwiseMoeGemm
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
gather_offsets(m0) = token_offset * problem.K;
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
|
||||
|
||||
@@ -1730,7 +1741,7 @@ struct GridwiseMoeGemm
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
const auto a_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize,
|
||||
@@ -1765,6 +1776,7 @@ struct GridwiseMoeGemm
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
IndexType,
|
||||
1,
|
||||
2>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -2029,6 +2041,7 @@ struct GridwiseMoeGemm
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
@@ -2038,7 +2051,10 @@ struct GridwiseMoeGemm
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
auto c_grid_buf = make_long_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
auto c_grid_buf = std::is_same_v<IndexType, long_index_t> ?
|
||||
make_long_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()):
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
@@ -2073,7 +2089,7 @@ struct GridwiseMoeGemm
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<long_index_t, EMRepeats>
|
||||
StaticallyIndexedArray<IndexType, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
@@ -2099,7 +2115,7 @@ struct GridwiseMoeGemm
|
||||
const float* p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
});
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ template <typename SliceLengths,
|
||||
bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
|
||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||
// save addr computation
|
||||
typename IndexType,
|
||||
index_t GatherDim = 1,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
@@ -88,7 +89,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<long_index_t, gather_num>& gather_offsets)
|
||||
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
|
||||
src_element_op_(src_element_op),
|
||||
@@ -221,7 +222,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
auto gather_offset =
|
||||
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
|
||||
|
||||
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const IndexType ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, true);
|
||||
|
||||
@@ -935,7 +936,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
DstCoord dst_coord_;
|
||||
const SrcElementwiseOperation src_element_op_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
StaticallyIndexedArray<long_index_t, gather_num> gather_offsets_;
|
||||
StaticallyIndexedArray<IndexType, gather_num> gather_offsets_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -43,6 +43,7 @@ template <typename SrcDatas,
|
||||
index_t DstScalarPerVector,
|
||||
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
|
||||
typename IndexType,
|
||||
index_t ScatterDim = 1,
|
||||
bool OutputScatter = true,
|
||||
index_t ScatterWeightIdx = 3,
|
||||
@@ -412,7 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
OOBCheck(thread_scratch_id);
|
||||
@@ -421,7 +422,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
// loop over space-filling curve
|
||||
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
|
||||
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
|
||||
long_index_t scatter_offset = 0;
|
||||
IndexType scatter_offset = 0;
|
||||
if constexpr(OutputScatter)
|
||||
{
|
||||
constexpr auto iScatter =
|
||||
@@ -431,7 +432,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
long_index_t dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
|
||||
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
@@ -490,7 +491,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<IndexType, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
|
||||
@@ -225,8 +225,8 @@ struct DynamicBuffer
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
#if 0
|
||||
bool constexpr use_amd_buffer_addressing = true;
|
||||
#if CK_USE_AMD_BUFFER_LOAD
|
||||
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
@@ -349,7 +349,7 @@ struct DynamicBuffer
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
#else
|
||||
// if(i >= 2169041600)
|
||||
*c_style_pointer_cast<X*>(p_data_ + i) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -380,12 +380,13 @@ struct DynamicBuffer
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
|
||||
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
sizeof(IndexType) <= sizeof(int32_t) && (
|
||||
is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
|
||||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
|
||||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
@@ -424,7 +425,7 @@ struct DynamicBuffer
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user