#include #include #include #include using namespace mlpack; using namespace mlpack::tree; template double traverse(const TreeType& node, const arma::vec& point, const int level = 0) { double result = 0.0; for (size_t i = 0; i < node.NumPoints(); ++i) { result += arma::dot(point, node.Dataset().col(node.Point(i))); } double leftResult = 0; double rightResult = 0; // if (level % 4 == 0) { if (node.Left()) { // #pragma omp task shared(node, leftResult, point) { leftResult = traverse(*node.Left(), point, level + 1); } } if (node.Right()) { #pragma omp task shared(node, rightResult, point) { rightResult = traverse(*node.Right(), point, level + 1); } } #pragma omp taskwait } // else // { // leftResult = (node.Left()) ? traverse(*node.Left(), point, level + 1) : 0; // rightResult = (node.Right()) ? traverse(*node.Right(), point, level + 1) : 0; // } return result + leftResult + rightResult; } int main(int argc, char** argv) { CLI::ParseCommandLine(argc, argv); start_t = clock(); arma::mat matrix; matrix.randu(50, 100000); Timer::Start("tree_building"); BinarySpaceTree > tree(matrix); Timer::Stop("tree_building"); double result = 0.0; Timer::Start("traversal"); arma::mat queryPoints(50, 1000); queryPoints.randu(); #pragma omp parallel { #pragma omp single { // Now traverse it. for (size_t i = 0; i < 1000; ++i) { result += traverse(tree, queryPoints.col(i)); } } } Timer::Stop("traversal"); std::cout << result << ".\n"; }