Commit 61dcd11b authored by Henrik Gramner, committed by Henrik Gramner

x86: Add an msac function for coefficient hi_tok decoding

This particular sequence is executed often enough to justify having
a separate slightly more optimized code path instead of just chaining
multiple generic symbol decoding function calls together.
parent e29fd5c0
......@@ -171,6 +171,22 @@ unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *const s,
return bit;
}
/*
 * Decode a coefficient "hi_tok" value.
 *
 * Reads between one and four adaptive symbols from `cdf` through
 * dav1d_msac_decode_symbol_adapt4(s, cdf, 3). The result starts at 3 and
 * every decoded symbol is added to it. Decoding stops after the first
 * symbol that is not equal to 3, or after the fourth symbol, whichever
 * comes first.
 *
 * s:   range decoder state, advanced by each symbol read.
 * cdf: CDF handed unchanged to every symbol decode call.
 *
 * Returns 3 plus the sum of the decoded symbols.
 */
unsigned dav1d_msac_decode_hi_tok_c(MsacContext *const s, uint16_t *const cdf) {
    unsigned tok = 3;
    for (int n = 0; n < 4; n++) {
        const unsigned sym = dav1d_msac_decode_symbol_adapt4(s, cdf, 3);
        tok += sym;
        // Only a symbol value of 3 chains into another read.
        if (sym != 3)
            break;
    }
    return tok;
}
void dav1d_msac_init(MsacContext *const s, const uint8_t *const data,
const size_t sz, const int disable_cdf_update_flag)
{
......
......@@ -58,6 +58,7 @@ unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f);
unsigned dav1d_msac_decode_hi_tok_c(MsacContext *s, uint16_t *cdf);
int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
/* Supported n_symbols ranges: adapt4: 1-4, adapt8: 1-7, adapt16: 3-15 */
......@@ -79,6 +80,9 @@ int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
#ifndef dav1d_msac_decode_bool
#define dav1d_msac_decode_bool dav1d_msac_decode_bool_c
#endif
#ifndef dav1d_msac_decode_hi_tok
#define dav1d_msac_decode_hi_tok dav1d_msac_decode_hi_tok_c
#endif
static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
unsigned v = 0;
......
......@@ -199,40 +199,13 @@ static int decode_coefs(Dav1dTileContext *const t,
printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
t_dim->ctx, chroma, ctx, eob, rc, tok, ts->msac.rng);
// hi tok
if (tok_br == 2) {
#define dbg_print_hi_tok(i, tok, tok_br) \
if (dbg)\
printf("Post-hi_tok[%d][%d][%d][%d=%d=%d->%d]: r=%d\n",\
imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok, tok_br,\
ts->msac.rng)
const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx], 3);
tok = 3 + tok_br;
dbg_print_hi_tok(eob, tok + tok_br, tok_br);
if (tok_br == 3) {
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx], 3);
tok = 6 + tok_br;
dbg_print_hi_tok(eob, tok + tok_br, tok_br);
if (tok_br == 3) {
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx],
3);
tok = 9 + tok_br;
dbg_print_hi_tok(eob, tok + tok_br, tok_br);
if (tok_br == 3) {
tok = 12 +
dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx],
3);
dbg_print_hi_tok(eob, tok + tok_br, tok_br);
}
}
}
tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
if (dbg)
printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
imin(t_dim->ctx, 3), chroma, br_ctx, eob, rc, tok,
ts->msac.rng);
}
cf[rc] = tok;
......@@ -249,37 +222,14 @@ static int decode_coefs(Dav1dTileContext *const t,
printf("Post-lo_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
t_dim->ctx, chroma, ctx, i, rc, tok, ts->msac.rng);
// hi tok
if (tok == 3) {
const int br_ctx = get_br_ctx(levels, 1, tx_class, x, y, stride);
int tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx], 3);
tok = 3 + tok_br;
dbg_print_hi_tok(i, tok + tok_br, tok_br);
if (tok_br == 3) {
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx], 3);
tok = 6 + tok_br;
dbg_print_hi_tok(i, tok + tok_br, tok_br);
if (tok_br == 3) {
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx],
3);
tok = 9 + tok_br;
dbg_print_hi_tok(i, tok + tok_br, tok_br);
if (tok_br == 3) {
tok = 12 + dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx],
3);
dbg_print_hi_tok(i, tok + tok_br, tok_br);
}
}
}
tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
if (dbg)
printf("Post-hi_tok[%d][%d][%d][%d=%d=%d]: r=%d\n",
imin(t_dim->ctx, 3), chroma, br_ctx, i, rc, tok,
ts->msac.rng);
}
#undef dbg_print_hi_tok
cf[rc] = tok;
levels[x * stride + y] = (uint8_t) tok;
}
......@@ -292,43 +242,13 @@ static int decode_coefs(Dav1dTileContext *const t,
printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
t_dim->ctx, chroma, ctx, dc_tok, ts->msac.rng);
// hi tok
if (dc_tok == 3) {
#define dbg_print_hi_tok(dc_tok, tok_br) \
if (dbg) \
printf("Post-dc_hi_tok[%d][%d][%d][%d->%d]: r=%d\n", \
imin(t_dim->ctx, 3), chroma, br_ctx, tok_br, dc_tok, ts->msac.rng);
const int br_ctx = get_br_ctx(levels, 0, tx_class, 0, 0, stride);
int tok_br =
dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[br_ctx], 3);
dc_tok = 3 + tok_br;
dbg_print_hi_tok(dc_tok + tok_br, tok_br);
if (tok_br == 3) {
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx], 3);
dc_tok = 6 + tok_br;
dbg_print_hi_tok(dc_tok + tok_br, tok_br);
if (tok_br == 3) {
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx],
3);
dc_tok = 9 + tok_br;
dbg_print_hi_tok(dc_tok + tok_br, tok_br);
if (tok_br == 3) {
dc_tok = 12 +
dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[br_ctx],
3);
dbg_print_hi_tok(dc_tok + tok_br, tok_br);
}
}
}
dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[br_ctx]);
if (dbg)
printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n",
imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng);
}
#undef dbg_print_hi_tok
}
} else { // dc-only
uint16_t *const lo_cdf = ts->cdf.coef.eob_base_tok[t_dim->ctx][chroma][0];
......@@ -338,38 +258,13 @@ static int decode_coefs(Dav1dTileContext *const t,
printf("Post-dc_lo_tok[%d][%d][%d][%d]: r=%d\n",
t_dim->ctx, chroma, 0, dc_tok, ts->msac.rng);
// hi tok
if (tok_br == 2) {
#define dbg_print_hi_tok(dc_tok, tok_br) \
if (dbg) \
printf("Post-dc_hi_tok[%d][%d][0][%d->%d]: r=%d\n", \
imin(t_dim->ctx, 3), chroma, tok_br, dc_tok, ts->msac.rng);
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3);
dc_tok = 3 + tok_br;
dbg_print_hi_tok(dc_tok + tok_br, tok_br);
if (tok_br == 3) {
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac, br_cdf[0], 3);
dc_tok = 6 + tok_br;
dbg_print_hi_tok(dc_tok + tok_br, tok_br);
if (tok_br == 3) {
tok_br = dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[0], 3);
dc_tok = 9 + tok_br;
dbg_print_hi_tok(dc_tok + tok_br, tok_br);
if (tok_br == 3) {
dc_tok = 12 +
dav1d_msac_decode_symbol_adapt4(&ts->msac,
br_cdf[0], 3);
dbg_print_hi_tok(dc_tok + tok_br, tok_br);
}
}
}
dc_tok = dav1d_msac_decode_hi_tok(&ts->msac, br_cdf[0]);
if (dbg)
printf("Post-dc_hi_tok[%d][%d][0][%d]: r=%d\n",
imin(t_dim->ctx, 3), chroma, dc_tok, ts->msac.rng);
}
}
#undef dbg_print_hi_tok
// residual and sign
int dc_sign = 1 << 6;
......
......@@ -27,7 +27,7 @@
SECTION_RODATA 64 ; avoids cacheline splits
dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
min_prob: dw 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
pw_0xff00: times 8 dw 0xff00
pw_32: times 8 dw 32
......@@ -35,21 +35,24 @@ pw_32: times 8 dw 32
%define resp resq
%define movp movq
%define c_shuf q3333
%define DECODE_SYMBOL_ADAPT_INIT
%macro DECODE_SYMBOL_ADAPT_INIT 0-1
%endmacro
%else
%define resp resd
%define movp movd
%define c_shuf q1111
%macro DECODE_SYMBOL_ADAPT_INIT 0
%macro DECODE_SYMBOL_ADAPT_INIT 0-1 0 ; hi_tok
mov t0, r0m
mov t1, r1m
%if %1 == 0
mov t2, r2m
%endif
%if STACK_ALIGNMENT >= 16
sub esp, 40
sub esp, 40-%1*4
%else
mov eax, esp
and esp, ~15
sub esp, 40
sub esp, 40-%1*4
mov [esp], eax
%endif
%endmacro
......@@ -69,13 +72,13 @@ endstruc
SECTION .text
%if WIN64
DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 3
%define buf rsp+8 ; shadow space
DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 3, 8
%define buf rsp+stack_offset+8 ; shadow space
%elif UNIX64
DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 0
DECLARE_REG_TMP 0, 1, 2, 3, 4, 5, 7, 0, 8
%define buf rsp-40 ; red zone
%else
DECLARE_REG_TMP 2, 3, 4, 1, 5, 6, 5, 2
DECLARE_REG_TMP 2, 3, 4, 1, 5, 6, 5, 2, 3
%define buf esp+8
%endif
......@@ -440,3 +443,158 @@ cglobal msac_decode_bool, 0, 6, 0
movzx eax, al
%endif
jmp m(msac_decode_symbol_adapt4).renorm3
; HI_TOK: decode loop and epilogue of msac_decode_hi_tok.
;
; %1 (update_cdf): when nonzero, the code that reads the word at [t1+3*2] and
; writes adapted values back to [t1] is assembled; when zero it is left out.
;
; Register state expected on entry (set up by the code that invokes the macro):
;   m0 = four words loaded from [t1], m2 = rng, m3 = dif, m4 = pw_0xff00,
;   m5 = four words from min_prob+12*2, t4 = dif, t5d = cnt,
;   t7 = pointer used for all msac.* field accesses.
;
; A running counter starts at -24. It lives in eax/[buf+8] on x86-32 and in
; t6d on x86-64 (initialized by the invoking code in that case). Each
; iteration adds 5 to it and then adds (tzcnt result - 5); the loop repeats
; as long as that second addition does not produce a carry. The value
; returned in eax is (counter + 30) >> 1.
%macro HI_TOK 1 ; update_cdf
%if ARCH_X86_64 == 0
mov eax, -24
%endif
%%loop:
%if %1
; word 3 of the array at t1; presumably the CDF adaptation counter - TODO confirm
movzx t2d, word [t1+3*2]
%endif
; m1 = (((m0 >> 6) << 7) * (rng & 0xff00)) >> 16, per word
mova m1, m0
pshuflw m2, m2, q0000
psrlw m1, 6
; store the broadcast rng so that its second copy lands at [buf+14], i.e. in
; the word directly in front of the table written to [buf+16] below
movd [buf+12], m2
pand m2, m4
psllw m1, 7
pmulhuw m1, m2
%if ARCH_X86_64 == 0
add eax, 5
mov [buf+8], eax
%endif
; broadcast one word of dif (selected by c_shuf) to the low four words of m3
pshuflw m3, m3, c_shuf
paddw m1, m5
movq [buf+16], m1
; the unsigned saturating subtract yields 0 in every word where m1 <= m3;
; the compare against zero turns those words into all-ones
psubusw m1, m3
pxor m2, m2
pcmpeqw m1, m2
; one mask bit per byte, so each word contributes two bits
pmovmskb eax, m1
%if %1
; ecx = (count + 80) >> 4 is used as the shift amount for the update below
lea ecx, [t2+80]
pcmpeqw m2, m2
shr ecx, 4
; count += (count < 32): cmp sets CF while t2d is below 32, adc adds it
cmp t2d, 32
adc t2d, 0
movd m3, ecx
; m2 = avg(0xffff, mask); m0 = (m0 - mask) + ((m2 - m0) >> shift)
pavgw m2, m1
psubw m2, m0
psubw m0, m1
psraw m2, m3
paddw m0, m2
; writes four words; word 3 is then overwritten with the updated count
movq [t1], m0
mov [t1+3*2], t2w
%endif
; eax = index of the lowest set mask bit = byte offset of the selected word
tzcnt eax, eax
; ecx = selected table entry, t2d = the entry before it (the stored rng when
; the selected entry is the first one)
movzx ecx, word [buf+rax+16]
movzx t2d, word [buf+rax+14]
; t4 is complemented here and complemented back after the add and shift
not t4
%if ARCH_X86_64
add t6d, 5
%endif
sub eax, 5 ; setup for merging the tok_br and tok branches
; t2d = previous entry - selected entry
sub t2d, ecx
shl rcx, gprsize*8-16
add t4, rcx
; ecx = 15 - (index of the highest set bit of t2d), assuming t2d < 0x10000
bsr ecx, t2d
xor ecx, 15
shl t2d, cl
shl t4, cl
; keep the new rng in m2 for the next iteration and store it to the context
movd m2, t2d
mov [t7+msac.rng], t2d
not t4
; cnt -= shift count; a negative result triggers the refill path
sub t5d, ecx
jge %%end
mov t2, [t7+msac.buf]
mov rcx, [t7+msac.end]
%if UNIX64 == 0
; t8 is saved around its use as a scratch register on non-UNIX64 targets
push t8
%endif
; take the byte-wise path when fewer than gprsize bytes remain before msac.end
lea t8, [t2+gprsize]
cmp t8, rcx
ja %%refill_eob
; fast path: load gprsize bytes at once, byte-swap them and merge into t4
mov t2, [t2]
lea ecx, [t5+23]
add t5d, 16
shr ecx, 3
bswap t2
sub t8, rcx
shl ecx, 3
shr t2, cl
sub ecx, t5d
mov t5d, gprsize*8-16
shl t2, cl
mov [t7+msac.buf], t8
%if UNIX64 == 0
pop t8
%endif
sub t5d, ecx
xor t4, t2
%%end:
; copy the updated dif into m3 for the next iteration
movp m3, t4
%if ARCH_X86_64
add t6d, eax ; CF = tok_br < 3 || tok == 15
jnc %%loop
lea eax, [t6+30]
%else
; same carry test as the x86-64 branch, with the counter kept at [buf+8]
add eax, [buf+8]
jnc %%loop
add eax, 30
; restore the stack pointer; presumably undoes the adjustment made by
; DECODE_SYMBOL_ADAPT_INIT 1 - TODO confirm
%if STACK_ALIGNMENT >= 16
add esp, 36
%else
mov esp, [esp]
%endif
%endif
; write dif and cnt back to the context and return (counter + 30) >> 1
mov [t7+msac.dif], t4
shr eax, 1
mov [t7+msac.cnt], t5d
RET
; refill path taken when fewer than gprsize bytes remain: consume one byte at
; a time until the read pointer reaches msac.end or the shift count in ecx
; drops below zero
%%refill_eob:
mov t8, rcx
mov ecx, gprsize*8-24
sub ecx, t5d
%%refill_eob_loop:
cmp t2, t8
jae %%refill_eob_end
movzx t5d, byte [t2]
inc t2
shl t5, cl
xor t4, t5
sub ecx, 8
jge %%refill_eob_loop
%%refill_eob_end:
%if UNIX64 == 0
pop t8
%endif
mov t5d, gprsize*8-24
mov [t7+msac.buf], t2
sub t5d, ecx
jmp %%end
%endmacro
; msac_decode_hi_tok: SSE2 hi_tok decoder entry point.
; Presumably the implementation behind the C declaration
;   unsigned dav1d_msac_decode_hi_tok_sse2(MsacContext *s, uint16_t *cdf)
; with t0 = s and t1 = cdf - TODO confirm against the cglobal name prefix.
; Loads the decoder state into registers and then runs one of two HI_TOK
; expansions, chosen by the msac.update_cdf field.
cglobal msac_decode_hi_tok, 0, 7 + ARCH_X86_64, 6
DECODE_SYMBOL_ADAPT_INIT 1
%if ARCH_X86_64 == 0 && PIC
; x86-32 PIC: address the constants below relative to min_prob+12*2 via t2
LEA t2, min_prob+12*2
%define base t2-(min_prob+12*2)
%else
%define base 0
%endif
; m0 = four words starting at t1
movq m0, [t1]
movd m2, [t0+msac.rng]
; eax = update_cdf flag, tested below to pick the HI_TOK variant
mov eax, [t0+msac.update_cdf]
movq m4, [base+pw_0xff00]
; dif is kept both in m3 (vector compare) and in t4 (scalar update)
movp m3, [t0+msac.dif]
; m5 = the four min_prob entries starting at index 12
movq m5, [base+min_prob+12*2]
mov t4, [t0+msac.dif]
mov t5d, [t0+msac.cnt]
%if ARCH_X86_64
; initial value of the running counter used by HI_TOK on x86-64
mov t6d, -24
%endif
; HI_TOK accesses the context through t7
movifnidn t7, t0
test eax, eax
jz .no_update_cdf
; HI_TOK ends in RET / a backwards jmp, so this expansion does not fall
; through into the one below
HI_TOK 1
.no_update_cdf:
HI_TOK 0
......@@ -37,11 +37,13 @@ unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
unsigned dav1d_msac_decode_bool_adapt_sse2(MsacContext *s, uint16_t *cdf);
unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
unsigned dav1d_msac_decode_hi_tok_sse2(MsacContext *s, uint16_t *cdf);
#if ARCH_X86_64 || defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_sse2
#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_sse2
#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
#define dav1d_msac_decode_hi_tok dav1d_msac_decode_hi_tok_sse2
#endif
#define dav1d_msac_decode_bool_adapt dav1d_msac_decode_bool_adapt_sse2
......
......@@ -38,7 +38,7 @@
/* The normal code doesn't use function pointers */
typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
typedef unsigned (*decode_bool_adapt_fn)(MsacContext *s, uint16_t *cdf);
typedef unsigned (*decode_adapt_fn)(MsacContext *s, uint16_t *cdf);
typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);
......@@ -46,9 +46,10 @@ typedef struct {
decode_symbol_adapt_fn symbol_adapt4;
decode_symbol_adapt_fn symbol_adapt8;
decode_symbol_adapt_fn symbol_adapt16;
decode_bool_adapt_fn bool_adapt;
decode_adapt_fn bool_adapt;
decode_bool_equi_fn bool_equi;
decode_bool_fn bool;
decode_adapt_fn hi_tok;
} MsacDSPContext;
static void randomize_cdf(uint16_t *const cdf, const int n) {
......@@ -199,6 +200,35 @@ static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) {
report("decode_bool");
}
/*
 * Checks c->hi_tok through the call_ref/call_new pair.
 *
 * For each cdf_update setting (0 and 1) both decoder states and both CDF
 * rows start out identical. call_ref and call_new are then invoked up to 64
 * times in lockstep; after every pair of calls the return values, the decoder
 * states (msac_cmp) and the CDF rows (memcmp) must match. The first mismatch
 * reports a failure and ends that round. bench_new is only run for the
 * cdf_update == 1 round.
 *
 * c:   function table; only c->hi_tok is read here.
 * buf: input bytes handed to dav1d_msac_init (BUF_SIZE bytes are used).
 */
