Commit b20a2d63 authored by Henrik Gramner's avatar Henrik Gramner

x86-64: Add msac_decode_bool_equi asm

parent 30d5f486
......@@ -27,7 +27,6 @@
#include "config.h"
#include <assert.h>
#include <limits.h>
#include "common/intops.h"
......@@ -68,7 +67,7 @@ static inline void ctx_norm(MsacContext *s, ec_win dif, unsigned rng) {
ctx_refill(s);
}
unsigned dav1d_msac_decode_bool_equi(MsacContext *const s) {
unsigned dav1d_msac_decode_bool_equi_c(MsacContext *const s) {
ec_win vw, dif = s->dif;
unsigned ret, v, r = s->rng;
assert((dif >> (EC_WIN_SIZE - 16)) < r);
......@@ -99,13 +98,6 @@ unsigned dav1d_msac_decode_bool(MsacContext *const s, const unsigned f) {
return !ret;
}
unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
unsigned v = 0;
while (n--)
v = (v << 1) | dav1d_msac_decode_bool_equi(s);
return v;
}
int dav1d_msac_decode_subexp(MsacContext *const s, const int ref,
const int n, const unsigned k)
{
......@@ -122,15 +114,6 @@ int dav1d_msac_decode_subexp(MsacContext *const s, const int ref,
n - 1 - inv_recenter(n - 1 - ref, v);
}
int dav1d_msac_decode_uniform(MsacContext *const s, const unsigned n) {
assert(n > 0);
const int l = ulog2(n) + 1;
assert(l > 1);
const unsigned m = (1 << l) - n;
const unsigned v = dav1d_msac_decode_bools(s, l - 1);
return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(s);
}
/* Decodes a symbol given an inverse cumulative distribution function (CDF)
* table in Q15. */
static unsigned decode_symbol(MsacContext *const s, const uint16_t *const cdf,
......
......@@ -28,6 +28,7 @@
#ifndef DAV1D_SRC_MSAC_H
#define DAV1D_SRC_MSAC_H
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
......@@ -47,12 +48,10 @@ void dav1d_msac_init(MsacContext *s, const uint8_t *data, size_t sz,
int disable_cdf_update_flag);
unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
unsigned dav1d_msac_decode_bool_equi(MsacContext *s);
unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f);
unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bools(MsacContext *s, unsigned n);
int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
int dav1d_msac_decode_uniform(MsacContext *s, unsigned n);
/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
#if ARCH_AARCH64 && HAVE_ASM
......@@ -65,6 +64,7 @@ unsigned dav1d_msac_decode_symbol_adapt16_neon(MsacContext *s, uint16_t *cdf,
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_neon
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_neon
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c
#elif ARCH_X86_64 && HAVE_ASM
unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
......@@ -72,13 +72,32 @@ unsigned dav1d_msac_decode_symbol_adapt8_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_sse2
#else
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c
#endif
static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
unsigned v = 0;
while (n--)
v = (v << 1) | dav1d_msac_decode_bool_equi(s);
return v;
}
static inline int dav1d_msac_decode_uniform(MsacContext *const s, const unsigned n) {
assert(n > 0);
const int l = ulog2(n) + 1;
assert(l > 1);
const unsigned m = (1 << l) - n;
const unsigned v = dav1d_msac_decode_bools(s, l - 1);
return v < m ? v : (v << 1) - m + dav1d_msac_decode_bool_equi(s);
}
#endif /* DAV1D_SRC_MSAC_H */
......@@ -111,6 +111,7 @@ cglobal msac_decode_symbol_adapt4, 3, 7, 6, s, cdf, ns
sub r2d, r1d ; rng
shl r1, 48
add r4, r1 ; ~dif
.renorm3:
mov r1d, [sq+msac.cnt]
movifnidn t0, sq
bsr ecx, r2d
......@@ -284,4 +285,21 @@ cglobal msac_decode_symbol_adapt16, 3, 7, 6, s, cdf, ns
%endif
jmp m(msac_decode_symbol_adapt4).renorm2
cglobal msac_decode_bool_equi, 1, 7, 0, s
mov r1d, [sq+msac.rng]
mov r4, [sq+msac.dif]
mov r2d, r1d
mov r1b, 8
mov r3, r4
mov eax, r1d
shr r1d, 1 ; v
shl rax, 47 ; vw
sub r2d, r1d ; r - v
sub r4, rax ; dif - vw
cmovb r2d, r1d
cmovb r4, r3
setb al ; the upper 32 bits contains garbage but that's OK
not r4
jmp m(msac_decode_symbol_adapt4).renorm3
%endif
......@@ -32,14 +32,18 @@
#include <string.h>
#define BUF_SIZE 8192
/* The normal code doesn't use function pointers */
typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
typedef struct {
decode_symbol_adapt_fn symbol_adapt4;
decode_symbol_adapt_fn symbol_adapt8;
decode_symbol_adapt_fn symbol_adapt16;
decode_bool_equi_fn bool_equi;
} MsacDSPContext;
static void randomize_cdf(uint16_t *const cdf, int n) {
......@@ -61,7 +65,7 @@ static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
if (check_func(c->symbol_adapt##n, "msac_decode_symbol_adapt%d", n)) { \
for (int cdf_update = 0; cdf_update <= 1; cdf_update++) { \
for (int ns = n_min; ns <= n_max; ns++) { \
dav1d_msac_init(&s_c, buf, sizeof(buf), !cdf_update); \
dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update); \
s_a = s_c; \
randomize_cdf(cdf[0], ns); \
memcpy(cdf[1], cdf[0], sizeof(*cdf)); \
......@@ -81,14 +85,13 @@ static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
} \
} while (0)
static void check_decode_symbol_adapt(MsacDSPContext *const c) {
static void check_decode_symbol_adapt(MsacDSPContext *const c,
uint8_t *const buf)
{
/* Use an aligned CDF buffer for more consistent benchmark
* results, and a misaligned one for checking correctness. */
ALIGN_STK_16(uint16_t, cdf, 2, [17]);
MsacContext s_c, s_a;
uint8_t buf[1024];
for (int i = 0; i < 1024; i++)
buf[i] = rnd();
declare_func(unsigned, MsacContext *s, uint16_t *cdf, size_t n_symbols);
CHECK_SYMBOL_ADAPT( 4, 1, 5);
......@@ -97,11 +100,33 @@ static void check_decode_symbol_adapt(MsacDSPContext *const c) {
report("decode_symbol_adapt");
}
static void check_decode_bool_equi(MsacDSPContext *const c,
uint8_t *const buf)
{
declare_func(unsigned, MsacContext *s);
if (check_func(c->bool_equi, "msac_decode_bool_equi")) {
MsacContext s_c, s_a;
dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
s_a = s_c;
for (int i = 0; i < 64; i++) {
unsigned c_res = call_ref(&s_c);
unsigned a_res = call_new(&s_a);
if (c_res != a_res || msac_cmp(&s_c, &s_a))
fail();
}
bench_new(&s_a);
}
report("decode_bool_equi");
}
void checkasm_check_msac(void) {
MsacDSPContext c;
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
c.bool_equi = dav1d_msac_decode_bool_equi_c;
#if ARCH_AARCH64 && HAVE_ASM
if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
......@@ -114,8 +139,14 @@ void checkasm_check_msac(void) {
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
c.bool_equi = dav1d_msac_decode_bool_equi_sse2;
}
#endif
check_decode_symbol_adapt(&c);
uint8_t buf[BUF_SIZE];
for (int i = 0; i < BUF_SIZE; i++)
buf[i] = rnd();
check_decode_symbol_adapt(&c, buf);
check_decode_bool_equi(&c, buf);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment