mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
impl int64 but result not correct
This commit is contained in:
@@ -132,7 +132,7 @@ static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = true;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
@@ -170,11 +170,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 8;
|
||||
ck::index_t valid_tile_num = 8;
|
||||
ck::index_t tokens = 128;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t experts = 1;
|
||||
ck::index_t sorted_tile_num = 2000;
|
||||
ck::index_t valid_tile_num = 2000;
|
||||
ck::index_t tokens = 256000;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
// ck::index_t tokens = batch * topk;
|
||||
@@ -237,11 +237,10 @@ int main(int argc, char* argv[])
|
||||
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
max_token_id.mData = {valid_size};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
expert_ids.mData[i] = 0;
|
||||
}
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
|
||||
@@ -58,7 +58,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
|
||||
const StaticallyIndexedArray<long_index_t, gather_num>& gather_offsets)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src_element_op,
|
||||
|
||||
@@ -149,7 +149,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
@@ -169,7 +169,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
|
||||
@@ -491,7 +491,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
template <typename ELayout>
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
|
||||
MakeCGridDescriptor_M_N(long_index_t M, long_index_t MPad, long_index_t N, long_index_t NPad, long_index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
@@ -1171,6 +1171,8 @@ struct GridwiseMoeGemm
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
|
||||
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
|
||||
if(threadIdx.x==0 && blockIdx.y+blockIdx.x==0)
|
||||
printf("%ld %ld %ld\n", a_grid_desc_ak0_m_ak1.GetElementSpaceSize(), c_grid_desc_m_n.GetElementSpaceSize(), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
if(expert_block_id * MPerBlock >= max_token_id)
|
||||
return;
|
||||
const index_t expert_id =
|
||||
@@ -1210,7 +1212,7 @@ struct GridwiseMoeGemm
|
||||
|
||||
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets;
|
||||
StaticallyIndexedArray<long_index_t, AMRepeats> gather_offsets;
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
@@ -1563,7 +1565,7 @@ struct GridwiseMoeGemm
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats>
|
||||
StaticallyIndexedArray<long_index_t, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
@@ -1713,7 +1715,7 @@ struct GridwiseMoeGemm
|
||||
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
|
||||
token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, AMRepeats>
|
||||
StaticallyIndexedArray<long_index_t, AMRepeats>
|
||||
gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
@@ -2073,7 +2075,7 @@ struct GridwiseMoeGemm
|
||||
const float* p_sorted_weights_0 = p_ds_grid[I0];
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
StaticallyIndexedArray<index_t, EMRepeats>
|
||||
StaticallyIndexedArray<long_index_t, EMRepeats>
|
||||
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
|
||||
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
|
||||
// too hack here, 2 specific for topk weights, fixme
|
||||
|
||||
@@ -88,7 +88,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op,
|
||||
const StaticallyIndexedArray<index_t, gather_num>& gather_offsets)
|
||||
const StaticallyIndexedArray<long_index_t, gather_num>& gather_offsets)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
|
||||
src_element_op_(src_element_op),
|
||||
@@ -935,7 +935,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
DstCoord dst_coord_;
|
||||
const SrcElementwiseOperation src_element_op_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
|
||||
StaticallyIndexedArray<long_index_t, gather_num> gather_offsets_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -412,7 +412,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
OOBCheck(thread_scratch_id);
|
||||
@@ -431,15 +431,24 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
long_index_t dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
|
||||
// if(dst_offset>80740352 && threadIdx.x==0) {
|
||||
// static_for<0, 1, 1>{}([&](auto idx) {
|
||||
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// printf("tid %d off %ld valid %d %ld %f\n",threadIdx.x, dst_offset,
|
||||
// is_dst_valid, dst_descs[i].GetElementSpaceSize(), type_convert<float>(dst_vectors[i].template
|
||||
// AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
@@ -491,7 +500,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<long_index_t, scatter_num>& scatter_offsets,
|
||||
StaticallyIndexedArray<float, scatter_num>& scatter_weights)
|
||||
{
|
||||
RunRead(src_descs, src_bufs, scatter_weights);
|
||||
|
||||
@@ -33,7 +33,7 @@ struct DynamicBuffer
|
||||
ElementSpaceSize element_space_size_;
|
||||
T invalid_element_value_ = T{0};
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
static constexpr long_index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
@@ -59,21 +59,21 @@ struct DynamicBuffer
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
|
||||
__host__ __device__ constexpr const T& operator[](long_index_t i) const { return p_data_[i]; }
|
||||
|
||||
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
|
||||
__host__ __device__ constexpr T& operator()(long_index_t i) { return p_data_[i]; }
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value ||
|
||||
!is_native_type<X>(),
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
|
||||
__host__ __device__ constexpr auto Get(long_index_t i, bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr long_index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr long_index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -86,7 +86,7 @@ struct DynamicBuffer
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
constexpr long_index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
@@ -140,7 +140,7 @@ struct DynamicBuffer
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
|
||||
__host__ __device__ void Update(long_index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == InMemoryDataOperationEnum::Set)
|
||||
{
|
||||
@@ -189,10 +189,10 @@ struct DynamicBuffer
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer, index_t NumElemsPerThread>
|
||||
template <typename DstBuffer, long_index_t NumElemsPerThread>
|
||||
__host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
|
||||
index_t src_offset,
|
||||
index_t dst_offset,
|
||||
long_index_t src_offset,
|
||||
long_index_t dst_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
// Copy data from global to LDS memory using direct loads.
|
||||
@@ -214,12 +214,12 @@ struct DynamicBuffer
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value ||
|
||||
!is_native_type<X>(),
|
||||
bool>::type = false>
|
||||
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
|
||||
__host__ __device__ void Set(long_index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr long_index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr long_index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -238,7 +238,7 @@ struct DynamicBuffer
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
constexpr long_index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
@@ -342,12 +342,13 @@ struct DynamicBuffer
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
#if 0
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
#else
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
// if(i >= 2169041600)
|
||||
*c_style_pointer_cast<X*>(p_data_ + i) = x;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -357,14 +358,14 @@ struct DynamicBuffer
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
|
||||
__host__ __device__ void AtomicAdd(long_index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr long_index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr long_index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -390,7 +391,7 @@ struct DynamicBuffer
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
constexpr long_index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
@@ -408,12 +409,12 @@ struct DynamicBuffer
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x)
|
||||
__host__ __device__ void AtomicMax(long_index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
constexpr long_index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
constexpr long_index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
@@ -429,7 +430,7 @@ struct DynamicBuffer
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
constexpr long_index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
|
||||
@@ -18,6 +18,13 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user