static void check_decode_hi_tok(MsacDSPContext *const c, uint8_t *const buf) {
// presumably declares a 16-byte aligned uint16_t cdf[2][16] - TODO confirm;
// cdf[0] is passed to call_ref, cdf[1] to call_new
ALIGN_STK_16(uint16_t, cdf, 2, [16]);
MsacContext s_c, s_a;
if (check_func(c->hi_tok, "msac_decode_hi_tok")) {
declare_func(unsigned, MsacContext *s, uint16_t *cdf);
for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {
// the last argument is disable_cdf_update_flag, hence the negation
dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);
// both calls start from the same decoder state and CDF contents
s_a = s_c;
randomize_cdf(cdf[0], 3);
memcpy(cdf[1], cdf[0], sizeof(*cdf));
for (int i = 0; i < 64; i++) {
unsigned c_res = call_ref(&s_c, cdf[0]);
unsigned a_res = call_new(&s_a, cdf[1]);
// compare return value, decoder state and the whole CDF row
if (c_res != a_res || msac_cmp(&s_c, &s_a) ||
memcmp(cdf[0], cdf[1], sizeof(*cdf)))
{
if (fail())
msac_dump(c_res, a_res, &s_c, &s_a, cdf[0], cdf[1], 3);
break;
}
}
if (cdf_update)
bench_new(&s_a, cdf[1]);
}
}
report("decode_hi_tok");
}
void checkasm_check_msac(void) {
MsacDSPContext c;
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt_c;
......@@ -207,6 +237,7 @@ void checkasm_check_msac(void) {
c.bool_adapt = dav1d_msac_decode_bool_adapt_c;
c.bool_equi = dav1d_msac_decode_bool_equi_c;
c.bool = dav1d_msac_decode_bool_c;
c.hi_tok = dav1d_msac_decode_hi_tok_c;
#if ARCH_AARCH64 && HAVE_ASM
if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
......@@ -225,6 +256,7 @@ void checkasm_check_msac(void) {
c.bool_adapt = dav1d_msac_decode_bool_adapt_sse2;
c.bool_equi = dav1d_msac_decode_bool_equi_sse2;
c.bool = dav1d_msac_decode_bool_sse2;
c.hi_tok = dav1d_msac_decode_hi_tok_sse2;
}
#endif
......@@ -234,4 +266,5 @@ void checkasm_check_msac(void) {
check_decode_symbol(&c, buf);
check_decode_bool(&c, buf);
check_decode_hi_tok(&c, buf);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment