Commit 78d27b7d authored by Henrik Gramner

x86: Rewrite wiener SSE2/SSSE3/AVX2 asm

The previous implementation did two separate passes in the horizontal
and vertical directions, with the intermediate values being stored
in a buffer on the stack. This caused bad cache thrashing.

By interleaving the horizontal and vertical passes and storing only a
few rows at a time in a ring buffer, performance is improved
significantly.
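
For illustration only, a minimal plain-C sketch of the idea, 8-bit only
(the helper names, ring handling and rounding here are assumptions made
for the example; the actual asm processes 32 pixels per iteration and
rotates seven row pointers instead of using a modulo index):

    #include <stddef.h>
    #include <stdint.h>

    #define MAX_UNIT_W 384 /* matches the 384*2-byte row stride in the asm */

    static int clamp(const int v, const int lo, const int hi) {
        return v < lo ? lo : v > hi ? hi : v;
    }

    /* Horizontally filter one row at a time into a 7-row ring buffer and
     * run the vertical filter as soon as 7 rows are available, so the
     * intermediate values stay in a handful of cache-resident rows. */
    static void wiener_ring_sketch(uint8_t *const dst, const ptrdiff_t dst_stride,
                                   const uint8_t *const src, const ptrdiff_t src_stride,
                                   const int w, const int h,
                                   const int16_t fh[7], const int16_t fv[7])
    {
        int16_t ring[7][MAX_UNIT_W]; /* only 7 rows of intermediates */

        for (int y = -3; y < h + 3; y++) {
            /* horizontal pass for one (vertically edge-clamped) source row */
            const uint8_t *const s = src + clamp(y, 0, h - 1) * src_stride;
            int16_t *const mid = ring[(y + 3) % 7];
            for (int x = 0; x < w; x++) {
                int sum = 0;
                for (int k = 0; k < 7; k++)
                    sum += fh[k] * s[clamp(x + k - 3, 0, w - 1)];
                mid[x] = (int16_t)((sum + 4) >> 3);
            }
            /* once 7 rows are buffered, emit one output row */
            if (y >= 3) {
                uint8_t *const d = dst + (y - 3) * dst_stride;
                for (int x = 0; x < w; x++) {
                    int sum = 0;
                    for (int k = 0; k < 7; k++)
                        sum += fv[k] * ring[(y - 3 + k) % 7][x];
                    d[x] = (uint8_t)clamp((sum + 1024) >> 11, 0, 255);
                }
            }
        }
    }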

Also split the function into 7-tap and 5-tap versions. The latter is
faster and fairly common (always for chroma, sometimes for luma).
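
A hypothetical helper (for illustration only) mirroring the dispatch
that lr_stripe() below now does via
dsp->lr.wiener[!(filter[0][0] | filter[1][0])]:

    /* wiener[0] is the 7-tap kernel, wiener[1] the 5-tap one, usable
     * whenever the outermost horizontal and vertical taps are zero
     * (always the case for chroma). */
    static wienerfilter_fn select_wiener(const Dav1dLoopRestorationDSPContext *const dsp,
                                         const int16_t filter[2][8])
    {
        return dsp->wiener[!(filter[0][0] | filter[1][0])];
    }
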
parent 3497c4c9
@@ -288,7 +288,7 @@ COLD void bitfn(dav1d_loop_restoration_dsp_init_arm)(Dav1dLoopRestorationDSPCont
if (!(flags & DAV1D_ARM_CPU_FLAG_NEON)) return;
c->wiener = wiener_filter_neon;
c->wiener[0] = c->wiener[1] = wiener_filter_neon;
if (bpc <= 10)
c->selfguided = sgr_filter_neon;
}
@@ -67,7 +67,7 @@ void (name)(pixel *dst, ptrdiff_t dst_stride, \
typedef decl_selfguided_filter_fn(*selfguided_fn);
typedef struct Dav1dLoopRestorationDSPContext {
wienerfilter_fn wiener;
wienerfilter_fn wiener[2]; /* 7-tap, 5-tap */
selfguided_fn selfguided;
} Dav1dLoopRestorationDSPContext;
@@ -509,7 +509,7 @@ static void selfguided_c(pixel *p, const ptrdiff_t p_stride,
}
COLD void bitfn(dav1d_loop_restoration_dsp_init)(Dav1dLoopRestorationDSPContext *const c, int bpc) {
c->wiener = wiener_c;
c->wiener[0] = c->wiener[1] = wiener_c;
c->selfguided = selfguided_c;
#if HAVE_ASM
@@ -163,6 +163,7 @@ static void lr_stripe(const Dav1dFrameContext *const f, pixel *p,
int stripe_h = imin((64 - 8 * !y) >> ss_ver, row_h - y);
ALIGN_STK_16(int16_t, filter, 2, [8]);
wienerfilter_fn wiener_fn = NULL;
if (lr->type == DAV1D_RESTORATION_WIENER) {
filter[0][0] = filter[0][6] = lr->filter_h[0];
filter[0][1] = filter[0][5] = lr->filter_h[1];
@@ -178,6 +179,8 @@ static void lr_stripe(const Dav1dFrameContext *const f, pixel *p,
filter[1][1] = filter[1][5] = lr->filter_v[1];
filter[1][2] = filter[1][4] = lr->filter_v[2];
filter[1][3] = 128 - (filter[1][0] + filter[1][1] + filter[1][2]) * 2;
wiener_fn = dsp->lr.wiener[!(filter[0][0] | filter[1][0])];
} else {
assert(lr->type == DAV1D_RESTORATION_SGRPROJ);
}
@@ -185,9 +188,9 @@ static void lr_stripe(const Dav1dFrameContext *const f, pixel *p,
while (y + stripe_h <= row_h) {
// Change HAVE_BOTTOM bit in edges to (y + stripe_h != row_h)
edges ^= (-(y + stripe_h != row_h) ^ edges) & LR_HAVE_BOTTOM;
if (lr->type == DAV1D_RESTORATION_WIENER) {
dsp->lr.wiener(p, p_stride, left, lpf, lpf_stride, unit_w, stripe_h,
filter, edges HIGHBD_CALL_SUFFIX);
if (wiener_fn) {
wiener_fn(p, p_stride, left, lpf, lpf_stride, unit_w, stripe_h,
filter, edges HIGHBD_CALL_SUFFIX);
} else {
dsp->lr.selfguided(p, p_stride, left, lpf, lpf_stride, unit_w, stripe_h,
lr->sgr_idx, lr->sgr_weights, edges HIGHBD_CALL_SUFFIX);
@@ -332,7 +332,7 @@ COLD void bitfn(dav1d_loop_restoration_dsp_init_ppc)
if (!(flags & DAV1D_PPC_CPU_FLAG_VSX)) return;
#if BITDEPTH == 8
c->wiener = wiener_filter_vsx;
c->wiener[0] = c->wiener[1] = wiener_filter_vsx;
#endif
}
@@ -29,20 +29,25 @@
%if ARCH_X86_64
SECTION_RODATA 32
wiener_shufA: db 1, 7, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12, 7, 13, 8, 14
wiener_shufB: db 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10
wiener_shufC: db 6, 5, 7, 6, 8, 7, 9, 8, 10, 9, 11, 10, 12, 11, 13, 12
wiener_shufD: db 4, -1, 5, -1, 6, -1, 7, -1, 8, -1, 9, -1, 10, -1, 11, -1
wiener_l_shuf: db 4, 4, 4, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
pb_0to31: db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
db 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
pb_right_ext_mask: times 32 db 0xff
times 32 db 0
pb_14x0_1_2: times 14 db 0
db 1, 2
pb_0_to_15_min_n: db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 13
db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 14
pb_15: times 16 db 15
pw_16: times 2 dw 16
pw_256: times 2 dw 256
pw_2048: times 2 dw 2048
pw_16380: times 2 dw 16380
pw_5_6: dw 5, 6
pd_6: dd 6
pd_1024: dd 1024
pb_3: times 4 db 3
pb_m5: times 4 db -5
pw_16: times 2 dw 16
pw_256: times 2 dw 256
pw_2056: times 2 dw 2056
pw_m16380: times 2 dw -16380
pw_5_6: dw 5, 6
pd_1024: dd 1024
pd_0xf0080029: dd 0xf0080029
pd_0xf00801c7: dd 0xf00801c7
@@ -50,277 +55,662 @@ cextern sgr_x_by_x
SECTION .text
INIT_YMM avx2
cglobal wiener_filter_h, 5, 12, 16, dst, left, src, stride, flt, w, h, edge
mov edged, edgem
vpbroadcastb m15, [fltq+0]
movifnidn wd, wm
vpbroadcastb m14, [fltq+2]
mov hd, hm
vpbroadcastb m13, [fltq+4]
vpbroadcastw m12, [fltq+6]
vpbroadcastd m11, [pw_2048]
vpbroadcastd m10, [pw_16380]
lea r11, [pb_right_ext_mask]
DEFINE_ARGS dst, left, src, stride, x, w, h, edge, srcptr, dstptr, xlim
; if (edge & has_right) align_w_to_32
; else w -= 32, and use that as limit in x loop
test edgeb, 2 ; has_right
jnz .align
mov xlimq, -3
jmp .loop
.align:
add wd, 31
and wd, ~31
xor xlimd, xlimd
; main y loop for vertical filter
.loop:
mov srcptrq, srcq
mov dstptrq, dstq
lea xq, [wq+xlimq]
%macro REPX 2-*
%xdefine %%f(x) %1
%rep %0 - 1
%rotate 1
%%f(%1)
%endrep
%endmacro
; load left edge pixels
test edgeb, 1 ; have_left
jz .emu_left
test leftq, leftq ; left == NULL for the edge-extended bottom/top
jz .load_left_combined
movd xm0, [leftq]
add leftq, 4
pinsrd xm0, [srcq], 1
pslldq xm0, 9
jmp .left_load_done
.load_left_combined:
movq xm0, [srcq-3]
pslldq xm0, 10
jmp .left_load_done
.emu_left:
movd xm0, [srcq]
pshufb xm0, [pb_14x0_1_2]
DECLARE_REG_TMP 4, 9, 7, 11, 12, 13, 14 ; wiener ring buffer pointers
; load right edge pixels
.left_load_done:
cmp xd, 32
jg .main_load
test xd, xd
jg .load_and_splat
je .splat_right
; for very small images (w=[1-2]), edge-extend the original cache,
; ugly, but only runs in very odd cases
add wd, wd
pshufb xm0, [r11-pb_right_ext_mask+pb_0_to_15_min_n+wq*8-16]
shr wd, 1
; main x loop, mostly this starts in .main_load
.splat_right:
; no need to load new pixels, just extend them from the (possibly previously
; extended) previous load into m0
pshufb xm1, xm0, [pb_15]
jmp .main_loop
.load_and_splat:
; load new pixels and extend edge for right-most
movu m1, [srcptrq+3]
sub r11, xq
movu m2, [r11-pb_right_ext_mask+pb_right_ext_mask+32]
add r11, xq
vpbroadcastb m3, [srcptrq+2+xq]
pand m1, m2
pandn m3, m2, m3
por m1, m3
jmp .main_loop
.main_load:
; load subsequent line
movu m1, [srcptrq+3]
INIT_YMM avx2
cglobal wiener_filter7, 5, 15, 16, -384*12-16, dst, dst_stride, left, lpf, \
lpf_stride, w, edge, flt, h
mov fltq, fltmp
mov edged, r8m
mov wd, wm
mov hd, r6m
vbroadcasti128 m6, [wiener_shufA]
vpbroadcastb m11, [fltq+ 0] ; x0 x0
vbroadcasti128 m7, [wiener_shufB]
vpbroadcastd m12, [fltq+ 2]
vbroadcasti128 m8, [wiener_shufC]
packsswb m12, m12 ; x1 x2
vpbroadcastw m13, [fltq+ 6] ; x3
vbroadcasti128 m9, [wiener_shufD]
add lpfq, wq
vpbroadcastd m10, [pw_m16380]
lea t1, [rsp+wq*2+16]
vpbroadcastd m14, [fltq+16] ; y0 y1
add dstq, wq
vpbroadcastd m15, [fltq+20] ; y2 y3
neg wq
test edgeb, 4 ; LR_HAVE_TOP
jz .no_top
call .h_top
add lpfq, lpf_strideq
mov t6, t1
mov t5, t1
add t1, 384*2
call .h_top
lea r7, [lpfq+lpf_strideq*4]
mov lpfq, dstq
mov t4, t1
add t1, 384*2
mov [rsp+8*1], lpf_strideq
add r7, lpf_strideq
mov [rsp+8*0], r7 ; below
call .h
mov t3, t1
mov t2, t1
dec hd
jz .v1
add lpfq, dst_strideq
add t1, 384*2
call .h
mov t2, t1
dec hd
jz .v2
add lpfq, dst_strideq
add t1, 384*2
call .h
dec hd
jz .v3
.main:
lea t0, [t1+384*2]
.main_loop:
vinserti128 m0, xm1, 1
palignr m2, m1, m0, 10
palignr m3, m1, m0, 11
palignr m4, m1, m0, 12
palignr m5, m1, m0, 13
palignr m6, m1, m0, 14
palignr m7, m1, m0, 15
punpcklbw m0, m2, m1
punpckhbw m2, m1
punpcklbw m8, m3, m7
punpckhbw m3, m7
punpcklbw m7, m4, m6
punpckhbw m4, m6
pxor m9, m9
punpcklbw m6, m5, m9
punpckhbw m5, m9
pmaddubsw m0, m15
pmaddubsw m2, m15
pmaddubsw m8, m14
pmaddubsw m3, m14
pmaddubsw m7, m13
pmaddubsw m4, m13
paddw m0, m8
paddw m2, m3
psllw m8, m6, 7
psllw m3, m5, 7
psubw m8, m10
psubw m3, m10
pmullw m6, m12
pmullw m5, m12
paddw m0, m7
paddw m2, m4
paddw m0, m6
paddw m2, m5
; for a signed overflow to happen we need filter and pixels as follow:
; filter => -5,-23,-17,90,-17,-23,-5
; pixels => 255,255,255,0,255,255,255 or 0,0,0,255,0,0,0
; m0 would fall in the range [-59A6;+59A6] = [A65A;59A6]
; m8 would fall in the range [-3FFC;+3F84] = [C004;3F84]
; 32-bit arithmetic m0+m8 = [-99A2;+992A] = [FFFF665E;992A]
; => signed 16-bit overflow occurs
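; (illustrative decimal check of the ranges above:
;  255*(5+23+17+17+23+5) = 22950 = 0x59A6, 0*128-16380 = -16380 = -0x3FFC,
;  255*128-16380 = 16260 = 0x3F84; the extremes sum to -39330/+39210,
;  outside the signed 16-bit range [-32768;+32767])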
paddsw m0, m8 ; paddsw clips this range to [-8000;+7FFF]
paddsw m2, m3
psraw m0, 3 ; shift changes the range to [-1000;+FFF]
psraw m2, 3
paddw m0, m11 ; adding back 800 (removed in m8) changes the
paddw m2, m11 ; range to [-800;+17FF] as defined in the spec
mova [dstptrq], xm0 ; (note that adding another 800 would give us
mova [dstptrq+16], xm2; the same range as in the C code => [0;1FFF])
vextracti128 [dstptrq+32], m0, 1
vextracti128 [dstptrq+48], m2, 1
vextracti128 xm0, m1, 1
add srcptrq, 32
add dstptrq, 64
sub xq, 32
cmp xd, 32
jg .main_load
test xd, xd
jg .load_and_splat
cmp xd, xlimd
jg .splat_right
add srcq, strideq
add dstq, 384*2
dec hd
jg .loop
call .hv
dec hd
jnz .main_loop
test edgeb, 8 ; LR_HAVE_BOTTOM
jz .v3
mov lpfq, [rsp+8*0]
call .hv_bottom
add lpfq, [rsp+8*1]
call .hv_bottom
.v1:
call .v
RET
.no_top:
lea r7, [lpfq+lpf_strideq*4]
mov lpfq, dstq
mov [rsp+8*1], lpf_strideq
lea r7, [r7+lpf_strideq*2]
mov [rsp+8*0], r7
call .h
mov t6, t1
mov t5, t1
mov t4, t1
mov t3, t1
mov t2, t1
dec hd
jz .v1
add lpfq, dst_strideq
add t1, 384*2
call .h
mov t2, t1
dec hd
jz .v2
add lpfq, dst_strideq
add t1, 384*2
call .h
dec hd
jz .v3
lea t0, [t1+384*2]
call .hv
dec hd
jz .v3
add t0, 384*8
call .hv
dec hd
jnz .main
.v3:
call .v
.v2:
call .v
jmp .v1
.extend_right:
movd xm2, r10d
vpbroadcastd m0, [pb_3]
vpbroadcastd m1, [pb_m5]
vpbroadcastb m2, xm2
movu m3, [pb_0to31]
psubb m0, m2
psubb m1, m2
pminub m0, m3
pminub m1, m3
pshufb m4, m0
pshufb m5, m1
ret
.h:
mov r10, wq
test edgeb, 1 ; LR_HAVE_LEFT
jz .h_extend_left
movd xm4, [leftq]
vpblendd m4, [lpfq+r10-4], 0xfe
add leftq, 4
jmp .h_main
.h_extend_left:
vbroadcasti128 m5, [lpfq+r10] ; avoid accessing memory located
mova m4, [lpfq+r10] ; before the start of the buffer
palignr m4, m5, 12
pshufb m4, [wiener_l_shuf]
jmp .h_main
.h_top:
mov r10, wq
movu m4, [lpfq+r10-4]
test edgeb, 1 ; LR_HAVE_LEFT
jnz .h_main
pshufb m4, [wiener_l_shuf]
jmp .h_main
.h_loop:
movu m4, [lpfq+r10-4]
.h_main:
movu m5, [lpfq+r10+4]
test edgeb, 2 ; LR_HAVE_RIGHT
jnz .h_have_right
cmp r10d, -34
jl .h_have_right
call .extend_right
.h_have_right:
pshufb m0, m4, m6
pmaddubsw m0, m11
pshufb m1, m5, m6
pmaddubsw m1, m11
pshufb m2, m4, m7
pmaddubsw m2, m12
pshufb m3, m5, m7
pmaddubsw m3, m12
paddw m0, m2
pshufb m2, m4, m8
pmaddubsw m2, m12
paddw m1, m3
pshufb m3, m5, m8
pmaddubsw m3, m12
pshufb m4, m9
paddw m0, m2
pmullw m2, m4, m13
pshufb m5, m9
paddw m1, m3
pmullw m3, m5, m13
psllw m4, 7
psllw m5, 7
paddw m4, m10
paddw m5, m10
paddw m0, m2
vpbroadcastd m2, [pw_2056]
paddw m1, m3
paddsw m0, m4
paddsw m1, m5
psraw m0, 3
psraw m1, 3
paddw m0, m2
paddw m1, m2
mova [t1+r10*2+ 0], m0
mova [t1+r10*2+32], m1
add r10, 32
jl .h_loop
ret
ALIGN function_align
.hv:
add lpfq, dst_strideq
mov r10, wq
test edgeb, 1 ; LR_HAVE_LEFT
jz .hv_extend_left
movd xm4, [leftq]
vpblendd m4, [lpfq+r10-4], 0xfe
add leftq, 4
jmp .hv_main
.hv_extend_left:
movu m4, [lpfq+r10-4]
pshufb m4, [wiener_l_shuf]
jmp .hv_main
.hv_bottom:
mov r10, wq
test edgeb, 1 ; LR_HAVE_LEFT
jz .hv_extend_left
.hv_loop:
movu m4, [lpfq+r10-4]
.hv_main:
movu m5, [lpfq+r10+4]
test edgeb, 2 ; LR_HAVE_RIGHT
jnz .hv_have_right
cmp r10d, -34
jl .hv_have_right
call .extend_right
.hv_have_right:
pshufb m0, m4, m6
pmaddubsw m0, m11
pshufb m1, m5, m6
pmaddubsw m1, m11
pshufb m2, m4, m7
pmaddubsw m2, m12
pshufb m3, m5, m7
pmaddubsw m3, m12
paddw m0, m2
pshufb m2, m4, m8
pmaddubsw m2, m12
paddw m1, m3
pshufb m3, m5, m8
pmaddubsw m3, m12
pshufb m4, m9
paddw m0, m2
pmullw m2, m4, m13
pshufb m5, m9
paddw m1, m3
pmullw m3, m5, m13
psllw m4, 7
psllw m5, 7
paddw m4, m10
paddw m5, m10
paddw m0, m2
paddw m1, m3
mova m2, [t4+r10*2]
paddw m2, [t2+r10*2]
mova m3, [t3+r10*2]
paddsw m0, m4
vpbroadcastd m4, [pw_2056]
paddsw m1, m5
mova m5, [t5+r10*2]
paddw m5, [t1+r10*2]
psraw m0, 3
psraw m1, 3
paddw m0, m4
paddw m1, m4
paddw m4, m0, [t6+r10*2]
mova [t0+r10*2], m0
punpcklwd m0, m2, m3
pmaddwd m0, m15
punpckhwd m2, m3
pmaddwd m2, m15
punpcklwd m3, m4, m5
pmaddwd m3, m14
punpckhwd m4, m5
pmaddwd m4, m14
paddd m0, m3
paddd m4, m2
mova m2, [t4+r10*2+32]
paddw m2, [t2+r10*2+32]
mova m3, [t3+r10*2+32]
mova m5, [t5+r10*2+32]
paddw m5, [t1+r10*2+32]
psrad m0, 11
psrad m4, 11
packssdw m0, m4
paddw m4, m1, [t6+r10*2+32]
mova [t0+r10*2+32], m1
punpcklwd m1, m2, m3
pmaddwd m1, m15
punpckhwd m2, m3
pmaddwd m2, m15
punpcklwd m3, m4, m5
pmaddwd m3, m14
punpckhwd m4, m5
pmaddwd m4, m14
paddd m1, m3
paddd m2, m4
psrad m1, 11
psrad m2, 11
packssdw m1, m2
packuswb m0, m1
mova [dstq+r10], m0
add r10, 32
jl .hv_loop
mov t6, t5
mov t5, t4
mov t4, t3
mov t3, t2
mov t2, t1
mov t1, t0
mov t0, t6
add dstq, dst_strideq
ret
.v:
mov r10, wq
.v_loop:
mova m2, [t4+r10*2+ 0]
paddw m2, [t2+r10*2+ 0]
mova m4, [t3+r10*2+ 0]
mova m6, [t1+r10*2+ 0]
paddw m8, m6, [t6+r10*2+ 0]
paddw m6, [t5+r10*2+ 0]
mova m3, [t4+r10*2+32]
paddw m3, [t2+r10*2+32]
mova m5, [t3+r10*2+32]
mova m7, [t1+r10*2+32]
paddw m9, m7, [t6+r10*2+32]
paddw m7, [t5+r10*2+32]
punpcklwd m0, m2, m4
pmaddwd m0, m15
punpckhwd m2, m4
pmaddwd m2, m15
punpcklwd m4, m8, m6