merge ck moe merge

This commit is contained in:
coderfeli
2025-03-03 14:58:15 +00:00
16 changed files with 883 additions and 237 deletions

View File

@@ -1,9 +1,19 @@
list(APPEND TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
list(APPEND TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
add_example_executable(example_moe_gemm2 moe_gemm2.cpp)
target_compile_options(example_moe_gemm1 PRIVATE ${TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS})
target_compile_options(example_moe_gemm2 PRIVATE ${TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS})
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
add_example_executable(example_moe_gemm1 moe_gemm1.cpp)
add_example_executable(example_moe_gemm2 moe_gemm2.cpp)
add_example_executable(example_moe_pk_i4_gemm1 moe_pk_i4_gemm1.cpp)
set(EXAMPLE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker -g -fverbose-asm)
target_compile_options(example_moe_pk_i4_gemm1 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
add_example_executable(example_moe_pk_i4_gemm2 moe_pk_i4_gemm2.cpp)

View File

@@ -133,8 +133,8 @@ using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t NXDLPerWave = 1;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
@@ -174,10 +174,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
4, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>,
// clang-format on
@@ -197,8 +198,6 @@ int main(int argc, char* argv[])
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 544;
ck::index_t topk = 2;
@@ -217,6 +216,17 @@ int main(int argc, char* argv[])
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else if(argc == 9) {
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
sorted_tile_num = std::stoi(argv[7]);
valid_tile_num = std::stoi(argv[8]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
@@ -227,6 +237,8 @@ int main(int argc, char* argv[])
exit(0);
}
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
if (tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
@@ -246,8 +258,10 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
int eids[] = {0, 0,1,2, 3,3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = eids[i];
}
@@ -287,8 +301,8 @@ int main(int argc, char* argv[])
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
@@ -296,12 +310,21 @@ int main(int argc, char* argv[])
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
// d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
// d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());

View File

@@ -64,7 +64,11 @@ struct MulABScale
const float& d0,
const float& d1) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
#else
e = ck::type_convert<EDataType>(c * d1 * d0);
#endif
}
};
@@ -84,7 +88,11 @@ struct MulABScaleSilu
{
// act
float x0 = 0;
#if CK_USE_PK4_LAYOUT_SHUFFLE
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16);
#else
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
#endif
e = ck::type_convert<EDataType>(x0);
}
};
@@ -131,13 +139,13 @@ using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
#if 0
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t MPerBlock = 64;
static constexpr ck::index_t MXDLPerWave = 1;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
@@ -154,8 +162,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, NXDLPerWave,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
// clang-format on
@@ -167,13 +175,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, 128, 128, 64,
256, MPerBlock, 128, 128,
16, 32,
32, 32,
4, 1,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
4, 1, S<1, 32, 1, 8>, S<4, 1, 1>,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
2, 1, S<1, 32, 1, 8>, S<4, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
// clang-format on
#endif
@@ -359,6 +367,7 @@ int main(int argc, char* argv[])
}
#endif
#if CK_USE_PK4_LAYOUT_SHUFFLE
// vector pk_i4x4 permute
for(int e = 0; e < experts; e++)
{
@@ -410,6 +419,7 @@ int main(int argc, char* argv[])
}
}
}
#endif
b0_device_buf.ToDevice(b0_preshuffled.mData.data());

View File

@@ -69,7 +69,12 @@ struct MulABScaleExpertWeight
//for real kernel use
//warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. tofix:felix
(void) d0;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * d1 * d2 * 16);
#else
e = ck::type_convert<EDataType>(c * d1 * d2);
#endif
}
// for reference cpu
template <>
@@ -81,7 +86,11 @@ struct MulABScaleExpertWeight
const float& d2) const
{
// for reference cpu
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * d0 * d1 * d2 * 16);
#else
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
#endif
}
};
@@ -137,7 +146,7 @@ static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t CShuffleNLane = 32;
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
@@ -151,7 +160,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
MNPerXDL, MNPerXDL,
MXDLPerWave, NXDLPerWave,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
MXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
// clang-format on
@@ -320,6 +329,7 @@ int main(int argc, char* argv[])
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, device_op.GetPreShuffleParameters());
#if CK_USE_PK4_LAYOUT_SHUFFLE
// vector pk_i4x4 permute
for(int e = 0; e < experts; e++)
{
@@ -371,6 +381,7 @@ int main(int argc, char* argv[])
}
}
}
#endif
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
@@ -443,8 +454,19 @@ int main(int argc, char* argv[])
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k_k,
b0_e_n_k,
d0_t_n,
d1_e_n,
d2_e_n,
c_t_n,
PassThrough{},
PassThrough{},
cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)

View File

@@ -194,17 +194,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_ds_read_inst_a = ck::math::integer_divide_ceil(num_ds_read_inst_a,MRepeat);
constexpr auto staged_num_mfma = ck::math::integer_divide_ceil(num_mfma , MRepeat);
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_ds_read_a = ck::math::integer_divide_ceil(staged_num_mfma , staged_num_ds_read_inst_a);
if constexpr(stage.value == 0)
{
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
constexpr auto staged_num_buffer_load_b_per_ds_read_a = ck::math::integer_divide_ceil(
num_buffer_load_inst_b , staged_num_ds_read_inst_a);
constexpr auto staged_num_mfma_per_buffer_load_b =ck::math::integer_divide_ceil(
staged_num_mfma , num_buffer_load_inst_b);
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;

View File

@@ -190,7 +190,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
// B global
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});

View File

@@ -65,16 +65,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op,
const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op,
scatter_offsets,
scatter_weights)
element_op)
{
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
@@ -129,12 +125,13 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
}
}
@@ -144,15 +141,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
}
}
@@ -160,10 +158,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
template <index_t ISrc>

View File

@@ -247,6 +247,7 @@ struct DeviceMoeGemm
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
// now");
constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
// Tail number always full
@@ -279,7 +280,6 @@ struct DeviceMoeGemm
// }
// else
{
constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// const auto kernel = kernel_moe_gemm<
@@ -304,8 +304,9 @@ struct DeviceMoeGemm
}
}
}
// else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
// {
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
// if(arg.KBatch > 1)
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
@@ -332,31 +333,29 @@ struct DeviceMoeGemm
// }
// }
// else
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// const auto kernel =
// kernel_moe_gemm_gather_2lds<
// GridwiseGemm,
// true,
// IsInputGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// RunKernel(kernel);
// }
// else
// {
// const auto kernel =
// kernel_moe_gemm_gather_2lds<
// GridwiseGemm,
// true,
// IsInputGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Even>;
// RunKernel(kernel);
// }
// }
// }
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
else
{
throw std::runtime_error("todo: only v1 & v2 support now");

View File

@@ -79,13 +79,30 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
return res.template AsType<half4_t>()[Number<0>{}];
}
__device__ inline f8x4_t i4_to_f8x4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
int lo = amd_assembly_and_b32(q, LO);
int hi = amd_assembly_and_b32(q, HI);
float f32_0 = amd_assemble_cvt_f32_i4(lo);
float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16);
float f32_2 = amd_assemble_cvt_f32_i4(hi);
float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16);
// vector_type<f8_t, 4> res;
// res.template AsType<f8x4_t>()(Number<0>{}) = amd_assemble_cvt_f8_f32(f32_1st, f32_2nd, f32_3rd, f32_4th);
return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
}
__device__ inline f8x8_t i4_to_fp8x8(int q)
{
vector_type<f8_t, 8> res;
res.template AsType<f8x8_t>()(Number<0>{}) = amd_assembly_i4_to_fp8x2(q);
return res.template AsType<f8x8_t>()[Number<0>{}];
// f8x8_t res;
// amd_assembly_i4_to_fp8x8(res, q);
// return res;
return amd_assembly_i4_to_fp8x8(q);
}
__device__ inline bhalf4_t i4_to_bhalf4(int q)
@@ -154,13 +171,55 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<f8_t, 8> result;
y = i4_to_fp8x8(bit_cast<int>(x));
result.template AsType<f8x8_t>()(Number<0>{}) = i4_to_fp8x8(bit_cast<int>(x));
// vector_type<f8_t, 8> result;
y = result.template AsType<f8x8_t>()[Number<0>{}];
// result.template AsType<f8x4_t>()(Number<0>{}) = i4_to_f8x4(bit_cast<int>(x));
// result.template AsType<f8x4_t>()(Number<1>{}) = i4_to_f8x4(bit_cast<int>(x) >> 8);
// y = result.template AsType<f8x8_t>()[Number<0>{}];
#else
// Added pk_i4_t to f8x2_fnuz_t conversion
vector_type<f8_t, 8> dst;
vector_type<float, 8> dst_tmp;
vector_type<pk_i4_t, 4> src{x};
// pk_i4_t to float2_t conversion
dst_tmp.template AsType<float2_t>()(Number<0>{}) =
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst_tmp.template AsType<float2_t>()(Number<1>{}) =
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst_tmp.template AsType<float2_t>()(Number<2>{}) =
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst_tmp.template AsType<float2_t>()(Number<3>{}) =
type_convert<float2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
// float to f8_t conversion
dst.template AsType<f8_t>()(Number<0>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<0>{}]);
dst.template AsType<f8_t>()(Number<1>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<1>{}]);
dst.template AsType<f8_t>()(Number<2>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<2>{}]);
dst.template AsType<f8_t>()(Number<3>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<3>{}]);
dst.template AsType<f8_t>()(Number<4>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<4>{}]);
dst.template AsType<f8_t>()(Number<5>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<5>{}]);
dst.template AsType<f8_t>()(Number<6>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<6>{}]);
dst.template AsType<f8_t>()(Number<7>{}) =
type_convert<f8_t>(dst_tmp.template AsType<float>()[Number<7>{}]);
y = dst.template AsType<f8x8_t>()[Number<0>{}];
#endif
}

