fix uint32

This commit is contained in:
coderfeli
2025-03-17 01:18:32 +00:00
parent da2659d502
commit bccc5192cf
2 changed files with 17 additions and 5 deletions

View File

@@ -164,7 +164,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, uint32_t, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -211,6 +211,18 @@ int main(int argc, char* argv[])
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
sorted_tile_num = std::stoi(argv[7]);
valid_tile_num = std::stoi(argv[8]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
@@ -236,10 +248,10 @@ int main(int argc, char* argv[])
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
expert_ids.mData[i] = 0;//eids[i];
}
if(tokens * topk > valid_size)
{

View File

@@ -1587,7 +1587,7 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
IndexType token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens
float weight = token_offset < static_cast<IndexType>(problem.NumTokens)
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
if constexpr(IsInputGemm)
@@ -2102,7 +2102,7 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = token_offset < problem.NumTokens
float weight = token_offset < static_cast<IndexType>(problem.NumTokens)
? p_sorted_weights_0[token_offset * problem.StrideDs[0]]
: 0.0;
// float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];