mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Get the test to compile
This commit is contained in:
@@ -35,24 +35,25 @@ class TestCkTileSinkHorn: public ::testing::Test
|
||||
|
||||
void RunGenericTest(const std::vector<ck_tile::index_t>& input_shape, const int max_iterations)
|
||||
{
|
||||
auto input_m = input_shape[0];
|
||||
|
||||
SinkhornKnoppArgs args{};
|
||||
args.input_m = static_cast<ck_tile::index_t>(input_shape[0]);
|
||||
args.max_iterations = max_iterations;
|
||||
|
||||
auto default_stride = {args.input_m, 1};
|
||||
auto default_stride = {input_m, 1};
|
||||
|
||||
ck_tile::HostTensor<XDataType> h_x(input_shape, default_stride);
|
||||
ck_tile::HostTensor<YDataType> h_y(input_shape, default_stride);
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(h_x);
|
||||
|
||||
auto buffer_size = h_xs.get_element_space_size_in_bytes();
|
||||
auto buffer_size = h_x.get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_y_mem(output_buffer_size);
|
||||
ck_tile::DeviceMem d_y_mem(buffer_size);
|
||||
|
||||
args.p_x = static_cast<void*>(d_x_mem.GetDeviceBuffer());
|
||||
args.out = static_cast<void*>(d_y_mem.GetDeviceBuffer());
|
||||
ck_tile::SinkhornKnoppArgs args{
|
||||
static_cast<void*>(d_y_mem.GetDeviceBuffer()),
|
||||
static_cast<void*>(d_x_mem.GetDeviceBuffer()),
|
||||
input_m,
|
||||
max_iterations
|
||||
};
|
||||
|
||||
d_x_mem.ToDevice(h_x.data());
|
||||
d_y_mem.ToDevice(h_y.data());
|
||||
@@ -64,13 +65,13 @@ class TestCkTileSinkHorn: public ::testing::Test
|
||||
>;
|
||||
using Kernel = ck_tile::SinkhornKnoppKernelDummyNonStochastic<
|
||||
Problem,
|
||||
ck_tile::SinkhornKnoppPolicy>;
|
||||
ck_tile::SinkhornKnoppDefaultPolicy>;
|
||||
|
||||
// Launch configuration
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
ck_tile::index_t kGridSize = 1 // TODO
|
||||
ck_tile::index_t kGridSize = 1; // TODO
|
||||
|
||||
//TODO
|
||||
// if(!Kernel::IsSupportedArgument())
|
||||
|
||||
Reference in New Issue
Block a user