mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
moe gemm2 bpreshuffle function passed
This commit is contained in:
@@ -175,7 +175,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
@@ -183,14 +183,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 64,
|
||||
MPerBlock, 32, KPerBlock,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
2, 2,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
@@ -202,14 +202,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 2;
|
||||
constexpr ck::index_t valid_tile_num = 2;
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 2;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
@@ -351,30 +351,44 @@ int main(int argc, char* argv[])
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 8:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
|
||||
@@ -965,6 +965,54 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
#if 0
|
||||
printf("blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: "
|
||||
"%d, ikxdl: %d, a_thread_vec=<0x%08x, 0x%08x, 0x%08x, "
|
||||
"0x%08x>, "
|
||||
"b_thread_vec=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, "
|
||||
"a_scale=0x%08x, "
|
||||
"b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
im_minor,
|
||||
in_minor,
|
||||
ik_minor,
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(a_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(a_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(a_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(a_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_thread_vec.template AsType<f4x8_t>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_thread_vec.template AsType<f4x8_t>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_thread_vec.template AsType<f4x8_t>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_thread_vec.template AsType<f4x8_t>()[Number<3>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(a_scale_thread_vec
|
||||
.template AsType<AScaleDataType>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_scale_thread_vec
|
||||
.template AsType<BScaleDataType>()[Number<0>{}]))),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<0>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<1>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<2>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<3>{}]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
|
||||
|
||||
@@ -42,12 +42,12 @@ template <typename GridwiseGemm,
|
||||
TailNumber TailNum = TailNumber::Even>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
@@ -79,12 +79,12 @@ template <typename GridwiseGemm,
|
||||
TailNumber TailNum = TailNumber::Even>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -885,7 +885,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
{
|
||||
// contiguous in LDS
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<AK0Number>{}, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
|
||||
@@ -2025,17 +2025,30 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
problem.NPadded,
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(problem.M / (MXdlPack * MPerXdl),
|
||||
// We pad the M unconditionaly for Scale
|
||||
const auto Padded_Scale_M =
|
||||
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / APackedSize)) *
|
||||
MPerXdl * MXdlPack / scale_pack_size_a,
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a,
|
||||
1));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / BPackedSize)) *
|
||||
NPerXdl * NXdlPack / scale_pack_size_b,
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b,
|
||||
1));
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -2102,23 +2115,15 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
|
||||
|
||||
// Gride buffer creation
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
#if 0
|
||||
printf("blkx: %u, blky: %u, tidx: %u, a_grid_size: %ld\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
#endif
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
const auto b_grid_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::SYSTEM_NT1>(
|
||||
p_b_grid + expert_id * expert_stride,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -2585,16 +2590,18 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin =
|
||||
@@ -2615,40 +2622,41 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
constexpr index_t scatter_weight_idx = 3; // hack fix felix
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
|
||||
// support arbitray type
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
|
||||
CDEBlockTransferCluster,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
|
||||
3, // index_t SrcVectorDim,
|
||||
3, // index_t DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
|
||||
// Sequence support
|
||||
// arbitray type
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
|
||||
CDEBlockTransferCluster,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
|
||||
3, // index_t SrcVectorDim,
|
||||
3, // index_t DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
IndexType,
|
||||
1, // ScatterDim
|
||||
true, // OutputScatter: false, only use scatter weights
|
||||
scatter_weight_idx // ScatterWeightIdx: ascale
|
||||
>{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
Reference in New Issue
Block a user