Commit c59f1940 authored by Rupert Swarbrick's avatar Rupert Swarbrick Committed by Ronald S. Bultje

Correctly flush at the end of OBUs

This fixes failures when an OBU has more than a byte's worth of
trailing zeros.

As part of this work, it also rejigs the dav1d_flush_get_bits function
slightly. This worked before, but it wasn't very obvious why (it
worked because bits_left was never more than 7). This patch renames it
to dav1d_bytealign_get_bits, which makes it clearer what it does and
adds a comment explaining why it works properly.

The new dav1d_bytealign_get_bits is also now void (rather than
returning the next byte to read). The patch defines
dav1d_get_bits_pos, which returns the current bit position. This feels
a little easier to reason about.

We also add a new check to make sure that we haven't fallen off the
end of the OBU. This can happen when a byte buffer contains more than
one OBU: the GetBits might not have got to EOF, but we might now be
half-way through the next OBU.
parent 2532642b
...@@ -126,8 +126,16 @@ int dav1d_get_bits_subexp(GetBits *const c, const int ref, const unsigned n) { ...@@ -126,8 +126,16 @@ int dav1d_get_bits_subexp(GetBits *const c, const int ref, const unsigned n) {
return (int) get_bits_subexp_u(c, ref + (1 << n), 2 << n) - (1 << n); return (int) get_bits_subexp_u(c, ref + (1 << n), 2 << n) - (1 << n);
} }
const uint8_t *dav1d_flush_get_bits(GetBits *c) { void dav1d_bytealign_get_bits(GetBits *c) {
// bits_left is never more than 7, because it is only incremented
// by refill(), called by dav1d_get_bits and that never reads more
// than 7 bits more than it needs.
//
// If this wasn't true, we would need to work out how many bits to
// discard (bits_left % 8), subtract that from bits_left and then
// shift state right by that amount.
assert(c->bits_left <= 7);
c->bits_left = 0; c->bits_left = 0;
c->state = 0; c->state = 0;
return c->ptr;
} }
...@@ -46,6 +46,13 @@ int dav1d_get_sbits(GetBits *c, unsigned n); ...@@ -46,6 +46,13 @@ int dav1d_get_sbits(GetBits *c, unsigned n);
unsigned dav1d_get_uniform(GetBits *c, unsigned max); unsigned dav1d_get_uniform(GetBits *c, unsigned max);
unsigned dav1d_get_vlc(GetBits *c); unsigned dav1d_get_vlc(GetBits *c);
int dav1d_get_bits_subexp(GetBits *c, int ref, unsigned n); int dav1d_get_bits_subexp(GetBits *c, int ref, unsigned n);
const uint8_t *dav1d_flush_get_bits(GetBits *c);
// Discard bits from the buffer until we're next byte-aligned.
void dav1d_bytealign_get_bits(GetBits *c);
// Return the current bit position relative to the start of the buffer.
static inline unsigned dav1d_get_bits_pos(const GetBits *c) {
return (c->ptr - c->ptr_start) * 8 - c->bits_left;
}
#endif /* __DAV1D_SRC_GETBITS_H__ */ #endif /* __DAV1D_SRC_GETBITS_H__ */
...@@ -46,15 +46,17 @@ ...@@ -46,15 +46,17 @@
static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
Av1SequenceHeader *const hdr) Av1SequenceHeader *const hdr)
{ {
const uint8_t *const init_ptr = gb->ptr;
#define DEBUG_SEQ_HDR 0 #define DEBUG_SEQ_HDR 0
#if DEBUG_SEQ_HDR
const unsigned init_bit_pos = dav1d_get_bits_pos(gb);
#endif
hdr->profile = dav1d_get_bits(gb, 3); hdr->profile = dav1d_get_bits(gb, 3);
if (hdr->profile > 2) goto error; if (hdr->profile > 2) goto error;
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-profile: off=%ld\n", printf("SEQHDR: post-profile: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
hdr->still_picture = dav1d_get_bits(gb, 1); hdr->still_picture = dav1d_get_bits(gb, 1);
...@@ -62,7 +64,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -62,7 +64,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
if (hdr->reduced_still_picture_header && !hdr->still_picture) goto error; if (hdr->reduced_still_picture_header && !hdr->still_picture) goto error;
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-stillpicture_flags: off=%ld\n", printf("SEQHDR: post-stillpicture_flags: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
if (hdr->reduced_still_picture_header) { if (hdr->reduced_still_picture_header) {
...@@ -97,7 +99,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -97,7 +99,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
} }
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-timinginfo: off=%ld\n", printf("SEQHDR: post-timinginfo: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
hdr->display_model_info_present = dav1d_get_bits(gb, 1); hdr->display_model_info_present = dav1d_get_bits(gb, 1);
...@@ -126,7 +128,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -126,7 +128,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
} }
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-operating-points: off=%ld\n", printf("SEQHDR: post-operating-points: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
} }
...@@ -136,7 +138,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -136,7 +138,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->max_height = dav1d_get_bits(gb, hdr->height_n_bits) + 1; hdr->max_height = dav1d_get_bits(gb, hdr->height_n_bits) + 1;
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-size: off=%ld\n", printf("SEQHDR: post-size: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
hdr->frame_id_numbers_present = hdr->frame_id_numbers_present =
hdr->reduced_still_picture_header ? 0 : dav1d_get_bits(gb, 1); hdr->reduced_still_picture_header ? 0 : dav1d_get_bits(gb, 1);
...@@ -146,7 +148,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -146,7 +148,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
} }
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-frame-id-numbers-present: off=%ld\n", printf("SEQHDR: post-frame-id-numbers-present: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
hdr->sb128 = dav1d_get_bits(gb, 1); hdr->sb128 = dav1d_get_bits(gb, 1);
...@@ -180,7 +182,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -180,7 +182,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->screen_content_tools = dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1); hdr->screen_content_tools = dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-screentools: off=%ld\n", printf("SEQHDR: post-screentools: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
hdr->force_integer_mv = hdr->screen_content_tools ? hdr->force_integer_mv = hdr->screen_content_tools ?
dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1) : 2; dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1) : 2;
...@@ -192,7 +194,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -192,7 +194,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->restoration = dav1d_get_bits(gb, 1); hdr->restoration = dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-featurebits: off=%ld\n", printf("SEQHDR: post-featurebits: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
const int hbd = dav1d_get_bits(gb, 1); const int hbd = dav1d_get_bits(gb, 1);
...@@ -243,18 +245,22 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -243,18 +245,22 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
} }
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-colorinfo: off=%ld\n", printf("SEQHDR: post-colorinfo: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
hdr->film_grain_present = dav1d_get_bits(gb, 1); hdr->film_grain_present = dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR #if DEBUG_SEQ_HDR
printf("SEQHDR: post-filmgrain: off=%ld\n", printf("SEQHDR: post-filmgrain: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left); dav1d_get_bits_pos(gb) - init_bit_pos);
#endif #endif
dav1d_get_bits(gb, 1); // dummy bit dav1d_get_bits(gb, 1); // dummy bit
return dav1d_flush_get_bits(gb) - init_ptr; // We needn't bother flushing the OBU here: we'll check we didn't
// overrun in the caller and will then discard gb, so there's no
// point in setting its position properly.
return 0;
error: error:
fprintf(stderr, "Error parsing sequence header\n"); fprintf(stderr, "Error parsing sequence header\n");
...@@ -313,16 +319,16 @@ static const Av1LoopfilterModeRefDeltas default_mode_ref_deltas = { ...@@ -313,16 +319,16 @@ static const Av1LoopfilterModeRefDeltas default_mode_ref_deltas = {
.ref_delta = { 1, 0, 0, 0, -1, 0, -1, -1 }, .ref_delta = { 1, 0, 0, 0, -1, 0, -1, -1 },
}; };
static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb, static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb) {
const int have_trailing_bit) #define DEBUG_FRAME_HDR 0
{
#if DEBUG_FRAME_HDR
const uint8_t *const init_ptr = gb->ptr; const uint8_t *const init_ptr = gb->ptr;
#endif
const Av1SequenceHeader *const seqhdr = &c->seq_hdr; const Av1SequenceHeader *const seqhdr = &c->seq_hdr;
Av1FrameHeader *const hdr = &c->frame_hdr; Av1FrameHeader *const hdr = &c->frame_hdr;
int res; int res;
#define DEBUG_FRAME_HDR 0
hdr->show_existing_frame = hdr->show_existing_frame =
!seqhdr->reduced_still_picture_header && dav1d_get_bits(gb, 1); !seqhdr->reduced_still_picture_header && dav1d_get_bits(gb, 1);
#if DEBUG_FRAME_HDR #if DEBUG_FRAME_HDR
...@@ -335,7 +341,7 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -335,7 +341,7 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->frame_presentation_delay = dav1d_get_bits(gb, seqhdr->frame_presentation_delay_length); hdr->frame_presentation_delay = dav1d_get_bits(gb, seqhdr->frame_presentation_delay_length);
if (seqhdr->frame_id_numbers_present) if (seqhdr->frame_id_numbers_present)
hdr->frame_id = dav1d_get_bits(gb, seqhdr->frame_id_n_bits); hdr->frame_id = dav1d_get_bits(gb, seqhdr->frame_id_n_bits);
goto end; return 0;
} }
hdr->frame_type = seqhdr->reduced_still_picture_header ? DAV1D_FRAME_TYPE_KEY : dav1d_get_bits(gb, 2); hdr->frame_type = seqhdr->reduced_still_picture_header ? DAV1D_FRAME_TYPE_KEY : dav1d_get_bits(gb, 2);
...@@ -976,21 +982,14 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb, ...@@ -976,21 +982,14 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
(gb->ptr - init_ptr) * 8 - gb->bits_left); (gb->ptr - init_ptr) * 8 - gb->bits_left);
#endif #endif
end: return 0;
if (have_trailing_bit)
dav1d_get_bits(gb, 1); // dummy bit
return dav1d_flush_get_bits(gb) - init_ptr;
error: error:
fprintf(stderr, "Error parsing frame header\n"); fprintf(stderr, "Error parsing frame header\n");
return -EINVAL; return -EINVAL;
} }
static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) { static void parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
const uint8_t *const init_ptr = gb->ptr;
int have_tile_pos = 0; int have_tile_pos = 0;
const int n_tiles = c->frame_hdr.tiling.cols * c->frame_hdr.tiling.rows; const int n_tiles = c->frame_hdr.tiling.cols * c->frame_hdr.tiling.rows;
if (n_tiles > 1) if (n_tiles > 1)
...@@ -1005,8 +1004,31 @@ static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) { ...@@ -1005,8 +1004,31 @@ static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
c->tile[c->n_tile_data].start = 0; c->tile[c->n_tile_data].start = 0;
c->tile[c->n_tile_data].end = n_tiles - 1; c->tile[c->n_tile_data].end = n_tiles - 1;
} }
}
// Check that we haven't read more than obu_len bytes from the buffer
// since init_bit_pos.
static int
check_for_overrun(GetBits *const gb, unsigned init_bit_pos, unsigned obu_len)
{
// Make sure we haven't actually read past the end of the gb buffer
if (gb->error) {
fprintf(stderr, "Overrun in OBU bit buffer\n");
return 1;
}
return dav1d_flush_get_bits(gb) - init_ptr; unsigned pos = dav1d_get_bits_pos(gb);
// We assume that init_bit_pos was the bit position of the buffer
// at some point in the past, so cannot be smaller than pos.
assert (init_bit_pos <= pos);
if (pos - init_bit_pos > 8 * obu_len) {
fprintf(stderr, "Overrun in OBU bit buffer into next OBU\n");
return 1;
}
return 0;
} }
int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
...@@ -1041,9 +1063,23 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { ...@@ -1041,9 +1063,23 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
} while (more); } while (more);
if (gb.error) goto error; if (gb.error) goto error;
unsigned off = dav1d_flush_get_bits(&gb) - in->data; const unsigned init_bit_pos = dav1d_get_bits_pos(&gb);
const unsigned init_off = off; const unsigned init_byte_pos = init_bit_pos >> 3;
if (len > in->sz - off) goto error; const unsigned pkt_bytelen = init_byte_pos + len;
// We must have read a whole number of bytes at this point (1 byte
// for the header and whole bytes at a time when reading the
// leb128 length field).
assert(init_bit_pos & 7 == 0);
// We also know that we haven't tried to read more than in->sz
// bytes yet (otherwise the error flag would have been set by the
// code in getbits.c)
assert(in->sz >= init_byte_pos);
// Make sure that there are enough bits left in the buffer for the
// rest of the OBU.
if (len > in->sz - init_byte_pos) goto error;
switch (type) { switch (type) {
case OBU_SEQ_HDR: { case OBU_SEQ_HDR: {
...@@ -1052,8 +1088,8 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { ...@@ -1052,8 +1088,8 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
c->have_frame_hdr = 0; c->have_frame_hdr = 0;
if ((res = parse_seq_hdr(c, &gb, hdr_ptr)) < 0) if ((res = parse_seq_hdr(c, &gb, hdr_ptr)) < 0)
return res; return res;
if ((unsigned)res != len) if (check_for_overrun(&gb, init_bit_pos, len))
goto error; return -EINVAL;
if (!c->have_frame_hdr || memcmp(&hdr, &c->seq_hdr, sizeof(hdr))) { if (!c->have_frame_hdr || memcmp(&hdr, &c->seq_hdr, sizeof(hdr))) {
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
if (c->refs[i].p.p.data[0]) if (c->refs[i].p.p.data[0])
...@@ -1076,29 +1112,48 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { ...@@ -1076,29 +1112,48 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
case OBU_FRAME_HDR: case OBU_FRAME_HDR:
c->have_frame_hdr = 0; c->have_frame_hdr = 0;
if (!c->have_seq_hdr) goto error; if (!c->have_seq_hdr) goto error;
if ((res = parse_frame_hdr(c, &gb, type != OBU_FRAME)) < 0) if ((res = parse_frame_hdr(c, &gb)) < 0)
return res; return res;
c->have_frame_hdr = 1; c->have_frame_hdr = 1;
for (int n = 0; n < c->n_tile_data; n++) for (int n = 0; n < c->n_tile_data; n++)
dav1d_data_unref(&c->tile[n].data); dav1d_data_unref(&c->tile[n].data);
c->n_tile_data = 0; c->n_tile_data = 0;
c->n_tiles = 0; c->n_tiles = 0;
if (type != OBU_FRAME) break; if (type != OBU_FRAME) {
// This is actually a frame header OBU so read the
// trailing bit and check for overrun.
dav1d_get_bits(&gb, 1);
if (check_for_overrun(&gb, init_bit_pos, len))
return -EINVAL;
break;
}
// OBU_FRAMEs shouldn't be signalled with show_existing_frame
if (c->frame_hdr.show_existing_frame) goto error; if (c->frame_hdr.show_existing_frame) goto error;
off += res;
// This is the frame header at the start of a frame OBU.
// There's no trailing bit at the end to skip, but we do need
// to align to the next byte.
dav1d_bytealign_get_bits(&gb);
// fall-through // fall-through
case OBU_TILE_GRP: case OBU_TILE_GRP: {
if (!c->have_frame_hdr) goto error; if (!c->have_frame_hdr) goto error;
if (c->n_tile_data >= 256) goto error; if (c->n_tile_data >= 256) goto error;
if ((res = parse_tile_hdr(c, &gb)) < 0) parse_tile_hdr(c, &gb);
return res; // Align to the next byte boundary and check for overrun.
off += res; dav1d_bytealign_get_bits(&gb);
if (off > len + init_off) if (check_for_overrun(&gb, init_bit_pos, len))
goto error; return -EINVAL;
// The current bit position is a multiple of 8 (because we
// just aligned it) and less than 8*pkt_bytelen because
// otherwise the overrun check would have fired.
const unsigned bit_pos = dav1d_get_bits_pos(&gb);
assert(bit_pos & 7 == 0);
assert(pkt_bytelen > (bit_pos >> 3));
dav1d_ref_inc(in->ref); dav1d_ref_inc(in->ref);
c->tile[c->n_tile_data].data.ref = in->ref; c->tile[c->n_tile_data].data.ref = in->ref;
c->tile[c->n_tile_data].data.data = in->data + off; c->tile[c->n_tile_data].data.data = in->data + (bit_pos >> 3);
c->tile[c->n_tile_data].data.sz = len + init_off - off; c->tile[c->n_tile_data].data.sz = pkt_bytelen - (bit_pos >> 3);
// ensure tile groups are in order and sane, see 6.10.1 // ensure tile groups are in order and sane, see 6.10.1
if (c->tile[c->n_tile_data].start > c->tile[c->n_tile_data].end || if (c->tile[c->n_tile_data].start > c->tile[c->n_tile_data].end ||
c->tile[c->n_tile_data].start != c->n_tiles) c->tile[c->n_tile_data].start != c->n_tiles)
...@@ -1113,6 +1168,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { ...@@ -1113,6 +1168,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
c->tile[c->n_tile_data].start; c->tile[c->n_tile_data].start;
c->n_tile_data++; c->n_tile_data++;
break; break;
}
case OBU_PADDING: case OBU_PADDING:
case OBU_TD: case OBU_TD:
case OBU_METADATA: case OBU_METADATA:
...@@ -1192,7 +1248,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) { ...@@ -1192,7 +1248,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
} }
} }
return len + init_off; return len + init_byte_pos;
error: error:
fprintf(stderr, "Error parsing OBU data\n"); fprintf(stderr, "Error parsing OBU data\n");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment