mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
revert back to v1
This commit is contained in:
@@ -139,7 +139,7 @@ static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t Nswizzle = true;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
@@ -175,9 +175,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
4, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, Nswizzle, true, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
// DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -257,8 +258,10 @@ int main(int argc, char* argv[])
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
|
||||
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
max_token_id.mData = {valid_size, 0, 2, 4, 6, 8, 10, 12, 14, 16};
|
||||
int eids[] = {0, 0,1, 1, 2,2, 3,3, 4,4, 5, 5, 6, 6, 7,7, 3, 3, 3}; // each expert id 0..7 appears twice, followed by three trailing 3s
|
||||
for (int i = 0; i < sorted_tile_num; i++) {
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
|
||||
@@ -141,6 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::MWaves;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
@@ -182,8 +183,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
|
||||
using Base::MWaves;
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num * MWaves;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num ;
|
||||
|
||||
@@ -1492,7 +1492,6 @@ struct GridwiseMoeGemm
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix
|
||||
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
@@ -1555,12 +1554,13 @@ struct GridwiseMoeGemm
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
|
||||
|
||||
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
|
||||
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
|
||||
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
|
||||
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
|
||||
const float *p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// FIXME: hacky -- the hard-coded 2 is specific to the top-k weights
|
||||
@@ -1568,8 +1568,6 @@ struct GridwiseMoeGemm
|
||||
|
||||
auto dstidx = sfc_cde_block.GetIndex(access_id);
|
||||
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
|
||||
// if(threadIdx.x==0 && blockIdx.x==0)
|
||||
// printf("cidx %d %d tpos %d\n", dstidx(I0), dstidx(I1), c_token_pos);
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
@@ -1581,13 +1579,13 @@ struct GridwiseMoeGemm
|
||||
const float *p_sorted_weights_2 = p_ds_grid[I2];
|
||||
weight = weight * p_sorted_weights_2[c_token_pos + m0];
|
||||
}
|
||||
|
||||
// if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
|
||||
// printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_weights(m0) = weight;
|
||||
// if(threadIdx.x % 16 == 0)
|
||||
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
|
||||
});
|
||||
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
@@ -1605,7 +1603,11 @@ struct GridwiseMoeGemm
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(c_grid_buf));
|
||||
tie(c_grid_buf),
|
||||
scatter_offsets,
|
||||
scatter_weights
|
||||
);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto cde_lds_and_global_step =
|
||||
|
||||
Reference in New Issue
Block a user