Get the test to compile

This commit is contained in:
Damien Lejeune
2026-01-21 04:57:01 -05:00
parent e3efa236ec
commit 557a8d3f21
3 changed files with 28 additions and 19 deletions

View File

@@ -35,24 +35,25 @@ class TestCkTileSinkHorn: public ::testing::Test
void RunGenericTest(const std::vector<ck_tile::index_t>& input_shape, const int max_iterations)
{
auto input_m = input_shape[0];
SinkhornKnoppArgs args{};
args.input_m = static_cast<ck_tile::index_t>(input_shape[0]);
args.max_iterations = max_iterations;
auto default_stride = {args.input_m, 1};
auto default_stride = {input_m, 1};
ck_tile::HostTensor<XDataType> h_x(input_shape, default_stride);
ck_tile::HostTensor<YDataType> h_y(input_shape, default_stride);
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(h_x);
auto buffer_size = h_xs.get_element_space_size_in_bytes();
auto buffer_size = h_x.get_element_space_size_in_bytes();
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_y_mem(output_buffer_size);
ck_tile::DeviceMem d_y_mem(buffer_size);
args.p_x = static_cast<void*>(d_x_mem.GetDeviceBuffer());
args.out = static_cast<void*>(d_y_mem.GetDeviceBuffer());
ck_tile::SinkhornKnoppArgs args{
static_cast<void*>(d_y_mem.GetDeviceBuffer()),
static_cast<void*>(d_x_mem.GetDeviceBuffer()),
input_m,
max_iterations
};
d_x_mem.ToDevice(h_x.data());
d_y_mem.ToDevice(h_y.data());
@@ -64,13 +65,13 @@ class TestCkTileSinkHorn: public ::testing::Test
>;
using Kernel = ck_tile::SinkhornKnoppKernelDummyNonStochastic<
Problem,
ck_tile::SinkhornKnoppPolicy>;
ck_tile::SinkhornKnoppDefaultPolicy>;
// Launch configuration
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kGridSize = 1 // TODO
ck_tile::index_t kGridSize = 1; // TODO
//TODO
// if(!Kernel::IsSupportedArgument())