#include "opencv2/core/ocl.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/optflow.hpp"
#include <fstream>
#include <iostream>
#include <stdio.h>

/* This tool finds correspondences between two images using Global Patch Collider
 * and calculates error using provided ground truth flow.
 *
 * It will look for the file named "forest.yml.gz" with a learned forest.
 * You can obtain the "forest.yml.gz" either by manually training it using another tool with *_train suffix
 * or by downloading one of the files trained on some publicly available dataset from here:
 *
 * https://drive.google.com/open?id=0B7Hb8cfuzrIIZDFscXVYd0NBNFU
 */

using namespace cv;

// Command line description consumed by cv::CommandLineParser in main().
const String keys = "{help h ? | | print this message}"
                    "{@image1 | | image1}"
                    "{@image2 | | image2}"
                    "{@groundtruth | | path to the .flo file}"
                    "{@output | | output to a file instead of displaying, output image path}"
                    "{g gpu | | use OpenCL}"
                    "{f forest |forest.yml.gz| path to the forest.yml.gz}";

// Template argument of optflow::GPCForest used when loading the model.
const int nTrees = 5;

/** Returns the Euclidean length of the 2D vector `v`. */
static double normL2( const Point2f &v ) { return sqrt( v.x * v.x + v.y * v.y ); }

/**
 * Maps a flow vector to a colour triple that the callers later pass through COLOR_HSV2BGR.
 *
 * @param f         flow vector to encode
 * @param logScale  when true, the magnitude is compressed with log( r + 1 ) before scaling
 * @param scaleDown divisor applied to the (possibly log-compressed) magnitude
 * @return ( angle in degrees within [0, 360], magnitude clamped to at most 1, 1 );
 *         an exactly zero vector yields ( 0, 0, 1 )
 */
static Vec3d getFlowColor( const Point2f &f, const bool logScale = true, const double scaleDown = 5 )
{
  if ( f.x == 0 && f.y == 0 )
    return Vec3d( 0, 0, 1 );

  double radius = normL2( f );
  if ( logScale )
    radius = log( radius + 1 );
  radius /= scaleDown;
  radius = std::min( 1.0, radius );

  // atan2 of the negated vector plus pi shifts the angle range from [-pi, pi] to [0, 2*pi].
  double angle = ( atan2( -f.y, -f.x ) + CV_PI ) * 180 / CV_PI;
  return Vec3d( angle, radius, 1 );
}

/**
 * Renders a dense flow field as a BGR float image.
 *
 * @param _flow input flow, read element-wise as Point2f (i.e. expected to be CV_32FC2)
 * @param _img  output image, (re)created as CV_32FC3 with the size of `_flow`
 */
static void displayFlow( InputArray _flow, OutputArray _img )
{
  const Size sz = _flow.size();
  Mat flow = _flow.getMat();
  _img.create( sz, CV_32FC3 );
  Mat img = _img.getMat();

  for ( int i = 0; i < sz.height; ++i )
    for ( int j = 0; j < sz.width; ++j )
      img.at< Vec3f >( i, j ) = getFlowColor( flow.at< Point2f >( i, j ) );

  cvtColor( img, img, COLOR_HSV2BGR );
}

/** Returns true when the file `name` can be opened for reading. */
static bool fileProbe( const char *name ) { return std::ifstream( name ).good(); }

/**
 * Entry point: loads the forest, the image pair and the ground truth flow, runs the matcher,
 * reports the match count, timing and average endpoint error, and then either writes the
 * visualisation to the output path or shows three windows.
 *
 * @return 0 on success (or when help was requested), 1 on any argument or input error
 */
