mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add a prototype of int4
This commit is contained in:
@@ -543,7 +543,7 @@ ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
add_subdirectory(library)
|
||||
#add_subdirectory(library)
|
||||
|
||||
if(NOT GPU_ARCHS)
|
||||
rocm_package_setup_component(tests
|
||||
@@ -556,20 +556,20 @@ if(NOT GPU_ARCHS)
|
||||
PACKAGE_NAME examples
|
||||
)
|
||||
add_subdirectory(example)
|
||||
if(BUILD_TESTING)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
#if(BUILD_TESTING)
|
||||
#add_subdirectory(test)
|
||||
#endif()
|
||||
endif()
|
||||
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
#add_subdirectory(profiler)
|
||||
|
||||
if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
|
||||
add_subdirectory(codegen)
|
||||
endif()
|
||||
#if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
|
||||
#add_subdirectory(codegen)
|
||||
#endif()
|
||||
|
||||
#Create an interface target for the include only files and call it "composablekernels"
|
||||
include(CMakePackageConfigHelpers)
|
||||
|
||||
@@ -29,6 +29,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
|
||||
add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
|
||||
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
|
||||
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
|
||||
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::half_t;
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
@@ -27,17 +27,31 @@ using DeviceGemmV2Instance =
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
#if 0
|
||||
64,
|
||||
16, 16,
|
||||
64, 16, 8,
|
||||
256, 8, 16,
|
||||
16, 16,
|
||||
1, 1,
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
1, 1, S<1, 16, 1, 4>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>;
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
|
||||
#else
|
||||
128,
|
||||
16, 32,
|
||||
128, 8, 16,
|
||||
16, 16,
|
||||
1, 1,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
1, 1, S<1, 16, 1, 8>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
|
||||
93
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
Normal file
93
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::pk_i4_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
inline __host__ __device__ ck::half2_t
|
||||
type_convert_packed_i4_to_half2(ck::pk_i4_t x)
|
||||
{
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
|
||||
uint8_t x_l = (x_u8 & 0x0f);
|
||||
uint8_t x_h = (x_u8 & 0xf0) >> 4;
|
||||
|
||||
auto l_f16 = ck::type_convert<ck::half_t>(x_l);
|
||||
auto h_f16 = ck::type_convert<ck::half_t>(x_h);
|
||||
|
||||
return {l_f16, h_f16};
|
||||
}
|
||||
|
||||
|
||||
struct ElementwisePackedI4ToHalf2
|
||||
{
|
||||
__host__ __device__ void
|
||||
operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
|
||||
{
|
||||
y = type_convert_packed_i4_to_half2(x);
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack2_invocable = true;
|
||||
};
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
#if 0
|
||||
64,
|
||||
16, 16,
|
||||
256, 8, 32,
|
||||
16, 16,
|
||||
1, 1,
|
||||
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 16, 1, 4>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
|
||||
#else
|
||||
128,
|
||||
16, 32,
|
||||
128, 8, 32,
|
||||
16, 16,
|
||||
1, 1,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 16, 1, 8>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
@@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
@@ -27,17 +27,17 @@ using DeviceGemmV2Instance =
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
224, 256,
|
||||
64, 8, 2,
|
||||
64,
|
||||
16, 16,
|
||||
256, 8, 8,
|
||||
16, 16,
|
||||
7, 8,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
1, 1,
|
||||
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 8, 2, 0,
|
||||
1, 2, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
|
||||
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 1, S<1, 16, 1, 4>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
@@ -228,6 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
#if 0
|
||||
if(config.do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
@@ -257,11 +258,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
get_atol<CDataType>());
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4});
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
|
||||
@@ -22,6 +22,19 @@ struct PassThroughPack2
|
||||
auto t = type_convert<float2_t>(x);
|
||||
y = type_convert<half2_t>(t);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
|
||||
{
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
|
||||
uint8_t x_l = (x_u8 & 0x0f) >> 0;
|
||||
uint8_t x_h = (x_u8 & 0xf0) >> 4;
|
||||
|
||||
auto l_f16 = ck::type_convert<ck::half_t>(x_l);
|
||||
auto h_f16 = ck::type_convert<ck::half_t>(x_h);
|
||||
|
||||
y = {l_f16, h_f16};
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack2_invocable = true;
|
||||
};
|
||||
|
||||
|
||||
@@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
|
||||
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||
{
|
||||
@@ -1015,6 +1022,8 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)), "pk data N cannot be 1");
|
||||
}
|
||||
|
||||
template <typename SrcRefToOriginDisplacement,
|
||||
@@ -1109,7 +1118,7 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
|
||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||
|
||||
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
||||
vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
|
||||
|
||||
using src_vector_t = typename decltype(src_tmp_vector)::type;
|
||||
|
||||
@@ -1120,7 +1129,7 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
if constexpr(SrcBuffer::IsDynamicBuffer())
|
||||
{
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
|
||||
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize, is_src_valid);
|
||||
}
|
||||
else if constexpr(SrcBuffer::IsStaticBuffer())
|
||||
{
|
||||
@@ -1129,11 +1138,34 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
|
||||
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset / PackedSize>{}];
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
using dst_v_t = typename vector_type_maker_t<DstData, PackedSize>::type;
|
||||
using src_v_t = typename vector_type_maker_t<SrcData, 1>::type;
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::PassThroughPack2{}(
|
||||
dst_tmp_vector.template AsType<dst_v_t>()(i),
|
||||
src_tmp_vector.template AsType<src_v_t>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value &&
|
||||
SrcScalarPerVector % 2 == 0)
|
||||
{
|
||||
|
||||
@@ -31,8 +31,8 @@ template <typename SliceLengths,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarPerVector_,
|
||||
index_t DstScalarPerVector_,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
@@ -55,6 +55,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
|
||||
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
@@ -67,6 +78,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
src_element_op_(src_element_op),
|
||||
dst_element_op_(dst_element_op)
|
||||
{
|
||||
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>, "SrcData != DstData");
|
||||
static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1 || DstScalarPerVector == 1)), "pk data N cannot be 1");
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
@@ -95,11 +108,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector * PackedSize>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
|
||||
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector * PackedSize) == 0,
|
||||
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
@@ -181,7 +194,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
|
||||
src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
@@ -279,7 +292,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
// OOB Check
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector * PackedSize>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -368,9 +381,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
constexpr auto scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
|
||||
SrcScalarPerVector,
|
||||
SrcScalarPerVector * PackedSize,
|
||||
DstVectorDim,
|
||||
DstScalarPerVector>{},
|
||||
DstScalarPerVector * PackedSize>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
|
||||
@@ -410,7 +423,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
}
|
||||
else
|
||||
{
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
constexpr auto packed_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
|
||||
|
||||
static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
}
|
||||
@@ -438,7 +456,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector * PackedSize>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -532,7 +550,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
// copy data from dst_vector_container to dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
dst_coord_.GetOffset() / PackedSize,
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
|
||||
@@ -429,7 +429,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
(is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
using r_t = typename vector_type<T, N>::type;
|
||||
|
||||
@@ -12,6 +12,7 @@ using half_t = _Float16;
|
||||
using int4_t = _BitInt(4);
|
||||
using f8_t = _BitInt(8);
|
||||
using bf8_t = unsigned _BitInt(8);
|
||||
using pk_i4_t = unsigned char;
|
||||
|
||||
// vector_type
|
||||
template <typename T, index_t N>
|
||||
|
||||
Reference in New Issue
Block a user