Last modified: 2023-09-19 @ 8ded1b5
Linear Algebra with Const Generics
I was recently doing a project with some friends where we extracted the pixel data from RAW images. One step in extracting a picture from a RAW image is demosaicing, which boils down to doing some kind of convolution over the pixel intensities in the image with a convolution kernel defined by the color filter array. That’s when I had the idea to just tinker with creating a linear algebra/image processing library in Rust that utilizes const generics in order to encode some properties of matrices in the type system. This post chronicles the creation of one such library. It is very bare bones, but I find it interesting what you can encode in the type system.
This is not useful as a general linear algebra/image processing library, since const generics are compile-time. Thus creating matrices of arbitrary dimensions at runtime can’t be done with const generics. This is purely an academic exercise in what you can do with them. But it was fun nonetheless!
Earlier (better) work
There are a couple of crates that do linear algebra, multidimensional arrays, and computer graphics. Here’s a small selection:
None of these utilize const generics the way I’ll be doing, and for good reason. We will get to that in due time, though.
In the beginning, there was the matrix
I will be restricting this to 2D matrices [1], as the original domain where I got the idea was 2D photographs. There are multiple ways we could define a matrix with M rows and N columns, encoding the dimensions with const generics. There are a couple of data types that you could base a Matrix<T> on. They all have their pros and cons.
Option 1: Vec<Vec<T>>
A nested Vec is very flexible and easy to reason about when it comes to indexing, as there is no need to manually calculate a linear index from a two-dimensional one, and its size is completely dynamic. However, given that it is a double Vec, there is also potentially a double indirection happening: you have to address memory in the outer Vec to find the inner one and then index into that one. So I opted not to use this representation.
Option 2: [[T; N]; M]
A two-dimensional array of T has the same advantage of intuitive indexing as a nested Vec, but without the disadvantage of a double indirection, since its dimensions are known statically. This means the compiler will take care of the calculation of a linear index for us! However, an array used as a plain value lives on the stack, and perhaps we want to work with some really large matrices in certain applications. So I opted not to use this representation either.
Option 3: [[T; M]; N]
Wait, did I accidentally put option 2 in twice? No, look again! The dimensions are switched. Option 2 contained M arrays of length N, whereas this one has N arrays of length M. These two layouts are known as row-major and column-major order, respectively.
Option 2 stored the elements in memory sequentially, row by row. This option instead stores them column by column. The choice between them is essentially arbitrary, but it is important to know which one you are using. I elected to store my matrices in row-major order, so this option is not that relevant, but it is very important to mention. There will be a cool little trick regarding row-major vs. column-major later.
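To make the two layouts concrete, here is a small sketch that writes the same 2x3 matrix as nested arrays in both orders and flattens them to show the order the elements end up in memory. The function names are made up for this illustration and are not part of the library.

```rust
/// The 2x3 matrix
///     1 2 3
///     4 5 6
/// stored as M = 2 arrays of length N = 3 (option 2, row-major).
pub fn row_major_flat() -> Vec<i32> {
    let m: [[i32; 3]; 2] = [[1, 2, 3], [4, 5, 6]];
    // Flattening visits the inner arrays in order, i.e. row by row.
    m.concat()
}

/// The same matrix stored as N = 3 arrays of length M = 2
/// (option 3, column-major).
pub fn column_major_flat() -> Vec<i32> {
    let m: [[i32; 2]; 3] = [[1, 4], [2, 5], [3, 6]];
    // Here each inner array is a column, so columns end up contiguous.
    m.concat()
}
```

The matrix is the same in both cases. Only the order of the elements in memory differs.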
Option 4: [T; M * N]
Ah, a linear array. Pretty simple, though you have to calculate the linear index manually to use it. That’s not too bad though. It’s really simple to do operations on all elements at once since there is no nesting to think about, and row-major vs. column-major order is entirely arbitrary. Both options can be represented by this type and you don’t need to go through std::mem::transmute to get from one to the other. Wonderful!
However, you cannot do this with const generics on stable Rust. There, const parameters can only be used as standalone arguments, not inside const expressions such as M * N. This just leaves us with…
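The manual index calculation for a flat buffer is small enough to sketch. The helper names below are hypothetical and only serve this example. Note that the product of two const parameters is fine when it is used as a value, as in the length check. It is only inside a type, like an array length, that stable Rust rejects it.

```rust
/// Linear index of (row, col) when rows are stored one after another.
pub const fn row_major_index(row: usize, col: usize, cols: usize) -> usize {
    row * cols + col
}

/// Linear index of (row, col) when columns are stored one after another.
pub const fn column_major_index(row: usize, col: usize, rows: usize) -> usize {
    col * rows + row
}

/// Reads element (row, col) of a ROWS x COLS matrix stored row-major in `flat`.
pub fn get<const ROWS: usize, const COLS: usize>(flat: &[i32], row: usize, col: usize) -> i32 {
    // `ROWS * COLS` is used as a value here, which is allowed.
    assert_eq!(flat.len(), ROWS * COLS);
    assert!(row < ROWS && col < COLS);
    flat[row_major_index(row, col, COLS)]
}
```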
Option 5: Vec<T>
This has the same properties as option 4, but with the huge advantage that it is actually valid Rust! Also, since it’s a Vec it lives on the heap and our Matrix<T> type can own that memory. No large amounts of data on the stack. And the information about the matrix dimensions is entirely on the Matrix type.
Our underlying type is completely agnostic to the size and layout of the Matrix type, which I think is pretty nice. No double bookkeeping. What relates to the matrix is defined in the type of the matrix. All of these reasons (and the row-major vs. column-major trick I will get to later) made me decide on this as the underlying type of the matrix:
#[derive(Debug, Clone)]
pub struct Matrix<
    T,
    const ROWS: usize,
    const COLS: usize
>(Vec<T>);
You can probably already start to see another reason why other libraries do not use const generics this way. That is one long type definition. But this is just for fun, so we’ll roll with it. Besides, it’s going to get so, so much worse :)
impl-blocks from Hell
Let’s implement some functionality for our shiny new Matrix<T>. How about being able to construct a matrix from a two-dimensional array? That would enable us to do things like
let matrix = Matrix::new([
    [1, 2],
    [3, 4],
]);
Pretty nice and readable! Now let’s see what the implementation of that would look like:
impl<T: Element, const ROWS: usize, const COLS: usize>
    Matrix<T, ROWS, COLS> {
    pub fn new(array: [[T; COLS]; ROWS]) -> Self {
        // Flatten the nested array row by row into one row-major Vec.
        Self(array.concat())
    }
}
Oh. Oh no, this is going to get out of hand fast. And that’s exactly what’s so fun about it. :)
Element is just a combination trait of Copy + Clone + PartialEq that I created and implemented for i64 and f64. It’s going to show up in every single impl block.
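The post does not show the definition of Element, so the following is only a guess at what such a combination trait could look like, based on the description above. The helper function count_equal is hypothetical and only there to show the bound in use.

```rust
/// Marker trait bundling the bounds every matrix element needs.
/// (Assumed shape; the original definition is not shown.)
pub trait Element: Copy + Clone + PartialEq {}

impl Element for i64 {}
impl Element for f64 {}

/// Hypothetical helper: counts how many items equal `needle`.
/// It relies only on what `Element` guarantees (`Copy` and `PartialEq`).
pub fn count_equal<T: Element>(items: &[T], needle: T) -> usize {
    items.iter().filter(|&&item| item == needle).count()
}
```

A marker trait like this keeps every later impl header one bound shorter, at the cost of having to opt in each element type by hand.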
Indexing
Continuing on with this syntax-driven development, one of the simplest operations I would like to be able to do is access elements of the matrix. Say I have a matrix m. I would like to get the element at row 0, column 1 by doing something like this:
#[test]
fn index() {
    let m = Matrix::new([
        [1, 2],
        [3, 4],
    ]);
    let element = m[(0, 1)];
    assert_eq!(element, 2);
}
This means we’ll have to implement the Index<(usize, usize)> trait for our Matrix<T>. However, it’s not a Matrix<T>, is it? It’s a Matrix<T, ROWS, COLS>. Oh no, time for another impl block!
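impl<T: Element, const ROWS: usize, const COLS: usize>
    Index<(usize, usize)> for Matrix<T, ROWS, COLS> {
    type Output = T;
    fn index(&self, (row, col): (usize, usize)) -> &T {
        // Panic on out-of-bounds access instead of returning an Option.
        assert!(row < ROWS && col < COLS);
        // Row-major: skip `row` full rows, then move `col` steps in.
        &self.0[row * COLS + col]
    }
}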
Okay, the implementation itself is straightforward. But are all the impl<...> lines going to look like that? No, no. They’re going to get worse ;)
IndexMut<(usize, usize)> is essentially identical, but pretty important to implement as well if we want to be able to change our matrices.
The assert! is pretty ugly, but I didn’t want to introduce even more syntax (there will be plenty of that) by making the indexing return an Option<&T>. It also makes operations on the numbers simpler by not having to check and unwrap them all the time, and since this is not intended to be a good, production-ready library I will just do the easy thing here.
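Since the IndexMut implementation is not shown, here is a minimal sketch of what it could look like for the three-parameter Matrix at this point in the post. The Element trait is reproduced in an assumed form (only for i64) so that the example stands on its own.

```rust
use std::ops::{Index, IndexMut};

// Assumed shape of the `Element` trait described earlier.
pub trait Element: Copy + Clone + PartialEq {}
impl Element for i64 {}

#[derive(Debug, Clone)]
pub struct Matrix<T, const ROWS: usize, const COLS: usize>(Vec<T>);

impl<T: Element, const ROWS: usize, const COLS: usize> Matrix<T, ROWS, COLS> {
    pub fn new(array: [[T; COLS]; ROWS]) -> Self {
        Self(array.concat())
    }
}

impl<T: Element, const ROWS: usize, const COLS: usize> Index<(usize, usize)>
    for Matrix<T, ROWS, COLS>
{
    type Output = T;

    fn index(&self, (row, col): (usize, usize)) -> &T {
        assert!(row < ROWS && col < COLS);
        &self.0[row * COLS + col]
    }
}

// Same bounds check and index arithmetic, but handing out a mutable reference.
impl<T: Element, const ROWS: usize, const COLS: usize> IndexMut<(usize, usize)>
    for Matrix<T, ROWS, COLS>
{
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut T {
        assert!(row < ROWS && col < COLS);
        &mut self.0[row * COLS + col]
    }
}
```

With this in place, `m[(0, 1)] = 20;` compiles for a mutable matrix. IndexMut requires Index to be implemented for the same index type, which is why both appear here.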
Transposition
Transposing a matrix is a pretty common operation. We want to be able to reflect it along its diagonal, swapping the rows and columns. So do we have to reach into our underlying Vec<T> and shuffle it around? Remember that it is stored in row-major order, so if the columns are now considered the rows, we have to swap all the elements around, right?
Wrong!
Here’s the sneaky little row-major vs. column-major trick: Don’t touch the underlying memory at all. Just add more information into the type system. That’s right, we’re adding another const generic to our matrix type, baby!
pub struct Matrix<
    T,
    const ROWS: usize,
    const COLS: usize,
    const TRANSPOSED: bool,
>(Vec<T>);
This means the two previous impl blocks have to be updated. For the one containing the constructor, I will simply restrict it to always have this new bool set to false, meaning all matrices start out non-transposed, in row-major order:
  impl<T: Element, const ROWS: usize, const COLS: usize>
-     Matrix<T, ROWS, COLS> {
+     Matrix<T, ROWS, COLS, false> {
      pub fn new(array: [[T; COLS]; ROWS]) -> Self {
          Self(array.concat())
      }
  }
For the Index<(usize, usize)> implementation, there are two possible ways to go about it. Either you just add the same false as for the constructor above and implement it again for true (we still want to be able to index transposed matrices after all), or you put the checking of this bool inside the index() function. Either way is fine. I’m just going to arbitrarily put the check inside to keep the number of impl blocks down.
- impl<T: Element, const ROWS: usize, const COLS: usize>
-     Index<(usize, usize)> for Matrix<T, ROWS, COLS> {
+ impl<T: Element, const ROWS: usize, const COLS: usize, const TRANSPOSED: bool>
+     Index<(usize, usize)> for Matrix<T, ROWS, COLS, TRANSPOSED> {
      type Output = T;
      fn index(&self, (row, col): (usize, usize)) -> &T {
          assert!(row < ROWS && col < COLS);
+         if TRANSPOSED {
+             &self.0[col * ROWS + row]
+         } else {
              &self.0[row * COLS + col]
+         }
      }
  }
Pretty simple, just flip col/COLS with row/ROWS! [2] Remember, the underlying Vec is always stored in row-major order, so we’re just pretending to actually have it be transposed to column-major.
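A worked example may help with the flipped index. The free function below is a hypothetical stand-in that mirrors the arithmetic of index() without needing the whole Matrix type. Take the buffer [1, 2, 3, 4, 5, 6]: read as a 2x3 non-transposed matrix it is 1 2 3 / 4 5 6, and read as its 3x2 pretend transpose it is 1 4 / 2 5 / 3 6.

```rust
/// Linear index into the row-major buffer for element (row, col) of a
/// ROWS x COLS matrix whose type says whether it is a pretend transpose.
pub fn linear_index<const ROWS: usize, const COLS: usize, const TRANSPOSED: bool>(
    row: usize,
    col: usize,
) -> usize {
    // ROWS and COLS are already the dimensions of the matrix as seen by
    // the caller, so the bounds check needs no swap.
    assert!(row < ROWS && col < COLS);
    if TRANSPOSED {
        // The buffer holds the original matrix, which has ROWS columns.
        col * ROWS + row
    } else {
        row * COLS + col
    }
}
```

For the transposed 3x2 view, element (2, 1) maps to 1 * 3 + 2 = 5, which is the last slot of the buffer and holds the 6, as expected.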
So where’s the trick? How do we actually do this pretend transposition? It’s really simple. It’s just this pair of impl blocks [3]:
impl<T: Element, const ROWS: usize, const COLS: usize>
    Matrix<T, ROWS, COLS, false> {
    pub fn transpose(self) -> Matrix<T, COLS, ROWS, true> {
        // Move the buffer as-is; only the type changes.
        Matrix(self.0)
    }
}
impl<T: Element, const ROWS: usize, const COLS: usize>
    Matrix<T, ROWS, COLS, true> {
    pub fn transpose(self) -> Matrix<T, COLS, ROWS, false> {
        Matrix(self.0)
    }
}
One for each value of TRANSPOSED. They just return a Matrix with ROWS and COLS swapped, and TRANSPOSED inverted. Each one takes ownership of self, transferring ownership of the memory to the transposed matrix. No copying/cloning at all! The memory is not touched. The underlying Vec points to the exact same memory!
#[test]
fn transpose_moves() {
    let m = Matrix::new([[1, 2, 3, 4]]);
    assert_eq!(m.0, [1, 2, 3, 4]);
    assert_eq!(m[(0, 1)], 2);
    let initial_addr = m.0.as_ptr();
    let m = m.transpose();
    assert_eq!(m.0, [1, 2, 3, 4]);
    assert_eq!(m[(1, 0)], 2);
    let transposed_addr = m.0.as_ptr();
    assert_eq!(initial_addr, transposed_addr);
}
The test passes, meaning indexing works, and the memory is left where it is with only ownership being transferred! Success! We can transpose matrices of any size in constant time.
Equality
Another thing I would like to be able to do is check whether two matrices are equal [4]. That’s easy, right? Just check if the underlying Vecs are equal!
Not so fast.
Remember the whole thing about row-major vs. column-major? And the transposition we just implemented? Yeah, since the transposition is pretend, and the underlying memory is not touched, two matrices can have the same memory contents but be unequal. Moreover, they can have different dimensions, so we shouldn’t even pass the type-check in that case. Expressed in a test, the following should pass:
#[test]
fn equality() {
    let m1 = Matrix::new([
        [1, 2],
        [3, 4],
    ]);
    let m2 = m1.clone();
    let m3 = m1.clone().transpose();
    assert_eq!(m1, m2);
    assert_ne!(m1, m3); // asymmetric transpose
    assert_eq!(m1.0, m2.0);
    assert_eq!(m1.0, m3.0);
    let m1 = Matrix::new([
        [1, 2],
        [2, 1],
    ]);
    let m2 = m1.clone().transpose();
    assert_eq!(m1, m2); // symmetric transpose
    assert_eq!(m1.0, m2.0);
}
Equality is impossible for matrices of different dimensions, so this test sticks to square matrices to prove that the asymmetric transposition equality check should fail, even though all the Vecs are identical. Now, we need to implement PartialEq for our matrix type. Oh no, I hear an impl block coming…
impl<
    T: Element,
    const ROWS: usize,
    const COLS: usize,
    const LHS_T: bool,
    const RHS_T: bool,
>
    PartialEq<Matrix<T, ROWS, COLS, RHS_T>>
    for Matrix<T, ROWS, COLS, LHS_T> {
    fn eq(&self, other: &Matrix<T, ROWS, COLS, RHS_T>) -> bool {
        // Compare through Index so each side's TRANSPOSED flag is respected.
        for row in 0..ROWS {
            for col in 0..COLS {
                if self[(row, col)] != other[(row, col)] {
                    return false;
                }
            }
        }
        true
    }
}
That’s right, this impl-block has yet another const generic. Either one of the left and right hand sides of the equals sign could have been transposed, and we need to make sure this is implemented for every possible combination. So the impl needs another bool in this case. Will it ever end? [5] Other than that, the implementation itself is very simple and utilizes the fact that we already implemented Index.
Element-wise addition
Now, let’s actually start doing things with our matrices! How about being able to add them? This simple test should suffice:
#[test]
fn addition() {
    let a = Matrix::new([
        [1, 2],
        [3, 4],
    ]).transpose();
    let b = Matrix::new([
        [1, 1],
        [1, 1],
    ]);
    let c = Matrix::new([
        [2, 3],
        [4, 5],
    ]).transpose();
    assert_eq!(a + b, c);
}
Okay, now let’s brace for the impl…
impl<
    T: Element + Add<Output = T>,
    const ROWS: usize,
    const COLS: usize,
    const LHS_T: bool,
    const RHS_T: bool,
>
    Add<Matrix<T, ROWS, COLS, RHS_T>>
    for Matrix<T, ROWS, COLS, LHS_T> {
    type Output = Matrix<T, ROWS, COLS, LHS_T>;
    fn add(mut self, rhs: Matrix<T, ROWS, COLS, RHS_T>) -> Self::Output {
        // Reuse the left-hand buffer and add element by element (needs IndexMut).
        for row in 0..ROWS {
            for col in 0..COLS {
                self[(row, col)] = self[(row, col)] + rhs[(row, col)];
            }
        }
        self
    }
}
Oh. That wasn’t so bad. There’s a new bound on T: it has to implement the Add trait, outputting itself, for the addition of the elements to work. Other than that though, the implementation looks really similar to the one for PartialEq. And in this one we’re using the implementation of IndexMut that I left out for brevity to assign into the left hand matrix. I only implemented the trait for taking ownership of the added matrices, but it’s straightforward to implement for taking them by reference as well.
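A by-reference version could look something like the sketch below. It is an illustration under assumptions rather than the library’s actual code: Element is reproduced in an assumed form, and because neither operand can be reused as the output, the sums are pushed into a fresh Vec in row-major order, which makes the result a non-transposed matrix.

```rust
use std::ops::{Add, Index};

// Assumed shape of the `Element` trait described earlier.
pub trait Element: Copy + Clone + PartialEq {}
impl Element for i64 {}

#[derive(Debug, Clone)]
pub struct Matrix<T, const ROWS: usize, const COLS: usize, const TRANSPOSED: bool>(Vec<T>);

impl<T: Element, const ROWS: usize, const COLS: usize> Matrix<T, ROWS, COLS, false> {
    pub fn new(array: [[T; COLS]; ROWS]) -> Self {
        Self(array.concat())
    }

    pub fn transpose(self) -> Matrix<T, COLS, ROWS, true> {
        Matrix(self.0)
    }
}

impl<T: Element, const ROWS: usize, const COLS: usize, const TRANSPOSED: bool>
    Index<(usize, usize)> for Matrix<T, ROWS, COLS, TRANSPOSED>
{
    type Output = T;

    fn index(&self, (row, col): (usize, usize)) -> &T {
        assert!(row < ROWS && col < COLS);
        if TRANSPOSED {
            &self.0[col * ROWS + row]
        } else {
            &self.0[row * COLS + col]
        }
    }
}

impl<
        T: Element + Add<Output = T>,
        const ROWS: usize,
        const COLS: usize,
        const LHS_T: bool,
        const RHS_T: bool,
    > Add<&Matrix<T, ROWS, COLS, RHS_T>> for &Matrix<T, ROWS, COLS, LHS_T>
{
    // The new buffer is filled row by row, so the result is not transposed.
    type Output = Matrix<T, ROWS, COLS, false>;

    fn add(self, rhs: &Matrix<T, ROWS, COLS, RHS_T>) -> Self::Output {
        let mut sums = Vec::with_capacity(ROWS * COLS);
        for row in 0..ROWS {
            for col in 0..COLS {
                sums.push(self[(row, col)] + rhs[(row, col)]);
            }
        }
        Matrix(sums)
    }
}
```

Usage is `let c = &a + &b;`, after which both a and b are still available.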
Well, since that wasn’t so bad, I guess we’ve been through the worst now! The trait impls surely won’t get much more complex than this! :)
Multiplication
Hey, did you know that matrix multiplication, unlike addition, doesn’t require the dimensions to be identical? That’s right, the impl for Mul is going to be more complex! But first, a test. Multiplication requires the left-hand side to have as many columns as the right-hand side has rows. This means that a matrix can always be multiplied with its transpose, which makes for a test that is pretty simple to write:
#[test]
fn multiplication() {
    let a = Matrix::new([
        [1, 2, 3],
        [4, 5, 6],
    ]);
    let b = a.clone().transpose();
    let c = Matrix::new([
        [14, 32],
        [32, 77],
    ]);
    assert_eq!(a * b, c);
}
That should be good enough! Now we just need to impl Mul. And this time I will just go through the trait bounds and const generics on their own before the implementation, because it will be clearer that way and I want to talk about some weird choices in the implementation. Here is the impl<...> part:
impl<
    T: Element + Mul<Output = T> + Add<Output = T>,
    const ROWS: usize,
    const COLS: usize,
    const MATCHING_DIM: usize,
    const LHS_T: bool,
    const RHS_T: bool,
>
Phew. Now, it’s the most complex one in this post, but it’s not that bad. We can see that T has a new Mul trait bound, but it also has the Add trait bound. This is because matrix multiplication is defined by a bunch of multiplications and additions, so both these bounds must be present. And then we have yet another const generic. Last one, I promise. As I said before, the left-hand side of the multiplication must have as many columns as the right-hand side has rows. This is what MATCHING_DIM will represent. The remaining two dimensions of the left and right side are arbitrary, but will become the dimensions of the output matrix. So multiplying a 2x3 matrix with a 3x4 matrix will yield a 2x4 matrix.
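The dimension rule can be seen in isolation with plain nested arrays. This is a simplified sketch, not the Matrix implementation: the function name mat_mul and the const parameter names R, K, and C are made up, and it is fixed to i64 so that it can start each sum from 0. K plays the role of MATCHING_DIM. Passing arrays whose inner dimensions disagree fails to type-check.

```rust
/// Multiplies an R x K array by a K x C array, giving an R x C array.
pub fn mat_mul<const R: usize, const K: usize, const C: usize>(
    lhs: [[i64; K]; R],
    rhs: [[i64; C]; K],
) -> [[i64; C]; R] {
    let mut out = [[0i64; C]; R];
    for row in 0..R {
        for col in 0..C {
            // Each output element is a sum of K products.
            for i in 0..K {
                out[row][col] += lhs[row][i] * rhs[i][col];
            }
        }
    }
    out
}
```

Calling it with a 2x3 and a 3x4 array makes the compiler infer R = 2, K = 3, and C = 4, and the return type is a 2x4 array.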
Now, before I show you the implementation, I have to warn you that it’s a little strange. It looks like this purely because I’m a nerd and liked the idea of it. The implementation makes the following (non-)assumptions of the underlying type T in our matrix:
1. T might not be numerical.
2. T might not have a multiplicative or additive identity. [6]
3. T might not impl Default.
These are implicit in the trait bounds I have chosen for T, but I figured I would make them explicit.
Now, on to the implementation (impl line from above excluded):
Mul<Matrix<T, MATCHING_DIM, COLS, RHS_T>>
for Matrix<T, ROWS, MATCHING_DIM, LHS_T> {
    // The buffer below is written in row-major order, so the result is
    // never marked as transposed, whatever LHS_T is.
    type Output = Matrix<T, ROWS, COLS, false>;
    fn mul(self, rhs: Matrix<T, MATCHING_DIM, COLS, RHS_T>) -> Self::Output {
        let mut vec = Vec::with_capacity(ROWS * COLS);
        unsafe {
            vec.set_len(vec.capacity());
        }
        for row in 0..ROWS {
            for col in 0..COLS {
                // Collect the products that make up this output element.
                let mut products = Vec::with_capacity(MATCHING_DIM);
                for i in 0..MATCHING_DIM {
                    products.push(
                        self[(row, i)] * rhs[(i, col)]
                    );
                }
                // Sum them, seeding the fold with the first product.
                vec[row * COLS + col] = products
                    .iter()
                    .skip(1)
                    .fold(products[0], |acc, x| acc + *x);
            }
        }
        Matrix(vec)
    }
}
Whoa, is that an unsafe block?? Yes, but don’t worry. It’s fine. Really, I promise [7]. Mathematically, the resulting matrix is guaranteed to have ROWS * COLS elements in it. And to avoid having a bunch of reallocations, I initialize the Vec to have this capacity. But if I don’t set the length, I can’t assign to indices in the Vec. Rust will panic! if I try. So into unsafe land we go, and we pinky-promise Rust that the Vec does indeed have the same length as it has capacity. So indexing into it and assigning things there is fine. Of course, in reality this is uninitialized memory. We are just manually initializing it. Strictly speaking, the documentation for Vec::set_len asks for the new elements to already be initialized, so this leans on T being Copy and on every slot being written before it is read.
“Why not just call Vec::resize instead?”, you may ask. To which I reply “And fill it with what?”. Vec::resize takes an element to pad the Vec with to the desired length. And remember point 3 from above. T might not impl Default, so what should we put there? Better to just leave it uninitialized since we’re immediately going to initialize it anyway.
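For comparison, here is a sketch of an alternative that stays within the same (non-)assumptions but needs no unsafe. It is not the implementation used in this post. Since the loops visit the output in row-major order anyway, each element can be pushed as it is computed, and Iterator::reduce sums the products without an identity value. The Element trait is reproduced in an assumed form, and the result type is chosen as non-transposed because the buffer is filled row by row.

```rust
use std::ops::{Add, Index, Mul};

// Assumed shape of the `Element` trait described earlier.
pub trait Element: Copy + Clone + PartialEq {}
impl Element for i64 {}

#[derive(Debug, Clone)]
pub struct Matrix<T, const ROWS: usize, const COLS: usize, const TRANSPOSED: bool>(Vec<T>);

impl<T: Element, const ROWS: usize, const COLS: usize> Matrix<T, ROWS, COLS, false> {
    pub fn new(array: [[T; COLS]; ROWS]) -> Self {
        Self(array.concat())
    }

    pub fn transpose(self) -> Matrix<T, COLS, ROWS, true> {
        Matrix(self.0)
    }
}

impl<T: Element, const ROWS: usize, const COLS: usize, const TRANSPOSED: bool>
    Index<(usize, usize)> for Matrix<T, ROWS, COLS, TRANSPOSED>
{
    type Output = T;

    fn index(&self, (row, col): (usize, usize)) -> &T {
        assert!(row < ROWS && col < COLS);
        if TRANSPOSED {
            &self.0[col * ROWS + row]
        } else {
            &self.0[row * COLS + col]
        }
    }
}

impl<
        T: Element + Mul<Output = T> + Add<Output = T>,
        const ROWS: usize,
        const COLS: usize,
        const MATCHING_DIM: usize,
        const LHS_T: bool,
        const RHS_T: bool,
    > Mul<Matrix<T, MATCHING_DIM, COLS, RHS_T>> for Matrix<T, ROWS, MATCHING_DIM, LHS_T>
{
    type Output = Matrix<T, ROWS, COLS, false>;

    fn mul(self, rhs: Matrix<T, MATCHING_DIM, COLS, RHS_T>) -> Self::Output {
        // Capacity is reserved up front, so pushing never reallocates.
        let mut vec = Vec::with_capacity(ROWS * COLS);
        for row in 0..ROWS {
            for col in 0..COLS {
                let element = (0..MATCHING_DIM)
                    .map(|i| self[(row, i)] * rhs[(i, col)])
                    // `reduce` uses the first product as the seed.
                    .reduce(|acc, x| acc + x)
                    .expect("MATCHING_DIM must be non-zero");
                // Elements arrive in row-major order, so push is enough.
                vec.push(element);
            }
        }
        Matrix(vec)
    }
}
```

The trade-off is mostly one of taste: push keeps the length and the initialized region in sync at every step, while the set_len approach allows writing to arbitrary indices.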
The rest of the implementation is relatively straightforward. Each element in the output matrix is a sum of products from the input matrices, so the products vector is self-explanatory. But why am I summing products like that? Why not just call .sum() or .fold() directly on the iterator? Well, that all comes back to point 2 from above. T might not have a multiplicative or additive identity. So we skip the first element, use that as the initial value for the accumulator in .fold(), and do a folding addition like normal. This implementation requires the dimensions to be non-zero, but we can just skip that check for brevity [8].
If I wanted to use .sum(), I could constrain T to require an additive identity by adding another trait bound requiring T to implement std::iter::Sum. If we also require std::iter::Product, we could force a multiplicative identity to exist, and that could simplify the innermost loop of the implementation. But again, I wanted to keep this general for fun.
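The difference the extra bound makes can be shown on a single row-times-column product. Both functions below are hypothetical helpers over slices, not part of the matrix type. With Sum, an empty input has a defined result. With only Mul and Add there is nothing to return for empty input, so the result becomes an Option.

```rust
use std::iter::Sum;
use std::ops::{Add, Mul};

/// Dot product using the extra `Sum` bound.
pub fn dot_with_sum<T>(lhs: &[T], rhs: &[T]) -> T
where
    T: Copy + Mul<Output = T> + Sum<T>,
{
    lhs.iter().zip(rhs).map(|(&a, &b)| a * b).sum()
}

/// Dot product using only `Mul` and `Add`, as in the general version.
pub fn dot_with_reduce<T>(lhs: &[T], rhs: &[T]) -> Option<T>
where
    T: Copy + Mul<Output = T> + Add<Output = T>,
{
    lhs.iter()
        .zip(rhs)
        .map(|(&a, &b)| a * b)
        // No identity available, so the first product seeds the sum.
        .reduce(|acc, x| acc + x)
}
```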
And there we have it! A very simple linear algebra library. Or rather, a simple matrix addition and multiplication library. But other operations are now pretty trivial to implement and build on top of this foundation.
Conclusion
This is not how you should do things. All the const generics and trait bounds make it really hard to read and implement anything, and trying to stay too general will result in strange implementations. But boy, is it fun!
The real libraries do not use const generics and implement this stuff in much smarter ways in order to actually be useful at runtime. This was all just a fun excursion into const generic land in order to learn and see how far I could take it. If nothing else, it was a fun way to spend an evening.
Use cases?
Very few. As mentioned, all matrix dimensions are created and checked at compile-time, so any real world use for a library like this would be pretty limited.
Further work
Absolutely not. But if one wanted to expand on this, implementing subtraction, inverses, convolutions, views, scalar multiplication, and all manner of fun things should in theory be possible. Don’t blame me for any loss of sanity if you try, though. Keeping all these trait bounds and const generics in check takes a good bit of care and energy.
[1] Vectors are just 2D matrices where one of the dimensions is 1. So matrix-vector operations will be supported.
[2] What about the assert!? Well, the ROWS/COLS of the transposed type is the COLS/ROWS of the non-transposed matrix. So we don’t have to do the swap there. It’s a bit confusing, but some thought and playing around gives a good feeling for it.
[3] I wish it could be one impl block that returns Matrix<T, COLS, ROWS, !TRANSPOSED>. But that doesn’t work for basically the same reason why [T; ROWS * COLS] doesn’t.
[4] Ignore the floating point equality monster under the bed.
[5] Not yet.
[6] I wish Rust had this, maybe as a trait of some kind. It would be very simple to implement myself, but I liked being more general for the fun of it in this case.
[7] Source: trust me bro
[8] “For brevity”, he said. In the almost 600 line, 3700 word markdown document.