int main( int argc, const char **argv )
{
  CommandLineParser parser( argc, argv, keys );
  parser.about( "Global Patch Collider evaluation tool" );

  if ( parser.has( "help" ) )
  {
    parser.printMessage();
    return 0;
  }

  String fromPath = parser.get< String >( 0 );
  String toPath = parser.get< String >( 1 );
  String gtPath = parser.get< String >( 2 );
  String outPath = parser.get< String >( 3 );
  const bool useOpenCL = parser.has( "gpu" );
  String forestDumpPath = parser.get< String >( "forest" );

  if ( !parser.check() )
  {
    parser.printErrors();
    return 1;
  }

  if ( !fileProbe( forestDumpPath.c_str() ) )
  {
    std::cerr << "Can't open the file with a trained model: `" << forestDumpPath
              << "`.\nYou can obtain this file either by manually training the model using another tool with *_train suffix or by "
                 "downloading one of the files trained on some publicly available dataset from "
                 "here:\nhttps://drive.google.com/open?id=0B7Hb8cfuzrIIZDFscXVYd0NBNFU"
              << std::endl;
    return 1;
  }

  ocl::setUseOpenCL( useOpenCL );

  Ptr< optflow::GPCForest< nTrees > > forest = Algorithm::load< optflow::GPCForest< nTrees > >( forestDumpPath );
  // The file exists but may not contain a usable model; do not dereference an empty pointer below.
  if ( forest.empty() )
  {
    std::cerr << "Can't load a trained model from `" << forestDumpPath << "`." << std::endl;
    return 1;
  }

  Mat from = imread( fromPath );
  Mat to = imread( toPath );
  if ( from.empty() || to.empty() )
  {
    std::cerr << "Can't read the input image: `" << ( from.empty() ? fromPath : toPath ) << "`." << std::endl;
    return 1;
  }

  Mat gt = readOpticalFlow( gtPath );
  if ( gt.empty() )
  {
    std::cerr << "Can't read the ground truth flow: `" << gtPath << "`." << std::endl;
    return 1;
  }
  // The ground truth is indexed with pixel coordinates of the first image, so the sizes must agree.
  if ( gt.size() != from.size() )
  {
    std::cerr << "Ground truth flow size " << gt.size() << " does not match the size of the first image " << from.size()
              << "." << std::endl;
    return 1;
  }

  std::vector< std::pair< Point2i, Point2i > > corr;

  TickMeter meter;
  meter.start();

  forest->findCorrespondences( from, to, corr, optflow::GPCMatchingParams( useOpenCL ) );

  meter.stop();

  std::cout << "Found " << corr.size() << " matches." << std::endl;
  std::cout << "Time:  " << meter.getTimeSec() << " sec." << std::endl;

  double error = 0;
  int totalCorrectFlowVectors = 0;
  // Both canvases start as ( 0, 0, 1 ) in every pixel, the same triple getFlowColor() returns for zero flow.
  Mat dispErr( from.size(), CV_32FC3, Scalar( 0, 0, 1 ) );
  Mat disp( from.size(), CV_32FC3, Scalar( 0, 0, 1 ) );

  for ( size_t i = 0; i < corr.size(); ++i )
  {
    const Point2f a = corr[i].first;
    const Point2f b = corr[i].second;
    const Point2f gtDisplacement = gt.at< Point2f >( corr[i].first.y, corr[i].first.x );

    // Check that flow vector is correct: skip NaN components and components >= 1e9
    // (presumably the "unknown flow" marker of the .flo format - TODO confirm).
    if ( !cvIsNaN( gtDisplacement.x ) && !cvIsNaN( gtDisplacement.y ) && gtDisplacement.x < 1e9 &&
         gtDisplacement.y < 1e9 )
    {
      const Point2f c = a + gtDisplacement;
      error += normL2( b - c );
      circle( dispErr, a, 3, getFlowColor( b - c, false, 32 ), -1 );
      ++totalCorrectFlowVectors;
    }

    circle( disp, a, 3, getFlowColor( b - a ), -1 );
  }

  // Average only over matches that had a usable ground truth vector; stays 0 when there were none.
  if ( totalCorrectFlowVectors )
    error /= totalCorrectFlowVectors;

  std::cout << "Average endpoint error: " << error << " px." << std::endl;

  cvtColor( disp, disp, COLOR_HSV2BGR );
  cvtColor( dispErr, dispErr, COLOR_HSV2BGR );

  Mat dispGroundTruth;
  displayFlow( gt, dispGroundTruth );

  if ( outPath.length() )
  {
    putText( disp, "Sparse matching: Global Patch Collider", Point2i( 24, 40 ), FONT_HERSHEY_DUPLEX, 1,
             Vec3b( 1, 0, 0 ), 2, LINE_AA );
    // cv::format builds the label without a fixed-size buffer.
    putText( disp, cv::format( "Average EPE: %.2f", error ), Point2i( 24, 80 ), FONT_HERSHEY_DUPLEX, 1,
             Vec3b( 1, 0, 0 ), 2, LINE_AA );
    putText( disp, cv::format( "Number of matches: %u", (unsigned)corr.size() ), Point2i( 24, 120 ),
             FONT_HERSHEY_DUPLEX, 1, Vec3b( 1, 0, 0 ), 2, LINE_AA );
    // The image holds floats; scale by 255 before handing it to imwrite.
    disp *= 255;
    imwrite( outPath, disp );
    return 0;
  }

  namedWindow( "Correspondences", WINDOW_AUTOSIZE );
  imshow( "Correspondences", disp );
  namedWindow( "Error", WINDOW_AUTOSIZE );
  imshow( "Error", dispErr );
  namedWindow( "Ground truth", WINDOW_AUTOSIZE );
  imshow( "Ground truth", dispGroundTruth );
  waitKey( 0 );

  return 0;
}