diff --git a/src/shaders.c b/src/shaders.c
index 6ec5052c8f7f6c3e982c7882bf2c103e29ed7e92..93c93afd993f8944fe07218e9f238bbf03aa196f 100644
--- a/src/shaders.c
+++ b/src/shaders.c
@@ -106,6 +106,21 @@ bool sh_try_compute(struct pl_shader *sh, int bw, int bh, bool flex, size_t mem)
         return false;
     }
 
+    if (bw > gpu->limits.max_group_size[0] ||
+        bh > gpu->limits.max_group_size[1] ||
+        (bw * bh) > gpu->limits.max_group_threads)
+    {
+        if (!flex) {
+            PL_TRACE(sh, "Disabling compute shader due to exceeded group "
+                     "thread count.");
+            return false;
+        } else {
+            // Pick better group sizes
+            bw = PL_MIN(bw, gpu->limits.max_group_size[0]);
+            bh = gpu->limits.max_group_threads / bw;
+        }
+    }
+
     sh->res.compute_shmem += mem;
 
     // If the current shader is either not a compute shader, or we have no
@@ -121,6 +136,7 @@ bool sh_try_compute(struct pl_shader *sh, int bw, int bh, bool flex, size_t mem)
     if (sh->flexible_work_groups && flex) {
         *sh_bw = PL_MAX(*sh_bw, bw);
         *sh_bh = PL_MAX(*sh_bh, bh);
+        pl_assert(*sh_bw * *sh_bh <= gpu->limits.max_group_threads);
         return true;
     }