Commit e16e2726 authored by Henrik Gramner, committed by Henrik Gramner

x86-64: Add msac_decode_bool and msac_decode_bool_adapt asm

parent e25ed555
Pipeline #6961 passed with stages
in 7 minutes and 42 seconds
...@@ -85,7 +85,7 @@ unsigned dav1d_msac_decode_bool_equi_c(MsacContext *const s) { ...@@ -85,7 +85,7 @@ unsigned dav1d_msac_decode_bool_equi_c(MsacContext *const s) {
/* Decode a single binary value. /* Decode a single binary value.
* f: The probability that the bit is one * f: The probability that the bit is one
* Return: The value decoded (0 or 1). */ * Return: The value decoded (0 or 1). */
unsigned dav1d_msac_decode_bool(MsacContext *const s, const unsigned f) { unsigned dav1d_msac_decode_bool_c(MsacContext *const s, const unsigned f) {
ec_win vw, dif = s->dif; ec_win vw, dif = s->dif;
unsigned ret, v, r = s->rng; unsigned ret, v, r = s->rng;
assert((dif >> (EC_WIN_SIZE - 16)) < r); assert((dif >> (EC_WIN_SIZE - 16)) < r);
...@@ -155,7 +155,7 @@ unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *const s, ...@@ -155,7 +155,7 @@ unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *const s,
return val; return val;
} }
unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s, unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *const s,
uint16_t *const cdf) uint16_t *const cdf)
{ {
const unsigned bit = dav1d_msac_decode_bool(s, *cdf); const unsigned bit = dav1d_msac_decode_bool(s, *cdf);
...@@ -164,11 +164,10 @@ unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s, ...@@ -164,11 +164,10 @@ unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s,
// update_cdf() specialized for boolean CDFs // update_cdf() specialized for boolean CDFs
const unsigned count = cdf[1]; const unsigned count = cdf[1];
const int rate = (count >> 4) | 4; const int rate = (count >> 4) | 4;
if (bit) { if (bit)
cdf[0] += (32768 - cdf[0]) >> rate; cdf[0] += (32768 - cdf[0]) >> rate;
} else { else
cdf[0] -= cdf[0] >> rate; cdf[0] -= cdf[0] >> rate;
}
cdf[1] = count + (count < 32); cdf[1] = count + (count < 32);
} }
......
...@@ -48,9 +48,9 @@ void dav1d_msac_init(MsacContext *s, const uint8_t *data, size_t sz, ...@@ -48,9 +48,9 @@ void dav1d_msac_init(MsacContext *s, const uint8_t *data, size_t sz,
int disable_cdf_update_flag); int disable_cdf_update_flag);
unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf, unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
size_t n_symbols); size_t n_symbols);
unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s); unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f); unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f);
unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k); int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */ /* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
...@@ -64,7 +64,9 @@ unsigned dav1d_msac_decode_symbol_adapt16_neon(MsacContext *s, uint16_t *cdf, ...@@ -64,7 +64,9 @@ unsigned dav1d_msac_decode_symbol_adapt16_neon(MsacContext *s, uint16_t *cdf,
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_neon #define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_neon
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_neon #define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_neon
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_c
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c #define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_c
#elif ARCH_X86_64 && HAVE_ASM #elif ARCH_X86_64 && HAVE_ASM
unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf, unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols); size_t n_symbols);
...@@ -72,16 +74,22 @@ unsigned dav1d_msac_decode_symbol_adapt8_sse2(MsacContext *s, uint16_t *cdf, ...@@ -72,16 +74,22 @@ unsigned dav1d_msac_decode_symbol_adapt8_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols); size_t n_symbols);
unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf, unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols); size_t n_symbols);
unsigned dav1d_msac_decode_bool_adapt_sse2(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s); unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2 #define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2 #define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_sse2
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_sse2 #define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_sse2
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_sse2
#else #else
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt_c #define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt_c #define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_c
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c #define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_c
#endif #endif
static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) { static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
......
...@@ -114,6 +114,7 @@ cglobal msac_decode_symbol_adapt4, 3, 7, 6, s, cdf, ns ...@@ -114,6 +114,7 @@ cglobal msac_decode_symbol_adapt4, 3, 7, 6, s, cdf, ns
.renorm3: .renorm3:
mov r1d, [sq+msac.cnt] mov r1d, [sq+msac.cnt]
movifnidn t0, sq movifnidn t0, sq
.renorm4:
bsr ecx, r2d bsr ecx, r2d
xor ecx, 15 ; d xor ecx, 15 ; d
shl r2d, cl shl r2d, cl
...@@ -285,6 +286,58 @@ cglobal msac_decode_symbol_adapt16, 3, 7, 6, s, cdf, ns ...@@ -285,6 +286,58 @@ cglobal msac_decode_symbol_adapt16, 3, 7, 6, s, cdf, ns
%endif %endif
jmp m(msac_decode_symbol_adapt4).renorm2 jmp m(msac_decode_symbol_adapt4).renorm2
; msac_decode_bool_adapt(s, cdf)
;
; Decodes one boolean whose probability is the 16-bit word at [cdfq] and,
; when msac.update_cdf is non-zero, adapts that probability and bumps the
; counter word at [cdfq+2]. The checkasm test compares this routine against
; dav1d_msac_decode_bool_adapt_c.
;
; Register use inside this routine:
;   r2d = range, r4 = dif window, r3d = v, then the update_cdf flag,
;   r5d = unmasked cdf[0], eax = v, then the decoded bit (0 or 1).
; The routine does not return by itself: it tail-jumps into the
; renormalisation code of msac_decode_symbol_adapt4 (.renorm3 / .renorm4).
; NOTE(review): the decoded bit is presumably what that shared tail returns
; in eax -- the tail is not fully visible here, confirm.
cglobal msac_decode_bool_adapt, 2, 7, 0, s, cdf
; eax = cdf[0]
movzx eax, word [cdfq]
; r3d = second byte of msac.rng (the "r >> 8" term of the scaling)
movzx r3d, byte [sq+msac.rng+1]
mov r4, [sq+msac.dif]
mov r2d, [sq+msac.rng]
; Keep the unmasked cdf[0] for the adaptation step further down.
mov r5d, eax
; Drop the low 6 bits of the probability before scaling it by the range.
and eax, ~63
imul eax, r3d
%if UNIX64
; Spare copy of dif so the conditional move below can restore it from a
; register. NOTE(review): r7 lies outside the 7 registers requested in the
; cglobal line; presumably a free volatile register on UNIX64 -- confirm
; against x86inc.asm.
mov r7, r4
%endif
shr eax, 7
add eax, 4 ; v
mov r3d, eax
; Move v to the top 16 bits of the 64-bit window. This also leaves the low
; 48 bits of rax zero, which setb relies on below.
shl rax, 48 ; vw
sub r2d, r3d ; r - v
; The carry flag from this subtraction (dif < vw) drives every cmovb/setb
; that follows; the instructions in between (cmov, mov) do not modify flags.
sub r4, rax ; dif - vw
; dif < vw: the new range is v instead of r - v.
cmovb r2d, r3d
; Load the update flag now; mov leaves the carry flag untouched.
mov r3d, [sq+msac.update_cdf]
; dif < vw: undo the subtraction by restoring the original dif.
%if UNIX64
cmovb r4, r7
%else
cmovb r4, [sq+msac.dif]
%endif
; eax = (dif < vw) ? 1 : 0 -- the decoded bit.
setb al
; Complement dif before handing it to the shared renormalisation code.
; NOTE(review): the reason for the inversion lives in the renorm tail,
; which is not visible here.
not r4
; CDF adaptation disabled: go straight to renormalisation.
test r3d, r3d
jz m(msac_decode_symbol_adapt4).renorm3
%if WIN64
; r7 is used as scratch below and restored before leaving.
; NOTE(review): presumably callee-saved on WIN64 -- confirm.
push r7
%endif
; r7d = count (adaptation counter stored next to the probability)
movzx r7d, word [cdfq+2]
movifnidn t0, sq
; ecx = count + 64; shifted right by 4 below this becomes (count >> 4) + 4.
; That equals the C reference's (count >> 4) | 4 while count >> 4 < 4
; (assumes the stored counter never exceeds 63 -- it stops incrementing
; at 32 per the cmp/adc pair that follows).
lea ecx, [r7+64]
; count += (count < 32), written back to cdf[1].
cmp r7d, 32
adc r7d, 0
mov [cdfq+2], r7w
; Branch-free form of the two-way update described in the comments on the
; right: r7d = bit ? -32769 : 0, so r7d + cdf[0] is either cdf[0] - 32769
; or cdf[0], and r5d - bit supplies the extra "+ 1" only when bit is set.
imul r7d, eax, -32769
shr ecx, 4 ; rate
add r7d, r5d ; if (bit)
sub r5d, eax ; cdf[0] -= ((cdf[0] - 32769) >> rate) + 1;
sar r7d, cl ; else
sub r5d, r7d ; cdf[0] -= cdf[0] >> rate;
mov [cdfq], r5w
%if WIN64
; Do by hand what .renorm3 would do (load msac.cnt into r1d; t0 was already
; set above) and enter at .renorm4 instead.
; NOTE(review): presumably needed because ecx aliases the sq argument
; register on WIN64, so sq is no longer valid after the lea into ecx --
; confirm against the x86inc register mapping.
mov r1d, [t0+msac.cnt]
pop r7
jmp m(msac_decode_symbol_adapt4).renorm4
%else
jmp m(msac_decode_symbol_adapt4).renorm3
%endif
cglobal msac_decode_bool_equi, 1, 7, 0, s cglobal msac_decode_bool_equi, 1, 7, 0, s
mov r1d, [sq+msac.rng] mov r1d, [sq+msac.rng]
mov r4, [sq+msac.dif] mov r4, [sq+msac.dif]
...@@ -302,4 +355,23 @@ cglobal msac_decode_bool_equi, 1, 7, 0, s ...@@ -302,4 +355,23 @@ cglobal msac_decode_bool_equi, 1, 7, 0, s
not r4 not r4
jmp m(msac_decode_symbol_adapt4).renorm3 jmp m(msac_decode_symbol_adapt4).renorm3
; msac_decode_bool(s, f)
;
; Decodes one boolean with a caller-supplied probability f (second argument,
; in r1d) and no CDF adaptation. The checkasm test compares this routine
; against dav1d_msac_decode_bool_c.
;
; Register use inside this routine:
;   r2d = range, r4 = dif window, r3 = saved copy of dif,
;   r1d = masked f, then v, eax = (r >> 8), then v, then the decoded bit.
; The routine tail-jumps into .renorm3 of msac_decode_symbol_adapt4, which
; reloads msac.cnt into r1d and starts normalising the range held in r2d.
; NOTE(review): the decoded bit is presumably returned in eax by that shared
; tail -- the tail is not fully visible here, confirm.
cglobal msac_decode_bool, 2, 7, 0, s, f
movzx eax, byte [sq+msac.rng+1] ; r >> 8
mov r4, [sq+msac.dif]
mov r2d, [sq+msac.rng]
; Drop the low 6 bits of f before scaling it by the range.
and r1d, ~63
imul eax, r1d
; Spare copy of dif so the conditional move below can undo the subtraction.
mov r3, r4
shr eax, 7
add eax, 4 ; v
mov r1d, eax
; Move v to the top 16 bits of the 64-bit window. This also leaves the low
; 48 bits of rax zero, which setb relies on below.
shl rax, 48 ; vw
sub r2d, r1d ; r - v
; The carry flag from this subtraction (dif < vw) drives the cmovb/setb
; instructions that follow; cmov itself does not modify flags.
sub r4, rax ; dif - vw
; dif < vw: the new range is v, and dif is restored to its original value.
cmovb r2d, r1d
cmovb r4, r3
; eax = (dif < vw) ? 1 : 0 -- the decoded bit.
setb al
; Complement dif before handing it to the shared renormalisation code.
; NOTE(review): the reason for the inversion lives in the renorm tail,
; which is not visible here.
not r4
jmp m(msac_decode_symbol_adapt4).renorm3
%endif %endif
...@@ -37,13 +37,17 @@ ...@@ -37,13 +37,17 @@
/* The normal code doesn't use function pointers */ /* The normal code doesn't use function pointers */
typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf, typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
size_t n_symbols); size_t n_symbols);
typedef unsigned (*decode_bool_adapt_fn)(MsacContext *s, uint16_t *cdf);
typedef unsigned (*decode_bool_equi_fn)(MsacContext *s); typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);
typedef struct { typedef struct {
decode_symbol_adapt_fn symbol_adapt4; decode_symbol_adapt_fn symbol_adapt4;
decode_symbol_adapt_fn symbol_adapt8; decode_symbol_adapt_fn symbol_adapt8;
decode_symbol_adapt_fn symbol_adapt16; decode_symbol_adapt_fn symbol_adapt16;
decode_bool_adapt_fn bool_adapt;
decode_bool_equi_fn bool_equi; decode_bool_equi_fn bool_equi;
decode_bool_fn bool;
} MsacDSPContext; } MsacDSPContext;
static void randomize_cdf(uint16_t *const cdf, int n) { static void randomize_cdf(uint16_t *const cdf, int n) {
...@@ -85,9 +89,7 @@ static int msac_cmp(const MsacContext *const a, const MsacContext *const b) { ...@@ -85,9 +89,7 @@ static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
} \ } \
} while (0) } while (0)
static void check_decode_symbol_adapt(MsacDSPContext *const c, static void check_decode_symbol(MsacDSPContext *const c, uint8_t *const buf) {
uint8_t *const buf)
{
/* Use an aligned CDF buffer for more consistent benchmark /* Use an aligned CDF buffer for more consistent benchmark
* results, and a misaligned one for checking correctness. */ * results, and a misaligned one for checking correctness. */
ALIGN_STK_16(uint16_t, cdf, 2, [17]); ALIGN_STK_16(uint16_t, cdf, 2, [17]);
...@@ -97,16 +99,36 @@ static void check_decode_symbol_adapt(MsacDSPContext *const c, ...@@ -97,16 +99,36 @@ static void check_decode_symbol_adapt(MsacDSPContext *const c,
CHECK_SYMBOL_ADAPT( 4, 1, 5); CHECK_SYMBOL_ADAPT( 4, 1, 5);
CHECK_SYMBOL_ADAPT( 8, 1, 8); CHECK_SYMBOL_ADAPT( 8, 1, 8);
CHECK_SYMBOL_ADAPT(16, 4, 16); CHECK_SYMBOL_ADAPT(16, 4, 16);
report("decode_symbol_adapt"); report("decode_symbol");
} }
static void check_decode_bool_equi(MsacDSPContext *const c, static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) {
uint8_t *const buf) MsacContext s_c, s_a;
{
declare_func(unsigned, MsacContext *s); if (check_func(c->bool_adapt, "msac_decode_bool_adapt")) {
declare_func(unsigned, MsacContext *s, uint16_t *cdf);
uint16_t cdf[2][2];
for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {
dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);
s_a = s_c;
cdf[0][0] = cdf[1][0] = rnd() % 32767 + 1;
cdf[0][1] = cdf[1][1] = 0;
for (int i = 0; i < 64; i++) {
unsigned c_res = call_ref(&s_c, cdf[0]);
unsigned a_res = call_new(&s_a, cdf[1]);
if (c_res != a_res || msac_cmp(&s_c, &s_a) ||
memcmp(cdf[0], cdf[1], sizeof(*cdf)))
{
fail();
}
}
if (cdf_update)
bench_new(&s_a, cdf[0]);
}
}
if (check_func(c->bool_equi, "msac_decode_bool_equi")) { if (check_func(c->bool_equi, "msac_decode_bool_equi")) {
MsacContext s_c, s_a; declare_func(unsigned, MsacContext *s);
dav1d_msac_init(&s_c, buf, BUF_SIZE, 1); dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
s_a = s_c; s_a = s_c;
for (int i = 0; i < 64; i++) { for (int i = 0; i < 64; i++) {
...@@ -118,7 +140,21 @@ static void check_decode_bool_equi(MsacDSPContext *const c, ...@@ -118,7 +140,21 @@ static void check_decode_bool_equi(MsacDSPContext *const c,
bench_new(&s_a); bench_new(&s_a);
} }
report("decode_bool_equi"); if (check_func(c->bool, "msac_decode_bool")) {
declare_func(unsigned, MsacContext *s, unsigned f);
dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
s_a = s_c;
for (int i = 0; i < 64; i++) {
const unsigned f = rnd() & 0x7fff;
unsigned c_res = call_ref(&s_c, f);
unsigned a_res = call_new(&s_a, f);
if (c_res != a_res || msac_cmp(&s_c, &s_a))
fail();
}
bench_new(&s_a, 16384);
}
report("decode_bool");
} }
void checkasm_check_msac(void) { void checkasm_check_msac(void) {
...@@ -126,7 +162,9 @@ void checkasm_check_msac(void) { ...@@ -126,7 +162,9 @@ void checkasm_check_msac(void) {
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c; c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c; c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c; c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
c.bool_adapt = dav1d_msac_decode_bool_adapt_c;
c.bool_equi = dav1d_msac_decode_bool_equi_c; c.bool_equi = dav1d_msac_decode_bool_equi_c;
c.bool = dav1d_msac_decode_bool_c;
#if ARCH_AARCH64 && HAVE_ASM #if ARCH_AARCH64 && HAVE_ASM
if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) { if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
...@@ -139,7 +177,9 @@ void checkasm_check_msac(void) { ...@@ -139,7 +177,9 @@ void checkasm_check_msac(void) {
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2; c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2; c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2; c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
c.bool_adapt = dav1d_msac_decode_bool_adapt_sse2;
c.bool_equi = dav1d_msac_decode_bool_equi_sse2; c.bool_equi = dav1d_msac_decode_bool_equi_sse2;
c.bool = dav1d_msac_decode_bool_sse2;
} }
#endif #endif
...@@ -147,6 +187,6 @@ void checkasm_check_msac(void) { ...@@ -147,6 +187,6 @@ void checkasm_check_msac(void) {
for (int i = 0; i < BUF_SIZE; i++) for (int i = 0; i < BUF_SIZE; i++)
buf[i] = rnd(); buf[i] = rnd();
check_decode_symbol_adapt(&c, buf); check_decode_symbol(&c, buf);
check_decode_bool_equi(&c, buf); check_decode_bool(&c, buf);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment