Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-13 01:36:06 +00:00)
No raw index calculation (#31)

* Replace most raw index calculations with coordinate transformations
* Overhaul blockwise and threadwise GEMM
* Overhaul driver for gridwise GEMM kernel

Co-authored-by: Jing Zhang <jizhan@amd.com>
File diff suppressed because it is too large
File diff suppressed because it is too large
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp (new file, 396 lines)
@@ -0,0 +1,396 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
#define CK_DRIVER_DYNAMIC_GEMM_V1

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "gridwise_operation_wrapper.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N,
          typename BBlockTransferThreadClusterLengths_K_N,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGlobalIteratorHacks,
          typename BGlobalIteratorHacks,
          typename CGlobalIteratorHacks,
          typename AGlobalMoveSliceWindowIteratorHacks,
          typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                             const FloatAB* p_b_global,
                                             FloatC* p_c_global,
                                             const AGlobalDesc& a_k_m_global_desc,
                                             const BGlobalDesc& b_k_n_global_desc,
                                             const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                                             const CBlockClusterDesc& c_block_cluster_desc,
                                             AGlobalIteratorHacks,
                                             BGlobalIteratorHacks,
                                             CGlobalIteratorHacks,
                                             AGlobalMoveSliceWindowIteratorHacks,
                                             BGlobalMoveSliceWindowIteratorHacks,
                                             index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto M = a_k_m_global_desc.GetLength(I1);
    const auto N = b_k_n_global_desc.GetLength(I1);
    const auto K = a_k_m_global_desc.GetLength(I0);

    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
    {
        throw std::runtime_error("wrong! GEMM size not divisible");
    }

    constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
    constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};

    if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
    {
        throw std::runtime_error("wrong! GEMM size not divisible");
    }

    // GEMM
    using gridwise_gemm =
        GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<BlockSize,
                                              FloatAB,
                                              FloatAcc,
                                              FloatC,
                                              CGlobalMemoryDataOperation,
                                              AGlobalDesc,
                                              BGlobalDesc,
                                              CGlobalDesc,
                                              CBlockClusterDesc,
                                              MPerBlock,
                                              NPerBlock,
                                              KPerBlock,
                                              MPerThread,
                                              NPerThread,
                                              KPerThread,
                                              MLevel0Cluster,
                                              NLevel0Cluster,
                                              MLevel1Cluster,
                                              NLevel1Cluster,
                                              ABlockTransferThreadSliceLengths_K_M,
                                              ABlockTransferThreadClusterLengths_K_M,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              ABlockTransferSrcAccessOrder,
                                              ABlockTransferSrcVectorDim,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_M,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              BBlockTransferThreadSliceLengths_K_N,
                                              BBlockTransferThreadClusterLengths_K_N,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              BBlockTransferSrcAccessOrder,
                                              BBlockTransferSrcVectorDim,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_N,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              CThreadTransferSrcDstAccessOrder,
                                              CThreadTransferSrcDstVectorDim,
                                              CThreadTransferDstScalarPerVector,
                                              AGlobalIteratorHacks,
                                              BGlobalIteratorHacks,
                                              CGlobalIteratorHacks,
                                              AGlobalMoveSliceWindowIteratorHacks,
                                              BGlobalMoveSliceWindowIteratorHacks>;

    const auto GridSize = (M / MPerBlock) * (N / NPerBlock);

    const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;

    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   integral_constant<bool, true>,
                                                   integral_constant<bool, true>>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          a_k_m_global_desc,
                                          p_a_global,
                                          b_k_n_global_desc,
                                          p_b_global,
                                          c_m0_m1_n0_n1_global_desc,
                                          p_c_global,
                                          c_block_cluster_desc,
                                          integral_constant<bool, true>{},
                                          integral_constant<bool, true>{});
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   integral_constant<bool, true>,
                                                   integral_constant<bool, false>>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          a_k_m_global_desc,
                                          p_a_global,
                                          b_k_n_global_desc,
                                          p_b_global,
                                          c_m0_m1_n0_n1_global_desc,
                                          p_c_global,
                                          c_block_cluster_desc,
                                          integral_constant<bool, true>{},
                                          integral_constant<bool, false>{});
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   integral_constant<bool, false>,
                                                   integral_constant<bool, true>>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          a_k_m_global_desc,
                                          p_a_global,
                                          b_k_n_global_desc,
                                          p_b_global,
                                          c_m0_m1_n0_n1_global_desc,
                                          p_c_global,
                                          c_block_cluster_desc,
                                          integral_constant<bool, false>{},
                                          integral_constant<bool, true>{});
    }
    else
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   integral_constant<bool, false>,
                                                   integral_constant<bool, false>>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          a_k_m_global_desc,
                                          p_a_global,
                                          b_k_n_global_desc,
                                          p_b_global,
                                          c_m0_m1_n0_n1_global_desc,
                                          p_c_global,
                                          c_block_cluster_desc,
                                          integral_constant<bool, false>{},
                                          integral_constant<bool, false>{});
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
    DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
    DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));

    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
    b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
    c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<BGlobalDesc>,
                                                   const FloatAB*,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC*,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,
                                                   true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<BGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,
                                                   false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<BGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   false,
                                                   true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   remove_reference_t<AGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<BGlobalDesc>,
                                                   FloatAB,
                                                   remove_reference_t<CGlobalDesc>,
                                                   FloatC,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   false,
                                                   false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            p_a_global,
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            p_b_global,
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            p_c_global,
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }

    return ave_time;
#endif
}

} // namespace ck
#endif
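Note: the two predicates above select one of four template instantiations of the double-buffered K loop. A minimal host-side sketch of the same arithmetic; the K and KPerBlock values are made-up examples, not from this commit:

#include <cstdio>
#include <initializer_list>

int main()
{
    const int KPerBlock = 8; // hypothetical block depth

    for(const int K : {8, 16, 24, 32})
    {
        // one "main" iteration consumes 2 * KPerBlock (double buffering)
        const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;

        // an even number of KPerBlock chunks leaves a double tail
        const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

        std::printf("K=%2d main=%d double_tail=%d\n",
                    K, has_main_k_block_loop, has_double_tail_k_block_loop);
    }
    return 0;
}

With KPerBlock = 8, the four K values above hit each of the four launch branches exactly once (K = 8: no main loop, single tail; K = 16: no main loop, double tail; K = 24: main loop, single tail; K = 32: main loop, double tail).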
@@ -5,6 +5,7 @@

 // TODO remove dependency on deprecated tensor descriptor
 #include "tensor_descriptor.hpp"
+#include "tensor_adaptor.hpp"

 namespace ck {

@@ -44,5 +45,30 @@ __host__ __device__ constexpr auto make_cluster_descriptor(
     return ClusterDescriptor<Lengths, decltype(order)>{};
 }

+#if 1
+template <typename Lengths,
+          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
+__host__ __device__ constexpr auto make_cluster_descriptor_v2(
+    const Lengths& lengths,
+    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
+{
+    constexpr index_t ndim_low = Lengths::Size();
+
+    const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
+
+    const auto low_lengths = generate_tuple(
+        [&](auto idim_low) { return reordered_lengths[idim_low]; }, Number<ndim_low>{});
+
+    const auto transform = make_merge_transform(low_lengths);
+
+    constexpr auto low_dim_old_top_ids = ArrangeOrder{};
+
+    constexpr auto up_dim_new_top_ids = Sequence<0>{};
+
+    return make_single_stage_tensor_adaptor(
+        make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
+}
+#endif
+
 } // namespace ck
 #endif
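Note: a minimal sketch of how the new make_cluster_descriptor_v2 might be used; the 4x8 lengths and the Sequence<1, 0> arrangement order below are made-up for illustration:

using namespace ck;

// hypothetical 4x8 thread cluster, reordered so dimension 1 varies slowest
constexpr auto cluster_desc =
    make_cluster_descriptor_v2(make_tuple(Number<4>{}, Number<8>{}), Sequence<1, 0>{});

// the returned single-stage adaptor merges the reordered lengths, so a flat
// thread id can be mapped back to its multi-dimensional cluster coordinate
const auto thread_cluster_idx = cluster_desc.CalculateBottomIndex(make_multi_index(10));

Unlike make_cluster_descriptor, which returns a ClusterDescriptor, the _v2 variant expresses the same id-to-coordinate mapping as a TensorAdaptor built from a single merge transform; that is what lets the blockwise transfer further down switch from CalculateClusterIndex to CalculateBottomIndex.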
@@ -1282,7 +1282,7 @@ struct DynamicFreeze
     __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                            const UpIdx& idx_up) const
     {
-        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
+        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
                       "wrong! inconsistent # of dimension");

         idx_low = low_idx_;
@@ -1299,7 +1299,7 @@ struct DynamicFreeze
                                               const UpIdx& idx_up_new,
                                               Number<Hack>)
     {
-        idx_diff_low(Number<0>{}) = index_t{Number<0>{}};
+        idx_diff_low(Number<0>{}) = 0;
     }

     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
@@ -1328,5 +1328,90 @@ struct DynamicFreeze
     }
 };

+template <typename VectorSize, typename UpLength>
+struct DynamicVectorize
+{
+    using LowerIndex = MultiIndex<1>;
+    using UpperIndex = MultiIndex<1>;
+
+    using UpLengths = decltype(make_tuple(UpLength{}));
+
+    UpLengths up_lengths_;
+    VectorSize vector_size_;
+
+    __host__ __device__ constexpr DynamicVectorize() = default;
+
+    __host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size,
+                                                   const UpLength& up_length)
+        : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
+    {
+    }
+
+    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
+
+    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
+
+    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
+
+    template <typename LowIdx, typename UpIdx>
+    __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const
+    {
+        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+
+        idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
+    }
+
+    template <typename LowIdxDiff,
+              typename UpIdxDiff,
+              typename LowIdx,
+              typename UpIdx,
+              index_t Hack>
+    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
+                                              const UpIdxDiff& idx_diff_up,
+                                              LowIdx& idx_low,
+                                              const UpIdx& idx_up_new,
+                                              Number<Hack>) const
+    {
+        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
+                          UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+
+        constexpr auto I0 = Number<0>{};
+
+        idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
+
+        idx_low += idx_diff_low;
+    }
+
+    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
+
+    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+    {
+        return true;
+    }
+
+    template <typename UpIdx>
+    __host__ __device__ static constexpr bool
+    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
+    {
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value;
+    }
+
+    __host__ __device__ void Print() const
+    {
+        printf("{");
+        printf("DynamicVectorize, ");
+        printf("up_lengths_");
+        print_multi_index(up_lengths_);
+        printf("}");
+    }
+};
+
 } // namespace ck
 #endif
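Note: the Vectorize transform is a pure scaling of one index: upper index i on a dimension of up_length vectors maps to lower index vector_size * i in the underlying scalar dimension. A small sketch with made-up sizes, viewing 32 scalars as 8 vectors of 4 elements each:

using namespace ck;

const auto vectorize = make_vectorize_transform(4, 8); // vector_size = 4, up_length = 8

auto idx_low       = make_multi_index(0);
const auto idx_up  = make_multi_index(3);

// idx_low becomes 4 * 3 = 12, the scalar offset where the 4th vector starts
vectorize.CalculateLowerIndex(idx_low, idx_up);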
@@ -74,5 +74,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
     return DynamicFreeze<LowerIndex>{low_idx};
 }

+template <typename VectorSize, typename UpLength>
+__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
+                                                            const UpLength& up_length)
+{
+    return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
+}
+
 } // namespace ck
 #endif
@@ -12,25 +12,6 @@ struct DynamicTensorCoordinate;
 template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
 struct DynamicTensorCoordinateIterator;

-template <typename LowerDimensionIdss, typename UpperDimensionIdss>
-__host__ __device__ constexpr index_t GetNumOfHiddenDimension(LowerDimensionIdss,
-                                                              UpperDimensionIdss)
-{
-    constexpr auto all_low_dim_ids =
-        unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
-
-    constexpr auto all_up_dim_ids =
-        unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
-
-    constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
-
-    using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
-                                                                  math::less<index_t>,
-                                                                  math::equal<index_t>>::type;
-
-    return unique_sort_all_dim_ids::Size();
-}
-
 // Transforms: Tuple<transforms...>
 // LowerDimensionIdss : Tuple<Sequence<...>, ...>
 // UpperDimensionIdss : Tuple<Sequence<...>, ...>
@@ -374,13 +355,13 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
         unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);

     // put everything together
-    const auto all_transforms = container_cat(old_tensor_desc.GetTransforms(), new_transforms);
+    const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);

     constexpr auto all_low_dim_hidden_idss =
-        container_cat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
+        container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);

     constexpr auto all_up_dim_hidden_idss =
-        container_cat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
+        container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);

     const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
composable_kernel/include/tensor_description/tensor_adaptor.hpp (new file, 456 lines)
@@ -0,0 +1,456 @@
#ifndef CK_TENSOR_ADAPTOR_HPP
#define CK_TENSOR_ADAPTOR_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"

namespace ck {

// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
struct TensorAdaptor
{
    __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }

    __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }

    __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
    {
        return LowerDimensionHiddenIdss{};
    }

    __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
    {
        return UpperDimensionHiddenIdss{};
    }

    __host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
    {
        return TopDimensionHiddenIds{};
    }

    __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
    {
        return BottomDimensionHiddenIds{};
    }

    __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_top) {
                constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);

                constexpr index_t itran   = tmp[Number<0>{}];
                constexpr index_t idim_up = tmp[Number<1>{}];
                constexpr bool found      = tmp[Number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];

                return length;
            },
            Number<ndim_top_>{});

        // TODO: make container_reduce support tuple of Number and index_t
        return container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
    {
        constexpr auto idim_top = Number<IDim>{};

        constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

            static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == idim_hidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    __host__ __device__ static constexpr index_t GetNumOfBottomDimension()
    {
        return BottomDimensionHiddenIds::Size();
    }

    __host__ __device__ static constexpr index_t GetNumOfTopDimension()
    {
        return TopDimensionHiddenIds::Size();
    }

    __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
    {
        constexpr auto all_low_dim_ids =
            unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
                   LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids =
            unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
                   UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      math::less<index_t>,
                                                                      math::equal<index_t>>::type;

        return unique_sort_all_dim_ids::Size();
    }

    constexpr static index_t ntransform_  = GetNumOfTransform();
    constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
    constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension();
    constexpr static index_t ndim_top_    = GetNumOfTopDimension();

    using HiddenIndex = MultiIndex<ndim_hidden_>;
    using BottomIndex = MultiIndex<ndim_bottom_>;
    using TopIndex    = MultiIndex<ndim_top_>;

    // may be index_t or Number<>
    using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;

    public:
    __host__ __device__ constexpr TensorAdaptor() = default;

    __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
        : transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
    {
        static_assert(Transforms::Size() == ntransform_ &&
                          LowerDimensionHiddenIdss::Size() == ntransform_ &&
                          UpperDimensionHiddenIdss::Size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(),
                      "wrong! # of dimension inconsistent");

        constexpr index_t ntransform  = GetNumOfTransform();
        constexpr index_t ndim_hidden = GetNumOfHiddenDimension();

        MultiIndex<ndim_hidden> idx_hidden;

        // initialize the topmost index
        set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top);

        // calculate hidden index
        static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
            auto itran              = itran_p1 - Number<1>{};
            const auto& tran        = GetTransforms().At(itran);
            constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran);
            constexpr auto dims_up  = GetUpperDimensionHiddenIdss().At(itran);

            const auto idx_up = get_container_subset(idx_hidden, dims_up);

            MultiIndex<dims_low.Size()> idx_low;

            tran.CalculateLowerIndex(idx_low, idx_up);

            set_container_subset(idx_hidden, dims_low, idx_low);
        });

        return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("TensorAdaptor, ");
        static_for<0, ntransform_, 1>{}([&](auto i) {
            printf("transforms: ");
            transforms_[i].Print();
            printf("LowerDimensionHiddenIds:");
            LowerDimensionHiddenIdss{}.At(i).Print();
            printf("UpperDimensionHiddenIds:");
            UpperDimensionHiddenIdss{}.At(i).Print();
        });

        printf("BottomDimensionHiddenIds:");
        BottomDimensionHiddenIds::Print();
        printf("TopDimensionHiddenIds:");
        TopDimensionHiddenIds::Print();

        printf("}");
    }

    private:
    Transforms transforms_;
    ElementSize element_size_;
};

template <typename TensorAdaptor0, typename TensorAdaptor1>
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
                                                         const TensorAdaptor1& adaptor1)
{
    static_assert(TensorAdaptor0::GetNumOfTopDimension() ==
                      TensorAdaptor1::GetNumOfBottomDimension(),
                  "wrong!");

    // all_transforms = transform0 + transform1
    const auto all_transforms =
        container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms());

    // shift
    constexpr index_t adaptor0_max_hidden_id = [&]() {
        index_t adaptor0_max_hidden_id = NumericLimits<index_t>::Min();

        static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                adaptor0_max_hidden_id =
                    math::max(adaptor0_max_hidden_id,
                              TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
            });

            constexpr index_t ndim_up =
                TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor0_max_hidden_id =
                    math::max(adaptor0_max_hidden_id,
                              TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
            });
        });

        return adaptor0_max_hidden_id;
    }();

    constexpr index_t adaptor1_min_hidden_id = [&]() {
        index_t adaptor1_min_hidden_id = NumericLimits<index_t>::Max();

        static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();

            // get the min of all lower dimensions, but not bottom dimensions (because their
            // ids will be matched with top ids from adaptor0)
            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                constexpr index_t low_dim_hidden_id =
                    TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;

                bool is_bottom_dim = false;
                static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
                    if constexpr(low_dim_hidden_id ==
                                 TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
                    {
                        is_bottom_dim = true;
                    }
                });

                if(!is_bottom_dim)
                {
                    adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id);
                }
            });

            constexpr index_t ndim_up =
                TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor1_min_hidden_id =
                    math::min(adaptor1_min_hidden_id,
                              TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
            });
        });

        return adaptor1_min_hidden_id;
    }();

    constexpr index_t adaptor1_hidden_id_shift =
        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

    constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();

    // all_low_dim_hidden_idss =
    //   low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
    constexpr auto low_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size();

            constexpr auto low_dim_hidden_ids_1 =
                TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];

            // sequence in, sequence out
            constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
            {
                auto low_dim_hidden_ids_1_mod = to_multi_index(low_dim_hidden_ids_1);

                // shift hidden id so every dim id is unique
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift;
                });

                // match hidden id
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
                        // if this low dim is a bottom dim, then do id matching
                        if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                                     TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
                        {
                            low_dim_hidden_ids_1_mod(idim_low_1) =
                                TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
                        }
                    });
                });

                return low_dim_hidden_ids_1_mod;
            }
            ();

            return generate_sequence_v2(
                [&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
                Number<ndim_low_1>{});
        },
        Number<TensorAdaptor1::GetNumOfTransform()>{});

    constexpr auto all_low_dim_hidden_idss =
        container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1);

    // all_up_dim_hidden_idss =
    //   up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
    constexpr auto up_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size();

            constexpr auto up_dim_hidden_ids_1 =
                TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];

            // sequence in, constexpr tuple out
            constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
            {
                auto up_dim_hidden_ids_1_mod = to_multi_index(up_dim_hidden_ids_1);

                // shift hidden id
                static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
                    up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift;
                });

                return up_dim_hidden_ids_1_mod;
            }
            ();

            // constexpr tuple to sequence
            return generate_sequence_v2(
                [&](auto i) constexpr { return Number<up_dim_hidden_ids_1_mod[i]>{}; },
                Number<ndim_up_1>{});
        },
        Number<TensorAdaptor1::GetNumOfTransform()>{});

    constexpr auto all_up_dim_hidden_idss =
        container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1);

    // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
    constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds();

    // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
    constexpr auto top_dim_hidden_ids =
        TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};

    // put everything together
    return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
                         remove_cv_t<decltype(all_low_dim_hidden_idss)>,
                         remove_cv_t<decltype(all_up_dim_hidden_idss)>,
                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
                         remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}

// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                                    LowerDimensionOldTopIdss,
                                                                    UpperDimensionNewTopIdss)
{
    constexpr index_t ntransform = Transforms::Size();

    static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
                      UpperDimensionNewTopIdss::Size() == ntransform,
                  "wrong!");

    // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
    constexpr auto all_low_dim_old_top_ids =
        unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
               LowerDimensionOldTopIdss{});

    constexpr auto all_up_dim_new_top_ids =
        unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
               UpperDimensionNewTopIdss{});

    static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
                      is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
                  "wrong!");

    constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
    constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();

    // low_dim_hidden_idss
    constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};

    // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
    constexpr auto up_dim_hidden_idss = generate_tuple(
        [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
        Number<ntransform>{});

    // bottom_dim_hidden_ids
    constexpr auto bottom_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};

    // top_dim_hidden_ids
    constexpr auto top_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};

    return TensorAdaptor<remove_cv_t<Transforms>,
                         remove_cv_t<decltype(low_dim_hidden_idss)>,
                         remove_cv_t<decltype(up_dim_hidden_idss)>,
                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
                         remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}

template <typename X,
          typename... Xs,
          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}

} // namespace ck
#endif
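Note: chain_tensor_adaptors composes two adaptors by shifting every hidden dimension id of adaptor1 past the largest id used by adaptor0, then renaming adaptor1's bottom ids to adaptor0's top ids so the two stages share those dimensions. A minimal sketch with made-up lengths, reshaping an 8x3 coordinate into a 6x4 one through a flat intermediate id:

using namespace ck;

// bottom (6, 4) <- top flat id of length 24
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
    make_tuple(make_merge_transform(make_tuple(Number<6>{}, Number<4>{}))),
    make_tuple(Sequence<0, 1>{}),
    make_tuple(Sequence<0>{}));

// bottom flat id of length 24 <- top (8, 3)
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
    make_tuple(make_unmerge_transform(make_tuple(Number<8>{}, Number<3>{}))),
    make_tuple(Sequence<0>{}),
    make_tuple(Sequence<0, 1>{}));

// adaptor0's top dimension count (1) matches adaptor1's bottom count (1):
// (5, 2) in 8x3 -> flat 5 * 3 + 2 = 17 -> (4, 1) in 6x4, since 17 = 4 * 4 + 1
constexpr auto chained = chain_tensor_adaptors(adaptor0, adaptor1);
const auto idx = chained.CalculateBottomIndex(make_multi_index(5, 2));

This is the same pattern the new blockwise GEMM uses further down, chaining three single-stage adaptors to map a flat thread id to the origin of its C sub-tile.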
@@ -67,26 +67,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            const auto thread_cluster_id =
-                thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
+            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+                make_multi_index(get_thread_local_1d_id()));

-            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
+            const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};

             threadwise_transfer_.SetSrcSliceOrigin(src_desc,
-                                                   src_block_slice_origin + thread_data_id_begin);
+                                                   src_block_slice_origin + thread_data_idx_begin);
             threadwise_transfer_.SetDstSliceOrigin(dst_desc,
-                                                   dst_block_slice_origin + thread_data_id_begin);
+                                                   dst_block_slice_origin + thread_data_idx_begin);
         }
     }

-    __device__ static constexpr auto CalculateThreadDataBegin()
-    {
-        const auto thread_cluster_id =
-            thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
-
-        return thread_cluster_id * ThreadSliceLengths{};
-    }
-
     template <typename SrcIteratorHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
                             const SrcData* p_src,
@@ -141,8 +133,9 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }

     private:
     static constexpr auto thread_cluster_desc_ =
-        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
+        make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     using ThreadwiseTransfer =
         ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
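Note: with the cluster descriptor now a TensorAdaptor, the per-thread slice origin is simply the thread's cluster coordinate scaled by its slice lengths. A numeric sketch with made-up lengths (a 4x8 cluster, each thread owning a 2x4 slice):

using namespace ck;

constexpr auto thread_cluster_desc =
    make_cluster_descriptor_v2(Sequence<4, 8>{}, Sequence<0, 1>{});

// thread 13 sits at cluster coordinate (1, 5), since 13 = 1 * 8 + 5
const auto thread_cluster_idx =
    thread_cluster_desc.CalculateBottomIndex(make_multi_index(13));

// scaling by the slice lengths gives this thread's data origin, (2, 20)
const auto thread_data_idx_begin = thread_cluster_idx * Sequence<2, 4>{};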
@@ -7,341 +7,107 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// C[M, N] += transpose(A[K, M]) * B[K, N]
|
||||
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
|
||||
// A and B are visable to the whole block, C is distributed among each thread
|
||||
// Assume:
|
||||
// 1. A:
|
||||
// 1. BlockMatrixA is known at compile-time
|
||||
// 1. ABlockDesc is known at compile-time
|
||||
// 2. ABlockBuffer is DynamicBuffer
|
||||
// 2. B:
|
||||
// 1. BlockMatrixA is known at compile-time
|
||||
// 1. ABlockDesc is known at compile-time
|
||||
// 2. BBlockBuffer is DynamicBuffer
|
||||
// 3. C:
|
||||
// 1. ThreadMatrixC is known at compile-time
|
||||
// 1. CThreadDesc is known at compile-time
|
||||
// 2. CThreadBuffer is StaticBuffer
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename BlockMatrixA,
|
||||
typename BlockMatrixB,
|
||||
typename ThreadMatrixC,
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t KPerThreadLoop,
|
||||
typename ABlockDesc,
|
||||
typename BBlockDesc,
|
||||
typename CThreadDesc,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
index_t MLevel0ThreadCluster,
|
||||
index_t NLevel0ThreadCluster,
|
||||
index_t MLevel1ThreadCluster,
|
||||
index_t NLevel1ThreadCluster,
|
||||
index_t ThreadGemmADataPerRead_M,
|
||||
index_t ThreadGemmBDataPerRead_N,
|
||||
typename std::enable_if<BlockMatrixA::IsKnownAtCompileTime() &&
|
||||
BlockMatrixB::IsKnownAtCompileTime() &&
|
||||
ThreadMatrixC::IsKnownAtCompileTime(),
|
||||
index_t AThreadCopyScalarPerVector_M1,
|
||||
index_t BThreadCopyScalarPerVector_N1,
|
||||
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
|
||||
BBlockDesc::IsKnownAtCompileTime() &&
|
||||
CThreadDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
|
||||
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
|
||||
{
|
||||
struct MatrixIndex
|
||||
{
|
||||
index_t row;
|
||||
index_t col;
|
||||
};
|
||||
using AIndex = MultiIndex<3>;
|
||||
using BIndex = MultiIndex<3>;
|
||||
using CIndex = MultiIndex<4>;
|
||||
|
||||
private:
|
||||
static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
|
||||
|
||||
static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
|
||||
|
||||
using AThreadCopy =
|
||||
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
BlockMatrixA,
|
||||
decltype(a_thread_mtx_desc_),
|
||||
Sequence<KPerThreadLoop, MPerThreadSubC>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ThreadGemmADataPerRead_M,
|
||||
AddressSpace::Generic,
|
||||
AddressSpace::Vgpr,
|
||||
1>;
|
||||
|
||||
using BThreadCopy =
|
||||
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
BlockMatrixB,
|
||||
decltype(b_thread_mtx_desc_),
|
||||
Sequence<KPerThreadLoop, NPerThreadSubC>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ThreadGemmBDataPerRead_N,
|
||||
AddressSpace::Generic,
|
||||
AddressSpace::Vgpr,
|
||||
1>;
|
||||
|
||||
MatrixIndex c_thread_begin_mtx_idx_;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
|
||||
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
|
||||
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)},
|
||||
b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)}
|
||||
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
|
||||
: c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
|
||||
a_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
|
||||
b_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
|
||||
{
|
||||
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
|
||||
BlockMatrixB::IsKnownAtCompileTime() &&
|
||||
ThreadMatrixC::IsKnownAtCompileTime(),
|
||||
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
|
||||
CThreadDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
|
||||
NLevel0ThreadCluster * NLevel1ThreadCluster,
|
||||
"wrong! blocksize and cluster size not consistent");
|
||||
|
||||
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
|
||||
MLevel1ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
|
||||
|
||||
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
|
||||
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
|
||||
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
|
||||
|
||||
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
|
||||
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
|
||||
"wrong! Cannot evenly divide work among");
|
||||
|
||||
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
|
||||
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
|
||||
"wrong! ThreadMatrixC lengths is wrong");
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetThreadMatrixCLengths()
|
||||
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
|
||||
{
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
|
||||
constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
|
||||
constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
|
||||
constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
|
||||
|
||||
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
|
||||
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
|
||||
// 4-d data space into 4-d thread space
|
||||
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_vectorize_transform(M0, 1),
|
||||
make_vectorize_transform(M1PerThread, M1 / M1PerThread),
|
||||
make_vectorize_transform(N0, 1),
|
||||
make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr index_t MRepeat =
|
||||
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
|
||||
constexpr index_t NRepeat =
|
||||
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
|
||||
// thread position 4-d thread space
|
||||
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_freeze_transform(make_multi_index(0)),
|
||||
make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
|
||||
make_freeze_transform(make_multi_index(0)),
|
||||
make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
|
||||
|
||||
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
|
||||
}
|
||||
// 4-d thread space to 1-d thread space
|
||||
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
|
||||
NLevel1ThreadCluster,
|
||||
MLevel0ThreadCluster,
|
||||
NLevel0ThreadCluster))),
|
||||
make_tuple(Sequence<0, 2, 1, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
|
||||
{
|
||||
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
|
||||
constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
|
||||
|
||||
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
|
||||
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
|
||||
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
|
||||
|
||||
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
|
||||
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
|
||||
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
|
||||
|
||||
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
|
||||
|
||||
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
}
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBlockBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx_desc = ThreadMatrixC{};
|
||||
|
||||
constexpr auto K = a_block_mtx.GetLength(I0);
|
||||
|
||||
constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0);
|
||||
constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
|
||||
constexpr index_t NPerLevel1Cluster =
|
||||
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");
|
||||
|
||||
// thread A-sub, B-sub
|
||||
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
|
||||
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
|
||||
make_tuple(Number<MPerThread>{}, Number<1>{}));
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
|
||||
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
|
||||
make_tuple(Number<NPerThread>{}, Number<1>{}));
|
||||
|
||||
constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
|
||||
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
|
||||
make_tuple(Number<NPerThread>{}, Number<1>{}));
|
||||
|
||||
auto a_thread_buf = make_static_buffer<FloatA>(a_thread_mtx_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<FloatB>(b_thread_mtx_desc_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_thread_sub_mtx),
|
||||
decltype(b_thread_sub_mtx),
|
||||
decltype(c_thread_sub_mtx)>{};
|
||||
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(BlockMatrixA{},
|
||||
make_tuple(I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_mtx_desc_,
|
||||
make_tuple(I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(BlockMatrixB{},
|
||||
make_tuple(I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_mtx_desc_,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(BlockMatrixB{},
|
||||
make_tuple(I0, Number<NPerLevel1Cluster>{}),
|
||||
                           b_block_buf,
                           b_thread_mtx_desc_,
                           make_tuple(I0, Number<NPerThreadSubC>{}),
                           b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(BlockMatrixA{},
                           make_tuple(I0, Number<MPerLevel1Cluster>{}),
                           a_block_buf,
                           a_thread_mtx_desc_,
                           make_tuple(I0, Number<MPerThreadSubC>{}),
                           a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
                            b_thread_buf, make_tuple(I0, I0),
                            c_thread_buf, make_tuple(I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
                            b_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}),
                            c_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}));

        // loop over rest of k
        static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
            // read A_sub_0
            a_thread_copy_.Run(BlockMatrixA{},
                               make_tuple(k, I0),
                               a_block_buf,
                               a_thread_mtx_desc_,
                               make_tuple(I0, I0),
                               a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, Number<MPerThreadSubC>{}),
                                b_thread_buf, make_tuple(I0, I0),
                                c_thread_buf, make_tuple(Number<MPerThreadSubC>{}, I0));

            // read B_sub_0
            b_thread_copy_.Run(BlockMatrixB{},
                               make_tuple(k, I0),
                               b_block_buf,
                               b_thread_mtx_desc_,
                               make_tuple(I0, I0),
                               b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, Number<MPerThreadSubC>{}),
                                b_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}),
                                c_thread_buf,
                                make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));

            // read B_sub_1
            b_thread_copy_.Run(BlockMatrixB{},
                               make_tuple(k, Number<NPerLevel1Cluster>{}),
                               b_block_buf,
                               b_thread_mtx_desc_,
                               make_tuple(I0, Number<NPerThreadSubC>{}),
                               b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(BlockMatrixA{},
                               make_tuple(k, Number<MPerLevel1Cluster>{}),
                               a_block_buf,
                               a_thread_mtx_desc_,
                               make_tuple(I0, Number<MPerThreadSubC>{}),
                               a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
                                b_thread_buf, make_tuple(I0, I0),
                                c_thread_buf, make_tuple(I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
                                b_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}),
                                c_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, Number<MPerThreadSubC>{}),
                            b_thread_buf, make_tuple(I0, I0),
                            c_thread_buf, make_tuple(Number<MPerThreadSubC>{}, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, Number<MPerThreadSubC>{}),
                            b_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}),
                            c_thread_buf,
                            make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
        return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
    }

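    // Note (illustrative, not from this commit): the 2x2 pipeline above splits each
    // thread's C tile into four sub-blocks and updates them as
    //
    //   [C_sub_00 C_sub_01]    [A_sub_0]^T
    //   [C_sub_10 C_sub_11] += [A_sub_1]    * [B_sub_0 B_sub_1]
    //
    // with the reads of the next A/B sub-tiles issued between the threadwise GEMM
    // calls, so LDS-to-register copies overlap with math on data already in registers.
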
    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
@@ -349,28 +115,394 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());

        constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
        constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                FloatB,
                                                FloatC,
                                                decltype(a_thread_desc_),
                                                decltype(b_thread_desc_),
                                                CThreadDesc,
                                                Sequence<KPerThread>,
                                                Sequence<M0_, M1PerThread>,
                                                Sequence<N0_, N1PerThread>>{};

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
        constexpr index_t K = ABlockDesc{}.GetLength(I0);

        if constexpr(MRepeat == 2 && NRepeat == 2)
        {
            Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf);
        }
        else
        {
            Run_naive(a_block_buf, b_block_buf, c_thread_buf);
        }
#else
        Run_naive(a_block_buf, b_block_buf, c_thread_buf);
#endif
        static_for<0, K, KPerThread>{}([&](auto k) {
            a_thread_copy_.Run(ABlockDesc{},
                               make_tuple(k, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

            b_thread_copy_.Run(BBlockDesc{},
                               make_tuple(k, I0, I0),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(I0, I0, I0),
                               b_thread_buf);

            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                                b_thread_buf, make_tuple(I0, I0, I0),
                                c_thread_buf, make_tuple(I0, I0, I0, I0));
        });
    }

    private:
    static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
    static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);

    // A[K, M0, M1]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));

    // B[K, N0, N1]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));

    using AThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                FloatA,
                                                ABlockDesc,
                                                decltype(a_thread_desc_),
                                                Sequence<KPerThread, M0_, M1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                AThreadCopyScalarPerVector_M1,
                                                AddressSpace::Generic,
                                                AddressSpace::Vgpr,
                                                1>;

    using BThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                FloatB,
                                                BBlockDesc,
                                                decltype(b_thread_desc_),
                                                Sequence<KPerThread, N0_, N1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                BThreadCopyScalarPerVector_N1,
                                                AddressSpace::Generic,
                                                AddressSpace::Vgpr,
                                                1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among each thread
// Assume:
//   1. A:
//      1. ABlockDesc is known at compile-time
//      2. ABlockBuffer is DynamicBuffer
//   2. B:
//      1. BBlockDesc is known at compile-time
//      2. BBlockBuffer is DynamicBuffer
//   3. C:
//      1. CThreadDesc is known at compile-time
//      2. CThreadBuffer is StaticBuffer
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc,
          typename BBlockDesc,
          typename CThreadDesc,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t MLevel0ThreadCluster,
          index_t NLevel0ThreadCluster,
          index_t MLevel1ThreadCluster,
          index_t NLevel1ThreadCluster,
          index_t AThreadCopyScalarPerVector_M1,
          index_t BThreadCopyScalarPerVector_N1,
          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
                                      BBlockDesc::IsKnownAtCompileTime() &&
                                      CThreadDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    public:
    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2()
        : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
                          CThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
                                       NLevel0ThreadCluster * NLevel1ThreadCluster,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // TODO: remove this restriction
        static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 &&
                          CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2,
                      "wrong");
    }

    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
    {
        constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
        constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
        constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
        constexpr index_t N1 = BBlockDesc{}.GetLength(I2);

        // 4-d data space into 4-d thread space
        constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
            make_tuple(make_vectorize_transform(M0, 1),
                       make_vectorize_transform(M1PerThread, M1 / M1PerThread),
                       make_vectorize_transform(N0, 1),
                       make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // thread position in 4-d thread space
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(
                make_freeze_transform(make_multi_index(0)),
                make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
                make_freeze_transform(make_multi_index(0)),
                make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

        // 4-d thread space to 1-d thread space
        constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
                                                       NLevel1ThreadCluster,
                                                       MLevel0ThreadCluster,
                                                       NLevel0ThreadCluster))),
            make_tuple(Sequence<0, 2, 1, 3>{}),
            make_tuple(Sequence<0>{}));

        constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);

        return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
    }

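    // Illustrative reading of the adaptor chain above (not from this commit): a thread
    // id decomposes, fastest-varying last, roughly as
    //
    //   nl0 = tid % NLevel0ThreadCluster;
    //   ml0 = (tid / NLevel0ThreadCluster) % MLevel0ThreadCluster;
    //   nl1 = (tid / (NLevel0ThreadCluster * MLevel0ThreadCluster)) % NLevel1ThreadCluster;
    //   ml1 =  tid / (NLevel0ThreadCluster * MLevel0ThreadCluster * NLevel1ThreadCluster);
    //
    // and the returned C origin is then
    //
    //   (m0, m1, n0, n1) = (0,
    //                       (ml1 * MLevel0ThreadCluster + ml0) * M1PerThread,
    //                       0,
    //                       (nl1 * NLevel0ThreadCluster + nl0) * N1PerThread);
    //
    // i.e. M0/N0 are frozen to 0 and each thread owns an M1PerThread x N1PerThread
    // sub-tile of C per (m0, n0) repeat.
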
    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                FloatB,
                                                FloatC,
                                                decltype(a_thread_desc_),
                                                decltype(b_thread_desc_),
                                                CThreadDesc,
                                                Sequence<KPerThread>,
                                                Sequence<1, M1PerThread>,
                                                Sequence<1, N1PerThread>>{};

        constexpr index_t K = ABlockDesc{}.GetLength(I0);

        // read A_sub_0
        a_thread_copy_.Run(ABlockDesc{}, make_tuple(I0, I0, I0), a_block_buf,
                           a_thread_desc_, make_tuple(I0, I0, I0), a_thread_buf);

        // read B_sub_0
        b_thread_copy_.Run(BBlockDesc{}, make_tuple(I0, I0, I0), b_block_buf,
                           b_thread_desc_, make_tuple(I0, I0, I0), b_thread_buf);

        // read B_sub_1
        b_thread_copy_.Run(BBlockDesc{}, make_tuple(I0, I1, I0), b_block_buf,
                           b_thread_desc_, make_tuple(I0, I1, I0), b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(ABlockDesc{}, make_tuple(I0, I1, I0), a_block_buf,
                           a_thread_desc_, make_tuple(I0, I1, I0), a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                            b_thread_buf, make_tuple(I0, I0, I0),
                            c_thread_buf, make_tuple(I0, I0, I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                            b_thread_buf, make_tuple(I0, I1, I0),
                            c_thread_buf, make_tuple(I0, I0, I1, I0));

        // loop over rest of k
        static_for<KPerThread, K, KPerThread>{}([&](auto k) {
            // read A_sub_0
            a_thread_copy_.Run(ABlockDesc{}, make_tuple(k, I0, I0), a_block_buf,
                               a_thread_desc_, make_tuple(I0, I0, I0), a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0),
                                b_thread_buf, make_tuple(I0, I0, I0),
                                c_thread_buf, make_tuple(I1, I0, I0, I0));

            // read B_sub_0
            b_thread_copy_.Run(BBlockDesc{}, make_tuple(k, I0, I0), b_block_buf,
                               b_thread_desc_, make_tuple(I0, I0, I0), b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0),
                                b_thread_buf, make_tuple(I0, I1, I0),
                                c_thread_buf, make_tuple(I1, I0, I1, I0));

            // read B_sub_1
            b_thread_copy_.Run(BBlockDesc{}, make_tuple(k, I1, I0), b_block_buf,
                               b_thread_desc_, make_tuple(I0, I1, I0), b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(ABlockDesc{}, make_tuple(k, I1, I0), a_block_buf,
                               a_thread_desc_, make_tuple(I0, I1, I0), a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                                b_thread_buf, make_tuple(I0, I0, I0),
                                c_thread_buf, make_tuple(I0, I0, I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0),
                                b_thread_buf, make_tuple(I0, I1, I0),
                                c_thread_buf, make_tuple(I0, I0, I1, I0));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0),
                            b_thread_buf, make_tuple(I0, I0, I0),
                            c_thread_buf, make_tuple(I1, I0, I0, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0),
                            b_thread_buf, make_tuple(I0, I1, I0),
                            c_thread_buf, make_tuple(I1, I0, I1, I0));
    }

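    // Illustrative schedule of the pipeline above (summary only, not from this commit):
    //
    //   prologue    : read A0 B0 B1 A1;          C00 += A0*B0; C01 += A0*B1
    //   per k-slice : read A0(k); C10 += A1*B0;  read B0(k); C11 += A1*B1;
    //                 read B1(k); read A1(k);    C00 += A0*B0; C01 += A0*B1
    //   epilogue    : C10 += A1*B0; C11 += A1*B1
    //
    // so every threadwise GEMM call runs while the next LDS-to-register copy is in flight.
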
    private:
    static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
    static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);

    // A[K, M0, M1]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));

    // B[K, N0, N1]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));

    using AThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                FloatA,
                                                ABlockDesc,
                                                decltype(a_thread_desc_),
                                                Sequence<KPerThread, 1, M1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                AThreadCopyScalarPerVector_M1,
                                                AddressSpace::Generic,
                                                AddressSpace::Vgpr,
                                                1>;

    using BThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                FloatB,
                                                BBlockDesc,
                                                decltype(b_thread_desc_),
                                                Sequence<KPerThread, 1, N1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                BThreadCopyScalarPerVector_N1,
                                                AddressSpace::Generic,
                                                AddressSpace::Vgpr,
                                                1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif

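For context, a hedged sketch of the per-K-block call pattern a gridwise kernel uses to
drive a blockwise GEMM like the one above (the copy-object method names follow the
surrounding transfer classes but are assumptions here; synchronization simplified):

    // stage the next K-slice of A and B into LDS, then consume it
    a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf);
    b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf);
    a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_buf);
    __syncthreads();
    blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
    a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
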
@@ -12,7 +12,36 @@

namespace ck {

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
          typename AGlobalDesc,
          typename FloatA,
          typename BGlobalDesc,
          typename FloatB,
          typename CGlobalDesc,
          typename FloatC,
          typename CBlockClusterDesc,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
                                       const FloatA* __restrict__ p_a_global,
                                       const BGlobalDesc b_k_n_global_desc,
                                       const FloatB* __restrict__ p_b_global,
                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
                                       FloatC* __restrict__ p_c_global,
                                       const CBlockClusterDesc c_block_cluster_desc)
{
    GridwiseGemm{}.Run(a_k_m_global_desc,
                       p_a_global,
                       b_k_n_global_desc,
                       p_b_global,
                       c_m0_m1_n0_n1_global_desc,
                       p_c_global,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that void pointers in the kernel signature point
// to the non-modifiable parameter address space, so the compiler can enable the corresponding
// optimization
@@ -23,16 +52,18 @@ template <typename GridwiseGemm,
          typename FloatB,
          typename CGlobalDesc,
          typename FloatC,
          typename CBlockClusterDesc,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
                                             const FloatA* __restrict__ p_a_global,
                                             const void __CONSTANT__* p_b_k_n_global_desc,
                                             const FloatB* __restrict__ p_b_global,
                                             const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
                                             FloatC* __restrict__ p_c_global)
__global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
                                       const FloatA* __restrict__ p_a_global,
                                       const void __CONSTANT__* p_b_k_n_global_desc,
                                       const FloatB* __restrict__ p_b_global,
                                       const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
                                       FloatC* __restrict__ p_c_global,
                                       const void __CONSTANT__* p_c_block_cluster_desc)
{
    // first cast void __CONSTANT__* to void*
    // first cast const void __CONSTANT__* to const void*
    // second cast void* to Desc*
    // the copy constructor of tensor descriptor doesn't take address_space(4)
    const auto a_k_m_global_desc =
@@ -42,12 +73,16 @@ __global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_gl
    const auto c_m0_m1_n0_n1_global_desc =
        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);

    const auto c_block_cluster_desc =
        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);

    GridwiseGemm{}.Run(a_k_m_global_desc,
                       p_a_global,
                       b_k_n_global_desc,
                       p_b_global,
                       c_m0_m1_n0_n1_global_desc,
                       p_c_global,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
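The void-pointer variant implies host-side staging of each descriptor; below is a minimal
sketch under assumptions (the buffer names and the launch-side cast are illustrative, not
from this commit; error handling omitted):

    // Stage the trivially-copyable descriptor in device memory, then pass its
    // address as the opaque kernel argument.
    AGlobalDesc a_desc = /* built on the host */;
    void* p_a_desc_dev = nullptr;
    hipMalloc(&p_a_desc_dev, sizeof(a_desc));
    hipMemcpy(p_a_desc_dev, &a_desc, sizeof(a_desc), hipMemcpyHostToDevice);
    // kernel argument: (const void __CONSTANT__*)p_a_desc_dev
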
@@ -61,6 +96,7 @@ template <index_t BlockSize,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
@@ -131,37 +167,30 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                        const FloatAB* __restrict__ p_b_global,
                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                        FloatC* __restrict__ p_c_global,
                        const CBlockClusterDesc& c_block_cluster_desc,
                        FloatAB* __restrict__ p_shared_block,
                        integral_constant<bool, HasMainKBlockLoop>,
                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        const auto K = a_k_m_global_desc.GetLength(I0);
        const auto M = a_k_m_global_desc.GetLength(I1);
        const auto N = b_k_n_global_desc.GetLength(I1);

        // divide block work by [M, N]
#if 0
        const auto m_block_work_num = M / Number<MPerBlock>{};
        const auto n_block_work_num = N / Number<NPerBlock>{};
        const auto block_work_idx =
            c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
        const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
        // HACK: this forces m/n_block_data_idx_on_global into an SGPR
        const index_t m_block_data_idx_on_global =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

#else
        // HACK: this forces the result into an SGPR
        const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
        const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);

        const index_t m_block_work_id =
            __builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
        const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#endif

        const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
        const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
        const index_t n_block_data_idx_on_global =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

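        // Worked example (illustrative, not from this commit): with M = 256, N = 128,
        // MPerBlock = 128, NPerBlock = 64 there are 2 x 2 = 4 block tiles. Block id 3
        // decomposes into m_block_work_id = 3 / 2 = 1 and n_block_work_id = 3 - 1 * 2 = 1,
        // so that workgroup owns the C tile whose origin is (1 * 128, 1 * 64) = (128, 64).
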
        // LDS max alignment
        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
@@ -204,7 +233,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
            AThreadTransferSrcResetCoordinateAfterRun,
            true>(
            a_k_m_global_desc,
            make_multi_index(0, m_block_data_on_global),
            make_multi_index(0, m_block_data_idx_on_global),
            a_k_m_block_desc,
            make_multi_index(0, 0));

@@ -233,7 +262,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(
            b_k_n_global_desc,
            make_multi_index(0, n_block_data_on_global),
            make_multi_index(0, n_block_data_idx_on_global),
            b_k_n_block_desc,
            make_multi_index(0, 0));

@@ -251,28 +280,45 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
        constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: find a more elegant way of defining c_thread_mtx
        constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
        constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
            a_k_m_block_desc,
            make_tuple(
                make_pass_through_transform(Number<KPerBlock>{}),
                make_unmerge_transform(make_tuple(
                    Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
            b_k_n_block_desc,
            make_tuple(
                make_pass_through_transform(Number<KPerBlock>{}),
                make_unmerge_transform(make_tuple(
                    Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        constexpr auto c_m0_m1_n0_n1_thread_desc =
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));

        const auto blockwise_gemm =
            BlockwiseGemm_km_kn_m0m1n0n1_v1r1<BlockSize, FloatAB, FloatAB, FloatAcc,
                decltype(a_k_m_block_desc), decltype(b_k_n_block_desc),
                decltype(c_m0m1_n0n1_thread_desc), MPerThread, NPerThread, KPerThread,
                MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster,
                MPerThread, NPerThread>{};
            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize, FloatAB, FloatAB,
                FloatAcc, decltype(a_k_m0_m1_block_desc), decltype(b_k_n0_n1_block_desc),
                decltype(c_m0_m1_n0_n1_thread_desc), MPerThread, NPerThread, KPerThread,
                MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster,
                MPerThread, NPerThread>{};

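        // Worked example (illustrative): with MPerBlock = 128 and
        // MPerThread * MLevel0Cluster * MLevel1Cluster = 64, MRepeat = 2 and the unmerge
        // above splits a block-local row index m in [0, 128) into
        // (m0, m1) = (m / 64, m % 64), matching the C[M0, M1, N0, N1] thread layout.
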
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
@@ -286,12 +332,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1

        // register allocation for output
        auto c_thread_buf =
            make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());
            make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                           decltype(c_m0m1_n0n1_thread_desc),
                                           Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
            .Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});
                                           decltype(c_m0_m1_n0_n1_thread_desc),
                                           Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
            .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
@@ -427,30 +473,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
        constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
        constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};

        // define input tensor descriptor for threadwise copy
        // thread input tensor, src of threadwise copy
        constexpr auto c_m0_m1_n0_n1_thread_desc =
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
                                                                      Number<MPerThread>{},
                                                                      Number<NRepeat>{},
                                                                      Number<NPerThread>{}));

        // calculate origin of thread input tensor on global memory
        // blockwise GEMM c matrix starting index
        const auto c_thread_mtx_on_block =
            blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t m_thread_data_on_global =
            m_block_data_on_global + c_thread_mtx_on_block.row;

        const index_t n_thread_data_on_global =
            n_block_data_on_global + c_thread_mtx_on_block.col;

        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
        constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};

        constexpr auto tmp = make_unmerge_transform(make_tuple(
            Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
        const auto c_thread_data_idx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());

        ThreadwiseDynamicTensorSliceTransfer_v1r3<
            FloatAcc,
@@ -465,11 +492,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
            AddressSpace::Global,
            CGlobalMemoryDataOperation,
            1,
            true>(c_m0_m1_n0_n1_global_desc,
                  make_multi_index(m_thread_data_on_global / M1,
                                   m_thread_data_on_global % M1,
                                   n_thread_data_on_global / N1,
                                   n_thread_data_on_global % N1))
            true>{
                c_m0_m1_n0_n1_global_desc,
                make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
                                 c_thread_data_idx_on_block[I1],
                                 n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
                                 c_thread_data_idx_on_block[I3])}
            .Run(c_m0_m1_n0_n1_thread_desc,
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
@@ -486,6 +514,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                        const FloatAB* __restrict__ p_b_global,
                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                        FloatC* __restrict__ p_c_global,
                        const CBlockClusterDesc& c_block_cluster_desc,
                        integral_constant<bool, HasMainKBlockLoop>,
                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
    {
@@ -499,6 +528,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                           p_b_global,
                           c_m0_m1_n0_n1_global_desc,
                           p_c_global,
                           c_block_cluster_desc,
                           p_shared_block,
                           integral_constant<bool, HasMainKBlockLoop>{},
                           integral_constant<bool, HasDoubleTailKBlockLoop>{});

@@ -1376,6 +1376,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to be known at compile-time");

        static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0, "wrong!");
    }

    template <typename SrcRefToOriginDisplacement,

@@ -140,5 +140,103 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
    }
};

// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
// Tensor element can be vectorized data
// Assume:
//   1. ADesc, BDesc, CDesc are known at compile-time
//   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ADesc,
          typename BDesc,
          typename CDesc,
          typename KLengths,
          typename MLengths,
          typename NLengths,
          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                      CDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{
    __device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
    {
        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLengths, MLengths and NLengths

        // TODO: remove this restriction
        static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
                      "wrong!");
    }

    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void Run(const ABuffer& a_buf,
                               AOriginIdx,
                               const BBuffer& b_buf,
                               BOriginIdx,
                               CBuffer& c_buf,
                               COriginIdx)
    {
        static_assert(
            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
            "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");

        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatC>>>::value,
                      "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto K  = KLengths{}[I0];
        constexpr auto M0 = MLengths{}[I0];
        constexpr auto M1 = MLengths{}[I1];
        constexpr auto N0 = NLengths{}[I0];
        constexpr auto N1 = NLengths{}[I1];

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        static_for<0, K, 1>{}([&](auto k) {
            static_for<0, M0, 1>{}([&](auto m0) {
                static_for<0, M1, 1>{}([&](auto m1) {
                    static_for<0, N0, 1>{}([&](auto n0) {
                        static_for<0, N1, 1>{}([&](auto n1) {

                            constexpr index_t a_offset =
                                ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
                            constexpr index_t b_offset =
                                BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
                            constexpr index_t c_offset = CDesc{}.CalculateOffset(
                                c_origin_idx + make_multi_index(m0, m1, n0, n1));

                            amd_assembly_inner_product(a_buf[Number<a_offset>{}],
                                                       b_buf[Number<b_offset>{}],
                                                       c_buf(Number<c_offset>{}));
                        });
                    });
                });
            });
        });
    }
};

} // namespace ck
#endif

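For orientation, the fully unrolled loop nest above is a register-resident GEMM; a plain
scalar sketch of the same accumulation (runtime loops for readability; names illustrative):

    // C[m0][m1][n0][n1] += A[k][m0][m1] * B[k][n0][n1]
    for(int k = 0; k < K; ++k)
        for(int m0 = 0; m0 < M0; ++m0)
            for(int m1 = 0; m1 < M1; ++m1)
                for(int n0 = 0; n0 < N0; ++n0)
                    for(int n1 = 0; n1 < N1; ++n1)
                        c[m0][m1][n0][n1] += a[k][m0][m1] * b[k][n0][n1];

In the device code every iteration resolves to a compile-time constant offset, which is
what lets amd_assembly_inner_product consume buffer elements addressed by Number<> indices.
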
@@ -6,6 +6,7 @@
#include "container_helper.hpp"
#include "statically_indexed_array.hpp"
#include "container_element_picker.hpp"
#include "data_type.hpp"
#include "float_type.hpp"
#include "buffer.hpp"
#include "functional.hpp"

@@ -20,7 +20,8 @@ struct ContainerElementPicker

    __host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array}
    {
        constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
        constexpr index_t imax =
            reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});

        static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
    }
@@ -85,7 +86,8 @@ struct ConstantContainerElementPicker

    __host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array}
    {
        constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
        constexpr index_t imax =
            reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});

        static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
    }

@@ -26,13 +26,13 @@ __host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>
template <typename... Ts, typename T>
__host__ __device__ constexpr auto container_push_front(const Tuple<Ts...>& a, const T& x)
{
    return container_cat(make_tuple(x), a);
    return container_concat(make_tuple(x), a);
}

template <typename... Ts, typename T>
__host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
{
    return container_cat(a, make_tuple(x));
    return container_concat(a, make_tuple(x));
}

template <typename TData, index_t NSize, index_t... IRs>
@@ -158,6 +158,7 @@ __host__ __device__ constexpr auto container_reduce_impl(
}

// the rocm-4.1 compiler would crash on a recursive lambda
// container reduce with initial value
template <typename Container,
          typename Reduce,
          typename Init,
@@ -299,27 +300,27 @@ container_reverse_inclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
}

template <typename X, typename... Ys>
__host__ __device__ constexpr auto container_cat(const X& x, const Ys&... ys)
__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys)
{
    return container_cat(x, container_cat(ys...));
    return container_concat(x, container_concat(ys...));
}

template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_cat(const Array<T, NX>& ax, const Array<T, NY>& ay)
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
    return unpack2(
        [&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}

template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_cat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}

template <typename Container>
__host__ __device__ constexpr auto container_cat(const Container& x)
__host__ __device__ constexpr auto container_concat(const Container& x)
{
    return x;
}

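A small usage sketch of the renamed helper (values illustrative, not from this commit):

    constexpr auto t0  = make_tuple(Number<1>{}, Number<2>{});
    constexpr auto t1  = make_tuple(Number<3>{});
    constexpr auto t01 = container_concat(t0, t1);              // Tuple holding 1, 2, 3
    constexpr auto t   = container_push_back(t01, Number<4>{}); // Tuple holding 1, 2, 3, 4
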
24
composable_kernel/include/utility/data_type.hpp
Normal file
@@ -0,0 +1,24 @@
#ifndef CK_DATA_TYPE_HPP
#define CK_DATA_TYPE_HPP

#include <limits>

namespace ck {

template <typename T>
struct NumericLimits;

template <>
struct NumericLimits<int32_t>
{
    __host__ __device__ static constexpr int32_t Min()
    {
        return std::numeric_limits<int32_t>::min();
    }

    __host__ __device__ static constexpr int32_t Max()
    {
        return std::numeric_limits<int32_t>::max();
    }
};

} // namespace ck
#endif

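A usage sketch (illustrative): the trait exposes device-callable integer bounds, e.g. for
seeding a running extremum before a reduction:

    int32_t running_max = ck::NumericLimits<int32_t>::Min();
    // update with math::maximize<int32_t>{}(running_max, value) as elements stream in
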
@@ -43,11 +43,17 @@ struct multiplies_v2
};

template <class T>
struct maxer
struct maximize
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};

template <class T>
struct minimize
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};

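// Usage sketch (illustrative, mirrors the reduce_on_sequence call updated above):
//   constexpr index_t imax =
//       reduce_on_sequence(Sequence<3, 7, 5>{}, math::maximize<index_t>{}, Number<0>{});
//   // imax == 7
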
template <class T>
struct integer_divide_ceiler
{