Hi,
I need to implement a kind of real-time Morlet wavelet transform for 200-sample blocks of signal in C++. I have the code ready: it convolves the incoming signal with a complex Morlet wavelet and then takes the squared norm of the convolution result to get the energy of the signal at each frequency (6–30 Hz) and each channel (2 channels). But I found that this part of the program is time-consuming: it usually takes about 70 ms to process one block of data, so after tens or hundreds of processing cycles the whole system falls behind. I have posted the wavelet code below — is there anything I can change to improve its efficiency? Thank you very much if any expert can take a look at it (mainly WAVELET1::testTF)!
#ifndef MY_HEADER
#define MY_HEADER

// Lowest analysed frequency; row 0 of the time-frequency map corresponds to it
// (units are Hz provided the sampling rate passed to testTF is in Hz).
#define LOWFREQ 1
// Highest analysed frequency, inclusive; rows run LOWFREQ .. HIGHFREQ in steps of 1.
#define HIGHFREQ 30
// Number of channels in an input block -- correspond to MAX_M in FIRFilter.h
#define CHANNEL 2
// Per-channel stride (and maximum sample count) of an input block
// -- correspond to MAX_N in FIRFilter.h
#define MAXSAMPLE 200

// Morlet-wavelet time-frequency analysis of a multi-channel block, plus a
// template correlation on the resulting time-frequency map.
class WAVELET1
{
public:
    WAVELET1();
    ~WAVELET1();

    // Fills fTF from one block of samples (does the wavelet convolution).
    void testTF(int fNumSample, double fSamplingRate, double *fTestdata);

    // Correlates a window of fTF with a left and a right template and returns
    // the difference of the two normalised correlations.
    double correlation(int fLowCheckFreq, int fHighCheckFreq,
                       int fLowCheckSample, int fHighCheckSample,
                       double fNormRightTemplate, double fNormLeftTemplate,
                       double *flefttemplate, double *frighttemplate);

private:
    // Time-frequency map: fTF[k][i] belongs to frequency LOWFREQ + k, sample i.
    double fTF[HIGHFREQ - LOWFREQ + 1][MAXSAMPLE];
};

