I recently tried to optimize convolutions using SIMD instructions, but what I thought would be a simple task ended up taking me days, with one issue popping up after another. Some of them make sense in hindsight, but others were utterly baffling. While the specific examples are for direct convolution, these considerations apply to pretty much any code with a hot loop.
I work on burn and recently wanted to
optimize direct convolution on the burn-ndarray
CPU backend.
For convolutions you need to move a two-dimensional kernel across an input
feature map, multiply each kernel weight with the input value under it, and sum
the products across the kernel window and all input channels. This is repeated
for each output channel. The input can have padding pixels of zero-padding
around the actual data, and the kernel can move in a strided manner (e.g. two
pixels at once). There are many algorithms with different tradeoffs, but I
decided to go with direct convolutions since they don't have memory overhead and
are still very efficient when implemented correctly. The basic outline is that
you have many nested loops, some bounds checks and a very frequently executed
fused-multiply-add (FMA) instruction.
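To make that loop structure concrete, here is a minimal, unoptimized sketch of direct convolution for a single image. It assumes a CHW layout, the same stride and padding on both axes, and no dilation or groups; the function name and signature are hypothetical and are not burn's API.

```rust
/// Naive direct 2D convolution over one image (hypothetical sketch).
/// input:  [in_c][in_h][in_w], weight: [out_c][in_c][k_h][k_w]
/// Returns (output as [out_c][out_h][out_w], out_h, out_w).
fn conv2d_naive(
    input: &[f32], in_c: usize, in_h: usize, in_w: usize,
    weight: &[f32], out_c: usize, k_h: usize, k_w: usize,
    padding: usize, stride: usize,
) -> (Vec<f32>, usize, usize) {
    let out_h = (in_h + 2 * padding - k_h) / stride + 1;
    let out_w = (in_w + 2 * padding - k_w) / stride + 1;
    let mut output = vec![0.0f32; out_c * out_h * out_w];
    for oc in 0..out_c {
        for oh in 0..out_h {
            for ow in 0..out_w {
                let mut acc = 0.0f32;
                for ic in 0..in_c {
                    for kh in 0..k_h {
                        for kw in 0..k_w {
                            // Position in the unpadded input (may be negative).
                            let ih = (oh * stride + kh) as isize - padding as isize;
                            let iw = (ow * stride + kw) as isize - padding as isize;
                            // Bounds check: padding pixels are zero, so skip them.
                            if ih < 0 || iw < 0 || ih >= in_h as isize || iw >= in_w as isize {
                                continue;
                            }
                            let x = input[(ic * in_h + ih as usize) * in_w + iw as usize];
                            let w = weight[((oc * in_c + ic) * k_h + kh) * k_w + kw];
                            acc = x.mul_add(w, acc); // the hot fused multiply-add
                        }
                    }
                }
                output[(oc * out_h + oh) * out_w + ow] = acc;
            }
        }
    }
    (output, out_h, out_w)
}
```

Every output pixel pays for the bounds check in the innermost loop, which is exactly the kind of branch the rest of this article is about getting rid of.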
My initial implementation looked something like this (simplified):
In this implementation I use several techniques. In addition to SIMD loads and
fmadds, I use the optimized loop order and register blocking (using the
macro) techniques from
this paper. I finished the implementation,
ran a benchmark, and... it was slower. More than two times slower than a
naive unvectorized implementation, in fact (~670ms vs ~300ms).
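For readers unfamiliar with register blocking, here is a scalar sketch of the idea, assuming a 1D, stride-1, unpadded convolution. OW_B neighboring output pixels are computed together so their partial sums can stay in registers, and each loaded weight is reused OW_B times. In a real SIMD kernel each accumulator would be a vector register; here plain f32 values stand in for them, and all names are hypothetical. Note the edge check that rounding the block count up with div_ceil forces into the innermost loop.

```rust
/// Block size: number of output pixels computed together (hypothetical).
const OW_B: usize = 8;

/// Register-blocked 1D "valid" convolution with a per-lane edge check.
fn conv1d_blocked_checked(input: &[f32], weight: &[f32]) -> Vec<f32> {
    let k_w = weight.len();
    let out_width = input.len() - k_w + 1;
    let mut output = vec![0.0f32; out_width];
    // Rounded up: the last block may run past the edge of the output row.
    let ow_blocks = out_width.div_ceil(OW_B);
    for ow_block in 0..ow_blocks {
        // In a SIMD kernel these would be OW_B vector registers.
        let mut acc = [0.0f32; OW_B];
        for kw in 0..k_w {
            let w = weight[kw]; // loaded once, reused for all OW_B pixels
            for i in 0..OW_B {
                let ow = ow_block * OW_B + i;
                // This branch sits right next to every FMA.
                if ow < out_width {
                    acc[i] = input[ow + kw].mul_add(w, acc[i]);
                }
            }
        }
        for i in 0..OW_B {
            let ow = ow_block * OW_B + i;
            if ow < out_width {
                output[ow] = acc[i];
            }
        }
    }
    output
}
```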
To find out why, I tried various profilers: cargo-flamegraph, samply and, after a lot of desperation, AMD μProf. After a few days of trying to get useful information out of these profilers (and getting μProf to work at all), I realized it wasn't getting me anywhere. The flamegraphs and hotspots just didn't seem to make any sense at all.
So what's the next step?
OK, none of my attempts at profiling led anywhere. So let's try to reduce the code to only what's actually needed in the benchmark. The benchmark uses unpadded, unstrided, undilated and ungrouped convolutions, so I stripped all padding checks and all stride/dilation calculations. It was faster, but still slow.
There was one branch left to eliminate: the check for border pixels in the register-blocking loop.
Just to check, I shortened the loop to only process output pixels up to the last multiple of 8 (the register-blocking factor ow_b). This yields incorrect results, but should help with debugging performance.
let ow_blocks = out_width / ow_b; // changed from `div_ceil`
for ow_block in 0..ow_blocks { /* register-blocked body */ }
Running the benchmark again, the code is now significantly faster on a single thread than the old code was with multiple threads!
Benchmarking - conv2d-input_16x512x512_weight_16x3x3_stride_1
―――――――― Result ―――――――――
Timing full
Samples 40
Mean 205.12ms
Variance 69.420µs
Median 203.24ms
Min 201.12ms
Max 207.23ms
The problem seems to have been a mixture of register spilling (as someone
previously focused on GPUs, I was shocked to find that x86-64 CPUs only have 16
SIMD registers available to AVX2 code) and too many branches. This is why the
profilers didn't lead me anywhere useful. Branching in modern CPUs is just too
complicated to be meaningfully represented by a profiler hotspot. This is
probably the biggest takeaway from this article: branches are much worse than
you think, because the CPU can't predict more than one branch per cycle. A
single if statement inside a loop is enough to stop any further instructions
from being decoded in that cycle. Since optimal performance requires 2 FMA
instructions per cycle (each FMA unit can start a new FMA every cycle, with a
5 cycle latency, and Zen 4 has 2 FMA units), having a branch for every FMA
massively hamstrings the performance. This may be different on Zen 5, but
remember, we still have other branches in addition to the ow bounds check that
need predicting. So it's worse than 50% performance in practice.
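As a back-of-the-envelope version of that argument, using the numbers above (two FMA units, at most one predicted branch per cycle) and assuming exactly one bounds-check branch per FMA:

$$
\text{FMAs per cycle} = \min\left(2,\ \frac{1\ \text{branch/cycle}}{1\ \text{branch/FMA}}\right) = 1
\quad\Rightarrow\quad \frac{1}{2} = 50\%\ \text{of peak FMA throughput}
$$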
Alright, now we have a good place to start. Let's add things back one at a time and see where the performance starts getting bad.
First, we need to deal with the remaining pixels after the register-blocking code. To do this, we're going to use a technique that we'll be using several more times coming up:
I mentioned before that I shortened the loop to only deal with clean multiples of the register-blocking factor. So the way to deal with the remaining pixels is to just... add another loop.
We add a second, unblocked loop that starts where the first loop ends and runs until the edge of the feature map. Since it handles one pixel at a time instead of a full block, it can never overshoot the edge, so we don't need to add any bounds checks.
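for ow_block in 0..ow_blocks { /* blocked: 8 pixels per iteration, no bounds check */ }
for ow in ow_blocks * 8..out_width { /* remainder: 1 pixel per iteration */ }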
Running the benchmarks again: it's still fast, yay! Running two loops is much more efficient than checking whether we're in bounds on every iteration.
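Here is the split-loop version of a scalar register-blocked kernel, again assuming a 1D, stride-1, unpadded convolution with hypothetical names. The blocked loop uses floor division, so none of its lanes can reach the edge, and a plain loop picks up the remaining out_width % OW_B pixels.

```rust
/// Block size: number of output pixels computed together (hypothetical).
const OW_B: usize = 8;

/// Register-blocked 1D "valid" convolution split into a blocked main loop
/// and an unblocked remainder loop, so neither needs an edge check.
fn conv1d_blocked_split(input: &[f32], weight: &[f32]) -> Vec<f32> {
    let k_w = weight.len();
    let out_width = input.len() - k_w + 1;
    let mut output = vec![0.0f32; out_width];
    // Floor division: every block lies fully inside the output row.
    let ow_blocks = out_width / OW_B;
    for ow_block in 0..ow_blocks {
        let base = ow_block * OW_B;
        let mut acc = [0.0f32; OW_B];
        for kw in 0..k_w {
            let w = weight[kw];
            for i in 0..OW_B {
                // No `ow < out_width` test: base + i < ow_blocks * OW_B <= out_width.
                acc[i] = input[base + i + kw].mul_add(w, acc[i]);
            }
        }
        output[base..base + OW_B].copy_from_slice(&acc);
    }
    // Remainder loop: one pixel at a time, so it stops exactly at the edge.
    for ow in ow_blocks * OW_B..out_width {
        let mut acc = 0.0f32;
        for kw in 0..k_w {
            acc = input[ow + kw].mul_add(weight[kw], acc);
        }
        output[ow] = acc;
    }
    output
}
```

In this scalar sketch, Rust's slice indexing still carries its own panic checks, which the optimizer may or may not remove; the point is only that the explicit edge test is gone from the hot loop.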
To add back padding, stride and dilation without tanking the performance again,
I decided to use compile-time monomorphization to specialize the common
zero-padding and unit stride/dilation cases. For stride and dilation, I borrowed
a technique from the original convolution implementation, added by
Justin Moore, that enables auto-vectorization for unit-stride convolutions.
By adding an if statement that checks whether stride and dilation are all 1,
we allow the compiler to constant-propagate this value into that branch. The
inner loop is extracted into a separate, inlined function. This trick allows
unstrided convolution to be auto-vectorized in the original, non-SIMD implementation.
for ow_block in 0..ow_blocks
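A minimal sketch of that specialization trick, assuming a 1D convolution with hypothetical function names: the inner loop lives in an #[inline(always)] function, and the caller passes the literal constants 1 and 1 in the unit-stride branch. After inlining, that copy of the loop sees constants instead of runtime values, so the index math simplifies to ow + kw. In the original, non-SIMD kernel this is what enabled auto-vectorization; the sketch itself makes no such guarantee.

```rust
/// Inner loop, meant to be inlined into each call site (hypothetical name).
#[inline(always)]
fn row_inner(input: &[f32], weight: &[f32], output: &mut [f32], stride: usize, dilation: usize) {
    for (ow, out) in output.iter_mut().enumerate() {
        let mut acc = 0.0f32;
        for (kw, &w) in weight.iter().enumerate() {
            acc = input[ow * stride + kw * dilation].mul_add(w, acc);
        }
        *out = acc;
    }
}

/// Dispatches to a specialized copy of the inner loop for the common case.
fn conv1d_row(input: &[f32], weight: &[f32], output: &mut [f32], stride: usize, dilation: usize) {
    if stride == 1 && dilation == 1 {
        // Literal constants: after inlining, `ow * 1 + kw * 1` folds to `ow + kw`.
        row_inner(input, weight, output, 1, 1);
    } else {
        // General path: stride and dilation stay runtime values.
        row_inner(input, weight, output, stride, dilation);
    }
}
```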
Padding support is added back via a const generic bool that sets the padding to 0 when padding is disabled. This allows the compiler to, once again, constant-propagate it.
Nice and easy.
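A sketch of the const generic switch, assuming a 1D, stride-1 convolution (the function name and the PAD parameter are hypothetical). Because PAD is a compile-time constant, each instantiation is compiled separately, and in the PAD = false copy the padding is the constant 0.

```rust
/// 1D convolution (stride 1) where PAD selects at compile time whether
/// padding is supported at all.
fn conv1d_pad<const PAD: bool>(input: &[f32], weight: &[f32], padding: usize) -> Vec<f32> {
    // In the PAD = false instantiation this is the constant 0, which the
    // compiler can propagate through everything below.
    let padding = if PAD { padding } else { 0 };
    let k_w = weight.len();
    let out_width = input.len() + 2 * padding - k_w + 1;
    let mut output = vec![0.0f32; out_width];
    for (ow, out) in output.iter_mut().enumerate() {
        let mut acc = 0.0f32;
        for (kw, &w) in weight.iter().enumerate() {
            // Left-padding positions wrap around to huge values, so one unsigned
            // comparison rejects both left and right padding pixels.
            // With padding == 0 the subtraction folds away.
            let iw = (ow + kw).wrapping_sub(padding);
            if iw < input.len() {
                acc = input[iw].mul_add(w, acc);
            }
        }
        *out = acc;
    }
    output
}
```

Usage: conv1d_pad::<false>(&x, &w, 0) for the unpadded fast path, conv1d_pad::<true>(&x, &w, p) otherwise.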
Let's run the benchmark again!
Benchmarking - conv2d-input_16x512x512_weight_16x3x3_stride_1
―――――――― Result ―――――――――
Timing full
Samples 40
Mean 8.136s (+3868%)
Variance 75.115µs
Median 8.042s (+3861%)
Min 8.020s (+3890%)
Max 8.341s (+3929%)
Oh. Oh dear. What happened?
To explain what just happened, I need to add another small background detail I
didn't mention before. To use modern SIMD features, the code uses runtime
feature selection with pulp. The way this works is that pulp annotates a
function with something like #[target_feature(enable = "avx2")], based on the
available features. This tells the compiler it's allowed to use avx2 features,
even if the target wouldn't normally include avx2. However, only inlined functions
will have the features enabled; non-inlined function calls will not (this is
This is where samply starts becoming actually useful. Running it lets me
see the assembly for each function and find the hotspots. And this time they
are actually meaningful! samply tells me I'm spending all my time in the
line that calls the FMA and in the FMA itself. So I take a look at the assembly
and... oh no!
movaps xmm6, xmmword [rsp + 0x2e0]
movaps xmm7, xmmword [rsp + 0x2f0]
movaps xmmword [rsp + 0x170], xmm15
movaps xmmword [rsp + 0x160], xmm14
movaps xmm0, xmmword [rsp + 0x180]
movaps xmmword [rsp + 0xb0], xmm0
movaps xmm0, xmmword [rsp + 0x190]
movaps xmmword [rsp + 0xa0], xmm0
mov rcx, rbx
lea rbx, qword [rsp + 0x320]
mov rdx, rbx
mov r8, rdi
mov r9, r14
call 0x4edb0
to execute the intrinsic for some reason?
Turns out: these things are linked. What follows is what I think is a pretty accurate guess of what happened here.
See, the compiler has a size limit for inlined functions. #[inline(always)]
tells the compiler to ignore the size limit, and almost all of my functions
were marked as #[inline(always)]. However, the outermost function was
These are the steps I think happened next:
s wrapper
function, the one that is marked with #[target_feature]
is an AVX2 instruction that requires 256-bit
registers, the compiler must now call it as a regular, non-inlined function and
pass the data via the stack. This is slow. Very slow. I'm somewhat unsure about that last step; maybe someone with more knowledge of compiler internals can enlighten me on the actual reason the intrinsic is no longer inlined.
Fortunately, the solution was much simpler than this chain of events: annotate
the top-level function with #[inline(always)].
vbroadcastss ymm9, dword [r12 + r11 * 1]
lea rcx, qword [r11 + r12 * 1]
vfmadd231ps ymm6, ymm8, ymm10
Much better. And the benchmark?
Benchmarking - conv2d-input_16x512x512_weight_16x3x3_stride_1
―――――――― Result ―――――――――
Timing full
Samples 40
Mean 230.12ms
Variance 69.420µs
Median 232.24ms
Min 224.12ms
Max 236.23ms
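The inlining rule behind this bug can be reproduced in miniature with only the standard library. The sketch below is a hand-written stand-in for the kind of runtime dispatch pulp performs, not pulp's actual API, and all function names are hypothetical. The AVX2/FMA entry point carries #[target_feature]; the hot helper only runs with those features because it is #[inline(always)] and therefore compiled into that function's body. If the helper were not inlined, it would be compiled for the baseline target instead.

```rust
/// Hot helper with no #[target_feature] of its own: it only gets AVX2/FMA
/// code generation when inlined into a function that enables them.
#[inline(always)]
fn fma_row(acc: &mut [f32], input: &[f32], weight: f32) {
    for (a, &x) in acc.iter_mut().zip(input) {
        *a = x.mul_add(weight, *a);
    }
}

/// Entry point compiled with AVX2 and FMA enabled.
/// Safety: the caller must ensure the CPU supports avx2 and fma.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn fma_row_avx2(acc: &mut [f32], input: &[f32], weight: f32) {
    // fma_row is inlined here, so this copy of its loop is compiled with avx2/fma.
    fma_row(acc, input, weight)
}

/// Runtime dispatch: use the feature-enabled version if the CPU supports it.
fn fma_row_dispatch(acc: &mut [f32], input: &[f32], weight: f32) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: both features were detected at runtime just above.
            unsafe { fma_row_avx2(acc, input, weight) };
            return;
        }
    }
    // Fallback for other CPUs/architectures: baseline code generation.
    fma_row(acc, input, weight)
}
```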
Performance was now good for unpadded convolutions, but would still have been
lackluster for padded ones, since we need to check whether we are in the padding
in every single loop iteration. To solve this, we can use the same technique we
used for out_width earlier: all output pixels that are at least padding pixels
away from the edges are guaranteed to always be in bounds, so we can run one loop
from padding_h to out_height - padding_h
and from padding_w to width - padding_w
without bounds checks, then a second loop for the border
pixels that does do bounds checks. This is much faster than checking every
pixel, since most pixels are always in bounds.
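A sketch of that split for a single-channel 2D convolution, assuming stride 1, no dilation, a square k x k kernel and the same padding p on all sides (names hypothetical). For stride 1, an output pixel with p <= oh < out_h - p reads input rows oh - p through oh - p + k - 1, which all lie inside the input, and the same holds for columns. The second loop visits only the border pixels: whole rows at the top and bottom, and just the left and right edge columns of the interior rows.

```rust
/// Single-channel 2D convolution, stride 1, k x k kernel, zero padding p on
/// every side. Interior outputs skip padding checks; border outputs keep them.
fn conv2d_split(input: &[f32], h: usize, w: usize, kernel: &[f32], k: usize, p: usize) -> Vec<f32> {
    let out_h = h + 2 * p - k + 1;
    let out_w = w + 2 * p - k + 1;
    let mut out = vec![0.0f32; out_h * out_w];

    // Checked path for border pixels: positions in the padding contribute zero.
    let checked = |oh: usize, ow: usize| -> f32 {
        let mut acc = 0.0f32;
        for kh in 0..k {
            // wrapping_sub maps "above/left of the input" to a huge index,
            // so one unsigned comparison per axis rejects both edges.
            let ih = (oh + kh).wrapping_sub(p);
            for kw in 0..k {
                let iw = (ow + kw).wrapping_sub(p);
                if ih < h && iw < w {
                    acc = input[ih * w + iw].mul_add(kernel[kh * k + kw], acc);
                }
            }
        }
        acc
    };
    // Unchecked path: only valid when the whole window lies inside the input.
    let unchecked = |oh: usize, ow: usize| -> f32 {
        let mut acc = 0.0f32;
        for kh in 0..k {
            for kw in 0..k {
                acc = input[(oh + kh - p) * w + (ow + kw - p)].mul_add(kernel[kh * k + kw], acc);
            }
        }
        acc
    };

    // Interior ranges; they are simply empty if the padding is too large.
    let (r0, r1) = (p, out_h.saturating_sub(p));
    let (c0, c1) = (p, out_w.saturating_sub(p));

    // Loop 1: interior pixels, no padding checks.
    for oh in r0..r1 {
        for ow in c0..c1 {
            out[oh * out_w + ow] = unchecked(oh, ow);
        }
    }
    // Loop 2: border pixels only, with padding checks.
    for oh in 0..out_h {
        if oh >= r0 && oh < r1 && c0 < c1 {
            // Interior row: only the left and right edge columns remain.
            for ow in (0..c0).chain(c1..out_w) {
                out[oh * out_w + ow] = checked(oh, ow);
            }
        } else {
            // Top or bottom border row (or no interior columns at all).
            for ow in 0..out_w {
                out[oh * out_w + ow] = checked(oh, ow);
            }
        }
    }
    out
}
```

The only remaining branch in loop 2 is per row, not per pixel, and the k x k padding checks run only for the thin frame of border pixels.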
Modern CPUs are weird, and performance is not always obvious. Inlining is fragile, and adding a single line of code can completely change the way your program is compiled without you even noticing. Profilers aren't always helpful, especially when your problem is more complex than something like using a slow function or allocating memory too often. My tip is to figure out exactly when things start getting bad, and maybe learn some basic assembly so you can spot things like way too many stack loads/stores (which indicate register spilling).
I hope this helps someone deal with their performance issues a little bit more quickly than I did.
Also, a big shout-out once again to samply. Even when the performance data didn't mean much, being able to easily view the assembly for any given function was very useful.
The final version of the code for this implementation can be found here. And just for fun, here is the final benchmark after all optimizations, using multi-threading:
Benchmarking - conv2d-input_16x512x512_weight_16x3x3_stride_1
―――――――― Result ―――――――――
Timing full
Samples 40
Mean 42.731ms
Variance 5.115µs
Median 42.906ms
Min 38.162ms
Max 47.554ms