View File

@@ -62,39 +62,43 @@ __global__ void
#endif // end of if (defined(__gfx9__))
}
// template <typename GridwiseGemm,
// bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// index_t MinimumOccupancy = 1,
// TailNumber TailNum = TailNumber::Even>
// __global__ void
// #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
// #endif
// // __attribute__((amdgpu_waves_per_eu(1, 1)))
// kernel_moe_gemm_gather_2lds(typename GridwiseGemm::Argument karg)
// {
// #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
// GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_ds_grid,
// karg.p_c_grid,
// p_shared,
// p_shared1,
// karg,
// karg.a_element_op,
// karg.b_element_op,
// karg.c_element_op);
// #else
// ignore = karg;
// #endif // end of if (defined(__gfx9__))
// }
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
}
template <typename ALayout,
typename BLayout,
@@ -743,40 +747,19 @@ struct GridwiseMoeGemm
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) /APackedSize < 1
? 1
: 32 * 4 / KPerBlock / sizeof(LDSTypeA);
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor(make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
make_tuple(make_xor_with_modulo_transform(
make_tuple(Number<MPerBlock>{}, Number<AK0Number>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_ak0_mldslayer_m_ak1,
make_tuple(make_pass_through_transform(AK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
return a_lds_block_desc_permuted;
}
else // ColumnMajor A
{
@@ -1508,32 +1491,6 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = MPerBlock / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
const float *p_sorted_weights_0 = p_ds_grid[I0];
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
if constexpr (IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
} else {
const float *p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
// if(threadIdx.x % 16 == 0)
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
@@ -1569,9 +1526,7 @@ struct GridwiseMoeGemm
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op,
scatter_offsets,
scatter_weights};
c_element_op};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -1600,8 +1555,37 @@ struct GridwiseMoeGemm
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const float *p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
if constexpr (IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
} else {
const float *p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
// if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
// printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
});
block_sync_lds();
// each thread write its data from VGPR to LDS
@@ -1619,7 +1603,522 @@ struct GridwiseMoeGemm
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf));
tie(c_grid_buf),
scatter_offsets,
scatter_weights
);
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
void* p_shared1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
IsInputGemm? problem.NumTokens * problem.TopK : problem.NumTokens , problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
// printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
// problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2};
// const index_t b_block_id = blockIdx.x % problem.NBlock;
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if (expert_block_id * MPerBlock >= max_token_id)
return;
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const auto block_mn = [&]() -> std::pair<int, int> {
if constexpr (NSwizzle)
{
// const index_t expert_block_id = blockIdx.x / problem.NBlock; //
// const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
// const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
// const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
// const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
// const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
// const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
// if(threadIdx.x==0)
// printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, es, p_sorted_expert_ids[expert_block_id]);
const index_t ecnt_prefix = p_max_token_id[1+expert_id];
const index_t prefix_block = ecnt_prefix * problem.NBlock;
const index_t ecnt = p_max_token_id[2+expert_id] - ecnt_prefix;
const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; //p_max_token_id[expert_id + 1]; // 2
const index_t bid_new = blockIdx.x - prefix_block;
const index_t nid = __builtin_amdgcn_readfirstlane(bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
// if(threadIdx.x==0)
// printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, ecnt, expert_id);
return {nid, mid};
} else {
return {blockIdx.x, blockIdx.y};
}
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
// if (threadIdx.x==0) {
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
// }
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads = AK0Threads * AK1Threads;
constexpr auto AMRepeats = MPerBlock / AMThreads;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
if constexpr (!IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize());
// if(threadIdx.x==0)
// printf("tid %d eid %d expert_stride %d bufsize %d\n",
// threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
1,
2>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
using EDataType = CDataType;
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType *ptr_ = p_ds_grid[i];
// hack logic here to support different kind of strides. todo fix it.
// ascale t, 1; bscale E, N, 1, move ptr to E
if (i.value == 1)
{
ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
// if ( threadIdx.x % 16 ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
}
return make_dynamic_buffer<AddressSpaceEnum::Global>(
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_m_id, 0, block_n_id, 0);
// return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
c_grid_desc_mblock_mperblock_nblock_nperblock;
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferCluster,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1, //ScatterDim
true, //OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
>
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const float *p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
if constexpr (IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
} else {
const float *p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
// if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
// printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
});
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf),
scatter_offsets,
scatter_weights
);
if constexpr(access_id < num_access - 1)
{
@@ -1644,8 +2143,12 @@ struct GridwiseMoeGemm
// template <bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// bool IsInputGemm = true,
// TailNumber TailNum = TailNumber::Odd>
// __device__ static void Run_2Lds(const ADataType* p_a_grid,
// __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
// const index_t* p_sorted_expert_ids,
// const index_t* p_max_token_id,
// const ADataType* p_a_grid,
// const BDataType* p_b_grid,
// DsGridPointer& p_ds_grid,
// CDataType* p_c_grid,
@@ -1656,37 +2159,7 @@ struct GridwiseMoeGemm
// BElementwiseOperation b_element_op,
// CElementwiseOperation c_element_op)
// {
// // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
// // Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// // p_a_grid,
// // p_b_grid,
// // p_ds_grid,
// // p_c_grid,
// // p_shared,
// // p_shared1,
// // problem,
// // a_element_op,
// // b_element_op,
// // c_element_op,
// // block_2_ctile_map);
// }
// template <typename Block2CTileMap,
// bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// TailNumber TailNum = TailNumber::Odd>
// __device__ static void Run_2Lds(const ADataType* p_a_grid,
// const BDataType* p_b_grid,
// DsGridPointer& p_ds_grid,
// CDataType* p_c_grid,
// void* p_shared,
// void* p_shared1,
// const Problem& problem,
// AElementwiseOperation a_element_op,
// BElementwiseOperation b_element_op,
// CElementwiseOperation c_element_op,
// const Block2CTileMap& block_2_ctile_map)
// {
// }
};

