kissfft.hh 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. #ifndef KISSFFT_CLASS_HH
  2. #define KISSFFT_CLASS_HH
  #include <cmath>
  #include <complex>
  #include <cstddef>
  #include <utility>
  #include <vector>
  /// Mixed-radix FFT engine for complex input (and real input packed as
  /// complex pairs).
  ///
  /// @tparam T_Scalar  underlying real type (e.g. float, double, long double).
  /// @tparam T_Complex complex type built on T_Scalar; defaults to
  ///                   std::complex<T_Scalar>.
  ///
  /// At construction the FFT length is factored into stages: 4s first, then
  /// 2s, then odd factors 3, 5, 7, ... . Radices 2, 3, 4 and 5 use dedicated
  /// butterflies; any other factor uses a generic O(p^2) butterfly.
  /// Transforms are unnormalized (see transform()).
  template <typename T_Scalar,
            typename T_Complex = std::complex<T_Scalar>
           >
  class kissfft
  {
  public:
      typedef T_Scalar scalar_type;
      typedef T_Complex cpx_type;

      /// Plans an FFT of length @c nfft.
      ///
      /// @param nfft    number of complex points per transform. 0 is
      ///                accepted; transforms are then no-ops.
      /// @param inverse if true, the twiddles use the positive exponent
      ///                exp(+2*pi*i*j*k/N) (unnormalized inverse DFT);
      ///                otherwise the negative exponent (forward DFT).
      kissfft( std::size_t nfft,
               bool inverse )
          :_nfft(nfft)
          ,_inverse(inverse)
      {
          using std::acos;
          using std::exp;

          if ( _nfft == 0 )
              return; // nothing to plan; transform()/transform_real() bail out early.

          // fill twiddle factors: _twiddles[i] = exp(+-2*pi*I*i/nfft)
          _twiddles.resize(_nfft);
          const scalar_type phinc =
              scalar_type(_inverse ? 2 : -2) * acos( scalar_type(-1) ) / scalar_type(_nfft);
          for (std::size_t i = 0; i < _nfft; ++i)
              _twiddles[i] = exp( cpx_type(0, scalar_type(i) * phinc) );

          // Factorize: pull out 4's first, then 2's, then odd trial divisors
          // 3, 5, 7, 9, ... . Once p*p exceeds the remainder, the remainder
          // has no smaller factor left and becomes the final radix.
          std::size_t n = _nfft;
          std::size_t p = 4;
          do {
              while (n % p) {
                  switch (p) {
                      case 4: p = 2; break;
                      case 2: p = 3; break;
                      default: p += 2; break;
                  }
                  if (p*p > n)
                      p = n; // no more factors
              }
              n /= p;
              _stageRadix.push_back(p);
              _stageRemainder.push_back(n);
          } while (n > 1);
      }

      /// Changes the FFT-length and/or the transform direction.
      ///
      /// @post The @c kissfft object will be in the same state as if it
      ///       had been newly constructed with the passed arguments.
      ///       However, the implementation may be faster than constructing a
      ///       new fft object.
      void assign( std::size_t nfft,
                   bool inverse )
      {
          if ( nfft != _nfft )
          {
              kissfft tmp( nfft, inverse ); // O(n) time.
              std::swap( tmp, *this ); // this is O(1) in C++11, O(n) otherwise.
          }
          else if ( inverse != _inverse )
          {
              // Same length, opposite direction: the new twiddles are the
              // complex conjugates of the current ones ...
              for ( typename std::vector<cpx_type>::iterator it = _twiddles.begin();
                    it != _twiddles.end(); ++it )
                  it->imag( -it->imag() );
              // ... and the flag must follow too, because kf_bfly4() and
              // transform_real() read _inverse directly.
              _inverse = inverse;
          }
      }

      /// Calculates the complex Discrete Fourier Transform.
      ///
      /// @param src input array of @c nfft elements (the length passed to the
      ///            constructor or to assign()).
      /// @param dst output array of @c nfft elements. It must not overlap
      ///            @c src: the input is read at strides while the output is
      ///            written contiguously, so in-place use is not supported.
      ///
      /// The sum of the squares of the absolute values in the @c dst
      /// array will be @c N times the sum of the squares of the absolute
      /// values in the @c src array, where @c N is the size of the array.
      /// In other words, the l_2 norm of the resulting array will be
      /// @c sqrt(N) times as big as the l_2 norm of the input array.
      /// This is also the case when the inverse flag is set in the
      /// constructor. Hence when applying the same transform twice, but with
      /// the inverse flag changed the second time, then the result will
      /// be equal to the original input times @c N.
      void transform( const cpx_type * src,
                      cpx_type * dst ) const
      {
          if ( _nfft == 0 )
              return; // no stages were planned; kf_work would index past them.
          kf_work(0, dst, src, 1, 1);
      }

      /// Calculates the Discrete Fourier Transform (DFT) of a real input
      /// of size @c 2*N, where @c N is the length passed to the constructor.
      ///
      /// @param src @c 2*N real samples (read as @c N complex values, see note).
      /// @param dst @c N complex outputs; must not overlap @c src.
      ///
      /// The 0-th and N-th value of the DFT are real numbers. These are
      /// stored in @c dst[0].real() and @c dst[0].imag() respectively.
      /// The remaining DFT values up to the index N-1 are stored in
      /// @c dst[1] to @c dst[N-1].
      /// The other half of the DFT values can be calculated from the
      /// symmetry relation
      /// @code
      /// DFT(src)[2*N-k] == conj( DFT(src)[k] );
      /// @endcode
      /// The same scaling factors and the same transform direction as in
      /// @c transform() apply.
      ///
      /// @note For this to work, the types @c scalar_type and @c cpx_type
      /// must fulfill the following requirements:
      ///
      /// For any object @c z of type @c cpx_type,
      /// @c reinterpret_cast<scalar_type(&)[2]>(z)[0] is the real part of @c z and
      /// @c reinterpret_cast<scalar_type(&)[2]>(z)[1] is the imaginary part of @c z.
      /// For any pointer to an element of an array of @c cpx_type named @c p
      /// and any valid array index @c i, @c reinterpret_cast<scalar_type*>(p)[2*i]
      /// is the real part of the complex number @c p[i], and
      /// @c reinterpret_cast<scalar_type*>(p)[2*i+1] is the imaginary part of the
      /// complex number @c p[i].
      ///
      /// Since C++11, these requirements are guaranteed to be satisfied for
      /// @c scalar_type being @c float, @c double or @c long @c double
      /// together with @c cpx_type being @c std::complex<scalar_type>.
      void transform_real( const scalar_type * src,
                           cpx_type * dst ) const
      {
          using std::acos;
          using std::exp;
          using std::conj;

          const std::size_t N = _nfft;
          if ( N == 0 )
              return;

          // perform complex FFT on the samples viewed as N packed complex values
          transform( reinterpret_cast<const cpx_type*>(src), dst );

          // post processing for k = 0 and k = N
          dst[0] = cpx_type( dst[0].real() + dst[0].imag(),
                             dst[0].real() - dst[0].imag() );

          // post processing for all the other k = 1, 2, ..., N-1.
          // scalar_type (not a double literal) keeps float instantiations
          // compiling: double * std::complex<float> has no matching operator.
          const scalar_type half = scalar_type(0.5);
          const scalar_type pi = acos( scalar_type(-1) );
          const scalar_type half_phi_inc = ( _inverse ? pi : -pi ) / N;
          const cpx_type twiddle_mul = exp( cpx_type(0, half_phi_inc) );
          for ( std::size_t k = 1; 2*k < N; ++k )
          {
              // w: DFT of the even samples, z: DFT of the odd samples at bin k
              const cpx_type w = half * cpx_type(
                   dst[k].real() + dst[N-k].real(),
                   dst[k].imag() - dst[N-k].imag() );
              const cpx_type z = half * cpx_type(
                   dst[k].imag() + dst[N-k].imag(),
                  -dst[k].real() + dst[N-k].real() );
              // exp(-+i*pi*k/N): even k is an existing twiddle; odd k needs
              // an extra half step.
              const cpx_type twiddle =
                  k % 2 == 0 ?
                  _twiddles[k/2] :
                  _twiddles[k/2] * twiddle_mul;
              dst[  k] =       w + twiddle * z;
              dst[N-k] = conj( w - twiddle * z );
          }
          // Middle bin k = N/2: the twiddle is exp(-+i*pi/2) = -+i, which
          // reduces the split to X[N/2] = conj(Z[N/2]) in the forward
          // direction and X[N/2] = Z[N/2] (already in place) in the inverse one.
          if ( N % 2 == 0 && !_inverse )
              dst[N/2] = conj( dst[N/2] );
      }

  private:
      /// Computes the DFT of size p*m for stage @c stage: recursively
      /// transforms the p decimated sub-sequences of @c f into consecutive
      /// length-m slices of @c Fout, then recombines them with a radix-p
      /// butterfly.
      void kf_work( std::size_t stage,
                    cpx_type * Fout,
                    const cpx_type * f,
                    std::size_t fstride,
                    std::size_t in_stride) const
      {
          const std::size_t p = _stageRadix[stage];
          const std::size_t m = _stageRemainder[stage];
          cpx_type * const Fout_beg = Fout;
          cpx_type * const Fout_end = Fout + p*m;

          if (m == 1) {
              // last stage: gather the decimated inputs directly
              do {
                  *Fout = *f;
                  f += fstride*in_stride;
              } while (++Fout != Fout_end);
          } else {
              do {
                  // recursive call:
                  // DFT of size m*p performed by doing
                  // p instances of smaller DFTs of size m,
                  // each one takes a decimated version of the input
                  kf_work(stage+1, Fout, f, fstride*p, in_stride);
                  f += fstride*in_stride;
              } while ((Fout += m) != Fout_end);
          }

          Fout = Fout_beg;

          // recombine the p smaller DFTs
          switch (p) {
              case 2: kf_bfly2(Fout, fstride, m); break;
              case 3: kf_bfly3(Fout, fstride, m); break;
              case 4: kf_bfly4(Fout, fstride, m); break;
              case 5: kf_bfly5(Fout, fstride, m); break;
              default: kf_bfly_generic(Fout, fstride, m, p); break;
          }
      }

      /// Radix-2 butterfly over m pairs (Fout[k], Fout[k+m]).
      void kf_bfly2( cpx_type * Fout, const std::size_t fstride, std::size_t m) const
      {
          for (std::size_t k = 0; k < m; ++k) {
              const cpx_type t = Fout[m+k] * _twiddles[k*fstride];
              Fout[m+k] = Fout[k] - t;
              Fout[k] += t;
          }
      }

      /// Radix-4 butterfly over m quadruples (Fout[k], Fout[k+m], ...).
      void kf_bfly4( cpx_type * Fout, const std::size_t fstride, const std::size_t m) const
      {
          cpx_type scratch[7];
          // sign of the +-i rotation depends on the transform direction
          const scalar_type negative_if_inverse = _inverse ? -1 : +1;
          for (std::size_t k = 0; k < m; ++k) {
              scratch[0] = Fout[k+  m] * _twiddles[k*fstride  ];
              scratch[1] = Fout[k+2*m] * _twiddles[k*fstride*2];
              scratch[2] = Fout[k+3*m] * _twiddles[k*fstride*3];
              scratch[5] = Fout[k] - scratch[1];

              Fout[k] += scratch[1];
              scratch[3] = scratch[0] + scratch[2];
              scratch[4] = scratch[0] - scratch[2];
              // multiply by -i (forward) or +i (inverse)
              scratch[4] = cpx_type(  scratch[4].imag()*negative_if_inverse,
                                     -scratch[4].real()*negative_if_inverse );

              Fout[k+2*m]  = Fout[k] - scratch[3];
              Fout[k    ] += scratch[3];
              Fout[k+  m]  = scratch[5] + scratch[4];
              Fout[k+3*m]  = scratch[5] - scratch[4];
          }
      }

      /// Radix-3 butterfly over m triples (Fout[k], Fout[k+m], Fout[k+2m]).
      void kf_bfly3( cpx_type * Fout, const std::size_t fstride, const std::size_t m) const
      {
          std::size_t k = m;
          const std::size_t m2 = 2*m;
          const cpx_type *tw1, *tw2;
          cpx_type scratch[5];
          // twiddle for one third of a turn; only its imaginary part is used
          const cpx_type epi3 = _twiddles[fstride*m];

          tw1 = tw2 = &_twiddles[0];

          do {
              scratch[1] = Fout[m]  * *tw1;
              scratch[2] = Fout[m2] * *tw2;

              scratch[3] = scratch[1] + scratch[2];
              scratch[0] = scratch[1] - scratch[2];
              tw1 += fstride;
              tw2 += fstride*2;

              Fout[m] = Fout[0] - scratch[3]*scalar_type(0.5);
              scratch[0] *= epi3.imag();

              Fout[0] += scratch[3];

              Fout[m2] = cpx_type( Fout[m].real() + scratch[0].imag(),
                                   Fout[m].imag() - scratch[0].real() );

              Fout[m] += cpx_type( -scratch[0].imag(), scratch[0].real() );
              ++Fout;
          } while (--k);
      }

      /// Radix-5 butterfly over m quintuples (Fout[u], Fout[u+m], ...).
      void kf_bfly5( cpx_type * Fout, const std::size_t fstride, const std::size_t m) const
      {
          cpx_type *Fout0, *Fout1, *Fout2, *Fout3, *Fout4;
          cpx_type scratch[13];
          // twiddles for one fifth and two fifths of a turn
          const cpx_type ya = _twiddles[fstride*m];
          const cpx_type yb = _twiddles[fstride*2*m];

          Fout0 = Fout;
          Fout1 = Fout0 + m;
          Fout2 = Fout0 + 2*m;
          Fout3 = Fout0 + 3*m;
          Fout4 = Fout0 + 4*m;

          for ( std::size_t u = 0; u < m; ++u ) {
              scratch[0] = *Fout0;

              scratch[1] = *Fout1 * _twiddles[  u*fstride];
              scratch[2] = *Fout2 * _twiddles[2*u*fstride];
              scratch[3] = *Fout3 * _twiddles[3*u*fstride];
              scratch[4] = *Fout4 * _twiddles[4*u*fstride];

              scratch[7]  = scratch[1] + scratch[4];
              scratch[10] = scratch[1] - scratch[4];
              scratch[8]  = scratch[2] + scratch[3];
              scratch[9]  = scratch[2] - scratch[3];

              *Fout0 += scratch[7];
              *Fout0 += scratch[8];

              scratch[5] = scratch[0] + cpx_type(
                      scratch[7].real()*ya.real() + scratch[8].real()*yb.real(),
                      scratch[7].imag()*ya.real() + scratch[8].imag()*yb.real()
                      );

              scratch[6] = cpx_type(
                       scratch[10].imag()*ya.imag() + scratch[9].imag()*yb.imag(),
                      -scratch[10].real()*ya.imag() - scratch[9].real()*yb.imag()
                      );

              *Fout1 = scratch[5] - scratch[6];
              *Fout4 = scratch[5] + scratch[6];

              scratch[11] = scratch[0] +
                  cpx_type(
                          scratch[7].real()*yb.real() + scratch[8].real()*ya.real(),
                          scratch[7].imag()*yb.real() + scratch[8].imag()*ya.real()
                          );

              scratch[12] = cpx_type(
                      -scratch[10].imag()*yb.imag() + scratch[9].imag()*ya.imag(),
                       scratch[10].real()*yb.imag() - scratch[9].real()*ya.imag()
                      );

              *Fout2 = scratch[11] + scratch[12];
              *Fout3 = scratch[11] - scratch[12];

              ++Fout0;
              ++Fout1;
              ++Fout2;
              ++Fout3;
              ++Fout4;
          }
      }

      /// Performs the butterfly for one stage of a mixed radix FFT with an
      /// arbitrary radix p (direct O(p^2) evaluation per output group).
      void kf_bfly_generic(
              cpx_type * Fout,
              const std::size_t fstride,
              std::size_t m,
              std::size_t p
              ) const
      {
          const cpx_type * twiddles = &_twiddles[0];
          // Heap scratch instead of a variable-length array (not standard
          // C++); allocated once per call and reused across all u.
          std::vector<cpx_type> scratchbuf(p);

          for ( std::size_t u = 0; u < m; ++u ) {
              std::size_t k = u;
              for ( std::size_t q1 = 0; q1 < p; ++q1 ) {
                  scratchbuf[q1] = Fout[k];
                  k += m;
              }

              k = u;
              for ( std::size_t q1 = 0; q1 < p; ++q1 ) {
                  std::size_t twidx = 0;
                  Fout[k] = scratchbuf[0];
                  for ( std::size_t q = 1; q < p; ++q ) {
                      // fstride*k < nfft, so one subtraction keeps twidx in range
                      twidx += fstride * k;
                      if (twidx >= _nfft)
                          twidx -= _nfft;
                      Fout[k] += scratchbuf[q] * twiddles[twidx];
                  }
                  k += m;
              }
          }
      }

      std::size_t _nfft;
      bool _inverse;
      std::vector<cpx_type> _twiddles;
      std::vector<std::size_t> _stageRadix;      // radix p of each stage
      std::vector<std::size_t> _stageRemainder;  // length m left after each stage
  };
  318. #endif