#endif
#include "PCHIncludes.h"
#pragma hdrstop
#include <vector>
#include <cmath>
#include <limits>
#include <sstream>
#include <complex>
#include "C:/BCI2000/fftw3.h"
#include <wavelet.h>
using namespace std;
// Constructs an analyser with an all-zero time-frequency map.
// Zeroing matters: correlation() reads fTF directly, so without this it would
// read indeterminate values if called before testTF(), or over sample columns
// that testTF() did not fill because fNumSample < MAXSAMPLE.
WAVELET1::WAVELET1(void)
{
    for (int k = 0; k < HIGHFREQ - LOWFREQ + 1; k++)
        for (int l = 0; l < MAXSAMPLE; l++)
            fTF[k][l] = 0.0;
}
// Nothing to release: the only data member is the fixed-size fTF array.
WAVELET1::~WAVELET1()
{
}
void WAVELET1::testTF(int fNumSample, double fSamplingRate, double *fTestdata )
{
// For the whole algorithm, please refer to matlab program c:\MATLAB6p5\eeglab4.515\functions
const int FREQBAND = HIGHFREQ-LOWFREQ+1;
double mTFraw[CHANNEL][FREQBAND][MAXSAMPLE];
double mTempConst;
double mStdevFdomain;
double mStdevTdomain;
double mTemp;
const double PI=3.1415926;
const double mTfactor=0.5;
const double mNcwFrequency=7.0; //which determine a wavelet family.increasing which results in
// better frequency resolution in expense of the time resolution.
int len_pow2;
int mLength;
int mFFTSize,mTimeLengthSize;
complex<double> *mTempOutput, *mTempOutConst;
complex<double> *p4,*p6,*p5;
// Temp output in the middle of wavelet calculation
complex<double> *mBuffer3;
complex<double> mBuffer4;
std::vector<int> mFrequencyVector;
std::vector<double> mTimeLength;
for (int i = 0; i < FREQBAND; i++) {
mFrequencyVector.push_back(i+LOWFREQ);
}
for( int channel = 0; channel < CHANNEL; channel++ )
{
for (int mfrequencyindex = 0; mfrequencyindex < FREQBAND; mfrequencyindex++)
{
//SD_f
mStdevFdomain = mFrequencyVector[mfrequencyindex]/mNcwFrequency;
//SD_t
mStdevTdomain = 1/(2*PI*mStdevFdomain);
//t: (SD_t*2)*3.5 is about all wavelet length, cover ncw cycles.
//so the input t should be (SD_t*2)*3.5
for (int mtindex = 0; mtindex<(7*mStdevTdomain*fSamplingRate); mtindex++)
{ mTimeLength.push_back(-3.5*mStdevTdomain + mtindex/fSamplingRate);
}
mTimeLengthSize = mTimeLength.size();
mTempConst = pow( mStdevTdomain*sqrt(PI),(-0.5));
mTempOutput = new complex<double> [mTimeLengthSize];
mTempOutConst = new complex<double> [mTimeLengthSize];
p4=mTempOutput;
p5=mTempOutConst;
for (int i = 0; i < mTimeLengthSize; i++) {
*p5= complex<double>(0,(2*PI*mFrequencyVector[mfrequencyindex]*mTimeLength[i]));
*p4= mTempConst* exp( -pow((mTimeLength[i]),2)/( 2*pow(mStdevTdomain,2)))* exp(*p5);
p4++;
p5++;
}
//do convolution by FFT
mLength = fNumSample + mTimeLengthSize-1;
if (mLength<=1024)
len_pow2=1024;
else if((mLength<=2048)&&( mLength>=1024))
len_pow2=2048;
else if ((mLength<=4096)&&( mLength>=2048))
len_pow2=4096;
else
len_pow2=4096*2;
mFFTSize = len_pow2;
fftw_complex *in1,*out1;
fftw_plan p1;
in1 = (fftw_complex*) fftw_malloc(sizeof(fftw_complex)*mFFTSize);
out1 = (fftw_complex*) fftw_malloc(sizeof(fftw_complex)*mFFTSize);
for( int j = 0; j < fNumSample; j++ ) {
in1[j][0] = fTestdata[MAXSAMPLE*channel+j]; //one channel one trial data
in1[j][1] = 0;
}
for( int j = fNumSample; j < mFFTSize; j++ ) {
in1[j][0] = 0;
in1[j][1] = 0;
}
p1 = fftw_plan_dft_1d(mFFTSize, in1, out1, FFTW_FORWARD,FFTW_ESTIMATE);
fftw_execute(p1);
fftw_destroy_plan(p1);
fftw_free(in1);
fftw_plan p2;
fftw_complex *in2,*out2;
in2 = (fftw_complex*) fftw_malloc(sizeof(fftw_complex)*mFFTSize);
out2 = (fftw_complex*) fftw_malloc(sizeof(fftw_complex)*mFFTSize);
for( int j = 0; j < mTimeLengthSize; j++ ) {
in2[j][0] = real(*(mTempOutput+j));
in2[j][1] = imag(*(mTempOutput+j));
}
for( int j = mTimeLengthSize; j < mFFTSize; j++ ) {
in2[j][0] = 0;
in2[j][1] = 0;
}
p2 = fftw_plan_dft_1d(mFFTSize, in2, out2, FFTW_FORWARD,FFTW_ESTIMATE);
fftw_execute(p2);
fftw_destroy_plan(p2);
fftw_free(in2);
fftw_plan p3;
fftw_complex *in3,*out3;
in3 = (fftw_complex*) fftw_malloc(sizeof(fftw_complex)*mFFTSize);
out3 = (fftw_complex*) fftw_malloc(sizeof(fftw_complex)*mFFTSize);
for( int j = 0; j < mFFTSize; j++ ) {
in3[j][0] = out1[j][0]*out2[j][0]-out1[j][1]*out2[j][1];
in3[j][1] = out1[j][1]*out2[j][0]+out1[j][0]*out2[j][1];
}
p3 = fftw_plan_dft_1d(mFFTSize, in3, out3, FFTW_BACKWARD,FFTW_ESTIMATE);
fftw_execute(p3);
fftw_destroy_plan(p3);
fftw_free(in3);
mBuffer3 = new complex<double> [mLength];
p6 = mBuffer3;
for( int i=0; i < mLength ; i++ ) {
*p6 = complex<double>(out3[i][0]/mFFTSize,out3[i][1]/mFFTSize);
p6++;
}
fftw_free(out1);
fftw_free(out2);
fftw_free(out3);
for( int i = 0; i < fNumSample; i++ ) {
mBuffer4 = mBuffer3[static_cast<int>(floor(mTimeLength.size()*0.5+mTfactor))+i-1];
mTFraw[channel][mfrequencyindex][i] = 10*log10(pow(abs(mBuffer4),2) );
}
delete[] mTempOutput;
delete[] mTempOutConst;
delete[] mBuffer3;
//delete[] mBuffer4;
mTimeLength.clear();
}
}
//c3-c4 test trial
for( int k = 0; k < FREQBAND; k++ ) {
for(int l = 0; l < fNumSample; l++)
fTF[k][l]= mTFraw[0][k][l]-mTFraw[1][k][l];
} //c3-c4
}
double WAVELET1::correlation (int fLowCheckFreq, int fHighCheckFreq, int fLowCheckSample, int fHighCheckSample, double fNormRightTemplate, double fNormLeftTemplate, double *flefttemplate, double *frighttemplate)
{
double mNormTest,mTestResult;
double mCr,mCl;
mCl=0;
mCr=0;
mNormTest=0;
// do correlation between test TF distribution
for (int i = fLowCheckFreq-1; i < fHighCheckFreq; i++) {
for(int j = fLowCheckSample-1; j < fHighCheckSample; j++) {
mNormTest=mNormTest+pow(fTF[i][j],2);
mCl =mCl+ (fTF[i][j])* flefttemplate[i*MAXSAMPLE+j];
mCr =mCr+(fTF[i][j])* frighttemplate[i*MAXSAMPLE+j];
}
}
mCr=mCr/(sqrt(mNormTest)*sqrt(fNormRightTemplate));
mCl=mCl/(sqrt(mNormTest)*sqrt(fNormLeftTemplate));
mTestResult = mCl-mCr;
return ( mTestResult);
} ;