Commit 2463174c authored by Henrik Gramner
Browse files

x86: AVX-512 decimate_score

Also drop the MMX versions and improve the SSE2, SSSE3 and AVX2 versions.
parent 49fb50a6
......@@ -460,9 +460,6 @@ void x264_quant_init( x264_t *h, int cpu, x264_quant_function_t *pf )
{
#if ARCH_X86
pf->denoise_dct = x264_denoise_dct_mmx;
pf->decimate_score15 = x264_decimate_score15_mmx2;
pf->decimate_score16 = x264_decimate_score16_mmx2;
pf->decimate_score64 = x264_decimate_score64_mmx2;
pf->coeff_last8 = x264_coeff_last8_mmx2;
pf->coeff_last[ DCT_LUMA_AC] = x264_coeff_last15_mmx2;
pf->coeff_last[ DCT_LUMA_4x4] = x264_coeff_last16_mmx2;
......@@ -562,6 +559,9 @@ void x264_quant_init( x264_t *h, int cpu, x264_quant_function_t *pf )
{
pf->dequant_4x4 = x264_dequant_4x4_avx512;
pf->dequant_8x8 = x264_dequant_8x8_avx512;
pf->decimate_score15 = x264_decimate_score15_avx512;
pf->decimate_score16 = x264_decimate_score16_avx512;
pf->decimate_score64 = x264_decimate_score64_avx512;
pf->coeff_last4 = x264_coeff_last4_avx512;
pf->coeff_last8 = x264_coeff_last8_avx512;
pf->coeff_last[ DCT_LUMA_AC] = x264_coeff_last15_avx512;
......@@ -594,9 +594,6 @@ void x264_quant_init( x264_t *h, int cpu, x264_quant_function_t *pf )
pf->quant_4x4 = x264_quant_4x4_mmx2;
pf->quant_8x8 = x264_quant_8x8_mmx2;
pf->quant_4x4_dc = x264_quant_4x4_dc_mmx2;
pf->decimate_score15 = x264_decimate_score15_mmx2;
pf->decimate_score16 = x264_decimate_score16_mmx2;
pf->decimate_score64 = x264_decimate_score64_mmx2;
pf->coeff_last[ DCT_LUMA_AC] = x264_coeff_last15_mmx2;
pf->coeff_last[ DCT_LUMA_4x4] = x264_coeff_last16_mmx2;
pf->coeff_last[ DCT_LUMA_8x8] = x264_coeff_last64_mmx2;
......@@ -736,6 +733,9 @@ void x264_quant_init( x264_t *h, int cpu, x264_quant_function_t *pf )
pf->dequant_4x4 = x264_dequant_4x4_avx512;
pf->dequant_8x8 = x264_dequant_8x8_avx512;
}
pf->decimate_score15 = x264_decimate_score15_avx512;
pf->decimate_score16 = x264_decimate_score16_avx512;
pf->decimate_score64 = x264_decimate_score64_avx512;
pf->coeff_last8 = x264_coeff_last8_avx512;
pf->coeff_last[ DCT_LUMA_AC] = x264_coeff_last15_avx512;
pf->coeff_last[DCT_LUMA_4x4] = x264_coeff_last16_avx512;
......
......@@ -32,7 +32,9 @@
SECTION_RODATA 64
%if HIGH_BIT_DEPTH == 0
%if HIGH_BIT_DEPTH
decimate_shuf_avx512: dd 0, 4, 8,12, 1, 5, 9,13, 2, 6,10,14, 3, 7,11,15
%else
dequant_shuf_avx512: dw 0, 2, 4, 6, 8,10,12,14,16,18,20,22,24,26,28,30
dw 32,34,36,38,40,42,44,46,48,50,52,54,56,58,60,62
%endif
......@@ -1370,13 +1372,12 @@ cglobal denoise_dct, 4,4,4
; int decimate_score( dctcoef *dct )
;-----------------------------------------------------------------------------
%macro DECIMATE_MASK 5
%if mmsize==16
%macro DECIMATE_MASK 4
%if HIGH_BIT_DEPTH
movdqa m0, [%3+ 0]
movdqa m1, [%3+32]
packssdw m0, [%3+16]
packssdw m1, [%3+48]
mova m0, [%3+0*16]
packssdw m0, [%3+1*16]
mova m1, [%3+2*16]
packssdw m1, [%3+3*16]
ABSW2 m0, m1, m0, m1, m3, m4
%else
ABSW m0, [%3+ 0], m3
......@@ -1388,40 +1389,35 @@ cglobal denoise_dct, 4,4,4
pcmpgtb m0, %4
pmovmskb %1, m2
pmovmskb %2, m0
%else ; mmsize==8
%endmacro
%macro DECIMATE_MASK16_AVX512 0
mova m0, [r0]
%if HIGH_BIT_DEPTH
movq m0, [%3+ 0]
movq m1, [%3+16]
movq m2, [%3+32]
movq m3, [%3+48]
packssdw m0, [%3+ 8]
packssdw m1, [%3+24]
packssdw m2, [%3+40]
packssdw m3, [%3+56]
%else
movq m0, [%3+ 0]
movq m1, [%3+ 8]
movq m2, [%3+16]
movq m3, [%3+24]
%endif
ABSW2 m0, m1, m0, m1, m6, m7
ABSW2 m2, m3, m2, m3, m6, m7
packsswb m0, m1
packsswb m2, m3
pxor m4, m4
pxor m6, m6
pcmpeqb m4, m0
pcmpeqb m6, m2
pcmpgtb m0, %4
pcmpgtb m2, %4
pmovmskb %5, m4
pmovmskb %1, m6
shl %1, 8
or %1, %5
pmovmskb %5, m0
pmovmskb %2, m2
shl %2, 8
or %2, %5
vptestmd k0, m0, m0
pabsd m0, m0
vpcmpud k1, m0, [pd_1] {1to16}, 6
%else
vptestmw k0, m0, m0
pabsw m0, m0
vpcmpuw k1, m0, [pw_1], 6
%endif
%endmacro
; SHRX dst, cnt
; Logical right shift of dst by the count held in cnt (dst >>= cnt).
;   %1 = register to shift in place
;   %2 = register holding the shift count
; When BMI2 is available the three-operand shrx is emitted, which accepts any
; count register and does not modify the flags. Otherwise the legacy shr is
; emitted; its variable count can only come from cl, so %2 has to be rcx/ecx
; (the 'b' suffix selects its low byte). Callers that branch on the result
; therefore cannot rely on the flags being set on both paths.
%macro SHRX 2
%if notcpuflag(bmi2)
    shr %1, %2b ; legacy form: %2 has to be rcx/ecx
%else
    shrx %1, %1, %2
%endif
%endmacro
; BLSR dst, src
; dst = src & (src - 1), i.e. src with its lowest set bit cleared.
;   %1 = destination register
;   %2 = source register
; With BMI1 this is a single blsr. The fallback first writes src-1 into dst
; and then ANDs src into it, so dst and src must be different registers there.
; ZF reflects whether the result is zero on both paths.
%macro BLSR 2
%if notcpuflag(bmi1)
    lea %1, [%2-1]
    and %1, %2
%else
    blsr %1, %2
%endif
%endmacro
......@@ -1431,33 +1427,60 @@ cextern decimate_table8
%macro DECIMATE4x4 1
cglobal decimate_score%1, 1,3
%ifdef PIC
lea r4, [decimate_table4]
lea r5, [decimate_mask_table4]
%define table r4
%define mask_table r5
%if cpuflag(avx512)
DECIMATE_MASK16_AVX512
xor eax, eax
kmovw edx, k0
%if %1 == 15
shr edx, 1
%else
%define table decimate_table4
%define mask_table decimate_mask_table4
test edx, edx
%endif
DECIMATE_MASK edx, eax, r0, [pb_1], ecx
jz .ret
ktestw k1, k1
jnz .ret9
%else
DECIMATE_MASK edx, eax, r0, [pb_1]
xor edx, 0xffff
je .ret
jz .ret
test eax, eax
jne .ret9
%if %1==15
jnz .ret9
%if %1 == 15
shr edx, 1
%endif
%endif
%ifdef PIC
lea r4, [decimate_mask_table4]
%define mask_table r4
%else
%define mask_table decimate_mask_table4
%endif
movzx ecx, dl
movzx eax, byte [mask_table + rcx]
%if ARCH_X86_64
xor edx, ecx
jz .ret
%if cpuflag(lzcnt)
lzcnt ecx, ecx
lea r5, [decimate_table4-32]
add r5, rcx
%else
bsr ecx, ecx
lea r5, [decimate_table4-1]
sub r5, rcx
%endif
%define table r5
%else
cmp edx, ecx
je .ret
jz .ret
bsr ecx, ecx
shr edx, 1
shr edx, cl
SHRX edx, ecx
%define table decimate_table4
%endif
tzcnt ecx, edx
shr edx, 1
shr edx, cl
SHRX edx, ecx
add al, byte [table + rcx]
add al, byte [mask_table + rdx]
.ret:
......@@ -1465,175 +1488,224 @@ cglobal decimate_score%1, 1,3
.ret9:
mov eax, 9
RET
%endmacro
%if ARCH_X86_64 == 0
INIT_MMX mmx2
DECIMATE4x4 15
DECIMATE4x4 16
%endif
INIT_XMM sse2
DECIMATE4x4 15
DECIMATE4x4 16
INIT_XMM ssse3
DECIMATE4x4 15
DECIMATE4x4 16
; 2x gt1 output, 2x nz output, 1x mask
%macro DECIMATE_MASK64_AVX2 5
pabsw m0, [r0+ 0]
pabsw m2, [r0+32]
pabsw m1, [r0+64]
pabsw m3, [r0+96]
packsswb m0, m2
packsswb m1, m3
pcmpgtb m2, m0, %5 ; the > 1 checks don't care about order, so
pcmpgtb m3, m1, %5 ; we can save latency by doing them here
pmovmskb %1, m2
pmovmskb %2, m3
or %1, %2
jne .ret9
%macro DECIMATE_MASK64_AVX2 2 ; nz_low, nz_high
mova m0, [r0+0*32]
packsswb m0, [r0+1*32]
mova m1, [r0+2*32]
packsswb m1, [r0+3*32]
mova m4, [pb_1]
pabsb m2, m0
pabsb m3, m1
por m2, m3 ; the > 1 checks don't care about order, so
ptest m4, m2 ; we can save latency by doing them here
jnc .ret9
vpermq m0, m0, q3120
vpermq m1, m1, q3120
pxor m4, m4
pcmpeqb m0, m4
pcmpeqb m1, m4
pmovmskb %3, m0
pmovmskb %4, m1
pmovmskb %1, m0
pmovmskb %2, m1
%endmacro
%macro DECIMATE8x8 0
; DECIMATE_MASK64_AVX512
; Builds the nonzero masks for the 64 coefficients stored at [r0] and bails
; out early when a coefficient is too large.
;   In:  r0 = pointer to the coefficients (4 zmm loads of dwords when
;        HIGH_BIT_DEPTH is set, 2 zmm loads of words otherwise)
;   Out: HIGH_BIT_DEPTH:  k1 = 64-bit mask, one bit per nonzero coefficient
;        otherwise:       k1 = mask of nonzero words at [r0],
;                         k2 = mask of nonzero words at [r0+64]
;   Jumps to .ret9 (must be provided by the function expanding this macro)
;   when any packed |coef| byte compares unsigned-greater-than pb_1.
;   Writes m0-m2 (plus m3 when HIGH_BIT_DEPTH is 0), k0-k2 and the flags.
; NOTE(review): pb_1 is defined outside this view; presumably 16+ bytes of 1,
; which makes the early-out an "|coef| > 1" test - confirm.
%macro DECIMATE_MASK64_AVX512 0
mova m0, [r0]
%if HIGH_BIT_DEPTH
; dwords -> words -> bytes with signed saturation; all 64 coefs end up in m0
packssdw m0, [r0+1*64]
mova m1, [r0+2*64]
packssdw m1, [r0+3*64]
packsswb m0, m1
vbroadcasti32x4 m1, [pb_1]
pabsb m2, m0
; predicate 6 = not-less-or-equal; unsigned so that abs(-128) = 0x80 still counts as large
vpcmpub k0, m2, m1, 6
; ktestq sets ZF when k0 is all zero, so jnz fires if any byte exceeded pb_1
ktestq k0, k0
jnz .ret9
; the packs interleave per 128-bit lane; this dword permute puts the bytes
; back into coefficient order before the nonzero mask is taken
mova m1, [decimate_shuf_avx512]
vpermd m0, m1, m0
vptestmb k1, m0, m0
%else
mova m1, [r0+64]
vbroadcasti32x4 m3, [pb_1]
; the packed copy in m2 is only used for the magnitude check, where element
; order does not matter; m0/m1 keep the coefficients in their original order
packsswb m2, m0, m1
pabsb m2, m2
; predicate 6 = not-less-or-equal; unsigned so that abs(-128) = 0x80 still counts as large
vpcmpub k0, m2, m3, 6
; ktestq sets ZF when k0 is all zero, so jnz fires if any byte exceeded pb_1
ktestq k0, k0
jnz .ret9
; one mask bit per nonzero word, taken from the unpacked registers
vptestmw k1, m0, m0
vptestmw k2, m1, m1
%endif
%endmacro
%macro DECIMATE8x8 0
%if ARCH_X86_64
cglobal decimate_score64, 1,5
%if mmsize == 64
DECIMATE_MASK64_AVX512
xor eax, eax
%if HIGH_BIT_DEPTH
kmovq r1, k1
test r1, r1
jz .ret
%else
kortestd k1, k2
jz .ret
kunpckdq k1, k2, k1
kmovq r1, k1
%endif
%elif mmsize == 32
DECIMATE_MASK64_AVX2 r1d, eax
not r1
shl rax, 32
xor r1, rax
jz .ret
%else
mova m5, [pb_1]
DECIMATE_MASK r1d, eax, r0+SIZEOF_DCTCOEF* 0, m5
test eax, eax
jnz .ret9
DECIMATE_MASK r2d, eax, r0+SIZEOF_DCTCOEF*16, m5
shl r2d, 16
or r1d, r2d
DECIMATE_MASK r2d, r3d, r0+SIZEOF_DCTCOEF*32, m5
shl r2, 32
or eax, r3d
or r1, r2
DECIMATE_MASK r2d, r3d, r0+SIZEOF_DCTCOEF*48, m5
not r1
shl r2, 48
xor r1, r2
jz .ret
add eax, r3d
jnz .ret9
%endif
%ifdef PIC
lea r4, [decimate_table8]
%define table r4
%else
%define table decimate_table8
%endif
mova m5, [pb_1]
%if mmsize==32
DECIMATE_MASK64_AVX2 eax, r2d, r1d, r3d, m5
shl r3, 32
or r1, r3
xor r1, -1
je .ret
%else
DECIMATE_MASK r1d, eax, r0+SIZEOF_DCTCOEF* 0, m5, null
test eax, eax
jne .ret9
DECIMATE_MASK r2d, eax, r0+SIZEOF_DCTCOEF*16, m5, null
shl r2d, 16
or r1d, r2d
DECIMATE_MASK r2d, r3d, r0+SIZEOF_DCTCOEF*32, m5, null
shl r2, 32
or eax, r3d
or r1, r2
DECIMATE_MASK r2d, r3d, r0+SIZEOF_DCTCOEF*48, m5, null
shl r2, 48
or r1, r2
xor r1, -1
je .ret
add eax, r3d
jne .ret9
%endif
mov al, -6
mov al, -6
.loop:
tzcnt rcx, r1
shr r1, cl
add al, byte [table + rcx]
jge .ret9
shr r1, 1
jne .loop
add al, 6
add al, byte [table + rcx]
jge .ret9
shr r1, 1
SHRX r1, rcx
%if cpuflag(bmi2)
test r1, r1
%endif
jnz .loop
add al, 6
.ret:
REP_RET
.ret9:
mov eax, 9
mov eax, 9
RET
%else ; ARCH
%if mmsize == 8
cglobal decimate_score64, 1,6
%else
cglobal decimate_score64, 1,5
%endif
mova m5, [pb_1]
%if mmsize==32
DECIMATE_MASK64_AVX2 r0, r2, r3, r4, m5
xor r3, -1
je .tryret
xor r4, -1
.cont:
%else
DECIMATE_MASK r3, r2, r0+SIZEOF_DCTCOEF* 0, m5, r5
test r2, r2
jne .ret9
DECIMATE_MASK r4, r2, r0+SIZEOF_DCTCOEF*16, m5, r5
shl r4, 16
or r3, r4
DECIMATE_MASK r4, r1, r0+SIZEOF_DCTCOEF*32, m5, r5
or r2, r1
DECIMATE_MASK r1, r0, r0+SIZEOF_DCTCOEF*48, m5, r5
shl r1, 16
or r4, r1
xor r3, -1
je .tryret
xor r4, -1
.cont:
add r0, r2
jne .ret9
%endif
mov al, -6
cglobal decimate_score64, 1,4
%if mmsize == 64
DECIMATE_MASK64_AVX512
xor eax, eax
%if HIGH_BIT_DEPTH
kshiftrq k2, k1, 32
%endif
kmovd r2, k1
kmovd r3, k2
test r2, r2
jz .tryret
%elif mmsize == 32
DECIMATE_MASK64_AVX2 r2, r3
xor eax, eax
not r3
xor r2, -1
jz .tryret
%else
mova m5, [pb_1]
DECIMATE_MASK r2, r1, r0+SIZEOF_DCTCOEF* 0, m5
test r1, r1
jnz .ret9
DECIMATE_MASK r3, r1, r0+SIZEOF_DCTCOEF*16, m5
not r2
shl r3, 16
xor r2, r3
mov r0m, r2
DECIMATE_MASK r3, r2, r0+SIZEOF_DCTCOEF*32, m5
or r2, r1
DECIMATE_MASK r1, r0, r0+SIZEOF_DCTCOEF*48, m5
add r0, r2
jnz .ret9
mov r2, r0m
not r3
shl r1, 16
xor r3, r1
test r2, r2
jz .tryret
%endif
mov al, -6
.loop:
tzcnt ecx, r2
add al, byte [decimate_table8 + ecx]
jge .ret9
sub ecx, 31 ; increase the shift count by one to shift away the lowest set bit as well
jz .run31 ; only bits 0-4 are used so we have to explicitly handle the case of 1<<31
shrd r2, r3, cl
SHRX r3, ecx
%if notcpuflag(bmi2)
test r2, r2
%endif
jnz .loop
BLSR r2, r3
jz .end
.largerun:
tzcnt ecx, r3
test r3, r3
je .largerun
shrd r3, r4, cl
shr r4, cl
add al, byte [decimate_table8 + ecx]
jge .ret9
shrd r3, r4, 1
shr r4, 1
test r3, r3
jne .loop
test r4, r4
jne .loop
add al, 6
.ret:
REP_RET
.tryret:
xor r4, -1
jne .cont
shr r3, 1
SHRX r3, ecx
.loop2:
tzcnt ecx, r3
add al, byte [decimate_table8 + ecx]
jge .ret9
shr r3, 1
SHRX r3, ecx
.run31:
test r3, r3
jnz .loop2
.end:
add al, 6
RET
.tryret:
BLSR r2, r3
jz .ret
mov al, -6
jmp .largerun
.ret9:
mov eax, 9
RET
.largerun:
mov r3, r4
xor r4, r4
tzcnt ecx, r3
shr r3, cl
shr r3, 1
jne .loop
add al, 6
RET
.ret:
REP_RET
%endif ; ARCH
%endmacro
%if ARCH_X86_64 == 0
INIT_MMX mmx2
DECIMATE8x8
%endif
INIT_XMM sse2
DECIMATE4x4 15
DECIMATE4x4 16
DECIMATE8x8
INIT_XMM ssse3
DECIMATE4x4 15
DECIMATE4x4 16
DECIMATE8x8
%if HIGH_BIT_DEPTH
INIT_ZMM avx512
%else
INIT_YMM avx2
DECIMATE8x8
INIT_YMM avx512
%endif
DECIMATE4x4 15
DECIMATE4x4 16
INIT_ZMM avx512
DECIMATE8x8
;-----------------------------------------------------------------------------
; int coeff_last( dctcoef *dct )
......
......@@ -88,16 +88,16 @@ void x264_denoise_dct_sse2 ( dctcoef *dct, uint32_t *sum, udctcoef *offset, int
void x264_denoise_dct_ssse3( dctcoef *dct, uint32_t *sum, udctcoef *offset, int size );
void x264_denoise_dct_avx ( dctcoef *dct, uint32_t *sum, udctcoef *offset, int size );
void x264_denoise_dct_avx2 ( dctcoef *dct, uint32_t *sum, udctcoef *offset, int size );
int x264_decimate_score15_mmx2( dctcoef *dct );
int x264_decimate_score15_sse2( dctcoef *dct );
int x264_decimate_score15_ssse3( dctcoef *dct );
int x264_decimate_score16_mmx2( dctcoef *dct );
int x264_decimate_score15_avx512( dctcoef *dct );
int x264_decimate_score16_sse2( dctcoef *dct );
int x264_decimate_score16_ssse3( dctcoef *dct );
int x264_decimate_score64_mmx2( dctcoef *dct );
int x264_decimate_score16_avx512( dctcoef *dct );
int x264_decimate_score64_sse2( dctcoef *dct );
int x264_decimate_score64_ssse3( dctcoef *dct );
int x264_decimate_score64_avx2( int16_t *dct );
int x264_decimate_score64_avx512( dctcoef *dct );
int x264_coeff_last4_mmx2( dctcoef *dct );
int x264_coeff_last8_mmx2( dctcoef *dct );
int x264_coeff_last15_mmx2( dctcoef *dct );
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment