mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
fix per token quant
This commit is contained in:
@@ -224,7 +224,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
-constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
|
||||
+constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
@@ -257,8 +257,8 @@ int main(int argc, char* argv[])
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
-expert_ids.savetxt("expert_ids.txt", "int");
|
||||
-sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
+// expert_ids.savetxt("expert_ids.txt", "int");
|
||||
+// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
|
||||
@@ -311,12 +311,12 @@ int main(int argc, char* argv[])
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
-a0_t_k_k.savetxt("a.txt");
|
||||
-expert_ids.savetxt("expert_ids.txt", "int");
|
||||
-sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
-d0_t_n.savetxt("d0_t_n.txt", "int");
|
||||
-d1_e_n.savetxt("d1_e_n.txt", "int");
|
||||
-d2_e_n.savetxt("d2_e_n.txt", "int");
|
||||
+// a0_t_k_k.savetxt("a.txt");
|
||||
+// expert_ids.savetxt("expert_ids.txt", "int");
|
||||
+// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
+// d0_t_n.savetxt("d0_t_n.txt", "int");
|
||||
+// d1_e_n.savetxt("d1_e_n.txt", "int");
|
||||
+// d2_e_n.savetxt("d2_e_n.txt", "int");
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
@@ -434,8 +434,8 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
-e_t_n_device_result.savetxt("out.txt");
|
||||
-e_t_n_host_result.savetxt("ref.txt");
|
||||
+// e_t_n_device_result.savetxt("out.txt");
|
||||
+// e_t_n_host_result.savetxt("ref.txt");
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
|
||||
@@ -1492,7 +1492,7 @@ struct GridwiseMoeGemm
|
||||
using CDEBlockTransferCluster =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
-constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
|
||||
+constexpr index_t scatter_weight_idx = 1;
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
@@ -1576,7 +1576,7 @@ struct GridwiseMoeGemm
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
-float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
|
||||
+float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
@@ -2000,7 +2000,7 @@ struct GridwiseMoeGemm
|
||||
using CDEBlockTransferCluster =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
-constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
|
||||
+constexpr index_t scatter_weight_idx = 1;
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
@@ -2084,7 +2084,7 @@ struct GridwiseMoeGemm
|
||||
static_for<0, EMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
-float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
|
||||
+float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]];
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
@@ -2139,27 +2139,6 @@ struct GridwiseMoeGemm
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
-// template <bool HasMainKBlockLoop,
|
||||
-// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
-// bool IsInputGemm = true,
|
||||
-// TailNumber TailNum = TailNumber::Odd>
|
||||
-// __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
|
||||
-// const index_t* p_sorted_expert_ids,
|
||||
-// const index_t* p_max_token_id,
|
||||
-// const ADataType* p_a_grid,
|
||||
-// const BDataType* p_b_grid,
|
||||
-// DsGridPointer& p_ds_grid,
|
||||
-// CDataType* p_c_grid,
|
||||
-// void* p_shared,
|
||||
-// void* p_shared1,
|
||||
-// const Problem& problem,
|
||||
-// AElementwiseOperation a_element_op,
|
||||
-// BElementwiseOperation b_element_op,
|
||||
-// CElementwiseOperation c_element_op)
|
||||
-// {
|
||||
|
||||
|
||||
-// }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user