Commit caca5725 authored and committed by Henrik Gramner

mc: Ensure high bitdepth intermediates fit in int16_t

An extreme edge case with the combination of 8-tap sharp_sharp, mx and my
both around 8, and a very specific pixel input pattern can cause overflows.

Add code to checkasm to trigger this scenario.
parent ad4d1c43
Pipeline #4339 passed with stages
in 6 minutes and 41 seconds
...@@ -39,9 +39,14 @@ ...@@ -39,9 +39,14 @@
#if BITDEPTH == 8 #if BITDEPTH == 8
#define get_intermediate_bits(bitdepth_max) 4 #define get_intermediate_bits(bitdepth_max) 4
// Output in interval [-5132, 9212], fits in int16_t as is
#define PREP_BIAS 0
#else #else
// 4 for 10 bits/component, 2 for 12 bits/component // 4 for 10 bits/component, 2 for 12 bits/component
#define get_intermediate_bits(bitdepth_max) (14 - bitdepth_from_max(bitdepth_max)) #define get_intermediate_bits(bitdepth_max) (14 - bitdepth_from_max(bitdepth_max))
// Output in interval [-20588, 36956] (10-bit), [-20602, 36983] (12-bit)
// Subtract a bias to ensure the output fits in int16_t
#define PREP_BIAS 8192
#endif #endif
static NOINLINE void static NOINLINE void
...@@ -63,7 +68,7 @@ prep_c(int16_t *tmp, const pixel *src, const ptrdiff_t src_stride, ...@@ -63,7 +68,7 @@ prep_c(int16_t *tmp, const pixel *src, const ptrdiff_t src_stride,
const int intermediate_bits = get_intermediate_bits(bitdepth_max); const int intermediate_bits = get_intermediate_bits(bitdepth_max);
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
tmp[x] = src[x] << intermediate_bits; tmp[x] = (src[x] << intermediate_bits) - PREP_BIAS;
tmp += w; tmp += w;
src += src_stride; src += src_stride;
...@@ -237,8 +242,12 @@ prep_8tap_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride, ...@@ -237,8 +242,12 @@ prep_8tap_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride,
mid_ptr = mid + 128 * 3; mid_ptr = mid + 128 * 3;
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++) {
tmp[x] = DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6); int t = DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6) -
PREP_BIAS;
assert(t >= INT16_MIN && t <= INT16_MAX);
tmp[x] = t;
}
mid_ptr += 128; mid_ptr += 128;
tmp += w; tmp += w;
...@@ -247,7 +256,8 @@ prep_8tap_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride, ...@@ -247,7 +256,8 @@ prep_8tap_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride,
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1, tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fh, 1,
6 - intermediate_bits); 6 - intermediate_bits) -
PREP_BIAS;
tmp += w; tmp += w;
src += src_stride; src += src_stride;
...@@ -257,7 +267,8 @@ prep_8tap_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride, ...@@ -257,7 +267,8 @@ prep_8tap_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride,
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fv, src_stride, tmp[x] = DAV1D_FILTER_8TAP_RND(src, x, fv, src_stride,
6 - intermediate_bits); 6 - intermediate_bits) -
PREP_BIAS;
tmp += w; tmp += w;
src += src_stride; src += src_stride;
...@@ -302,7 +313,8 @@ prep_8tap_scaled_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride, ...@@ -302,7 +313,8 @@ prep_8tap_scaled_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride,
GET_V_FILTER(my >> 6); GET_V_FILTER(my >> 6);
for (x = 0; x < w; x++) for (x = 0; x < w; x++)
tmp[x] = fv ? DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6) : mid_ptr[x]; tmp[x] = (fv ? DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6)
: mid_ptr[x]) - PREP_BIAS;
my += dy; my += dy;
mid_ptr += (my >> 10) * 128; mid_ptr += (my >> 10) * 128;
...@@ -499,7 +511,8 @@ static void prep_bilin_c(int16_t *tmp, ...@@ -499,7 +511,8 @@ static void prep_bilin_c(int16_t *tmp,
mid_ptr = mid; mid_ptr = mid;
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my, 128, 4); tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my, 128, 4) -
PREP_BIAS;
mid_ptr += 128; mid_ptr += 128;
tmp += w; tmp += w;
...@@ -508,7 +521,8 @@ static void prep_bilin_c(int16_t *tmp, ...@@ -508,7 +521,8 @@ static void prep_bilin_c(int16_t *tmp,
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
tmp[x] = FILTER_BILIN_RND(src, x, mx, 1, tmp[x] = FILTER_BILIN_RND(src, x, mx, 1,
4 - intermediate_bits); 4 - intermediate_bits) -
PREP_BIAS;
tmp += w; tmp += w;
src += src_stride; src += src_stride;
...@@ -518,7 +532,7 @@ static void prep_bilin_c(int16_t *tmp, ...@@ -518,7 +532,7 @@ static void prep_bilin_c(int16_t *tmp,
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
tmp[x] = FILTER_BILIN_RND(src, x, my, src_stride, tmp[x] = FILTER_BILIN_RND(src, x, my, src_stride,
4 - intermediate_bits); 4 - intermediate_bits) - PREP_BIAS;
tmp += w; tmp += w;
src += src_stride; src += src_stride;
...@@ -557,7 +571,7 @@ static void prep_bilin_scaled_c(int16_t *tmp, ...@@ -557,7 +571,7 @@ static void prep_bilin_scaled_c(int16_t *tmp,
int x; int x;
for (x = 0; x < w; x++) for (x = 0; x < w; x++)
tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my >> 6, 128, 4); tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my >> 6, 128, 4) - PREP_BIAS;
my += dy; my += dy;
mid_ptr += (my >> 10) * 128; mid_ptr += (my >> 10) * 128;
...@@ -571,7 +585,8 @@ static void avg_c(pixel *dst, const ptrdiff_t dst_stride, ...@@ -571,7 +585,8 @@ static void avg_c(pixel *dst, const ptrdiff_t dst_stride,
HIGHBD_DECL_SUFFIX) HIGHBD_DECL_SUFFIX)
{ {
const int intermediate_bits = get_intermediate_bits(bitdepth_max); const int intermediate_bits = get_intermediate_bits(bitdepth_max);
const int sh = intermediate_bits + 1, rnd = 1 << intermediate_bits; const int sh = intermediate_bits + 1;
const int rnd = (1 << intermediate_bits) + PREP_BIAS * 2;
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
dst[x] = iclip_pixel((tmp1[x] + tmp2[x] + rnd) >> sh); dst[x] = iclip_pixel((tmp1[x] + tmp2[x] + rnd) >> sh);
...@@ -587,7 +602,8 @@ static void w_avg_c(pixel *dst, const ptrdiff_t dst_stride, ...@@ -587,7 +602,8 @@ static void w_avg_c(pixel *dst, const ptrdiff_t dst_stride,
const int weight HIGHBD_DECL_SUFFIX) const int weight HIGHBD_DECL_SUFFIX)
{ {
const int intermediate_bits = get_intermediate_bits(bitdepth_max); const int intermediate_bits = get_intermediate_bits(bitdepth_max);
const int sh = intermediate_bits + 4, rnd = 8 << intermediate_bits; const int sh = intermediate_bits + 4;
const int rnd = (8 << intermediate_bits) + PREP_BIAS * 16;
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
dst[x] = iclip_pixel((tmp1[x] * weight + dst[x] = iclip_pixel((tmp1[x] * weight +
...@@ -604,7 +620,8 @@ static void mask_c(pixel *dst, const ptrdiff_t dst_stride, ...@@ -604,7 +620,8 @@ static void mask_c(pixel *dst, const ptrdiff_t dst_stride,
const uint8_t *mask HIGHBD_DECL_SUFFIX) const uint8_t *mask HIGHBD_DECL_SUFFIX)
{ {
const int intermediate_bits = get_intermediate_bits(bitdepth_max); const int intermediate_bits = get_intermediate_bits(bitdepth_max);
const int sh = intermediate_bits + 6, rnd = 32 << intermediate_bits; const int sh = intermediate_bits + 6;
const int rnd = (32 << intermediate_bits) + PREP_BIAS * 64;
do { do {
for (int x = 0; x < w; x++) for (int x = 0; x < w; x++)
dst[x] = iclip_pixel((tmp1[x] * mask[x] + dst[x] = iclip_pixel((tmp1[x] * mask[x] +
...@@ -668,7 +685,8 @@ static void w_mask_c(pixel *dst, const ptrdiff_t dst_stride, ...@@ -668,7 +685,8 @@ static void w_mask_c(pixel *dst, const ptrdiff_t dst_stride,
// and then load this intermediate to calculate final value for odd rows // and then load this intermediate to calculate final value for odd rows
const int intermediate_bits = get_intermediate_bits(bitdepth_max); const int intermediate_bits = get_intermediate_bits(bitdepth_max);
const int bitdepth = bitdepth_from_max(bitdepth_max); const int bitdepth = bitdepth_from_max(bitdepth_max);
const int sh = intermediate_bits + 6, rnd = 32 << intermediate_bits; const int sh = intermediate_bits + 6;
const int rnd = (32 << intermediate_bits) + PREP_BIAS * 64;
const int mask_sh = bitdepth + intermediate_bits - 4; const int mask_sh = bitdepth + intermediate_bits - 4;
const int mask_rnd = 1 << (mask_sh - 5); const int mask_rnd = 1 << (mask_sh - 5);
do { do {
...@@ -797,7 +815,7 @@ static void warp_affine_8x8t_c(int16_t *tmp, const ptrdiff_t tmp_stride, ...@@ -797,7 +815,7 @@ static void warp_affine_8x8t_c(int16_t *tmp, const ptrdiff_t tmp_stride,
const int8_t *const filter = const int8_t *const filter =
dav1d_mc_warp_filter[64 + ((tmy + 512) >> 10)]; dav1d_mc_warp_filter[64 + ((tmy + 512) >> 10)];
tmp[x] = FILTER_WARP_RND(mid_ptr, x, filter, 8, 7); tmp[x] = FILTER_WARP_RND(mid_ptr, x, filter, 8, 7) - PREP_BIAS;
} }
mid_ptr += 8; mid_ptr += 8;
tmp += tmp_stride; tmp += tmp_stride;
......
...@@ -84,6 +84,17 @@ static void check_mc(Dav1dMCDSPContext *const c) { ...@@ -84,6 +84,17 @@ static void check_mc(Dav1dMCDSPContext *const c) {
report("mc"); report("mc");
} }
/*
 * Fill a 135x135 pixel buffer with test input.
 *
 * The top-left 8x8 corner receives a fixed worst-case pattern: every sample
 * there is either 0 or bitdepth_max, chosen by XOR-ing the per-column and
 * per-row pattern entries, and the whole corner is randomly inverted once
 * per call. All remaining samples are random values masked to bitdepth_max.
 */
static void generate_mct_input(pixel *const buf, const int bitdepth_max) {
    enum { DIM = 135, CORNER = 8 };
    static const int8_t pattern[CORNER] = { -1, 0, -1, 0, 0, -1, 0, -1 };

    /* Either 0 (keep the pattern) or -1 (flip every bit of the pattern) */
    const int invert = -(rnd() & 1);

    pixel *row = buf;
    for (int y = 0; y < DIM; y++, row += DIM) {
        for (int x = 0; x < DIM; x++) {
            if (x < CORNER && y < CORNER) {
                /* Deterministic extreme sample: all bits set or all clear */
                row[x] = (pattern[x] ^ pattern[y] ^ invert) & bitdepth_max;
            } else {
                row[x] = rnd() & bitdepth_max;
            }
        }
    }
}
static void check_mct(Dav1dMCDSPContext *const c) { static void check_mct(Dav1dMCDSPContext *const c) {
ALIGN_STK_32(pixel, src_buf, 135 * 135,); ALIGN_STK_32(pixel, src_buf, 135 * 135,);
ALIGN_STK_32(int16_t, c_tmp, 128 * 128,); ALIGN_STK_32(int16_t, c_tmp, 128 * 128,);
...@@ -107,9 +118,7 @@ static void check_mct(Dav1dMCDSPContext *const c) { ...@@ -107,9 +118,7 @@ static void check_mct(Dav1dMCDSPContext *const c) {
#else #else
const int bitdepth_max = 0xff; const int bitdepth_max = 0xff;
#endif #endif
generate_mct_input(src_buf, bitdepth_max);
for (int i = 0; i < 135 * 135; i++)
src_buf[i] = rnd() & bitdepth_max;
call_ref(c_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX); call_ref(c_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
call_new(a_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX); call_new(a_tmp, src, w, w, h, mx, my HIGHBD_TAIL_SUFFIX);
...@@ -127,12 +136,10 @@ static void init_tmp(Dav1dMCDSPContext *const c, pixel *const buf, ...@@ -127,12 +136,10 @@ static void init_tmp(Dav1dMCDSPContext *const c, pixel *const buf,
int16_t (*const tmp)[128 * 128], const int bitdepth_max) int16_t (*const tmp)[128 * 128], const int bitdepth_max)
{ {
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
for (int j = 0; j < 135 * 135; j++) generate_mct_input(buf, bitdepth_max);
buf[j] = rnd() & bitdepth_max; c->mct[FILTER_2D_8TAP_SHARP](tmp[i], buf + 135 * 3 + 3,
c->mct[rnd() % N_2D_FILTERS](tmp[i], buf + 135 * 3 + 3, 135 * sizeof(pixel), 128, 128,
128 * sizeof(pixel), 128, 128, 8, 8 HIGHBD_TAIL_SUFFIX);
rnd() & 15, rnd() & 15
HIGHBD_TAIL_SUFFIX);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment