fix per token quant

This commit is contained in:
coderfeli
2025-03-04 14:24:22 +00:00
parent a007ce04e7
commit 7cca528aa7
3 changed files with 15 additions and 36 deletions

View File

@@ -224,7 +224,7 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 0};
ck::index_t KBatch = 1;
@@ -257,8 +257,8 @@ int main(int argc, char* argv[])
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
// expert_ids.savetxt("expert_ids.txt", "int");
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));

View File

@@ -311,12 +311,12 @@ int main(int argc, char* argv[])
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
d2_e_n.savetxt("d2_e_n.txt", "int");
// a0_t_k_k.savetxt("a.txt");
// expert_ids.savetxt("expert_ids.txt", "int");
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
// d0_t_n.savetxt("d0_t_n.txt", "int");
// d1_e_n.savetxt("d1_e_n.txt", "int");
// d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
@@ -434,8 +434,8 @@ int main(int argc, char* argv[])
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
e_t_n_device_result.savetxt("out.txt");
e_t_n_host_result.savetxt("ref.txt");
// e_t_n_device_result.savetxt("out.txt");
// e_t_n_host_result.savetxt("ref.txt");
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0

View File

@@ -1492,7 +1492,7 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
constexpr index_t scatter_weight_idx = 1;
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1576,7 +1576,7 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
@@ -2000,7 +2000,7 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
constexpr index_t scatter_weight_idx = 1;
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -2084,7 +2084,7 @@ struct GridwiseMoeGemm
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
if constexpr(IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
@@ -2139,27 +2139,6 @@ struct GridwiseMoeGemm
});
}
}
// template <bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// bool IsInputGemm = true,
// TailNumber TailNum = TailNumber::Odd>
// __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
// const index_t* p_sorted_expert_ids,
// const index_t* p_max_token_id,
// const ADataType* p_a_grid,
// const BDataType* p_b_grid,
// DsGridPointer& p_ds_grid,
// CDataType* p_c_grid,
// void* p_shared,
// void* p_shared1,
// const Problem& problem,
// AElementwiseOperation a_element_op,
// BElementwiseOperation b_element_op,
// CElementwiseOperation c_element_op)
// {
// }
};
} // namespace ck