From b129d9f2cb897cedba77a60bd5e3621c14ee5484 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20Storsj=C3=B6?= <martin@martin.st>
Date: Tue, 24 Dec 2024 21:52:15 +0200
Subject: [PATCH] mc: Reduce stack use in {put,prep}_scaled_{bilin,8tap}

For the bilin cases, this seems to make things marginally faster
(measured on x86_64; 7-25% faster with compiler autovectorization).
For 8tap, it doesn't make much of a difference at all.

Before:                                      GCC   Clang
mc_scaled_8tap_regular_w128_8bpc_c:     115155.5   98549.3
mc_scaled_8tap_regular_w128_8bpc_ssse3:  17936.0   18411.1
mc_scaled_bilinear_w128_8bpc_c:          40290.0   51812.9
mc_scaled_bilinear_w128_8bpc_ssse3:      18243.9   18177.0
After:
mc_scaled_8tap_regular_w128_8bpc_c:     116304.3   99453.2
mc_scaled_8tap_regular_w128_8bpc_ssse3:  18387.0   18077.3
mc_scaled_bilinear_w128_8bpc_c:          37381.4   41145.0
mc_scaled_bilinear_w128_8bpc_ssse3:      18423.8   18031.6

(Benchmarked with the seed 0; the total runtime for the scaled
benchmarks are significantly affected by the random seed.)

This reduces the stack usage of these functions from around 65 KB
each, to less than 1 KB for bilin, and around 2 KB for 8tap.

With this in place, the required stack space for dav1d should
be mostly identical across configurations; on x86_64 (both with
and without assembly), it can run with 62 KB of stack, and
on arm and aarch64, it can run with 58 KB of stack.
---
 src/mc_tmpl.c | 227 +++++++++++++++++++++++++++++---------------------
 1 file changed, 134 insertions(+), 93 deletions(-)

diff --git a/src/mc_tmpl.c b/src/mc_tmpl.c
index 46a57a14..0fd63669 100644
--- a/src/mc_tmpl.c
+++ b/src/mc_tmpl.c
@@ -84,18 +84,34 @@ prep_c(int16_t *tmp, const pixel *src, const ptrdiff_t src_stride,
      F[6] * src[x + +3 * stride] + \
      F[7] * src[x + +4 * stride])
 
+#define FILTER_8TAP2(src, x, F) \
+    (F[0] * src[0][x] + \
+     F[1] * src[1][x] + \
+     F[2] * src[2][x] + \
+     F[3] * src[3][x] + \
+     F[4] * src[4][x] + \
+     F[5] * src[5][x] + \
+     F[6] * src[6][x] + \
+     F[7] * src[7][x])
+
 #define DAV1D_FILTER_8TAP_RND(src, x, F, stride, sh) \
     ((FILTER_8TAP(src, x, F, stride) + ((1 << (sh)) >> 1)) >> (sh))
 
 #define DAV1D_FILTER_8TAP_RND2(src, x, F, stride, rnd, sh) \
     ((FILTER_8TAP(src, x, F, stride) + (rnd)) >> (sh))
 
+#define DAV1D_FILTER_8TAP_RND3(src, x, F, sh) \
+    ((FILTER_8TAP2(src, x, F) + ((1 << (sh)) >> 1)) >> (sh))
+
 #define DAV1D_FILTER_8TAP_CLIP(src, x, F, stride, sh) \
     iclip_pixel(DAV1D_FILTER_8TAP_RND(src, x, F, stride, sh))
 
 #define DAV1D_FILTER_8TAP_CLIP2(src, x, F, stride, rnd, sh) \
     iclip_pixel(DAV1D_FILTER_8TAP_RND2(src, x, F, stride, rnd, sh))
 
+#define DAV1D_FILTER_8TAP_CLIP3(src, x, F, sh) \
+    iclip_pixel(DAV1D_FILTER_8TAP_RND3(src, x, F, sh))
+
 #define GET_H_FILTER(mx) \
     const int8_t *const fh = !(mx) ? NULL : w > 4 ? \
         dav1d_mc_subpel_filters[filter_type & 3][(mx) - 1] : \
