Commit c59f1940 authored by Rupert Swarbrick's avatar Rupert Swarbrick Committed by Ronald S. Bultje
Browse files

Correctly flush at the end of OBUs

This fixes failures when an OBU has more than a byte's worth of
trailing zeros.

As part of this work, it also rejigs the dav1d_flush_get_bits function
slightly. This worked before, but it wasn't very obvious why (it
worked because bits_left was never more than 7). This patch renames it
to dav1d_bytealign_get_bits, which makes it clearer what it does and
adds a comment explaining why it works properly.

The new dav1d_bytealign_get_bits is also now void (rather than
returning the next byte to read). The patch defines
dav1d_get_bits_pos, which returns the current bit position. This feels
a little easier to reason about.

We also add a new check to make sure that we haven't fallen off the
end of the OBU. This can happen when a byte buffer contains more than
one OBU: the GetBits might not have got to EOF, but we might now be
half-way through the next OBU.
parent 2532642b
Pipeline #2604 passed with stages
in 2 minutes and 53 seconds
......@@ -126,8 +126,16 @@ int dav1d_get_bits_subexp(GetBits *const c, const int ref, const unsigned n) {
return (int) get_bits_subexp_u(c, ref + (1 << n), 2 << n) - (1 << n);
}
const uint8_t *dav1d_flush_get_bits(GetBits *c) {
void dav1d_bytealign_get_bits(GetBits *c) {
// bits_left is never more than 7, because it is only incremented
// by refill(), called by dav1d_get_bits and that never reads more
// than 7 bits more than it needs.
//
// If this wasn't true, we would need to work out how many bits to
// discard (bits_left % 8), subtract that from bits_left and then
// shift state right by that amount.
assert(c->bits_left <= 7);
c->bits_left = 0;
c->state = 0;
return c->ptr;
}
......@@ -46,6 +46,13 @@ int dav1d_get_sbits(GetBits *c, unsigned n);
unsigned dav1d_get_uniform(GetBits *c, unsigned max);
unsigned dav1d_get_vlc(GetBits *c);
int dav1d_get_bits_subexp(GetBits *c, int ref, unsigned n);
const uint8_t *dav1d_flush_get_bits(GetBits *c);
// Discard bits from the buffer until we're next byte-aligned.
void dav1d_bytealign_get_bits(GetBits *c);
// Return the current bit position relative to the start of the buffer.
static inline unsigned dav1d_get_bits_pos(const GetBits *c) {
return (c->ptr - c->ptr_start) * 8 - c->bits_left;
}
#endif /* __DAV1D_SRC_GETBITS_H__ */
......@@ -46,15 +46,17 @@
static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
Av1SequenceHeader *const hdr)
{
const uint8_t *const init_ptr = gb->ptr;
#define DEBUG_SEQ_HDR 0
#if DEBUG_SEQ_HDR
const unsigned init_bit_pos = dav1d_get_bits_pos(gb);
#endif
hdr->profile = dav1d_get_bits(gb, 3);
if (hdr->profile > 2) goto error;
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-profile: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->still_picture = dav1d_get_bits(gb, 1);
......@@ -62,7 +64,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
if (hdr->reduced_still_picture_header && !hdr->still_picture) goto error;
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-stillpicture_flags: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
if (hdr->reduced_still_picture_header) {
......@@ -97,7 +99,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
}
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-timinginfo: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->display_model_info_present = dav1d_get_bits(gb, 1);
......@@ -126,7 +128,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
}
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-operating-points: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
}
......@@ -136,7 +138,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->max_height = dav1d_get_bits(gb, hdr->height_n_bits) + 1;
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-size: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->frame_id_numbers_present =
hdr->reduced_still_picture_header ? 0 : dav1d_get_bits(gb, 1);
......@@ -146,7 +148,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
}
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-frame-id-numbers-present: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->sb128 = dav1d_get_bits(gb, 1);
......@@ -180,7 +182,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->screen_content_tools = dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-screentools: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->force_integer_mv = hdr->screen_content_tools ?
dav1d_get_bits(gb, 1) ? ADAPTIVE : dav1d_get_bits(gb, 1) : 2;
......@@ -192,7 +194,7 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->restoration = dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-featurebits: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
const int hbd = dav1d_get_bits(gb, 1);
......@@ -243,18 +245,22 @@ static int parse_seq_hdr(Dav1dContext *const c, GetBits *const gb,
}
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-colorinfo: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
hdr->film_grain_present = dav1d_get_bits(gb, 1);
#if DEBUG_SEQ_HDR
printf("SEQHDR: post-filmgrain: off=%ld\n",
(gb->ptr - init_ptr) * 8 - gb->bits_left);
dav1d_get_bits_pos(gb) - init_bit_pos);
#endif
dav1d_get_bits(gb, 1); // dummy bit
return dav1d_flush_get_bits(gb) - init_ptr;
// We needn't bother flushing the OBU here: we'll check we didn't
// overrun in the caller and will then discard gb, so there's no
// point in setting its position properly.
return 0;
error:
fprintf(stderr, "Error parsing sequence header\n");
......@@ -313,16 +319,16 @@ static const Av1LoopfilterModeRefDeltas default_mode_ref_deltas = {
.ref_delta = { 1, 0, 0, 0, -1, 0, -1, -1 },
};
static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
const int have_trailing_bit)
{
static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb) {
#define DEBUG_FRAME_HDR 0
#if DEBUG_FRAME_HDR
const uint8_t *const init_ptr = gb->ptr;
#endif
const Av1SequenceHeader *const seqhdr = &c->seq_hdr;
Av1FrameHeader *const hdr = &c->frame_hdr;
int res;
#define DEBUG_FRAME_HDR 0
hdr->show_existing_frame =
!seqhdr->reduced_still_picture_header && dav1d_get_bits(gb, 1);
#if DEBUG_FRAME_HDR
......@@ -335,7 +341,7 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
hdr->frame_presentation_delay = dav1d_get_bits(gb, seqhdr->frame_presentation_delay_length);
if (seqhdr->frame_id_numbers_present)
hdr->frame_id = dav1d_get_bits(gb, seqhdr->frame_id_n_bits);
goto end;
return 0;
}
hdr->frame_type = seqhdr->reduced_still_picture_header ? DAV1D_FRAME_TYPE_KEY : dav1d_get_bits(gb, 2);
......@@ -976,21 +982,14 @@ static int parse_frame_hdr(Dav1dContext *const c, GetBits *const gb,
(gb->ptr - init_ptr) * 8 - gb->bits_left);
#endif
end:
if (have_trailing_bit)
dav1d_get_bits(gb, 1); // dummy bit
return dav1d_flush_get_bits(gb) - init_ptr;
return 0;
error:
fprintf(stderr, "Error parsing frame header\n");
return -EINVAL;
}
static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
const uint8_t *const init_ptr = gb->ptr;
static void parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
int have_tile_pos = 0;
const int n_tiles = c->frame_hdr.tiling.cols * c->frame_hdr.tiling.rows;
if (n_tiles > 1)
......@@ -1005,8 +1004,31 @@ static int parse_tile_hdr(Dav1dContext *const c, GetBits *const gb) {
c->tile[c->n_tile_data].start = 0;
c->tile[c->n_tile_data].end = n_tiles - 1;
}
}
// Check that we haven't read more than obu_len bytes from the buffer
// since init_bit_pos.
static int
check_for_overrun(GetBits *const gb, unsigned init_bit_pos, unsigned obu_len)
{
// Make sure we haven't actually read past the end of the gb buffer
if (gb->error) {
fprintf(stderr, "Overrun in OBU bit buffer\n");
return 1;
}
return dav1d_flush_get_bits(gb) - init_ptr;
unsigned pos = dav1d_get_bits_pos(gb);
// We assume that init_bit_pos was the bit position of the buffer
// at some point in the past, so cannot be smaller than pos.
assert (init_bit_pos <= pos);
if (pos - init_bit_pos > 8 * obu_len) {
fprintf(stderr, "Overrun in OBU bit buffer into next OBU\n");
return 1;
}
return 0;
}
int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
......@@ -1041,9 +1063,23 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
} while (more);
if (gb.error) goto error;
unsigned off = dav1d_flush_get_bits(&gb) - in->data;
const unsigned init_off = off;
if (len > in->sz - off) goto error;
const unsigned init_bit_pos = dav1d_get_bits_pos(&gb);
const unsigned init_byte_pos = init_bit_pos >> 3;
const unsigned pkt_bytelen = init_byte_pos + len;
// We must have read a whole number of bytes at this point (1 byte
// for the header and whole bytes at a time when reading the
// leb128 length field).
assert(init_bit_pos & 7 == 0);
// We also know that we haven't tried to read more than in->sz
// bytes yet (otherwise the error flag would have been set by the
// code in getbits.c)
assert(in->sz >= init_byte_pos);
// Make sure that there are enough bits left in the buffer for the
// rest of the OBU.
if (len > in->sz - init_byte_pos) goto error;
switch (type) {
case OBU_SEQ_HDR: {
......@@ -1052,8 +1088,8 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
c->have_frame_hdr = 0;
if ((res = parse_seq_hdr(c, &gb, hdr_ptr)) < 0)
return res;
if ((unsigned)res != len)
goto error;
if (check_for_overrun(&gb, init_bit_pos, len))
return -EINVAL;
if (!c->have_frame_hdr || memcmp(&hdr, &c->seq_hdr, sizeof(hdr))) {
for (int i = 0; i < 8; i++) {
if (c->refs[i].p.p.data[0])
......@@ -1076,29 +1112,48 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
case OBU_FRAME_HDR:
c->have_frame_hdr = 0;
if (!c->have_seq_hdr) goto error;
if ((res = parse_frame_hdr(c, &gb, type != OBU_FRAME)) < 0)
if ((res = parse_frame_hdr(c, &gb)) < 0)
return res;
c->have_frame_hdr = 1;
for (int n = 0; n < c->n_tile_data; n++)
dav1d_data_unref(&c->tile[n].data);
c->n_tile_data = 0;
c->n_tiles = 0;
if (type != OBU_FRAME) break;
if (type != OBU_FRAME) {
// This is actually a frame header OBU so read the
// trailing bit and check for overrun.
dav1d_get_bits(&gb, 1);
if (check_for_overrun(&gb, init_bit_pos, len))
return -EINVAL;
break;
}
// OBU_FRAMEs shouldn't be signalled with show_existing_frame
if (c->frame_hdr.show_existing_frame) goto error;
off += res;
// This is the frame header at the start of a frame OBU.
// There's no trailing bit at the end to skip, but we do need
// to align to the next byte.
dav1d_bytealign_get_bits(&gb);
// fall-through
case OBU_TILE_GRP:
case OBU_TILE_GRP: {
if (!c->have_frame_hdr) goto error;
if (c->n_tile_data >= 256) goto error;
if ((res = parse_tile_hdr(c, &gb)) < 0)
return res;
off += res;
if (off > len + init_off)
goto error;
parse_tile_hdr(c, &gb);
// Align to the next byte boundary and check for overrun.
dav1d_bytealign_get_bits(&gb);
if (check_for_overrun(&gb, init_bit_pos, len))
return -EINVAL;
// The current bit position is a multiple of 8 (because we
// just aligned it) and less than 8*pkt_bytelen because
// otherwise the overrun check would have fired.
const unsigned bit_pos = dav1d_get_bits_pos(&gb);
assert(bit_pos & 7 == 0);
assert(pkt_bytelen > (bit_pos >> 3));
dav1d_ref_inc(in->ref);
c->tile[c->n_tile_data].data.ref = in->ref;
c->tile[c->n_tile_data].data.data = in->data + off;
c->tile[c->n_tile_data].data.sz = len + init_off - off;
c->tile[c->n_tile_data].data.data = in->data + (bit_pos >> 3);
c->tile[c->n_tile_data].data.sz = pkt_bytelen - (bit_pos >> 3);
// ensure tile groups are in order and sane, see 6.10.1
if (c->tile[c->n_tile_data].start > c->tile[c->n_tile_data].end ||
c->tile[c->n_tile_data].start != c->n_tiles)
......@@ -1113,6 +1168,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
c->tile[c->n_tile_data].start;
c->n_tile_data++;
break;
}
case OBU_PADDING:
case OBU_TD:
case OBU_METADATA:
......@@ -1192,7 +1248,7 @@ int dav1d_parse_obus(Dav1dContext *const c, Dav1dData *const in) {
}
}
return len + init_off;
return len + init_byte_pos;
error:
fprintf(stderr, "Error parsing OBU data\n");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment