change tilem

This commit is contained in:
felix
2025-10-26 08:24:24 +00:00
parent a35bc01f27
commit 4fdde500eb

View File

@@ -144,7 +144,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -156,10 +156,10 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 4,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
@@ -171,16 +171,11 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
ck::index_t N = 256;
ck::index_t K = 7168;
ck::index_t experts = 256;
ck::index_t tokens = 64;
ck::index_t topk = 8;
if(argc == 1)
{
@@ -216,6 +211,10 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t sorted_tile_num = experts > tokens * topk ? experts : tokens * topk;
ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
@@ -231,34 +230,18 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
for(int i = 0; i < sorted_tile_num; i++)
{
if(i < valid_tile_num)
{
eids[i] = (i * experts) / valid_tile_num;
}
else
{
eids[i] = 3;
}
expert_ids.mData[i] = i / (valid_tile_num / experts);
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
if(tile_off < token_per_tile && tokenid < tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
@@ -457,10 +440,11 @@ int main(int argc, char* argv[])
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
int valid_expert = tokens * topk < experts? tokens * topk : experts;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
sizeof(A0DataType) * tokens * K * topk/ 2 + sizeof(B0DataType) * K * N * valid_expert / 2 +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
sizeof(XDataType) * K / ScaleBlockSize * N * valid_expert + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;