OpenCV_4.2.0/opencv_contrib-4.2.0/modules/tracking/samples/tracking_by_matching.cpp

255 lines
9.3 KiB
C++

#include <opencv2/core.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/tracking/tracking_by_matching.hpp>
#include <iostream>
#ifdef HAVE_OPENCV_DNN
#include <opencv2/dnn.hpp>
using namespace std;
using namespace cv;
using namespace cv::tbm;
static const char* keys =
{ "{video_name | | video name }"
"{start_frame |0| Start frame }"
"{frame_step |1| Frame step }"
"{detector_model | | Path to detector's Caffe model }"
"{detector_weights | | Path to detector's Caffe weights }"
"{desired_class_id |-1| The desired class that should be tracked }"
};
static void help()
{
cout << "\nThis example shows the functionality of \"Tracking-by-Matching\" approach:"
" detector is used to detect objects on frames, \n"
"matching is used to find correspondences between new detections and tracked objects.\n"
"Detection is made by DNN detection network every `--frame_step` frame.\n"
"Point a .prototxt file of the network as the parameter `--detector_model`, and a .caffemodel file"
" as the parameter `--detector_weights`.\n"
"(As an example of such detection network is a popular MobileNet_SSD network trained on VOC dataset.)\n"
"If `--desired_class_id` parameter is set, the detection result is filtered by class id,"
" returned by the detection network.\n"
"(That is, if a detection net was trained on VOC dataset, then to track pedestrians point --desired_class_id=15)\n"
"Example of <video_name> is in opencv_extra/testdata/cv/tracking/\n"
"Call:\n"
"./example_tracking_tracking_by_matching --video_name=<video_name> --detector_model=<detector_model_path> --detector_weights=<detector_weights_path> \\\n"
" [--start_frame=<start_frame>] \\\n"
" [--frame_step=<frame_step>] \\\n"
" [--desired_class_id=<desired_class_id>]\n"
<< endl;
cout << "\n\nHot keys: \n"
"\tq - quit the program\n"
"\tp - pause/resume video\n";
}
cv::Ptr<ITrackerByMatching> createTrackerByMatchingWithFastDescriptor();
class DnnObjectDetector
{
public:
DnnObjectDetector(const String& net_caffe_model_path, const String& net_caffe_weights_path,
int desired_class_id=-1,
float confidence_threshold = 0.2,
//the following parameters are default for popular MobileNet_SSD caffe model
const String& net_input_name="data",
const String& net_output_name="detection_out",
double net_scalefactor=0.007843,
const Size& net_size = Size(300,300),
const Scalar& net_mean = Scalar(127.5, 127.5, 127.5),
bool net_swapRB=false)
:desired_class_id(desired_class_id),
confidence_threshold(confidence_threshold),
net_input_name(net_input_name),
net_output_name(net_output_name),
net_scalefactor(net_scalefactor),
net_size(net_size),
net_mean(net_mean),
net_swapRB(net_swapRB)
{
net = dnn::readNetFromCaffe(net_caffe_model_path, net_caffe_weights_path);
if (net.empty())
CV_Error(Error::StsError, "Cannot read Caffe net");
}
TrackedObjects detect(const cv::Mat& frame, int frame_idx)
{
Mat resized_frame;
resize(frame, resized_frame, net_size);
Mat inputBlob = cv::dnn::blobFromImage(resized_frame, net_scalefactor, net_size, net_mean, net_swapRB);
net.setInput(inputBlob, net_input_name);
Mat detection = net.forward(net_output_name);
Mat detection_as_mat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
TrackedObjects res;
for (int i = 0; i < detection_as_mat.rows; i++)
{
float cur_confidence = detection_as_mat.at<float>(i, 2);
int cur_class_id = static_cast<int>(detection_as_mat.at<float>(i, 1));
int x_left = static_cast<int>(detection_as_mat.at<float>(i, 3) * frame.cols);
int y_bottom = static_cast<int>(detection_as_mat.at<float>(i, 4) * frame.rows);
int x_right = static_cast<int>(detection_as_mat.at<float>(i, 5) * frame.cols);
int y_top = static_cast<int>(detection_as_mat.at<float>(i, 6) * frame.rows);
Rect cur_rect(x_left, y_bottom, (x_right - x_left), (y_top - y_bottom));
if (cur_confidence < confidence_threshold)
continue;
if ((desired_class_id >= 0) && (cur_class_id != desired_class_id))
continue;
//clipping by frame size
cur_rect = cur_rect & Rect(Point(), frame.size());
if (cur_rect.empty())
continue;
TrackedObject cur_obj(cur_rect, cur_confidence, frame_idx, -1);
res.push_back(cur_obj);
}
return res;
}
private:
cv::dnn::Net net;
int desired_class_id;
float confidence_threshold;
String net_input_name;
String net_output_name;
double net_scalefactor;
Size net_size;
Scalar net_mean;
bool net_swapRB;
};
cv::Ptr<ITrackerByMatching>
createTrackerByMatchingWithFastDescriptor() {
cv::tbm::TrackerParams params;
cv::Ptr<ITrackerByMatching> tracker = createTrackerByMatching(params);
std::shared_ptr<IImageDescriptor> descriptor_fast =
std::make_shared<ResizedImageDescriptor>(
cv::Size(16, 32), cv::InterpolationFlags::INTER_LINEAR);
std::shared_ptr<IDescriptorDistance> distance_fast =
std::make_shared<MatchTemplateDistance>();
tracker->setDescriptorFast(descriptor_fast);
tracker->setDistanceFast(distance_fast);
return tracker;
}
int main( int argc, char** argv ){
CommandLineParser parser( argc, argv, keys );
cv::Ptr<ITrackerByMatching> tracker = createTrackerByMatchingWithFastDescriptor();
String video_name = parser.get<String>("video_name");
int start_frame = parser.get<int>("start_frame");
int frame_step = parser.get<int>("frame_step");
String detector_model = parser.get<String>("detector_model");
String detector_weights = parser.get<String>("detector_weights");
int desired_class_id = parser.get<int>("desired_class_id");
if( video_name.empty() || detector_model.empty() || detector_weights.empty() )
{
help();
return -1;
}
//open the capture
VideoCapture cap;
cap.open( video_name );
cap.set( CAP_PROP_POS_FRAMES, start_frame );
if( !cap.isOpened() )
{
help();
cout << "***Could not initialize capturing...***\n";
cout << "Current parameter's value: \n";
parser.printMessage();
return -1;
}
// If you use the popular MobileNet_SSD detector, the default parameters may be used.
// Otherwise, set your own parameters (net_mean, net_scalefactor, etc).
DnnObjectDetector detector(detector_model, detector_weights, desired_class_id);
Mat frame;
namedWindow( "Tracking by Matching", 1 );
int frame_counter = -1;
int64 time_total = 0;
bool paused = false;
for ( ;; )
{
if( paused )
{
char c = (char) waitKey(30);
if (c == 'p')
paused = !paused;
if (c == 'q')
break;
continue;
}
cap >> frame;
if(frame.empty()){
break;
}
frame_counter++;
if (frame_counter < start_frame)
continue;
if (frame_counter % frame_step != 0)
continue;
int64 frame_time = getTickCount();
TrackedObjects detections = detector.detect(frame, frame_counter);
// timestamp in milliseconds
uint64_t cur_timestamp = static_cast<uint64_t>(1000.0 / 30 * frame_counter);
tracker->process(frame, detections, cur_timestamp);
frame_time = getTickCount() - frame_time;
time_total += frame_time;
// Drawing colored "worms" (tracks).
frame = tracker->drawActiveTracks(frame);
// Drawing all detected objects on a frame by BLUE COLOR
for (const auto &detection : detections) {
cv::rectangle(frame, detection.rect, cv::Scalar(255, 0, 0), 3);
}
// Drawing tracked detections only by RED color and print ID and detection
// confidence level.
for (const auto &detection : tracker->trackedDetections()) {
cv::rectangle(frame, detection.rect, cv::Scalar(0, 0, 255), 3);
std::string text = std::to_string(detection.object_id) +
" conf: " + std::to_string(detection.confidence);
cv::putText(frame, text, detection.rect.tl(), cv::FONT_HERSHEY_COMPLEX,
1.0, cv::Scalar(0, 0, 255), 3);
}
imshow( "Tracking by Matching", frame );
char c = (char) waitKey( 2 );
if (c == 'q')
break;
if (c == 'p')
paused = !paused;
}
double s = frame_counter / (time_total / getTickFrequency());
printf("FPS: %f\n", s);
return 0;
}
#else // #ifdef HAVE_OPENCV_DNN
int main(int, char**){
CV_Error(cv::Error::StsNotImplemented, "At the moment the sample 'tracking_by_matching' can work only when opencv_dnn module is built.");
}
#endif // #ifdef HAVE_OPENCV_DNN