Commit c59f1940 authored by Rupert Swarbrick's avatar Rupert Swarbrick Committed by Ronald S. Bultje

Correctly flush at the end of OBUs

This fixes failures when an OBU has more than a byte's worth of
trailing zeros.

As part of this work, it also rejigs the dav1d_flush_get_bits function
slightly. This worked before, but it wasn't very obvious why (it
worked because bits_left was never more than 7). This patch renames it
to dav1d_bytealign_get_bits, which makes it clearer what it does and
adds a comment explaining why it works properly.

The new dav1d_bytealign_get_bits is also now void (rather than
returning the next byte to read). The patch defines
dav1d_get_bits_pos, which returns the current bit position. This feels
a little easier to reason about.

We also add a new check to make sure that we haven't fallen off the
end of the OBU. This can happen when a byte buffer contains more than
one OBU: the GetBits might not have got to EOF, but we might now be
half-way through the next OBU.
parent 2532642b
......@@ -126,8 +126,16 @@ int dav1d_get_bits_subexp(GetBits *const c, const int ref, const unsigned n) {
return (int) get_bits_subexp_u(c, ref + (1 << n), 2 << n) - (1 << n);
}
const uint8_t *dav1d_flush_get_bits(GetBits *c) {
void dav1d_bytealign_get_bits(GetBits *c) {
// bits_left is never more than 7, because it is only incremented
// by refill(), called by dav1d_get_bits and that never reads more
// than 7 bits more than it needs.
//
// If this wasn't true, we would need to work out how many bits to
// discard (bits_left % 8), subtract that from bits_left and then
// shift state right by that amount.
assert(c->bits_left <= 7);
c->bits_left = 0;
c->state = 0;
return c->ptr;
}
......@@ -46,6 +46,13 @@ int dav1d_get_sbits(GetBits *c, unsigned n);
unsigned dav1d_get_uniform(GetBits *c, unsigned max);
unsigned dav1d_get_vlc(GetBits *c);
int dav1d_get_bits_subexp(GetBits *c, int ref, unsigned n);
const uint8_t *dav1d_flush_get_bits(GetBits *c);
// Discard bits from the buffer until we're next byte-aligned.
void dav1d_bytealign_get_bits(GetBits *c);
// Return the current bit position relative to the start of the buffer.
static inline unsigned dav1d_get_bits_pos(const GetBits *c) {
return (c->ptr - c->ptr_start) * 8 - c->bits_left;
}
#endif /* __DAV1D_SRC_GETBITS_H__ */
......@@ -46,15 +46,17 @@
static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
Av1SequenceHeader *const hdr)
{
const uint8_t *const init_ptr = gb->ptr;
#define DEBUG_SEQ_HDR 0
#if DEBUG_SEQ_HDR
const unsigned init_bit_pos = dav1d_get_bits_pos(gb);
#endif
hdr->profile = dav1d_get_bits(gb, 3);
if (hdr->profile > 2) goto error;
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-profile: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->still_picture = dav1d_get_bits(gb, 1);
......@@ -62,7 +64,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
if (hdr->reduced_still_picture_header && !hdr->still_picture) goto error;
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-stillpicture_flags: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
if (hdr->reduced_still_picture_header) {
......@@ -97,7 +99,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
}
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-timinginfo: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->display_model_info_present = dav1d_get_bits(gb, 1);
......@@ -126,7 +128,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
}
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-operating-points: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
}
......@@ -136,7 +138,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->max_height = dav1d_get_bits(gb, hdr->height_n_bits) + 1;
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-size: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->frame_id_numbers_present =
hdr->reduced_still_picture_header ? 0 : dav1d_get_bits(gb, 1);
......@@ -146,7 +148,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
}
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-frame-id-numbers-present: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->sb128 = dav1d_get_bits(gb, 1);
......@@ -180,7 +182,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->screen_content_tools = dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-screentools: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->force_integer_mv = hdr->screen_content_tools ?
dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1) : 2;
......@@ -192,7 +194,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->restoration = dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-featurebits: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
const int hbd = dav1d_get_bits(gb, 1);
......@@ -243,18 +245,22 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
}
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-colorinfo: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->film_grain_present = dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-filmgrain: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
dav1d_get_bits(gb, 1); // dummy bit
return dav1d_flush_get_bits(gb) - init_ptr;
// We needn't bother flushing the OBU here: we'll check we didn't
// overrun in the caller and will then discard gb, so there's no
// point in setting its position properly.
return 0;
error:
fprintf(stderr, "Error parsing sequence header\n");
......@@ -313,16 +319,16 @@ static const Av1LoopfilterModeRefDeltas default_mode_ref_deltas = {
.ref_delta = { 1, 0, 0, 0, -1, 0, -1, -1 },
};
static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
const int have_trailing_bit)
{
static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb) {
#define DEBUG_FRAME_HDR 0
#if DEBUG_FRAME_HDR
const uint8_t *const init_ptr = gb->ptr;
#endif
const Av1SequenceHeader *const seqhdr = &c->seq_hdr;
Av1FrameHeader *const hdr = &c->frame_hdr;
int res;
#define DEBUG_FRAME_HDR 0
hdr->show_existing_frame =
!seqhdr->reduced_still_picture_header && dav1d_get_bits(gb, 1);
#if DEBUG_FRAME_HDR
......@@ -335,7 +341,7 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->frame_presentation_delay = dav1d_get_bits(gb, seqhdr->frame_presentation_delay_length);
if (seqhdr->frame_id_numbers_present)
hdr->frame_id = dav1d_get_bits(gb, seqhdr->frame_id_n_bits);
goto end;
return 0;
}
hdr->frame_type = seqhdr->reduced_still_picture_header ? DAV1D_FRAME_TYPE_KEY : dav1d_get_bits(gb, 2);
......@@ -976,21 +982,14 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
(gb->ptr - init_ptr) * 8 - gb->bits_left);
#endif
end:
if (have_trailing_bit)
dav1d_get_bits(gb, 1); // dummy bit
return dav1d_flush_get_bits(gb) - init_ptr;
return 0;
error:
fprintf(stderr, "Error parsing frame header\n");
return -EINVAL;
}
static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
const uint8_t *const init_ptr = gb->ptr;
static void parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
int have_tile_pos = 0;
const int n_tiles = c->frame_hdr.tiling.cols * c->frame_hdr.tiling.rows;
if (n_tiles > 1)
......@@ -1005,8 +1004,31 @@ static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
c->tile[c->n_tile_data].start = 0;
c->tile[c->n_tile_data].end = n_tiles - 1;
}
}
// Check that we haven't read more than obu_len bytes from the buffer
// since init_bit_pos.
static int
check_for_overrun(GetBits *const gb, unsigned init_bit_pos, unsigned obu_len)
{
// Make sure we haven't actually read past the end of the gb buffer
if (gb->error) {
fprintf(stderr, "Overrun in OBU bit buffer\n");
return 1;
}
return dav1d_flush_get_bits(gb) - init_ptr;
unsigned pos = dav1d_get_bits_pos(gb);
// We assume that init_bit_pos was the bit position of the buffer
// at some point in the past, so cannot be smaller than pos.
assert (init_bit_pos <= pos);
if (pos - init_bit_pos > 8 * obu_len) {
fprintf(stderr, "Overrun in OBU bit buffer into next OBU\n");
return 1;
}
return 0;
}
int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
......@@ -1041,9 +1063,23 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
} while (more);
if (gb.error) goto error;
unsigned off = dav1d_flush_get_bits(&gb) - in->data;
const unsigned init_off = off;
if (len > in->sz - off) goto error;
const unsigned init_bit_pos = dav1d_get_bits_pos(&gb);
const unsigned init_byte_pos = init_bit_pos >> 3;
const unsigned pkt_bytelen = init_byte_pos + len;
// We must have read a whole number of bytes at this point (1 byte
// for the header and whole bytes at a time when reading the
// leb128 length field).
assert(init_bit_pos & 7 == 0);
// We also know that we haven't tried to read more than in->sz
// bytes yet (otherwise the error flag would have been set by the
// code in getbits.c)
assert(in->sz >= init_byte_pos);
// Make sure that there are enough bits left in the buffer for the
// rest of the OBU.
if (len > in->sz - init_byte_pos) goto error;
switch (type) {
case OBU_SEQ_HDR: {
......@@ -1052,8 +1088,8 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
c->have_frame_hdr = 0;
if ((res = parse_seq_hdr(c, &gb, hdr_ptr)) < 0)
return res;
if ((unsigned)res != len)
goto error;
if (check_for_overrun(&gb, init_bit_pos, len))
return -EINVAL;
if (!c->have_frame_hdr || memcmp(&hdr, &c->seq_hdr, sizeof(hdr))) {
for (int i = 0; i < 8; i++) {
if (c->refs[i].p.p.data[0])
......@@ -1076,29 +1112,48 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
case OBU_FRAME_HDR:
c->have_frame_hdr = 0;
if (!c->have_seq_hdr) goto error;
if ((res = parse_frame_hdr(c, &gb, type != OBU_FRAME)) < 0)
if ((res = parse_frame_hdr(c, &gb)) < 0)
return res;
c->have_frame_hdr = 1;
for (int n = 0; n < c->n_tile_data; n++)
dav1d_data_unref(&c->tile[n].data);
c->n_tile_data = 0;
c->n_tiles = 0;
if (type != OBU_FRAME) break;
if (type != OBU_FRAME) {
// This is actually a frame header OBU so read the
// trailing bit and check for overrun.
dav1d_get_bits(&gb, 1);
if (check_for_overrun(&gb, init_bit_pos, len))
return -EINVAL;
break;
}
// OBU_FRAMEs shouldn't be signalled with show_existing_frame
if (c->frame_hdr.show_existing_frame) goto error;
off += res;
// This is the frame header at the start of a frame OBU.
// There's no trailing bit at the end to skip, but we do need
// to align to the next byte.
dav1d_bytealign_get_bits(&gb);
// fall-through
case OBU_TILE_GRP:
case OBU_TILE_GRP: {
if (!c->have_frame_hdr) goto error;
if (c->n_tile_data >= 256) goto error;
if ((res = parse_tile_hdr(c, &gb)) < 0)
return res;
off += res;
if (off > len + init_off)
goto error;
parse_tile_hdr(c, &gb);
// Align to the next byte boundary and check for overrun.
dav1d_bytealign_get_bits(&gb);
if (check_for_overrun(&gb, init_bit_pos, len))
return -EINVAL;
// The current bit position is a multiple of 8 (because we
// just aligned it) and less than 8*pkt_bytelen because
// otherwise the overrun check would have fired.
const unsigned bit_pos = dav1d_get_bits_pos(&gb);
assert(bit_pos & 7 == 0);
assert(pkt_bytelen > (bit_pos >> 3));
dav1d_ref_inc(in->ref);
c->tile[c->n_tile_data].data.ref = in->ref;
c->tile[c->n_tile_data].data.data = in->data + off;
c->tile[c->n_tile_data].data.sz = len + init_off - off;
c->tile[c->n_tile_data].data.data = in->data + (bit_pos >> 3);
c->tile[c->n_tile_data].data.sz = pkt_bytelen - (bit_pos >> 3);
// ensure tile groups are in order and sane, see 6.10.1
if (c->tile[c->n_tile_data].start > c->tile[c->n_tile_data].end ||
c->tile[c->n_tile_data].start != c->n_tiles)
......@@ -1113,6 +1168,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
c->tile[c->n_tile_data].start;
c->n_tile_data++;
break;
}
case OBU_PADDING:
case OBU_TD:
case OBU_METADATA:
......@@ -1192,7 +1248,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
}
}
return len + init_off;
return len + init_byte_pos;
error:
fprintf(stderr, "Error parsing OBU data\n");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment