mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
* Turning compare warnings on
* Cleaning part I
* Cleaning part II
* Explicit static_cast to ck::type_convert
* Resolving large tensor size issue.
* format
* revert change to tensor descriptor; promote ElementSpaceSize to 64bit
* use integer value for GEMM test
* Review remarks
* Review remarks + issues with (un)signed arithmetic
* Format fix
* Format
* Clang-format.
* fix 2gb limit issue
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
[ROCm/composable_kernel commit: f03a1738d9]
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
@@ -35,6 +33,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
|
||||
}
|
||||
#endif
|
||||
|
||||
// Lengths..., Strides... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) Number<>, which is known at compile-time
|
||||
// element_space_size could be:
|
||||
// 1) long_index_t, or
|
||||
// 2) LongNumber<>
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
@@ -68,10 +72,10 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
|
||||
}
|
||||
};
|
||||
|
||||
const auto element_space_size = f(f, Number<0>{}, Number<1>{});
|
||||
const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
|
||||
#else
|
||||
const auto element_space_size =
|
||||
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
|
||||
calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
|
||||
#endif
|
||||
|
||||
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
|
||||
@@ -82,9 +86,12 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
|
||||
element_space_size};
|
||||
}
|
||||
|
||||
// Lengths... can be:
|
||||
// 1) index_t, which is known at run-time
|
||||
// Lengths... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) Number<>, which is known at compile-time
|
||||
// element_space_size could be:
|
||||
// 1) long_index_t, or
|
||||
// 2) LongNumber<>
|
||||
template <typename... Lengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
|
||||
@@ -100,7 +107,7 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
|
||||
const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});
|
||||
|
||||
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
@@ -110,6 +117,12 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
|
||||
element_space_size};
|
||||
}
|
||||
|
||||
// Lengths... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) Number<>, which is known at compile-time
|
||||
// align could be:
|
||||
// 1) index_t, or
|
||||
// 2) Number<>
|
||||
template <typename... Lengths, typename Align>
|
||||
__host__ __device__ constexpr auto
|
||||
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
|
||||
@@ -146,4 +159,3 @@ make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align ali
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -635,11 +635,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
d_grid_desc_mblock_mperblock_{},
|
||||
compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
|
||||
b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
|
||||
c_grid_desc_m_n_.GetElementSpaceSize(),
|
||||
d_grid_desc_m_.GetElementSpaceSize(),
|
||||
d_grid_desc_m_.GetElementSpaceSize()},
|
||||
compute_base_ptr_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
|
||||
@@ -384,9 +384,10 @@ struct DeviceBatchedGemmXdl
|
||||
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
|
||||
b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
|
||||
c_grid_desc_m_n_.GetElementSpaceSize()},
|
||||
compute_ptr_offset_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
|
||||
@@ -697,7 +697,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
|
||||
@@ -1412,7 +1412,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
|
||||
@@ -861,17 +861,11 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// Input tensors can't be bigger than 2GB each.
|
||||
constexpr std::size_t GB2 = 2 * 1e9;
|
||||
constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
|
||||
|
||||
if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2)
|
||||
if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 ||
|
||||
arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 ||
|
||||
arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -372,17 +372,18 @@ struct DeviceGroupedGemmXdl
|
||||
{
|
||||
grid_size_ = 0;
|
||||
|
||||
group_count_ = static_cast<int>(gemm_shapes.size());
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
|
||||
|
||||
if(!(group_count_ == p_a.size() && group_count_ == p_b.size() &&
|
||||
group_count_ == p_c.size()))
|
||||
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_b.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_c.size())))
|
||||
{
|
||||
throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
|
||||
}
|
||||
|
||||
gemm_desc_kernel_arg_.reserve(group_count_);
|
||||
|
||||
for(index_t i = 0; i < gemm_shapes.size(); i++)
|
||||
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
|
||||
{
|
||||
const index_t M = gemm_shapes[i].M;
|
||||
const index_t N = gemm_shapes[i].N;
|
||||
@@ -563,7 +564,7 @@ struct DeviceGroupedGemmXdl
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_)
|
||||
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
|
||||
return false;
|
||||
else
|
||||
return true;
|
||||
|
||||
@@ -8,5 +8,8 @@ namespace ck {
|
||||
template <index_t N>
|
||||
using Number = integral_constant<index_t, N>;
|
||||
|
||||
template <index_t N>
|
||||
using LongNumber = integral_constant<long_index_t, N>;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -158,5 +158,11 @@ __host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
||||
return StaticBuffer<AddressSpace, T, N, true>{};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
|
||||
{
|
||||
return StaticBuffer<AddressSpace, T, N, true>{};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user