@@ -179,43 +195,50 @@ put_8tap_scaled_c(pixel *dst, const ptrdiff_t dst_stride,
 {
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
     const int intermediate_rnd = (1 << intermediate_bits) >> 1;
-    int tmp_h = (((h - 1) * dy + my) >> 10) + 8;
-    int16_t mid[128 * (256 + 7)], *mid_ptr = mid;
+    int16_t mid[128 * 8];
+    int16_t *mid_ptrs[8];
+    int in_y = -8;
     src_stride = PXSTRIDE(src_stride);
 
-    src -= src_stride * 3;
-    do {
-        int x;
-        int imx = mx, ioff = 0;
-
-        for (x = 0; x < w; x++) {
-            GET_H_FILTER(imx >> 6);
-            mid_ptr[x] = fh ? DAV1D_FILTER_8TAP_RND(src, ioff, fh, 1,
-                                                    6 - intermediate_bits) :
-                              src[ioff] << intermediate_bits;
-            imx += dx;
-            ioff += imx >> 10;
-            imx &= 0x3ff;
-        }
+    for (int i = 0; i < 8; i++)
+        mid_ptrs[i] = &mid[128 * i];
 
-        mid_ptr += 128;
-        src += src_stride;
-    } while (--tmp_h);
+    src -= src_stride * 3;
 
-    mid_ptr = mid + 128 * 3;
     for (int y = 0; y < h; y++) {
         int x;
-        GET_V_FILTER(my >> 6);
+        int src_y = my >> 10;
+        GET_V_FILTER((my & 0x3ff) >> 6);
+
+        while (in_y < src_y) {
+            int imx = mx, ioff = 0;
+            int16_t *mid_ptr = mid_ptrs[0];
+
+            for (int i = 0; i < 7; i++)
+                mid_ptrs[i] = mid_ptrs[i + 1];
+            mid_ptrs[7] = mid_ptr;
+
+            for (x = 0; x < w; x++) {
+                GET_H_FILTER(imx >> 6);
+                mid_ptr[x] = fh ? DAV1D_FILTER_8TAP_RND(src, ioff, fh, 1,
+                                                        6 - intermediate_bits) :
+                                  src[ioff] << intermediate_bits;
+                imx += dx;
+                ioff += imx >> 10;
+                imx &= 0x3ff;
+            }
+
+            src += src_stride;
+            in_y++;
+        }
 
         for (x = 0; x < w; x++)
-            dst[x] = fv ? DAV1D_FILTER_8TAP_CLIP(mid_ptr, x, fv, 128,
-                                                 6 + intermediate_bits) :
-                          iclip_pixel((mid_ptr[x] + intermediate_rnd) >>
+            dst[x] = fv ? DAV1D_FILTER_8TAP_CLIP3(mid_ptrs, x, fv,
+                                                  6 + intermediate_bits) :
+                          iclip_pixel((mid_ptrs[3][x] + intermediate_rnd) >>
                                               intermediate_bits);
 
         my += dy;
-        mid_ptr += (my >> 10) * 128;
-        my &= 0x3ff;
         dst += PXSTRIDE(dst_stride);
     }
 }
@@ -288,41 +311,48 @@ prep_8tap_scaled_c(int16_t *tmp, const pixel *src, ptrdiff_t src_stride,
                    HIGHBD_DECL_SUFFIX)
 {
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
-    int tmp_h = (((h - 1) * dy + my) >> 10) + 8;
-    int16_t mid[128 * (256 + 7)], *mid_ptr = mid;
+    int16_t mid[128 * 8];
+    int16_t *mid_ptrs[8];
+    int in_y = -8;
     src_stride = PXSTRIDE(src_stride);
 
-    src -= src_stride * 3;
-    do {
-        int x;
-        int imx = mx, ioff = 0;
-
-        for (x = 0; x < w; x++) {
-            GET_H_FILTER(imx >> 6);
-            mid_ptr[x] = fh ? DAV1D_FILTER_8TAP_RND(src, ioff, fh, 1,
-                                                    6 - intermediate_bits) :
-                              src[ioff] << intermediate_bits;
-            imx += dx;
-            ioff += imx >> 10;
-            imx &= 0x3ff;
-        }
+    for (int i = 0; i < 8; i++)
+        mid_ptrs[i] = &mid[128 * i];
 
-        mid_ptr += 128;
-        src += src_stride;
-    } while (--tmp_h);
+    src -= src_stride * 3;
 
-    mid_ptr = mid + 128 * 3;
     for (int y = 0; y < h; y++) {
         int x;
-        GET_V_FILTER(my >> 6);
+        int src_y = my >> 10;
+        GET_V_FILTER((my & 0x3ff) >> 6);
+
+        while (in_y < src_y) {
+            int imx = mx, ioff = 0;
+            int16_t *mid_ptr = mid_ptrs[0];
+
+            for (int i = 0; i < 7; i++)
+                mid_ptrs[i] = mid_ptrs[i + 1];
+            mid_ptrs[7] = mid_ptr;
+
+            for (x = 0; x < w; x++) {
+                GET_H_FILTER(imx >> 6);
+                mid_ptr[x] = fh ? DAV1D_FILTER_8TAP_RND(src, ioff, fh, 1,
+                                                        6 - intermediate_bits) :
+                                  src[ioff] << intermediate_bits;
+                imx += dx;
+                ioff += imx >> 10;
+                imx &= 0x3ff;
+            }
+
+            src += src_stride;
+            in_y++;
+        }
 
         for (x = 0; x < w; x++)
-            tmp[x] = (fv ? DAV1D_FILTER_8TAP_RND(mid_ptr, x, fv, 128, 6)
-                         : mid_ptr[x]) - PREP_BIAS;
+            tmp[x] = (fv ? DAV1D_FILTER_8TAP_RND3(mid_ptrs, x, fv, 6)
+                         : mid_ptrs[3][x]) - PREP_BIAS;
 
         my += dy;
-        mid_ptr += (my >> 10) * 128;
-        my &= 0x3ff;
         tmp += w;
     }
 }
@@ -392,6 +422,15 @@ filter_fns(sharp_smooth,   DAV1D_FILTER_8TAP_SHARP,   DAV1D_FILTER_8TAP_SMOOTH)
 #define FILTER_BILIN_CLIP(src, x, mxy, stride, sh) \
     iclip_pixel(FILTER_BILIN_RND(src, x, mxy, stride, sh))
 
+#define FILTER_BILIN2(src1, src2, x, mxy) \
+    (16 * src1[x] + ((mxy) * (src2[x] - src1[x])))
+
+#define FILTER_BILIN_RND2(src1, src2, x, mxy, sh) \
+    ((FILTER_BILIN2(src1, src2, x, mxy) + ((1 << (sh)) >> 1)) >> (sh))
+
+#define FILTER_BILIN_CLIP2(src1, src2, x, mxy, sh) \
+    iclip_pixel(FILTER_BILIN_RND2(src1, src2, x, mxy, sh))
+
 static void put_bilin_c(pixel *dst, ptrdiff_t dst_stride,
                         const pixel *src, ptrdiff_t src_stride,
                         const int w, int h, const int mx, const int my
@@ -456,36 +495,37 @@ static void put_bilin_scaled_c(pixel *dst, ptrdiff_t dst_stride,
                                HIGHBD_DECL_SUFFIX)
 {
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
-    int tmp_h = (((h - 1) * dy + my) >> 10) + 2;
-    int16_t mid[128 * (256 + 1)], *mid_ptr = mid;
+    int16_t mid[128 * 2];
+    int in_y = -2;
 
     do {
         int x;
-        int imx = mx, ioff = 0;
-
-        for (x = 0; x < w; x++) {
-            mid_ptr[x] = FILTER_BILIN_RND(src, ioff, imx >> 6, 1,
-                                          4 - intermediate_bits);
-            imx += dx;
-            ioff += imx >> 10;
-            imx &= 0x3ff;
-        }
-
-        mid_ptr += 128;
-        src += PXSTRIDE(src_stride);
-    } while (--tmp_h);
+        int y = my >> 10;
+        int16_t *mid1 = &mid[(y & 1) * 128];
+        int16_t *mid2 = &mid[((y + 1) & 1) * 128];
+        int dmy = my & 0x3ff;
+
+        while (in_y < y) {
+            int imx = mx, ioff = 0;
+            int16_t *mid_ptr = &mid[(in_y & 1) * 128];
+
+            for (x = 0; x < w; x++) {
+                mid_ptr[x] = FILTER_BILIN_RND(src, ioff, imx >> 6, 1,
+                                              4 - intermediate_bits);
+                imx += dx;
+                ioff += imx >> 10;
+                imx &= 0x3ff;
+            }
 
-    mid_ptr = mid;
-    do {
-        int x;
+            src += PXSTRIDE(src_stride);
+            in_y++;
+        }
 
         for (x = 0; x < w; x++)
-            dst[x] = FILTER_BILIN_CLIP(mid_ptr, x, my >> 6, 128,
+            dst[x] = FILTER_BILIN_CLIP2(mid1, mid2, x, dmy >> 6,
                                        4 + intermediate_bits);
 
         my += dy;
-        mid_ptr += (my >> 10) * 128;
-        my &= 0x3ff;
         dst += PXSTRIDE(dst_stride);
     } while (--h);
 }
@@ -551,35 +591,36 @@ static void prep_bilin_scaled_c(int16_t *tmp,
                                 const int dx, const int dy HIGHBD_DECL_SUFFIX)
 {
     const int intermediate_bits = get_intermediate_bits(bitdepth_max);
-    int tmp_h = (((h - 1) * dy + my) >> 10) + 2;
-    int16_t mid[128 * (256 + 1)], *mid_ptr = mid;
+    int16_t mid[128 * 2];
+    int in_y = -2;
 
     do {
         int x;
-        int imx = mx, ioff = 0;
-
-        for (x = 0; x < w; x++) {
-            mid_ptr[x] = FILTER_BILIN_RND(src, ioff, imx >> 6, 1,
-                                          4 - intermediate_bits);
-            imx += dx;
-            ioff += imx >> 10;
-            imx &= 0x3ff;
-        }
-
-        mid_ptr += 128;
-        src += PXSTRIDE(src_stride);
-    } while (--tmp_h);
+        int y = my >> 10;
+        int16_t *mid1 = &mid[(y & 1) * 128];
+        int16_t *mid2 = &mid[((y + 1) & 1) * 128];
+        int dmy = my & 0x3ff;
+
+        while (in_y < y) {
+            int imx = mx, ioff = 0;
+            int16_t *mid_ptr = &mid[(in_y & 1) * 128];
+
+            for (x = 0; x < w; x++) {
+                mid_ptr[x] = FILTER_BILIN_RND(src, ioff, imx >> 6, 1,
+                                              4 - intermediate_bits);
+                imx += dx;
+                ioff += imx >> 10;
+                imx &= 0x3ff;
+            }
 
-    mid_ptr = mid;
-    do {
-        int x;
+            src += PXSTRIDE(src_stride);
+            in_y++;
+        }
 
         for (x = 0; x < w; x++)
-            tmp[x] = FILTER_BILIN_RND(mid_ptr, x, my >> 6, 128, 4) - PREP_BIAS;
+            tmp[x] = FILTER_BILIN_RND2(mid1, mid2, x, dmy >> 6, 4) - PREP_BIAS;
 
         my += dy;
-        mid_ptr += (my >> 10) * 128;
-        my &= 0x3ff;
         tmp += w;
     } while (--h);
 }
-- 
GitLab