Merge MDCT orthogonalization factor in rot twiddles (remove 1 mult by samples in common case)

This commit is contained in:
Antoine Soulier
2023-05-08 10:44:34 -07:00
parent 79795643ce
commit 49101e4bbc
3 changed files with 537 additions and 518 deletions

View File

@@ -267,12 +267,11 @@ LC3_HOT static void mdct_pre_fft(const struct lc3_mdct_rot_def *def,
* Post-rotate FFT N/4 points coefficients, resulting MDCT N points
* def Size and twiddles factors
* x, y Input and output coefficients
* scale Scale on output coefficients
*
* `x` and y` can be the same buffer
*/
LC3_HOT static void mdct_post_fft(const struct lc3_mdct_rot_def *def,
const struct lc3_complex *x, float *y, float scale)
const struct lc3_complex *x, float *y)
{
int n4 = def->n4, n8 = n4 >> 1;
@@ -283,11 +282,11 @@ LC3_HOT static void mdct_post_fft(const struct lc3_mdct_rot_def *def,
for ( ; y1 > y; x0++, x1--, w0++, w1--) {
float u0 = (x0->im * w0->im + x0->re * w0->re) * scale;
float u1 = (x1->re * w1->im - x1->im * w1->re) * scale;
float u0 = x0->im * w0->im + x0->re * w0->re;
float u1 = x1->re * w1->im - x1->im * w1->re;
float v0 = (x0->re * w0->im - x0->im * w0->re) * scale;
float v1 = (x1->im * w1->im + x1->re * w1->re) * scale;
float v0 = x0->re * w0->im - x0->im * w0->re;
float v1 = x1->im * w1->im + x1->re * w1->re;
*(y0++) = u0; *(y0++) = u1;
*(--y1) = v0; *(--y1) = v1;
@@ -330,14 +329,13 @@ LC3_HOT static void imdct_pre_fft(const struct lc3_mdct_rot_def *def,
* Post-rotate FFT N/4 points coefficients, resulting IMDCT N points
* def Size and twiddles factors
* x, y Input and output coefficients
* scale Scale on output coefficients
*
* `x` and y` can be the same buffer
* The real and imaginary parts of `x` are swapped,
* to operate on FFT instead of IFFT
*/
LC3_HOT static void imdct_post_fft(const struct lc3_mdct_rot_def *def,
const struct lc3_complex *x, float *y, float scale)
const struct lc3_complex *x, float *y)
{
int n4 = def->n4;
@@ -350,11 +348,11 @@ LC3_HOT static void imdct_post_fft(const struct lc3_mdct_rot_def *def,
struct lc3_complex uz = *(x0++), vz = *(--x1);
struct lc3_complex uw = *(w0++), vw = *(--w1);
*(y0++) = (uz.re * uw.im - uz.im * uw.re) * scale;
*(--y1) = (uz.re * uw.re + uz.im * uw.im) * scale;
*(y0++) = uz.re * uw.im - uz.im * uw.re;
*(--y1) = uz.re * uw.re + uz.im * uw.im;
*(--y1) = (vz.re * vw.im - vz.im * vw.re) * scale;
*(y0++) = (vz.re * vw.re + vz.im * vw.im) * scale;
*(--y1) = vz.re * vw.im - vz.im * vw.re;
*(y0++) = vz.re * vw.re + vz.im * vw.im;
}
}
@@ -409,6 +407,19 @@ LC3_HOT static void imdct_window(enum lc3_dt dt, enum lc3_srate sr,
}
}
/**
* Rescale samples
* x, n Input and count of samples, scaled as output
* scale Scale factor
*/
LC3_HOT static void rescale(float *x, int n, float f)
{
for (int i = 0; i < (n >> 2); i++) {
*(x++) *= f; *(x++) *= f;
*(x++) *= f; *(x++) *= f;
}
}
/**
* Forward MDCT transformation
*/
@@ -416,7 +427,7 @@ void lc3_mdct_forward(enum lc3_dt dt, enum lc3_srate sr,
enum lc3_srate sr_dst, const float *x, float *d, float *y)
{
const struct lc3_mdct_rot_def *rot = lc3_mdct_rot[dt][sr];
int nf = LC3_NS(dt, sr_dst);
int ns_dst = LC3_NS(dt, sr_dst);
int ns = LC3_NS(dt, sr);
struct lc3_complex buffer[LC3_MAX_NS / 2];
@@ -427,7 +438,10 @@ void lc3_mdct_forward(enum lc3_dt dt, enum lc3_srate sr,
mdct_pre_fft(rot, u.f, u.z);
u.z = fft(u.z, ns/2, u.z, z);
mdct_post_fft(rot, u.z, y, sqrtf( (2.f*nf) / (ns*ns) ));
mdct_post_fft(rot, u.z, y);
if (ns != ns_dst)
rescale(y, ns_dst, sqrtf((float)ns_dst / ns));
}
/**
@@ -437,7 +451,7 @@ void lc3_mdct_inverse(enum lc3_dt dt, enum lc3_srate sr,
enum lc3_srate sr_src, const float *x, float *d, float *y)
{
const struct lc3_mdct_rot_def *rot = lc3_mdct_rot[dt][sr];
int nf = LC3_NS(dt, sr_src);
int ns_src = LC3_NS(dt, sr_src);
int ns = LC3_NS(dt, sr);
struct lc3_complex buffer[LC3_MAX_NS / 2];
@@ -446,7 +460,10 @@ void lc3_mdct_inverse(enum lc3_dt dt, enum lc3_srate sr,
imdct_pre_fft(rot, x, z);
z = fft(z, ns/2, z, u.z);
imdct_post_fft(rot, z, u.f, sqrtf(2.f / nf));
imdct_post_fft(rot, z, u.f);
if (ns != ns_src)
rescale(u.f, ns, sqrtf((float)ns / ns_src));
imdct_window(dt, sr, u.f, d, y);
}