mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Merge commit '42048bdb7d8d931966af76c6dacfedce1c9da90a' into develop
This commit is contained in:
@@ -225,3 +225,99 @@ TEST(TensorForeach, ClearTensorZeros)
|
||||
|
||||
EXPECT_THAT(actual, Eq(0));
|
||||
}
|
||||
|
||||
TEST(TensorForeach, CopyTensor)
|
||||
{
|
||||
constexpr auto dt = ckb::DataType::I32;
|
||||
const ckt::Extent shape = {10, 3, 45, 23, 6};
|
||||
using Counter = uint32_t;
|
||||
|
||||
const auto src_desc = ckt::make_descriptor<dt>(shape, ckt::PackedRightLayout{});
|
||||
const auto dst_desc = ckt::make_descriptor<dt>(shape, ckt::PackedLeftLayout{});
|
||||
|
||||
auto src_buffer = ckt::alloc_tensor_buffer(src_desc);
|
||||
auto dst_buffer = ckt::alloc_tensor_buffer(dst_desc);
|
||||
|
||||
const auto gen = [](const auto& index, const auto& lengths) {
|
||||
// Simple incrementing counter
|
||||
return static_cast<Counter>(ckt::calculate_offset(index, lengths));
|
||||
};
|
||||
|
||||
ckt::fill_tensor(
|
||||
src_desc, src_buffer.get(), [lengths = src_desc.get_lengths(), gen](const auto& index) {
|
||||
return gen(index, lengths);
|
||||
});
|
||||
ckt::clear_tensor_buffer(dst_desc, dst_buffer.get());
|
||||
|
||||
// Perform the actual test
|
||||
|
||||
ckt::copy_tensor(src_desc, src_buffer.get(), dst_desc, dst_buffer.get());
|
||||
|
||||
// Check that the dst tensor has the same data
|
||||
|
||||
auto d_invalid = ckt::alloc_buffer(sizeof(Counter));
|
||||
ckt::check_hip(hipMemset(d_invalid.get(), 0, sizeof(Counter)));
|
||||
|
||||
ckt::tensor_foreach(shape,
|
||||
[lengths = dst_desc.get_lengths(),
|
||||
gen,
|
||||
dst = dst_buffer.get(),
|
||||
invalid = reinterpret_cast<Counter*>(d_invalid.get()),
|
||||
strides = dst_desc.get_strides()](const auto& index) {
|
||||
const auto offset = ckt::calculate_offset(index, strides);
|
||||
const auto expected = gen(index, lengths);
|
||||
const auto actual = reinterpret_cast<const Counter*>(dst)[offset];
|
||||
|
||||
if(expected != actual)
|
||||
atomicAdd(invalid, 1);
|
||||
});
|
||||
|
||||
Counter invalid = 0;
|
||||
ckt::check_hip(hipMemcpy(&invalid, d_invalid.get(), sizeof(Counter), hipMemcpyDeviceToHost));
|
||||
|
||||
EXPECT_THAT(invalid, Eq(0));
|
||||
}
|
||||
|
||||
TEST(TensorForeach, FlatTensorIterator)
|
||||
{
|
||||
using Counter = uint32_t;
|
||||
|
||||
constexpr auto dt = ckb::DataType::I32;
|
||||
const ckt::Extent shape = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
|
||||
const ckt::Extent packed_strides = ckt::PackedRightLayout{}(shape);
|
||||
|
||||
const auto desc = ckt::make_descriptor<dt>(shape, ckt::PackedLeftLayout{});
|
||||
|
||||
auto buffer = ckt::alloc_tensor_buffer(desc);
|
||||
|
||||
// Fill the tensor with random values according to the *flat* index. The
|
||||
// FlatTensorIterator iterates over flat values even if the strides are not
|
||||
// packed, so indexing these elements according to the flat index in the
|
||||
// iterator should yield again this value.
|
||||
ckt::fill_tensor(desc, buffer.get(), [packed_strides](const auto& index) {
|
||||
const auto flat_index = ckt::calculate_offset(index, packed_strides);
|
||||
return static_cast<int32_t>(flat_index * 10001 % 1001);
|
||||
});
|
||||
|
||||
auto iterator = ckt::FlatTensorIterator(desc, reinterpret_cast<const int32_t*>(buffer.get()));
|
||||
|
||||
auto d_invalid = ckt::alloc_buffer(sizeof(Counter));
|
||||
ckt::check_hip(hipMemset(d_invalid.get(), 0, sizeof(Counter)));
|
||||
|
||||
ckt::tensor_foreach(shape,
|
||||
[iterator,
|
||||
packed_strides,
|
||||
strides = desc.get_strides(),
|
||||
data = reinterpret_cast<const int32_t*>(buffer.get()),
|
||||
invalid = reinterpret_cast<Counter*>(d_invalid.get())](const auto& index) {
|
||||
const auto flat_index = ckt::calculate_offset(index, packed_strides);
|
||||
const auto offset = ckt::calculate_offset(index, strides);
|
||||
if(iterator[flat_index] != data[offset])
|
||||
atomicAdd(invalid, 1);
|
||||
});
|
||||
|
||||
Counter invalid = 0;
|
||||
ckt::check_hip(hipMemcpy(&invalid, d_invalid.get(), sizeof(Counter), hipMemcpyDeviceToHost));
|
||||
|
||||
EXPECT_THAT(invalid, Eq(0));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user