问题描述
- 求看一下这个 MATLAB 与 C++ 混合编程(MEX)的代码实现的是什么功能?
-
RT,程序代码如下:#include <math.h> #include<omp.h> #include <sys/types.h> #include "mex.h" #define INF 1E20 typedef signed int int32_t; /* * Generalized distance transforms based on Felzenswalb and Huttenlocher. */ static inline int square(int x) { return x*x; } void dt1d(float *src, float *dst, int *ptr, int step, int n, double a, double b) { int *v = new int[n]; float *z = new float[n+1]; int k = 0; v[0] = 0; z[0] = -INF; z[1] = +INF; for (int q = 1; q <= n-1; q++) { float s = ((src[q*step] - src[v[k]*step]) - b*(q - v[k]) + a*(square(q) - square(v[k]))) / (2*a*(q-v[k])); while (s <= z[k]) { // Update pointer k--; s = ((src[q*step] - src[v[k]*step]) - b*(q - v[k]) + a*(square(q) - square(v[k]))) / (2*a*(q-v[k])); } k++; v[k] = q; z[k] = s; z[k+1] = +INF; } k = 0; for (int q = 0; q <= n-1; q++) { while (z[k+1] < q) k++; dst[q*step] = a*square(q-v[k]) + b*(q-v[k]) + src[v[k]*step]; ptr[q*step] = v[k]; } delete [] v; delete [] z; } // matlab entry point // [M, Ix, Iy] = dt(vals, ax, bx, ay, by) void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { if (nrhs != 5) mexErrMsgTxt("Wrong number of inputs"); if (nlhs != 3) mexErrMsgTxt("Wrong number of outputs"); if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) mexErrMsgTxt("Invalid input"); const int *dims = mxGetDimensions(prhs[0]); float *vals = (float *)mxGetPr(prhs[0]); double ax = mxGetScalar(prhs[1]); double bx = mxGetScalar(prhs[2]); double ay = mxGetScalar(prhs[3]); double by = mxGetScalar(prhs[4]); mxArray *mxM = mxCreateNumericArray(2, dims, mxSINGLE_CLASS, mxREAL); mxArray *mxIx = mxCreateNumericArray(2, dims, mxINT32_CLASS, mxREAL); mxArray *mxIy = mxCreateNumericArray(2, dims, mxINT32_CLASS, mxREAL); float *M = (float *)mxGetPr(mxM); int32_t *Ix = (int32_t *)mxGetPr(mxIx); int32_t *Iy = (int32_t *)mxGetPr(mxIy); float *tmpM = (float *)mxCalloc(dims[0]*dims[1], sizeof(float)); int32_t *tmpIx = (int32_t *)mxCalloc(dims[0]*dims[1], sizeof(int32_t)); int32_t *tmpIy = (int32_t 
*)mxCalloc(dims[0]*dims[1], sizeof(int32_t)); int coreNum = omp_get_num_procs(); int threadNum=coreNum; if (threadNum > dims[1]) threadNum = dims[1]; if (threadNum > dims[0]) threadNum = dims[0]; omp_set_num_threads(threadNum); #pragma omp parallel for schedule (guided) for (int x = 0; x < dims[1]; x++) dt1d(vals+x*dims[0], tmpM+x*dims[0], tmpIy+x*dims[0], 1, dims[0], -ay, -by); #pragma omp parallel for schedule (guided) for (int y = 0; y < dims[0]; y++) dt1d(tmpM+y, M+y, tmpIx+y, dims[0], dims[1], -ax, -bx); // get argmins and adjust for matlab indexing from 1 for (int x = 0; x < dims[1]; x++) { for (int y = 0; y < dims[0]; y++) { int p = x*dims[0]+y; Ix[p] = tmpIx[p]+1; Iy[p] = tmpIy[tmpIx[p]*dims[0]+y]+1; } } mxFree(tmpM); mxFree(tmpIx); mxFree(tmpIy); plhs[0] = mxM; plhs[1] = mxIx; plhs[2] = mxIy; } /* %%% DEBUGGING CODE %%% N = 500; A = rand(N,N); w = rand(4,1); tic; [res,ix,iy] = dt(A,w(1),w(2),w(3),w(4)); toc; tic; [res2,ix2,iy2] = dt2(A,w(1),w(2),w(3),w(4)); toc; norm(res(:) - res2(:)) mean(ix(:) ~= ix2(:)) mean(iy(:) ~= iy2(:)) */
不太清楚这里的dt1d是从哪里来的?
解决方案
这个dt1d不就是程序一开始定义的函数么?它实现的是基于 Felzenszwalb 和 Huttenlocher 的一维广义距离变换(generalized distance transform);mexFunction 先对每一列、再对每一行调用它,从而得到二维的变换结果 M 以及对应的位置索引 Ix、Iy。
时间: 2025-01-09 07:03:33