View File

@@ -100,14 +100,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
const ElementwiseOperation& element_op,
const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
const ElementwiseOperation& element_op)
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
element_op_(element_op),
scatter_offsets_(scatter_offsets),
scatter_weights_(scatter_weights)
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! cannot evenly divide");
@@ -158,6 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
@@ -181,9 +178,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1, "scatter weight dim, should only one vec");
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number<iScatter>{}));
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights(Number<iScatter>{}));
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{}); });
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights(Number<iScatter>{}); });
}
else if constexpr(SrcScalarPerVectors{}[i] == 1)
{
@@ -418,6 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
OOBCheck(thread_scratch_id);
@@ -430,13 +428,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
if constexpr (OutputScatter)
{
constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
scatter_offset = scatter_offsets_(Number<iScatter>{});
scatter_offset = scatter_offsets(Number<iScatter>{});
}
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();//hack felix, todo use coord
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
@@ -449,11 +447,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
dst_offset,
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
// if(1) {
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
// static_for<0, 1, 1>{}([&](auto idx) {
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
// using print_vec_t = typename vector_type<DstData, 1>::type;
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid,
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid,
// type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
// });
// }
@@ -509,10 +507,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -683,8 +683,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
auto adjusted_step_idx_scatter = [&]()
{
Index step_;
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
});
return step_;
}
();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
}
@@ -709,8 +719,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
SrcCoords src_coords_;
DstCoords dst_coords_;
const ElementwiseOperation element_op_;
StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
StaticallyIndexedArray<float, scatter_num> scatter_weights_;
};
} // namespace ck

