mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
tile m = 64 ok
This commit is contained in:
@@ -133,13 +133,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
//threadnum, mblock, nblock, kblock
|
||||
256, 32, 128, 128,
|
||||
256, 64, 128, 128,
|
||||
// ak1, bk1
|
||||
8, 8,
|
||||
// mn_perxdl
|
||||
32, 32,
|
||||
// mn_xdlperwave
|
||||
1, 1,
|
||||
2, 1,
|
||||
// a,b: loadtransfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra
|
||||
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
@@ -169,8 +169,8 @@ int main(int argc, char* argv[])
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 8192;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 8;
|
||||
ck::index_t sorted_tile_size = 32;
|
||||
ck::index_t sorted_tile_num = 1;
|
||||
ck::index_t sorted_tile_size = 64;
|
||||
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
|
||||
ck::index_t tokens = 64;
|
||||
|
||||
@@ -368,7 +368,7 @@ int main(int argc, char* argv[])
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(
|
||||
sorted_token_ids, expert_ids, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < SORTED_SIZE; ++m)
|
||||
|
||||
@@ -176,8 +176,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
static_assert(NWave * warpSize == BlockSize);
|
||||
// static constexpr index_t NumTokens = 1;
|
||||
static constexpr index_t Experts = 8;
|
||||
static constexpr index_t SortedTileSize = 32;
|
||||
static constexpr index_t SortedTileSize = MPerBlock;
|
||||
|
||||
|
||||
static constexpr auto MakeDsGridPointer()
|
||||
|
||||
@@ -29,15 +29,17 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ck::index_t>& sorted_token_ids,
|
||||
const Tensor<ck::index_t>& expert_ids,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<ck::index_t>& expert_ids,
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: expert_ids_{expert_ids},
|
||||
sorted_token_ids_{sorted_token_ids},
|
||||
: sorted_token_ids_{sorted_token_ids},
|
||||
expert_ids_{expert_ids},
|
||||
sorted_tile_size_{sorted_tile_size},
|
||||
a_t_k_{a_t_k},
|
||||
b_e_n_k_{b_e_n_k},
|
||||
c_m_n_{c_m_n},
|
||||
@@ -56,7 +58,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t sorted_tile_size = 32;
|
||||
index_t sorted_tile_size_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -73,7 +75,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
const int t = arg.sorted_token_ids_(m);
|
||||
const int e = arg.expert_ids_(m / arg.sorted_tile_size);
|
||||
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
|
||||
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
|
||||
if(t < token_cnt) {
|
||||
for(int k = 0; k < K; ++k)
|
||||
@@ -135,6 +137,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
|
||||
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
|
||||
const Tensor<ck::index_t>& expert_ids,
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
@@ -142,7 +145,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{sorted_token_ids, expert_ids, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op};
|
||||
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
Reference in New Issue
Block a user