impl int64 but result not correct

This commit is contained in:
coderfeli
2025-03-14 13:01:07 +00:00
parent d4925e1637
commit f911cf7396
8 changed files with 66 additions and 48 deletions

View File

@@ -132,7 +132,7 @@ static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = true;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
@@ -170,11 +170,11 @@ int main(int argc, char* argv[])
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8;
ck::index_t valid_tile_num = 8;
ck::index_t tokens = 128;
ck::index_t K = 6144;
ck::index_t experts = 1;
ck::index_t sorted_tile_num = 2000;
ck::index_t valid_tile_num = 2000;
ck::index_t tokens = 256000;
ck::index_t topk = 2;
// ck::index_t tokens = batch * topk;
@@ -237,11 +237,10 @@ int main(int argc, char* argv[])
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
expert_ids.mData[i] = 0;
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;

View File

@@ -58,7 +58,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
const StaticallyIndexedArray<long_index_t, gather_num>& gather_offsets)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,

View File

@@ -149,7 +149,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
@@ -169,7 +169,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);

View File

@@ -491,7 +491,7 @@ struct GridwiseMoeGemm
template <typename ELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
MakeCGridDescriptor_M_N(long_index_t M, long_index_t MPad, long_index_t N, long_index_t NPad, long_index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
@@ -1171,6 +1171,8 @@ struct GridwiseMoeGemm
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if(threadIdx.x==0 && blockIdx.y+blockIdx.x==0)
printf("%ld %ld %ld\n", a_grid_desc_ak0_m_ak1.GetElementSpaceSize(), c_grid_desc_m_n.GetElementSpaceSize(), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
if(expert_block_id * MPerBlock >= max_token_id)
return;
const index_t expert_id =
@@ -1210,7 +1212,7 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets;
StaticallyIndexedArray<long_index_t, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -1563,7 +1565,7 @@ struct GridwiseMoeGemm
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats>
StaticallyIndexedArray<long_index_t, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
@@ -1713,7 +1715,7 @@ struct GridwiseMoeGemm
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats>
StaticallyIndexedArray<long_index_t, AMRepeats>
gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
@@ -2073,7 +2075,7 @@ struct GridwiseMoeGemm
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats>
StaticallyIndexedArray<long_index_t, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme

View File

@@ -88,7 +88,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
const StaticallyIndexedArray<long_index_t, gather_num>& gather_offsets)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op),
@@ -935,7 +935,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
StaticallyIndexedArray<long_index_t, gather_num> gather_offsets_;
};
} // namespace ck

View File

@@ -412,7 +412,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
OOBCheck(thread_scratch_id);
@@ -431,15 +431,24 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
long_index_t dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
// if(dst_offset>80740352 && threadIdx.x==0) {
// static_for<0, 1, 1>{}([&](auto idx) {
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
// using print_vec_t = typename vector_type<DstData, 1>::type;
// printf("tid %d off %ld valid %d %ld %f\n",threadIdx.x, dst_offset,
// is_dst_valid, dst_descs[i].GetElementSpaceSize(), type_convert<float>(dst_vectors[i].template
// AsType<print_vec_t>()[idx]));
// });
// }
});
// move coordinate
@@ -491,7 +500,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
{
RunRead(src_descs, src_bufs, scatter_weights);

View File

@@ -33,7 +33,7 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
static constexpr index_t PackedSize = []() {
static constexpr long_index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
return 2;
else
@@ -59,21 +59,21 @@ struct DynamicBuffer
return BufferAddressSpace;
}
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr const T& operator[](long_index_t i) const { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
__host__ __device__ constexpr T& operator()(long_index_t i) { return p_data_[i]; }
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value ||
!is_native_type<X>(),
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
__host__ __device__ constexpr auto Get(long_index_t i, bool is_valid_element) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr long_index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr long_index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -86,7 +86,7 @@ struct DynamicBuffer
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr long_index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(InvalidElementUseNumericalZeroValue)
{
@@ -140,7 +140,7 @@ struct DynamicBuffer
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
__host__ __device__ void Update(long_index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == InMemoryDataOperationEnum::Set)
{
@@ -189,10 +189,10 @@ struct DynamicBuffer
}
}
template <typename DstBuffer, index_t NumElemsPerThread>
template <typename DstBuffer, long_index_t NumElemsPerThread>
__host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
index_t src_offset,
index_t dst_offset,
long_index_t src_offset,
long_index_t dst_offset,
bool is_valid_element) const
{
// Copy data from global to LDS memory using direct loads.
@@ -214,12 +214,12 @@ struct DynamicBuffer
typename scalar_type<remove_cvref_t<T>>::type>::value ||
!is_native_type<X>(),
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
__host__ __device__ void Set(long_index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr long_index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr long_index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -238,7 +238,7 @@ struct DynamicBuffer
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr long_index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
@@ -342,12 +342,13 @@ struct DynamicBuffer
{
if(is_valid_element)
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#if 0
X tmp = x;
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
// if(i >= 2169041600)
*c_style_pointer_cast<X*>(p_data_ + i) = x;
#endif
}
}
@@ -357,14 +358,14 @@ struct DynamicBuffer
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
__host__ __device__ void AtomicAdd(long_index_t i, bool is_valid_element, const X& x)
{
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr long_index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr long_index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -390,7 +391,7 @@ struct DynamicBuffer
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr long_index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
@@ -408,12 +409,12 @@ struct DynamicBuffer
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x)
__host__ __device__ void AtomicMax(long_index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr long_index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
constexpr long_index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
@@ -429,7 +430,7 @@ struct DynamicBuffer
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr long_index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);

View File

@@ -18,6 +18,13 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber<N>)
{
return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
{