mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Merge commit '42048bdb7d8d931966af76c6dacfedce1c9da90a' into develop
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -259,9 +259,118 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
using Instance = ckb::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
|
||||
EXPECT_THAT(
|
||||
ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNHWC\n"
|
||||
"│ ├─ Weight Layout: GKYXC\n"
|
||||
"│ ├─ Output Layout: GNHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 256×256×32\n"
|
||||
" ├─ Gemm padding: DEFAULT\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V4\n"
|
||||
" ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 8×8\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 2\n"
|
||||
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
|
||||
" └─ Struct does not contain optional num_groups_to_merge parameter"));
|
||||
}
|
||||
|
||||
// Test printing of optional parameters num_groups_to_merge,
|
||||
// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
|
||||
TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
|
||||
{
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
|
||||
Default, // ConvBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // AK1
|
||||
32, // MPerWMMA
|
||||
32, // NPerXDL
|
||||
4, // MRepeat
|
||||
4, // NRepeat
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
|
||||
4, // NumGroupsToMerge
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
1, // MaxTransposeTransferSrcScalarPerVector
|
||||
1>; // MaxTransposeTransferDstScalarPerVector>
|
||||
|
||||
EXPECT_THAT(ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"2D Backward Weight Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNHWC\n"
|
||||
@@ -272,37 +381,146 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 256×256×32\n"
|
||||
" ├─ Gemm padding: DEFAULT\n"
|
||||
" ├─ Data tile size: 128×128×16\n"
|
||||
" ├─ Struct does not contain optional gemm_padding argument\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V4\n"
|
||||
" ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
" ├─ Pipeline version: V1\n"
|
||||
" ├─ Pipeline scheduler: DEFAULT\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 8×8\n"
|
||||
" │ ├─ subtile size: 32×32\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" └─ Vector access (GMEM write) instruction size: 2"));
|
||||
" ├─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
|
||||
" ├─ Max Transpose transfer scr scalar per vector: 1\n"
|
||||
" ├─ Max Transpose dst scalar per vector: 1\n"
|
||||
" └─ Num groups to merge: 4"));
|
||||
}
|
||||
|
||||
// Test printing of optional parameters num_groups_to_merge,
|
||||
// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
|
||||
TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest)
|
||||
{
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
3, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNDHWC, // InLayout
|
||||
ck::tensor_layout::convolution::GKZYXC, // WeiLayout
|
||||
ck::tensor_layout::convolution::GNDHWK, // OutLayout
|
||||
ck::half_t, // InDataType
|
||||
ck::half_t, // WeiDataType
|
||||
ck::half_t, // OutDataType
|
||||
float, // AccDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
|
||||
Default, // ConvBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerWmma
|
||||
32, // NPerWmma
|
||||
4, // MRepeat
|
||||
4, // NRepeat
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
1, // ABlockLdsAddExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
1, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock_
|
||||
1, // NummGemmKPrefetchStage
|
||||
ck::LoopScheduler::Default, // BlkGemmPipeSched
|
||||
ck::PipelineVersion::v1, // BlkGemmPipelineVer
|
||||
false>; // BComputeDataType
|
||||
|
||||
EXPECT_THAT(
|
||||
ckr::describe<Instance>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"3D Backward Weight Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Input Layout: GNDHWC\n"
|
||||
"│ ├─ Weight Layout: GKZYXC\n"
|
||||
"│ ├─ Output Layout: GNDHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 128×128×16\n"
|
||||
" ├─ Struct does not contain optional gemm_padding argument\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V1\n"
|
||||
" ├─ Pipeline scheduler: DEFAULT\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 32×32\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 2×128×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" ├─ Vector access (GMEM write) instruction size: 8\n"
|
||||
" ├─ Num gemm k prefetch stage: 1\n"
|
||||
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
|
||||
"parameter\n"
|
||||
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
|
||||
" └─ Struct does not contain optional num_groups_to_merge parameter"));
|
||||
}
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasInstanceString)
|
||||
|
||||
@@ -209,7 +209,8 @@ struct ReferenceOutputMatcher
|
||||
// Round to 2 digits
|
||||
const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f;
|
||||
*listener << e.wrong_elements << "/" << e.total_elements
|
||||
<< " incorrect elements (~" << percentage << "%)";
|
||||
<< " incorrect elements (~" << percentage << "%)," << " max error "
|
||||
<< e.max_error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,8 +98,10 @@ TEST(ConvFwdTesting, Validate)
|
||||
[&]([[maybe_unused]] std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123});
|
||||
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123});
|
||||
ckt::clear_tensor_buffer(
|
||||
desc, a.get().*ptr, ck::type_convert<ck::bhalf_t, float>(123));
|
||||
ckt::clear_tensor_buffer(
|
||||
desc, b.get().*ptr, ck::type_convert<ck::bhalf_t, float>(123));
|
||||
});
|
||||
|
||||
const auto report = ckt::validate(ARGS, a.get(), b.get());
|
||||
@@ -115,8 +117,10 @@ TEST(ConvFwdTesting, Validate)
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
++field_count;
|
||||
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2});
|
||||
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1});
|
||||
ckt::clear_tensor_buffer(
|
||||
desc, a.get().*ptr, ck::type_convert<ck::bhalf_t, float>(2));
|
||||
ckt::clear_tensor_buffer(
|
||||
desc, b.get().*ptr, ck::type_convert<ck::bhalf_t, float>(1));
|
||||
});
|
||||
|
||||
const auto report = ckt::validate(ARGS, a.get(), b.get());
|
||||
|
||||
@@ -225,3 +225,99 @@ TEST(TensorForeach, ClearTensorZeros)
|
||||
|
||||
EXPECT_THAT(actual, Eq(0));
|
||||
}
|
||||
|
||||
TEST(TensorForeach, CopyTensor)
|
||||
{
|
||||
constexpr auto dt = ckb::DataType::I32;
|
||||
const ckt::Extent shape = {10, 3, 45, 23, 6};
|
||||
using Counter = uint32_t;
|
||||
|
||||
const auto src_desc = ckt::make_descriptor<dt>(shape, ckt::PackedRightLayout{});
|
||||
const auto dst_desc = ckt::make_descriptor<dt>(shape, ckt::PackedLeftLayout{});
|
||||
|
||||
auto src_buffer = ckt::alloc_tensor_buffer(src_desc);
|
||||
auto dst_buffer = ckt::alloc_tensor_buffer(dst_desc);
|
||||
|
||||
const auto gen = [](const auto& index, const auto& lengths) {
|
||||
// Simple incrementing counter
|
||||
return static_cast<Counter>(ckt::calculate_offset(index, lengths));
|
||||
};
|
||||
|
||||
ckt::fill_tensor(
|
||||
src_desc, src_buffer.get(), [lengths = src_desc.get_lengths(), gen](const auto& index) {
|
||||
return gen(index, lengths);
|
||||
});
|
||||
ckt::clear_tensor_buffer(dst_desc, dst_buffer.get());
|
||||
|
||||
// Perform the actual test
|
||||
|
||||
ckt::copy_tensor(src_desc, src_buffer.get(), dst_desc, dst_buffer.get());
|
||||
|
||||
// Check that the dst tensor has the same data
|
||||
|
||||
auto d_invalid = ckt::alloc_buffer(sizeof(Counter));
|
||||
ckt::check_hip(hipMemset(d_invalid.get(), 0, sizeof(Counter)));
|
||||
|
||||
ckt::tensor_foreach(shape,
|
||||
[lengths = dst_desc.get_lengths(),
|
||||
gen,
|
||||
dst = dst_buffer.get(),
|
||||
invalid = reinterpret_cast<Counter*>(d_invalid.get()),
|
||||
strides = dst_desc.get_strides()](const auto& index) {
|
||||
const auto offset = ckt::calculate_offset(index, strides);
|
||||
const auto expected = gen(index, lengths);
|
||||
const auto actual = reinterpret_cast<const Counter*>(dst)[offset];
|
||||
|
||||
if(expected != actual)
|
||||
atomicAdd(invalid, 1);
|
||||
});
|
||||
|
||||
Counter invalid = 0;
|
||||
ckt::check_hip(hipMemcpy(&invalid, d_invalid.get(), sizeof(Counter), hipMemcpyDeviceToHost));
|
||||
|
||||
EXPECT_THAT(invalid, Eq(0));
|
||||
}
|
||||
|
||||
TEST(TensorForeach, FlatTensorIterator)
|
||||
{
|
||||
using Counter = uint32_t;
|
||||
|
||||
constexpr auto dt = ckb::DataType::I32;
|
||||
const ckt::Extent shape = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
|
||||
const ckt::Extent packed_strides = ckt::PackedRightLayout{}(shape);
|
||||
|
||||
const auto desc = ckt::make_descriptor<dt>(shape, ckt::PackedLeftLayout{});
|
||||
|
||||
auto buffer = ckt::alloc_tensor_buffer(desc);
|
||||
|
||||
// Fill the tensor with random values according to the *flat* index. The
|
||||
// FlatTensorIterator iterates over flat values even if the strides are not
|
||||
// packed, so indexing these elements according to the flat index in the
|
||||
// iterator should yield again this value.
|
||||
ckt::fill_tensor(desc, buffer.get(), [packed_strides](const auto& index) {
|
||||
const auto flat_index = ckt::calculate_offset(index, packed_strides);
|
||||
return static_cast<int32_t>(flat_index * 10001 % 1001);
|
||||
});
|
||||
|
||||
auto iterator = ckt::FlatTensorIterator(desc, reinterpret_cast<const int32_t*>(buffer.get()));
|
||||
|
||||
auto d_invalid = ckt::alloc_buffer(sizeof(Counter));
|
||||
ckt::check_hip(hipMemset(d_invalid.get(), 0, sizeof(Counter)));
|
||||
|
||||
ckt::tensor_foreach(shape,
|
||||
[iterator,
|
||||
packed_strides,
|
||||
strides = desc.get_strides(),
|
||||
data = reinterpret_cast<const int32_t*>(buffer.get()),
|
||||
invalid = reinterpret_cast<Counter*>(d_invalid.get())](const auto& index) {
|
||||
const auto flat_index = ckt::calculate_offset(index, packed_strides);
|
||||
const auto offset = ckt::calculate_offset(index, strides);
|
||||
if(iterator[flat_index] != data[offset])
|
||||
atomicAdd(invalid, 1);
|
||||
});
|
||||
|
||||
Counter invalid = 0;
|
||||
ckt::check_hip(hipMemcpy(&invalid, d_invalid.get(), sizeof(Counter), hipMemcpyDeviceToHost));
|
||||
|
||||
EXPECT_THAT(invalid, Eq(0));
|
||||
}
|
||||
|
||||
@@ -74,7 +74,8 @@ TYPED_TEST(ValidationReportTests, SingleCorrect)
|
||||
ckt::fill_tensor(desc, b.get(), generator);
|
||||
|
||||
ckt::ValidationReport report;
|
||||
report.check("correct", desc, b.get(), a.get());
|
||||
report.check("correct - explicit tolerance", desc, b.get(), a.get());
|
||||
report.check_by_accumulations("correct - implicit tolerance", desc, b.get(), a.get(), 0);
|
||||
|
||||
EXPECT_THAT(report.get_errors().size(), Eq(0));
|
||||
}
|
||||
@@ -97,17 +98,22 @@ TYPED_TEST(ValidationReportTests, SingleIncorrect)
|
||||
});
|
||||
|
||||
ckt::ValidationReport report;
|
||||
report.check("incorrect", desc, b.get(), a.get());
|
||||
report.check("incorrect - explicit tolerance", desc, b.get(), a.get());
|
||||
report.check_by_accumulations("incorrect - implicit tolerance", desc, b.get(), a.get(), 0);
|
||||
|
||||
const auto errors = report.get_errors();
|
||||
|
||||
const auto flat_size = desc.get_element_size();
|
||||
const auto expected_errors = flat_size >= 999999 ? 3 : flat_size >= 12345 ? 2 : 1;
|
||||
|
||||
ASSERT_THAT(errors.size(), Eq(1));
|
||||
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect"));
|
||||
EXPECT_THAT(errors[0].wrong_elements, Eq(expected_errors));
|
||||
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
|
||||
ASSERT_THAT(errors.size(), Eq(2));
|
||||
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect - explicit tolerance"));
|
||||
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect - implicit tolerance"));
|
||||
for(int i = 0; i < 2; ++i)
|
||||
{
|
||||
EXPECT_THAT(errors[i].wrong_elements, Eq(expected_errors));
|
||||
EXPECT_THAT(errors[i].total_elements, Eq(desc.get_element_size()));
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
|
||||
@@ -121,14 +127,20 @@ TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
|
||||
ckt::clear_tensor_buffer(desc, b.get());
|
||||
|
||||
ckt::ValidationReport report;
|
||||
report.check("zero_is_incorrect", desc, b.get(), a.get());
|
||||
report.check("zero_is_incorrect - explicit tolerance", desc, b.get(), a.get());
|
||||
report.check_by_accumulations(
|
||||
"zero_is_incorrect - implicit tolerance", desc, b.get(), a.get(), 0);
|
||||
|
||||
const auto errors = report.get_errors();
|
||||
ASSERT_THAT(errors.size(), Eq(1));
|
||||
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect"));
|
||||
EXPECT_THAT(errors[0].wrong_elements, Eq(0));
|
||||
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
|
||||
EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size()));
|
||||
ASSERT_THAT(errors.size(), Eq(2));
|
||||
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect - explicit tolerance"));
|
||||
EXPECT_THAT(errors[1].tensor_name, StrEq("zero_is_incorrect - implicit tolerance"));
|
||||
for(int i = 0; i < 2; ++i)
|
||||
{
|
||||
EXPECT_THAT(errors[i].wrong_elements, Eq(0));
|
||||
EXPECT_THAT(errors[i].total_elements, Eq(desc.get_element_size()));
|
||||
EXPECT_THAT(errors[i].both_all_zero, Eq(true));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ValidationReportTests, MultipleSomeIncorrect)
|
||||
@@ -143,11 +155,12 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
|
||||
auto b = ckt::alloc_tensor_buffer(desc);
|
||||
|
||||
ckt::fill_tensor_buffer(
|
||||
desc, a.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(i % 100); });
|
||||
desc, a.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(float(i % 100)); });
|
||||
ckt::fill_tensor_buffer(
|
||||
desc, b.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(i % 101); });
|
||||
desc, b.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(float(i % 101)); });
|
||||
|
||||
report.check("incorrect 1", desc, b.get(), a.get());
|
||||
report.check("incorrect 1 - explicit tolerance", desc, b.get(), a.get());
|
||||
report.check("incorrect 1 - implicit tolerance", desc, b.get(), a.get(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
@@ -169,7 +182,8 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
|
||||
}
|
||||
});
|
||||
|
||||
report.check("correct", desc, b.get(), a.get());
|
||||
report.check("correct - explicit tolerance", desc, b.get(), a.get());
|
||||
report.check("correct - implicit tolerance", desc, b.get(), a.get(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
@@ -182,16 +196,21 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
|
||||
ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 1; });
|
||||
ckt::fill_tensor_buffer(desc, b.get(), []([[maybe_unused]] size_t i) { return 555; });
|
||||
|
||||
report.check("incorrect 2", desc, b.get(), a.get());
|
||||
report.check("incorrect 2 - explicit tolerance", desc, b.get(), a.get());
|
||||
report.check("incorrect 2 - implicit tolerance", desc, b.get(), a.get(), 0);
|
||||
}
|
||||
|
||||
const auto errors = report.get_errors();
|
||||
|
||||
ASSERT_THAT(errors.size(), Eq(2));
|
||||
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1"));
|
||||
ASSERT_THAT(errors.size(), Eq(4));
|
||||
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1 - explicit tolerance"));
|
||||
EXPECT_THAT(errors[0].wrong_elements, Eq(46840334));
|
||||
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 2"));
|
||||
EXPECT_THAT(errors[1].wrong_elements, Eq(482800));
|
||||
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 1 - implicit tolerance"));
|
||||
EXPECT_THAT(errors[1].wrong_elements, Eq(46840334));
|
||||
EXPECT_THAT(errors[2].tensor_name, StrEq("incorrect 2 - explicit tolerance"));
|
||||
EXPECT_THAT(errors[2].wrong_elements, Eq(482800));
|
||||
EXPECT_THAT(errors[3].tensor_name, StrEq("incorrect 2 - implicit tolerance"));
|
||||
EXPECT_THAT(errors[3].wrong_elements, Eq(482800));
|
||||
}
|
||||
|
||||
// MatchesReference operates on the types defined in testing.hpp, so just
|
||||
@@ -234,7 +253,7 @@ ValidationReport validate<DUMMY_SIGNATURE>(const Args<DUMMY_SIGNATURE>& args,
|
||||
{
|
||||
ValidationReport report;
|
||||
report.check("a", args.make_a_descriptor(), actual.a, expected.a);
|
||||
report.check("b", args.make_b_descriptor(), actual.b, expected.b);
|
||||
report.check_by_accumulations("b", args.make_b_descriptor(), actual.b, expected.b, 0);
|
||||
return report;
|
||||
}
|
||||
|
||||
@@ -299,5 +318,5 @@ TEST(MatchesReference, Incorrect)
|
||||
EXPECT_THAT(listener.str(),
|
||||
StringEqWithDiff( //
|
||||
"1 tensors failed to validate\n"
|
||||
" - a: 625/625 incorrect elements (~100%)"));
|
||||
" - a: 625/625 incorrect elements (~100%), max error 1"));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user