View File

@@ -11,6 +11,13 @@
namespace ck {
inline __device__ int amd_assembly_and_b32(int a, int b)
{
int c;
asm volatile("v_and_b32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
return c;
}
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{
int c;
@@ -32,7 +39,24 @@ inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
return c;
}
inline __device__ f8x8_t amd_assembly_i4_to_fp8x2(int a)
inline __device__ float amd_assemble_cvt_f32_i4(int b)
{
float a;
asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(a) : "v"(b));
return a;
}
inline __device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3)
{
f8x4_t a;
asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2\n"
"v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n"
: "=v"(a)
: "v"(b0), "v"(b1), "v"(b2), "v"(b3));
return a;
}
inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a)
{
uint32_t i4x8 = static_cast<uint32_t>(a);
uint32_t fp8x4_0;
@@ -60,14 +84,7 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x2(int a)
[v_src] "+v"(i4x8)
:);
union
{
uint64_t as_uint64;
f8x8_t as_f8x8;
} convert;
convert.as_uint64 = (static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0;
return convert.as_f8x8;
return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
}
// c0 += inner_product(a, b0)

View File

@@ -1191,11 +1191,15 @@ struct vector_type<T, 32, typename ck::enable_if_t<is_native_type<T>()>>
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
} data_ = { .d32_ = {0} };
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__attribute__((host)) __attribute__((device)) constexpr vector_type() { }
__attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { }
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
// __host__ __device__ constexpr vector_type() : data_{type{0}} {}
// __host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const

View File

@@ -877,19 +877,11 @@ struct MoeSortingKernel
}
__syncthreads();
}
if (tid == 0) {
//temp hack ptr for expert tile cnt
p_total_tokens_post_pad[1] = 0;
}
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
int e_start = smem_cumsum(i_e);
int e_end = smem_cumsum(i_e + 1);
//temp hack ptr for expert tile cnt
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + i_e;
p_sorted_expert_cnts[1] = unit_size_mdiv.div(e_end);
int expert_id = [&]() {
if constexpr(Problem::LocalExpertMasking)
{
@@ -994,11 +986,18 @@ struct MoeSortingKernel
__syncthreads();
}
if (tid == 0) {
//temp hack ptr for expert tile cnt
p_total_tokens_post_pad[1] = 0;
}
// add the skip number
for(int eid = tid; eid < num_experts; eid += block_size)
{
//temp hack ptr for expert tile cnt
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + eid;
int e_start = smem_cumsum(eid);
int e_end = smem_cumdup(eid + 1);
p_sorted_expert_cnts[1] = unit_size_mdiv.div(e_end);
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
@@ -1682,6 +1681,8 @@ struct MoeSortingMultiPhaseKernel_P2
if(position < kargs.num_experts)
{
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + position;//temp mock for p_sorted_expert_cnts, fixme:felix
p_sorted_expert_cnts[0] = out_0;
p_expert_cumsum[position] = out_0 * kargs.unit_size_mdiv.divisor;
}
@@ -1710,6 +1711,7 @@ struct MoeSortingMultiPhaseKernel_P2
{
auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
p_total_tokens_post_pad[0] = total_tokens_post_pad;
p_total_tokens_post_pad[kargs.num_experts+1] = prev_cumsum_a; //temp mock for p_sorted_expert_cnts, fixme:felix
p_expert_cumsum[kargs.num_experts] = total_tokens_post_pad;
}
}

View File

@@ -91,7 +91,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_a = i4_to_f32_gfx9(i4);
#else
v_a = i4 - 8;
#endif
}
else
{
@@ -106,7 +110,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_b = i4_to_f32_gfx9(i4);
#else
v_b = i4 - 8;
#endif
}
else
{

View File

@@ -106,7 +106,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_a = i4_to_f32_gfx9(i4);
#else
v_a = i4 - 8;
#endif
}
else
{
@@ -120,7 +124,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_b = i4_to_f32_gfx9(i4);
#else
v_b = i4 - 8;
#endif
}
else
{
@@ -198,6 +206,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
return str.str();
}
#if CK_USE_PK4_LAYOUT_SHUFFLE
static float i4_to_f32_gfx9(uint8_t i4)
{
static std::unordered_map<uint8_t, float> u = {{0b1000, -0.5000f},
@@ -219,6 +228,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
return u[i4];
}
#endif
};
} // namespace host