#include <memory>
#include <vector>

#include "catch2/catch.hpp"

#include "kompute/Kompute.hpp"

// Every OpMult test below multiplies {0, 1, 2} by {2, 4, 6} and expects
// {0, 4, 12}. Those products are small integers that float represents
// exactly, so comparing the output vector with == is safe here.

TEST_CASE("End to end OpMult Flow should execute correctly from manager")
{
    kp::Manager mgr;

    auto tensorLHS =
      std::make_shared<kp::Tensor>(std::vector<float>{ 0, 1, 2 });
    mgr.evalOp<kp::OpCreateTensor>({ tensorLHS });

    auto tensorRHS =
      std::make_shared<kp::Tensor>(std::vector<float>{ 2, 4, 6 });
    mgr.evalOp<kp::OpCreateTensor>({ tensorRHS });

    auto tensorOutput =
      std::make_shared<kp::Tensor>(std::vector<float>{ 0, 0, 0 });
    mgr.evalOp<kp::OpCreateTensor>({ tensorOutput });

    mgr.evalOp<kp::OpMult<>>({ tensorLHS, tensorRHS, tensorOutput });

    REQUIRE(tensorOutput->data() == std::vector<float>{ 0, 4, 12 });
}

TEST_CASE("End to end OpMult Flow should execute correctly from sequence")
{
    auto tensorLHS =
      std::make_shared<kp::Tensor>(std::vector<float>{ 0, 1, 2 });
    auto tensorRHS =
      std::make_shared<kp::Tensor>(std::vector<float>{ 2, 4, 6 });
    auto tensorOutput =
      std::make_shared<kp::Tensor>(std::vector<float>{ 0, 0, 0 });

    kp::Manager mgr;

    std::weak_ptr<kp::Sequence> sqWeakPtr =
      mgr.getOrCreateManagedSequence("newSequence");

    // Scope the locked handle so it is released once the work is evaluated.
    {
        std::shared_ptr<kp::Sequence> sq = sqWeakPtr.lock();
        // Fail loudly here instead of silently skipping the recording and
        // only surfacing the problem as a data mismatch further down.
        REQUIRE(sq != nullptr);

        sq->begin();

        sq->record<kp::OpCreateTensor>({ tensorLHS });
        sq->record<kp::OpCreateTensor>({ tensorRHS });
        sq->record<kp::OpCreateTensor>({ tensorOutput });

        sq->record<kp::OpMult<>>({ tensorLHS, tensorRHS, tensorOutput });

        sq->end();
        sq->eval();
    }

    REQUIRE(tensorOutput->data() == std::vector<float>{ 0, 4, 12 });
}

TEST_CASE("Test manager get create functionality for sequences")
{
    kp::Manager mgr;

    std::weak_ptr<kp::Sequence> sqWeakPtrOne =
      mgr.getOrCreateManagedSequence("sqOne");
    std::weak_ptr<kp::Sequence> sqWeakPtrTwo =
      mgr.getOrCreateManagedSequence("sqTwo");
    // Requesting the same names again should hand back the same sequences.
    std::weak_ptr<kp::Sequence> sqWeakPtrOneRef =
      mgr.getOrCreateManagedSequence("sqOne");
    std::weak_ptr<kp::Sequence> sqWeakPtrTwoRef =
      mgr.getOrCreateManagedSequence("sqTwo");

    std::shared_ptr<kp::Sequence> sqOne = sqWeakPtrOne.lock();
    std::shared_ptr<kp::Sequence> sqTwo = sqWeakPtrTwo.lock();
    std::shared_ptr<kp::Sequence> sqOneRef = sqWeakPtrOneRef.lock();
    std::shared_ptr<kp::Sequence> sqTwoRef = sqWeakPtrTwoRef.lock();

    // Check the handles are live first, so a failure points at the real
    // cause rather than at an identity comparison between null pointers.
    REQUIRE(sqOne != nullptr);
    REQUIRE(sqTwo != nullptr);
    REQUIRE(sqOneRef != nullptr);
    REQUIRE(sqTwoRef != nullptr);

    REQUIRE(sqOne == sqOneRef);
    REQUIRE(sqTwo != sqOneRef);
    REQUIRE(sqTwo == sqTwoRef);
    REQUIRE(sqOneRef != sqTwoRef);
}

TEST_CASE(
  "End to end OpMult Flow with OpCreateTensor called with multiple tensors")
{
    auto tensorLHS =
      std::make_shared<kp::Tensor>(std::vector<float>{ 0, 1, 2 });
    auto tensorRHS =
      std::make_shared<kp::Tensor>(std::vector<float>{ 2, 4, 6 });
    auto tensorOutput =
      std::make_shared<kp::Tensor>(std::vector<float>{ 0, 0, 0 });

    kp::Manager mgr;

    std::weak_ptr<kp::Sequence> sqWeakPtr =
      mgr.getOrCreateManagedSequence("newSequence");

    // Scope the locked handle so it is released once the work is evaluated.
    {
        std::shared_ptr<kp::Sequence> sq = sqWeakPtr.lock();
        // Fail loudly here instead of silently skipping the recording and
        // only surfacing the problem as a data mismatch further down.
        REQUIRE(sq != nullptr);

        sq->begin();

        // A single OpCreateTensor record covers all three tensors.
        sq->record<kp::OpCreateTensor>(
          { tensorLHS, tensorRHS, tensorOutput });

        REQUIRE(tensorLHS->isInit());
        REQUIRE(tensorRHS->isInit());
        REQUIRE(tensorOutput->isInit());

        sq->record<kp::OpMult<>>({ tensorLHS, tensorRHS, tensorOutput });

        sq->end();
        sq->eval();
    }

    REQUIRE(tensorOutput->data() == std::vector<float>{ 0, 4, 12 });
}