mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
add vector_type support into thread_copy_v3r1 (#969)
* add vector_type support into thread_copy_v3r1
* remove unncessary type_convert
* fixed datatype
* fixed dataType
* changed API with is_packx_invocable
* changed example
* add missing cmake file
* fixed ci
* fixed cmake
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: 2ce9b56c64]
This commit is contained in:
10
example/60_gemm_multi_ABD/CMakeLists.txt
Normal file
10
example/60_gemm_multi_ABD/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
363
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp
Normal file
363
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp
Normal file
@@ -0,0 +1,363 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using ELayout = Row;
|
||||
|
||||
struct AddScale
|
||||
{
|
||||
static constexpr auto I0 = ck::Number<0>{};
|
||||
static constexpr auto I1 = ck::Number<1>{};
|
||||
static constexpr auto I2 = ck::Number<2>{};
|
||||
static constexpr auto I3 = ck::Number<3>{};
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const
|
||||
{
|
||||
const auto a0_v_t = ck::vector_type<ck::half_t, 4>{a0};
|
||||
const auto a1_v_t = ck::vector_type<ck::half_t, 4>{a1};
|
||||
|
||||
auto r_v_t = ck::vector_type<ck::half_t, 4>{};
|
||||
|
||||
r_v_t.AsType<ck::half_t>()(I0) =
|
||||
scale * (a0_v_t.AsType<ck::half_t>()[I0] + a1_v_t.AsType<ck::half_t>()[I0]);
|
||||
r_v_t.AsType<ck::half_t>()(I1) =
|
||||
scale * (a0_v_t.AsType<ck::half_t>()[I1] + a1_v_t.AsType<ck::half_t>()[I1]);
|
||||
r_v_t.AsType<ck::half_t>()(I2) =
|
||||
scale * (a0_v_t.AsType<ck::half_t>()[I2] + a1_v_t.AsType<ck::half_t>()[I2]);
|
||||
r_v_t.AsType<ck::half_t>()(I3) =
|
||||
scale * (a0_v_t.AsType<ck::half_t>()[I3] + a1_v_t.AsType<ck::half_t>()[I3]);
|
||||
|
||||
a = r_v_t.AsType<ck::half4_t>()[I0];
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const
|
||||
{
|
||||
a = scale * (a0 + a1);
|
||||
}
|
||||
|
||||
// this attribute controls the copy_function applying element_wise_op with
|
||||
// pack4_data
|
||||
constexpr const static bool is_pack4_invocable = true;
|
||||
|
||||
float scale = 1.0;
|
||||
};
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
|
||||
ck::half_t& e, const float& c, const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
|
||||
};
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
|
||||
using AElementOp = AddScale;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AlphaBetaAdd;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle<
|
||||
ck::Tuple<ALayout, ALayout>,
|
||||
ck::Tuple<BLayout>,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ck::Tuple<ADataType, ADataType>,
|
||||
ck::Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideD = 4096;
|
||||
ck::index_t StrideE = 4096;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 1.0f;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
alpha = std::stof(argv[4]);
|
||||
beta = std::stof(argv[5]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideD = std::stoi(argv[9]);
|
||||
StrideE = std::stoi(argv[10]);
|
||||
|
||||
alpha = std::stof(argv[11]);
|
||||
beta = std::stof(argv[12]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
|
||||
"beta\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{0.2};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<ck::index_t, 2>{StrideA, StrideA},
|
||||
std::array<ck::index_t, 1>{StrideB},
|
||||
std::array<ck::index_t, 1>{StrideD},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
Tensor<ADataType> a_m_k({M, K});
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k));
|
||||
}
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
[&](auto i) {
|
||||
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
|
||||
|
||||
return MakeAGridDescriptor_M_N<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
|
||||
return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
}
|
||||
@@ -656,7 +656,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ComputeDataType,
|
||||
ComputeDataType, // ComputeDataType for A
|
||||
ComputeDataType, // ComputeDataType for B
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
dst_vector_type op_r_v;
|
||||
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
|
||||
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
|
||||
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
|
||||
});
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<src_vector_t>(
|
||||
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
|
||||
.template SetAsType<dst_vector_t>(src_data_idx_seq,
|
||||
op_r_v.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// convert from SrcData to DstData here
|
||||
dst_thread_scratch_(idx) =
|
||||
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
#else
|
||||
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
|
||||
// TODO make this logic more generic for more sub-dword datatype
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<SrcData>>::value &&
|
||||
is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
constexpr auto data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
|
||||
|
||||
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
|
||||
|
||||
// get DstScalarPerVector # of read-only references to src vectors from
|
||||
@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
Number<num_dst_vector>{});
|
||||
|
||||
// do data transpose
|
||||
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
|
||||
src_vector_refs, dst_vector_refs);
|
||||
});
|
||||
}
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
DstData dst_v;
|
||||
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
|
||||
dst_thread_scratch_(idx) = dst_v;
|
||||
});
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
|
||||
|
||||
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
SrcData,
|
||||
SrcScalarPerVector,
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>;
|
||||
using SrcThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData, // apply data_convert with SrcThreadScratch
|
||||
SrcScalarPerVector,
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
|
||||
@@ -132,9 +132,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
Number<num>{});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using has_vec_len = decltype(std::declval<T&>().vec_len);
|
||||
|
||||
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
|
||||
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
|
||||
template <typename SrcBuffers,
|
||||
@@ -159,94 +156,63 @@ struct ThreadwiseTensorSliceTransfer_v7r2
|
||||
is_src_valid);
|
||||
});
|
||||
|
||||
if constexpr(is_detected<has_vec_len, decltype(element_op_)>::value)
|
||||
{
|
||||
constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack8_invocable)
|
||||
return math::min(8, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack4_invocable)
|
||||
return math::min(4, SrcScalarPerVector);
|
||||
}
|
||||
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
|
||||
{
|
||||
if constexpr(decltype(element_op_)::is_pack2_invocable)
|
||||
return math::min(2, SrcScalarPerVector);
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
static_assert(is_same<remove_cvref_t<decltype(elem_op_vec_len)>, index_t>::value,
|
||||
"vec_len in element_op_ type is not index_t");
|
||||
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
|
||||
|
||||
static_assert(elem_op_vec_len == 1 || elem_op_vec_len == 2 ||
|
||||
elem_op_vec_len == 4 || elem_op_vec_len == 8,
|
||||
"vec_len in element_op_ must be 1, 2, 4, 8");
|
||||
// apply pointwise function
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
|
||||
|
||||
static_assert(SrcScalarPerVector % elem_op_vec_len == 0,
|
||||
"vec_len in element_op_ cannot be divided by SrcScalarPerVector!");
|
||||
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
|
||||
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
|
||||
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// apply pointwise function
|
||||
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
|
||||
|
||||
using elem_op_vec_t =
|
||||
typename vector_type<SrcData, elem_op_vec_len>::type;
|
||||
|
||||
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
|
||||
using elem_op_vec_t =
|
||||
typename vector_type<DstData, elem_op_vec_len>::type;
|
||||
|
||||
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// apply pointwise function
|
||||
// pointwise function signature:
|
||||
// element_op_(dst_data_refs[I0],
|
||||
// dst_data_refs[I1],
|
||||
// ...,
|
||||
// src_data_refs[I0],
|
||||
// src_data_refs[I1],
|
||||
// ...)
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// apply pointwise function
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
|
||||
|
||||
return src_vectors[iSrc].template AsType<SrcData>()[i];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iDst) -> auto& {
|
||||
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
|
||||
|
||||
return dst_vectors(iDst).template AsType<DstData>()(i);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// apply pointwise function
|
||||
// pointwise function signature:
|
||||
// element_op_(dst_data_refs[I0],
|
||||
// dst_data_refs[I1],
|
||||
// ...,
|
||||
// src_data_refs[I0],
|
||||
// src_data_refs[I1],
|
||||
// ...)
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
}
|
||||
// pointwise function signature:
|
||||
// element_op_(dst_data_refs[I0],
|
||||
// dst_data_refs[I1],
|
||||
// ...,
|
||||
// src_data_refs[I0],
|
||||
// src_data_refs[I1],
|
||||
// ...)
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
|
||||
dst_vectors_tuple_(iAccess) = dst_vectors;
|
||||
|
||||
|
||||
@@ -31,4 +31,13 @@ struct nonesuch
|
||||
template <template <class...> class Op, class... Args>
|
||||
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
|
||||
|
||||
template <typename T>
|
||||
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
|
||||
|
||||
template <typename T>
|
||||
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
|
||||
|
||||
template <typename T>
|
||||
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user