// azONNXWrapper.cpp (1.1 KB)
#include "azONNXWrapper.h"

/// Loads an ONNX model and configures it to run on the OpenCV DNN backend
/// on the CPU.
///
/// @param model_path  Path of the ONNX model file to load.
/// @return 0 on success; -1 if @p model_path is empty or the loaded network
///         is empty.
/// @note Any cv::Exception raised while reading the model is not caught
///       here and propagates to the caller.
int AZONNXWrapper::init(const std::string& model_path)
{
	if (model_path.empty()) {
		return -1;
	}

	// Load and initialise the network.
	net_ = cv::dnn::readNetFromONNX(model_path);

	// Do not report success for a network that contains no layers.
	if (net_.empty()) {
		return -1;
	}

	net_.setPreferableBackend(cv::dnn::Backend::DNN_BACKEND_OPENCV);
	net_.setPreferableTarget(cv::dnn::Target::DNN_TARGET_CPU);
	return 0;
}

/// Runs one forward pass of the loaded network on a single image and copies
/// the leading output values into the member buffer `predictions`.
///
/// Pre-processing, in order:
///   1. bring the image to 3 channels (see the branch comments below);
///   2. convert to CV_32F, scaling by 1/255;
///   3. per channel, subtract @p mean and then divide by @p std;
///   4. pack into a 4-D blob with cv::dnn::blobFromImage.
///
/// @param img   Input image; must be non-empty.
/// @param mean  Per-channel value subtracted after the 1/255 scaling.
/// @param std   Per-channel divisor applied after the mean subtraction.
/// @return Pointer to the member buffer `predictions` (released by the
///         destructor). Its 512 entries are overwritten on every call;
///         entries past the end of the network output are set to 0.
float* AZONNXWrapper::forward(cv::Mat img, cv::Scalar mean, cv::Scalar std)
{
	CV_Assert(!img.empty());
	CV_Assert(!net_.empty());  // init() must have loaded a model first

	// Bring the image to 3 channels (the pipeline only supports 3 channels).
	cv::Mat rgb;
	const int channels = img.channels();
	if (channels < 3) {
		cv::cvtColor(img, rgb, cv::COLOR_GRAY2RGB);
	}
	else if (channels > 3) {
		// 4-channel input: converted to 3 channels with B and R swapped.
		cv::cvtColor(img, rgb, cv::COLOR_BGR2RGB);
	}
	else {
		// NOTE(review): 3-channel input is passed through without a B/R
		// swap, unlike the 4-channel branch above. Its channel order is
		// whatever the caller supplied — confirm which order the model
		// expects.
		rgb = img;
	}

	// Convert into a separate Mat. `rgb` may share pixel data with the
	// caller's image, and an in-place conversion of CV_32F input would
	// write through to that data.
	cv::Mat input;
	rgb.convertTo(input, CV_32F, 1 / 255.);

	// Normalise: (x - mean) / std.
	cv::subtract(input, mean, input);
	cv::divide(input, std, input);

	// Set the network input.
	const cv::Mat blob = cv::dnn::blobFromImage(input);
	net_.setInput(blob);

	// Inference.
	const cv::Mat predicts = net_.forward();
	CV_Assert(predicts.depth() == CV_32F && predicts.isContinuous());

	// Copy at most kFeatureDim values. This guards against reading past the
	// end of a shorter output, and zero-fills the tail so that no values
	// from a previous call survive.
	constexpr int kFeatureDim = 512;  // assumed capacity of `predictions` — TODO confirm against the header
	const int available = static_cast<int>(predicts.total());
	const int count = available < kFeatureDim ? available : kFeatureDim;
	const float* src = predicts.ptr<float>();
	for (int i = 0; i < kFeatureDim; ++i) {
		predictions[i] = (i < count) ? src[i] : 0.f;
	}
	return predictions;
}


/// Default constructor. Members get whatever in-class initialisers the
/// header declares; nothing further is done here.
/// NOTE(review): forward() writes 512 floats into `predictions` and the
/// destructor delete[]s it, but nothing in this file allocates it. Confirm
/// that the header initialises it (e.g. `new float[512]`).
AZONNXWrapper::AZONNXWrapper() = default;

/// Releases the output buffer that forward() fills and returns.
/// NOTE(review): `predictions` is owned through a raw pointer. If the header
/// does not delete or define copy construction/assignment, copies would
/// share the buffer and delete[] it twice. Confirm the header follows the
/// rule of three/five, or consider std::vector<float>.
AZONNXWrapper::~AZONNXWrapper()
{
	float* const owned = predictions;
	predictions = nullptr;
	delete[] owned;
}