Commit e16e2726, authored and committed by Henrik Gramner

x86-64: Add msac_decode_bool and msac_decode_bool_adapt asm

parent e25ed555
Pipeline #6961 passed with stages
in 7 minutes and 42 seconds
......@@ -85,7 +85,7 @@ unsigned dav1d_msac_decode_bool_equi_c(MsacContext *const s) {
/* Decode a single binary value.
* f: The probability that the bit is one
* Return: The value decoded (0 or 1). */
unsigned dav1d_msac_decode_bool(MsacContext *const s, const unsigned f) {
unsigned dav1d_msac_decode_bool_c(MsacContext *const s, const unsigned f) {
ec_win vw, dif = s->dif;
unsigned ret, v, r = s->rng;
assert((dif >> (EC_WIN_SIZE - 16)) < r);
......@@ -155,8 +155,8 @@ unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *const s,
return val;
}
unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s,
uint16_t *const cdf)
unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *const s,
uint16_t *const cdf)
{
const unsigned bit = dav1d_msac_decode_bool(s, *cdf);
......@@ -164,11 +164,10 @@ unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s,
// update_cdf() specialized for boolean CDFs
const unsigned count = cdf[1];
const int rate = (count >> 4) | 4;
if (bit) {
if (bit)
cdf[0] += (32768 - cdf[0]) >> rate;
} else {
else
cdf[0] -= cdf[0] >> rate;
}
cdf[1] = count + (count < 32);
}
......
......@@ -48,9 +48,9 @@ void dav1d_msac_init(MsacContext *s, const uint8_t *data, size_t sz,
int disable_cdf_update_flag);
unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f);
unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f);
int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
......@@ -64,7 +64,9 @@ unsigned dav1d_msac_decode_symbol_adapt16_neon(MsacContext *s, uint16_t *cdf,
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_neon
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_neon
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_c
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_c
#elif ARCH_X86_64 && HAVE_ASM
unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
......@@ -72,16 +74,22 @@ unsigned dav1d_msac_decode_symbol_adapt8_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
unsigned dav1d_msac_decode_bool_adapt_sse2(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_sse2
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_sse2
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_sse2
#else
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_c
#define dav1d_msac_decode_bool_equi dav1d_msac_decode_bool_equi_c
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_c
#endif
static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
......
......@@ -114,6 +114,7 @@ cglobal msac_decode_symbol_adapt4, 3, 7, 6, s, cdf, ns
.renorm3:
mov r1d, [sq+msac.cnt]
movifnidn t0, sq
.renorm4:
bsr ecx, r2d
xor ecx, 15 ; d
shl r2d, cl
......@@ -285,6 +286,58 @@ cglobal msac_decode_symbol_adapt16, 3, 7, 6, s, cdf, ns
%endif
jmp m(msac_decode_symbol_adapt4).renorm2
; msac_decode_bool_adapt(s, cdf)
;
; Decodes one boolean from the MsacContext at `s`, using the 16-bit word at
; [cdfq] (cdf[0]) as the probability input. When s->update_cdf is non-zero it
; also adapts cdf[0] and bumps the counter word at [cdfq+2] (cdf[1]). Control
; then tail-jumps into the renormalization code of msac_decode_symbol_adapt4.
;
; State handed to the shared tail (as set up below):
;   eax = decoded bit (0 or 1), r2d = new range, r4 = ~(new dif)
; NOTE(review): the tail past .renorm4 is outside this view; presumably it
; stores rng/dif/cnt back and returns eax unchanged - confirm there.
; NOTE(review): "2, 7, 0" is presumably x86inc's args / GPR / XMM count.
cglobal msac_decode_bool_adapt, 2, 7, 0, s, cdf
; eax = cdf[0], zero-extended
movzx eax, word [cdfq]
; r3d = byte 1 of s->rng, i.e. (r >> 8) & 0xff
movzx r3d, byte [sq+msac.rng+1]
; r4 = dif (64-bit window), r2d = r (current range)
mov r4, [sq+msac.dif]
mov r2d, [sq+msac.rng]
; r5d keeps the unmasked cdf[0] for the adaptation step further down
mov r5d, eax
; drop the low 6 bits of cdf[0], then multiply by (r >> 8)
and eax, ~63
imul eax, r3d
%if UNIX64
; keep an untouched copy of dif so it can be restored by cmovb below
mov r7, r4
%endif
; eax = (((cdf[0] & ~63) * (r >> 8)) >> 7) + 4
shr eax, 7
add eax, 4 ; v
mov r3d, eax
; after this shift the low 48 bits of rax (and thus all of eax) are zero
shl rax, 48 ; vw
sub r2d, r3d ; r - v
; this sub is the last flag-setting instruction before the cmovb/setb group:
; CF = 1 iff dif < vw (unsigned borrow)
sub r4, rax ; dif - vw
; on borrow the new range is v, otherwise it stays r - v
cmovb r2d, r3d
; plain mov: loads the update flag without disturbing CF
mov r3d, [sq+msac.update_cdf]
; on borrow undo the subtraction, i.e. dif is left unchanged
%if UNIX64
cmovb r4, r7
%else
; no saved copy on this path; reload the original dif from the context
cmovb r4, [sq+msac.dif]
%endif
; eax = (dif < vw) ? 1 : 0 (eax was zeroed by the shl above)
setb al
; bitwise complement of the window before entering the shared tail
; NOTE(review): presumably the renorm code works on an inverted dif - confirm
not r4
; adaptation disabled: go straight to renormalization
test r3d, r3d
jz m(msac_decode_symbol_adapt4).renorm3
%if WIN64
; NOTE(review): presumably r7 is callee-saved under x86inc's WIN64 register
; mapping (it is not covered by the 7 GPRs requested above) - confirm
push r7
%endif
; r7d = count = cdf[1]
movzx r7d, word [cdfq+2]
; copy s into t0 (emits nothing when t0 and sq are the same register); the
; same instruction is the second one of .renorm3
movifnidn t0, sq
; ecx = count + 64, turned into the shift amount by the shr below
lea ecx, [r7+64]
; CF = (count < 32), so the adc yields count + (count < 32): the counter
; stops growing once it reaches 32
cmp r7d, 32
adc r7d, 0
mov [cdfq+2], r7w
; r7d = bit * -32769
imul r7d, eax, -32769
; rate = (count + 64) >> 4, using the count read before the increment
shr ecx, 4 ; rate
; Branch-free update; with bit in eax the four instructions below compute
;   cdf[0] = cdf[0] - bit - ((cdf[0] - 32769 * bit) >> rate)
; (arithmetic shift), which reduces to the two cases in the trailing comments.
add r7d, r5d ; if (bit)
sub r5d, eax ; cdf[0] -= ((cdf[0] - 32769) >> rate) + 1;
sar r7d, cl ; else
sub r5d, r7d ; cdf[0] -= cdf[0] >> rate;
; last use of cdfq: store the adapted cdf[0]
mov [cdfq], r5w
%if WIN64
; Do the remaining .renorm3 work here (load cnt via t0) and enter at .renorm4.
; NOTE(review): presumably needed because sq aliases rcx on WIN64 and was
; overwritten by the lea into ecx above - confirm against x86inc's mapping.
mov r1d, [t0+msac.cnt]
pop r7
jmp m(msac_decode_symbol_adapt4).renorm4
%else
jmp m(msac_decode_symbol_adapt4).renorm3
%endif
cglobal msac_decode_bool_equi, 1, 7, 0, s
mov r1d, [sq+msac.rng]
mov r4, [sq+msac.dif]
......@@ -302,4 +355,23 @@ cglobal msac_decode_bool_equi, 1, 7, 0, s
not r4
jmp m(msac_decode_symbol_adapt4).renorm3
; msac_decode_bool(s, f)
;
; Decodes one boolean from the MsacContext at `s` with probability input `f`
; (second argument, r1d). No CDF is read or written here. Computes
;   v  = (((f & ~63) * (r >> 8)) >> 7) + 4
;   vw = v << 48
; and compares the 64-bit window dif against vw, then tail-jumps into the
; renormalization code of msac_decode_symbol_adapt4.
;
; State handed to the shared tail (as set up below):
;   eax = decoded bit (0 or 1), r2d = new range, r4 = ~(new dif)
; NOTE(review): the tail itself is outside this view; presumably it stores
; rng/dif/cnt back and returns eax unchanged - confirm there.
cglobal msac_decode_bool, 2, 7, 0, s, f
movzx eax, byte [sq+msac.rng+1] ; r >> 8
; r4 = dif (64-bit window), r2d = r (current range)
mov r4, [sq+msac.dif]
mov r2d, [sq+msac.rng]
; drop the low 6 bits of f, then multiply by (r >> 8)
and r1d, ~63
imul eax, r1d
; r3 keeps the untouched dif so it can be restored by cmovb below
mov r3, r4
shr eax, 7
add eax, 4 ; v
; r1d = v (f is no longer needed)
mov r1d, eax
; after this shift the low 48 bits of rax (and thus all of eax) are zero
shl rax, 48 ; vw
sub r2d, r1d ; r - v
; last flag-setting instruction before the cmovb/setb group:
; CF = 1 iff dif < vw (unsigned borrow)
sub r4, rax ; dif - vw
; on borrow: range becomes v and dif is restored; otherwise keep r - v and
; dif - vw
cmovb r2d, r1d
cmovb r4, r3
; eax = (dif < vw) ? 1 : 0 (eax was zeroed by the shl above)
setb al
; bitwise complement of the window before entering the shared tail
; NOTE(review): presumably the renorm code works on an inverted dif - confirm
not r4
jmp m(msac_decode_symbol_adapt4).renorm3
%endif
......@@ -37,13 +37,17 @@
/* The normal code doesn't use function pointers */
/* Function-pointer types used only by this test so that the C and asm
 * variants of each decoder can be selected at run time. The signatures
 * mirror the dav1d_msac_decode_* prototypes. */

/* Adaptive multi-symbol decode: context, CDF array, number of symbols. */
typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
/* Adaptive boolean decode: context and the boolean's CDF. */
typedef unsigned (*decode_bool_adapt_fn)(MsacContext *s, uint16_t *cdf);
/* Boolean decode that takes only the context (the "equi" variant). */
typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
/* Boolean decode with an explicit probability argument f. */
typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);
/* Table of the implementations under test. checkasm_check_msac() fills it
 * with the C versions and then overrides entries with asm versions. */
typedef struct {
decode_symbol_adapt_fn symbol_adapt4;
decode_symbol_adapt_fn symbol_adapt8;
decode_symbol_adapt_fn symbol_adapt16;
decode_bool_adapt_fn bool_adapt;
decode_bool_equi_fn bool_equi;
/* NOTE(review): `bool` is a macro once <stdbool.h> is in scope and a keyword
 * in C23, so this field name only compiles while neither applies to this
 * translation unit - consider renaming it together with its users. */
decode_bool_fn bool;
} MsacDSPContext;
static void randomize_cdf(uint16_t *const cdf, int n) {
......@@ -85,9 +89,7 @@ static int msac_cmp(const MsacContext *const a, const MsacContext *const b) {
} \
} while (0)
static void check_decode_symbol_adapt(MsacDSPContext *const c,
uint8_t *const buf)
{
static void check_decode_symbol(MsacDSPContext *const c, uint8_t *const buf) {
/* Use an aligned CDF buffer for more consistent benchmark
* results, and a misaligned one for checking correctness. */
ALIGN_STK_16(uint16_t, cdf, 2, [17]);
......@@ -97,16 +99,36 @@ static void check_decode_symbol_adapt(MsacDSPContext *const c,
CHECK_SYMBOL_ADAPT( 4, 1, 5);
CHECK_SYMBOL_ADAPT( 8, 1, 8);
CHECK_SYMBOL_ADAPT(16, 4, 16);
report("decode_symbol_adapt");
report("decode_symbol");
}
static void check_decode_bool_equi(MsacDSPContext *const c,
uint8_t *const buf)
{
declare_func(unsigned, MsacContext *s);
static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) {
MsacContext s_c, s_a;
if (check_func(c->bool_adapt, "msac_decode_bool_adapt")) {
declare_func(unsigned, MsacContext *s, uint16_t *cdf);
uint16_t cdf[2][2];
for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {
dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);
s_a = s_c;
cdf[0][0] = cdf[1][0] = rnd() % 32767 + 1;
cdf[0][1] = cdf[1][1] = 0;
for (int i = 0; i < 64; i++) {
unsigned c_res = call_ref(&s_c, cdf[0]);
unsigned a_res = call_new(&s_a, cdf[1]);
if (c_res != a_res || msac_cmp(&s_c, &s_a) ||
memcmp(cdf[0], cdf[1], sizeof(*cdf)))
{
fail();
}
}
if (cdf_update)
bench_new(&s_a, cdf[0]);
}
}
if (check_func(c->bool_equi, "msac_decode_bool_equi")) {
MsacContext s_c, s_a;
declare_func(unsigned, MsacContext *s);
dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
s_a = s_c;
for (int i = 0; i < 64; i++) {
......@@ -118,7 +140,21 @@ static void check_decode_bool_equi(MsacDSPContext *const c,
bench_new(&s_a);
}
report("decode_bool_equi");
if (check_func(c->bool, "msac_decode_bool")) {
declare_func(unsigned, MsacContext *s, unsigned f);
dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
s_a = s_c;
for (int i = 0; i < 64; i++) {
const unsigned f = rnd() & 0x7fff;
unsigned c_res = call_ref(&s_c, f);
unsigned a_res = call_new(&s_a, f);
if (c_res != a_res || msac_cmp(&s_c, &s_a))
fail();
}
bench_new(&s_a, 16384);
}
report("decode_bool");
}
void checkasm_check_msac(void) {
......@@ -126,7 +162,9 @@ void checkasm_check_msac(void) {
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
c.bool_adapt = dav1d_msac_decode_bool_adapt_c;
c.bool_equi = dav1d_msac_decode_bool_equi_c;
c.bool = dav1d_msac_decode_bool_c;
#if ARCH_AARCH64 && HAVE_ASM
if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
......@@ -139,7 +177,9 @@ void checkasm_check_msac(void) {
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
c.bool_adapt = dav1d_msac_decode_bool_adapt_sse2;
c.bool_equi = dav1d_msac_decode_bool_equi_sse2;
c.bool = dav1d_msac_decode_bool_sse2;
}
#endif
......@@ -147,6 +187,6 @@ void checkasm_check_msac(void) {
for (int i = 0; i < BUF_SIZE; i++)
buf[i] = rnd();
check_decode_symbol_adapt(&c, buf);
check_decode_bool_equi(&c, buf);
check_decode_symbol(&c, buf);
check_decode_bool(&c, buf);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment