Category Archives: math

The Strassen Algorithm in C++

Quite some time ago I wrote about some experiments I did with some matrix multiplication algorithms. I’ve finally got around to cleaning up and posting (most of) the source code I used to generate the data in that post.

The source is now hosted in my GitHub repository here.

The project contains a few toys for playing around with matrices, but the meat is in

src/strassen/strassen_matrix_multiplier.hpp
src/strassen/parallel_strassen_matrix_multiplier.hpp

There are a few other algorithms in there for comparison’s sake, as well as a simple testing client that provides some timing information. I’m using CMake to generate the Makefiles; if you’ve never used it before, you should. It’s great.

Some things to note:

  • The Strassen algorithm has a very high constant factor; for this reason, the code falls back to an optimized naive multiplication algorithm for smaller matrices and sub-matrices.
  • The algorithm needs to pad matrices with zeroes if they’re non-square or have dimensions that are not a power of two. So if your input matrices are square with a dimension like N = 2^n + 1, the algorithm works with a matrix almost twice that size, and it will perform badly on inputs with this characteristic (see the sketch after this list).
  • This implementation is mostly a demonstrative toy in the hope someone finds it useful; it’s not particularly optimized and the parallel Strassen class is kind of naive.
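
As a rough illustration of the padding rule from the second bullet, here’s a hypothetical sketch (not the repository’s actual code) of how the padded dimension would be chosen: round the largest dimension of the two inputs up to the next power of two, so an N = 2^n + 1 input ends up embedded in a 2^(n+1) x 2^(n+1) zero-padded matrix.

#include <algorithm>
#include <cstdint>

/* Hypothetical helper, for illustration only: compute the dimension of the
 * zero-padded square matrix that the Strassen recursion would operate on. */
static uint64_t
padded_dimension (uint64_t rows_a, uint64_t cols_a, uint64_t rows_b, uint64_t cols_b)
{
  uint64_t n = std::max (std::max (rows_a, cols_a), std::max (rows_b, cols_b));
  uint64_t p = 1;

  while (p < n)
    p <<= 1;

  return p;   /* e.g. n = 1025 pads all the way up to 2048 */
}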

Happy hacking!


Hash Functions

Hash functions are a particular interest of mine, and have been ever since back in university when I began writing password crackers for GNU/Linux systems. I like talking about them so I figured I’d ramble a little bit on here about the subject.

There are two principal uses for these functions:

  1. Cryptography
  2. Hash table mapping

These two uses should probably be considered mutually exclusive: a crypto-grade hash function is (relatively) expensive to compute, while a quick hash-table hash function takes no active steps to prevent reversibility. In cryptography, hash functions are primarily used for things like computing (relatively) unique digests of messages or files, or creating digital signatures and certificates. The most common algorithms belonging to this category are MD5 and SHA-1. However, both of these algorithms have had flaws discovered by security researchers and are considered compromised for critical cryptographic use. Vulnerabilities have yet to be found in the SHA-2 family and WHIRLPOOL, which are currently considered safer alternatives.

Hash functions are typically described by an expression such as this:

h : {0,1}* -> {0,1}^n

Meaning, a hash function h(M) takes an arbitrarily long binary message M and computes a finite binary sequence of length n. Because we’re mapping from an infinite space of binary messages to a finite set of strings, we can’t expect the hash of every message to be unique; this is the Pigeonhole Principle. But there are some requirements h(M) must meet in order to be considered a secure hash function:

  1. Pre-image resistance: Given hash output H, it should be computationally infeasible to find an M such that h(M) = H. This should require about 2^n hashing operations.
  2. Second pre-image resistance: Given M and its output h(M) = H, it should be computationally infeasible to find another M’ such that h(M) = h(M’). Again, this should require about 2^n hashing operations.
  3. Collision resistance: It should be computationally infeasible to find any two different messages M and M’ such that h(M) = h(M’), taking at least 2^(n/2) operations. Only about this many hash operations are expected to be required because of the Birthday Paradox.

If a proposed attack on a cryptographic hash function violates one of the above properties, the system has been compromised. Note that the biggest difference between (2) and (3) above is that (3) applies to any pair of messages M, whereas (2) refers to finding a collision for a specific message.

In general, a secure hash function will satisfy the following (non-exhaustive list) of properties:

  • The algorithm should be efficient; it should be relatively quick to compute the message digest
  • The produced hash should be irreversible
  • The hash is deterministically created, and all hashes are the same size
  • The hash should indicate nothing about the input message (see avalanche effect)
  • Produced hashes should represent a good distribution over the available finite hash domain

Both MD5 and the current SHA families, as well as many other cryptographic hash functions, are based on an iterative principle called the Merkle-Damgård construction. It works something like the following (a toy sketch appears after the diagram notes below):

  1. Pad the message to be hashed until its length is a multiple of the desired message block size
  2. Divide the message into blocks
  3. Iteratively mix and compress data from the message and from previous rounds together, passing through nonlinear one-way compression functions, and possibly mixed with auxiliary disturbance vectors
  4. Condense the intermediate hash data into the final hash value of some standard size and output

Graphically, it looks like this:

  • B1..6 are the blocks of the message to be hashed
  • The initial mixing values are just some numbers selected to mix in with the initial message block
  • Each function f() is a one-way compression function (or collection of them)
  • The final function g() collects and finalizes all data from the iterative rounds, and produces the hash
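
To make the iteration concrete, here’s a toy sketch in C++. The 64-bit block size, padding scheme, initial value, and mixing constants are all made up for illustration; this is nothing like a real, secure compression function.

#include <cstddef>
#include <cstdint>
#include <string>

/* Toy Merkle-Damgård-style hash: pad, split into 64-bit blocks, and fold each
 * block into the running state with a (deliberately trivial) compression
 * function. Illustrative only; offers no security whatsoever. */
static uint64_t
toy_compress (uint64_t state, uint64_t block)
{
  state ^= block * 0x9E3779B97F4A7C15ULL;      /* mix the block in */
  state = (state << 13) | (state >> 51);       /* rotate */
  return state * 0xBF58476D1CE4E5B9ULL;        /* scramble */
}

static uint64_t
toy_hash (const std::string &msg)
{
  const size_t block_size = 8;

  /* 1. Pad the message to a multiple of the block size */
  std::string padded = msg;
  padded.push_back ('\x80');
  while (padded.size () % block_size != 0)
    padded.push_back ('\0');

  /* 2./3. Divide into blocks and iteratively compress */
  uint64_t state = 0x6A09E667F3BCC908ULL;      /* arbitrary initial mixing value */
  for (size_t i = 0; i < padded.size (); i += block_size)
    {
      uint64_t block = 0;
      for (size_t j = 0; j < block_size; j++)
        block |= (uint64_t) (unsigned char) padded[i + j] << (8 * j);
      state = toy_compress (state, block);
    }

  /* 4. The finalization g() is trivial here: just output the state */
  return state;
}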

Although this has been the way cryptographic hash functions have traditionally been designed, there are some newer algorithms which follow a different model. For example, MD6 uses a hash tree to allow parallel computation of the hash, and WHIRLPOOL uses a Miyaguchi-Preneel based compression function instead of the traditional Merkle-Damgård one. There’s been some interesting research into the parallel computation of hashes; it’s been reported that parallel implementations of MD6 can reach over 1 Gb/s in processing speed. Note, however, that non-cryptographic hashes designed for speed can operate at two to three times that in a single thread of execution. MurmurHash is probably my favourite algorithm of this category.

The National Institute of Standards and Technology is actually holding a cryptographic hash function competition to replace the SHA-2 family of hash functions with a modern and more secure variant, which they’ll call SHA-3.


Thoughts on the Go Programming Language

Part of the description of Go from its homepage states that it’s a “compiled language that feels like a dynamically typed, interpreted language”. In my relatively brief experience with the language, that is a statement I can attest to. I feel like an expert in Go could build natively-compiled applications at a pace closer to Python than to C++ or Java. Part of this has to do with its extensive standard library, and part with the fact that the language is just not particularly complicated. Not to mention that the language’s concurrency model is very intuitive.

My initial interest in Go stemmed from goroutines. These implement concurrency in an extremely simple and highly-abstracted fashion: concurrently-executing functions. There isn’t much more to it than that; goroutines are apparently multiplexed over a pool of host threads which are hidden from the programmer. You can let them spin off on their own and terminate, or synchronize and communicate with them using message-passing in the form of channels. In this sense, concurrency in Go feels a lot like Erlang. Some other cool features include closures and array slices. Technically Go isn’t object-oriented, but you’re able to mimic it by using interfaces and embedding methods into structures. Also, Go is garbage-collected.

Actually, the features of Go make it feel like a combination of a bunch of different languages. There’s obvious inspiration from C, and some from what feels like Java (interfaces, GC), Erlang (message-passing and concurrency), functional languages in general (closures) and Python. Let’s dive into an example which illustrates a few of the things I’ve mentioned so far.

Here’s the (highly synthetic) situation: We’d like to compute the square roots of a list of floating-point numbers in parallel using our own hand-rolled function implementing Newton’s Method. We’re going to compute each square root in a separate goroutine and collect each value through a separate channel to that goroutine. To make things more interesting, we’ll package a bunch of it up using a closure.

/* Compute and return the square root of x. Note
 * that return type comes after function parameters */
func newtons_method (x float64) float64 {

	/* No parentheses around conditional */
	if x == 0.0 {
		panic ("Divide by zero!") // throw exception
	}

	/* "Initialize" statement - variable type inferred by
 	 * compiler */
	last := float64(0.0)
	y := float64(x / 2.0)

	/* No "while" loops - iterate "for as long as" condition */
	for math.Abs(last - y) > 0.0001 {
		last = y
		y = y - ((y * y - x) / (2 * y))
	}

	return y
}

/* Given a slice of floats, compute their square roots in
 * parallel and print the results. Note variable name
 * precedes the type. */
func compute_sqrts (L []float64) {

	/* Type inferred at compile-time */
	n := len (L)

	/* A "slice" definition - used here as expandable array */
	var chans []chan float64

	/* Fill the slice with channels */
	for i := 0; i < n; i++ {
		chans = append(chans, make(chan float64))
	}

	/* For index, element in the list */
	for i, x := range L {

		/* Copy the range variables so the closures
		 * reference the right values */
		__x := x
		__i := i

		/* Define a quick closure and execute it in parallel
		 * as a goroutine via the go () statement */
		go func() {
			/* Calculate square root */
			var root float64 = newtons_method (float64(__x))
			chans[__i] <- root /* Write result to channel */
			close(chans[__i])  /* Close the channel */
		} ()
	}

	for i := range chans {
		x := <- chans[i] /* Read from channel */
		fmt.Println ("Square root of ", L[i], " is ", x)
	}
}

Things to note:

  1. Yes, there are like ten different ways to declare variables
  2. Yes, there are like ten different ways to declare for loop conditions (and no while loops)
  3. Yes, opening braces ('{') must go on the same line as the function definition or condition
  4. There is a new allocation directive called make(…) in addition to new(…)
  5. This program is primarily imperative
  6. Even though there is flexibility in how you define variables, they are still very much strongly typed

Obviously, the most interesting part of this little program is what’s going on inside the go func() { … }() block in compute_sqrts(). This is a function literal, and in Go all function literals are also closures; meaning, variables referenced by the function literal stay in scope for as long as the literal needs them. In addition, this function literal is preceded by go and followed by (). This executes the literal inside a goroutine, which, as discussed earlier, will run concurrently with the main thread and all other goroutines (within reason).

The literal itself calls the newtons_method() function and writes the return value into the closure’s channel. It then closes the channel, which would normally indicate that the closure has completed its work and exited. The values written to each of the respective channels will wait to be read, even after the goroutine has exited.

I mentioned earlier that you can play around with some pseudo-OO in Go, so I’ll give a brief example:

/* Interface definition - any structure which
 * implements newtons_method (x float64)
 * also implements this interface */
type SqrtInterface interface {
	newtons_method ()
}

/* Structure definition - this is also
 * a SqrtInterface type */
type SqrtPair struct {
	x float64
	sqrt_x float64
}

/* Overloaded newtons_method() which is bound
 * to pointers of type SqrtPair. Will reference
 * the members of the 'calling' struct. */
func (sp *SqrtPair) newtons_method () {

	/* ... */

	sp.sqrt_x = newtons_method (sp.x)
}

/* This does nothing except call member functions
 * of SqrtInterface interface objects */
func compute_for_pair (sp SqrtInterface) {
	sp.newtons_method ()
}

func main () {

	/* Note sp has type *SqrtPair */
	sp := new (SqrtPair)
	sp.x = 17
	compute_for_pair (sp)
	fmt.Println ("Square root of ", sp.x, " is ", sp.sqrt_x)
}

I hope that blows your mind, because it blew mine. Let me illustrate what’s going on:

  • Define an interface of type SqrtInterface which is basically empty except for a member function called newtons_method()
  • Define a structure of type SqrtPair with a couple of floats
  • Embed a function called newtons_method() into structures of type SqrtPair*
  • SqrtPair now implements the interface SqrtInterface
  • SqrtPair.newtons_method() can now be called on instantiated SqrtPair structures acting as SqrtInterface objects, modifying their internal members, even though the method wasn’t present in the initial structure definition

So obviously there’s a lot of cool stuff happening inside Go. TIOBE ranks Go as the world’s 21st most popular programming language at the time of this post (interpret these rankings however you wish), beating out D, arguably one of its closest competitors, at 33rd. GitHub has a huge pile of Go projects. There’s no doubt that Go is a very powerful language, which is even supported by a GCC backend in addition to its standard compiler. It also comes with an enormous standard library with built-in items like HTTP servers. I like Go a lot, but there are some nagging (and probably petty) issues I have with it which probably means I won’t be using it as my language of choice in the near future:

  • Although minimalistic, I find the syntax kind of random. While in some sense it can be beneficial to mix the best aspects of Python, Erlang, and C syntax together, I think it makes the end result kind of messy. For example, sliceVar = append(sliceVar, item) vs. mapVar[key] = item
  • I am frustrated that I can define variables ten different ways, including not needing to specify variable type, but I can’t put the type before the variable definition.
  • Similarly, for a language so flexible, I don’t know why it’s required to put braces on opening lines, and wrap one-line loops and if-statements in braces.
  • No templates. How can there be no templates you ask? There just aren’t. I don’t know why.
  • Although technically not object-oriented, you can mix in elements of OO as you see fit. I feel that this can lead to the same mix of paradigms that drives C++ people insane.
  • “Exported” variables and functions are denoted by title-case, getting rid of ‘extern’ and ‘public’. I don’t know what sort of problem this is trying to solve, but I suspect it can just lead to more confusion.

But to end on a positive note, a list of Go features I love playing around with:

  • Goroutines and channels
  • Closures / function literals
  • Huge standard library, including built-in RPC
  • Slices
  • Multiple return values (so good!)

Apache Thrift Tutorial – The Sequel

I’m going to cover building a simple C++ server using the Apache Thrift framework here, while my buddy Ian Chan will cover the front-end PHP interface in his own blog post.

The other day Ian and I were talking and thought it would be cool to do another Facebook/Apache Thrift tutorial, but this time he’d do the front-end client interface and I’d do the backend stuff. He really wanted to do an example of something that you’d find useful to send to a backend for processing from a PHP frontend client. So, we came up with the following thrift interface:

namespace cpp calculator

typedef list<double> Vector

enum BinaryOperation
{
  ADDITION = 1,
  SUBTRACTION = 2,
  MULTIPLICATION = 3,
  DIVISION = 4,
  MODULUS = 5,
}

struct ArithmeticOperation
{
  1:BinaryOperation op,
  2:double lh_term,
  3:double rh_term,
}

exception ArithmeticException
{
  1:string msg,
  2:optional double x,
}

struct Matrix
{
  1:i64 rows,
  2:i64 cols,
  3:list<Vector> data,
}

exception MatrixException
{
  1:string msg,
}

service Calculator
{
  /* Note you can't overload functions */

  double calc (1:ArithmeticOperation op) throws (1:ArithmeticException ae),
  Matrix mult (1:Matrix A, 2:Matrix B) throws (1:MatrixException me),
  Matrix transpose (1:Matrix A) throws (1:MatrixException me),
}

As you can see, we defined a simple calculator with a couple more functions for doing some basic matrix operations (yes, this seems to come up often); something that would be painful to do in PHP. We generated the code with

thrift --gen cpp calculator.thrift

And away we go with the autogenerated C++ code. After you run the thrift generation for C++, it’ll make a directory called gen-cpp/. Under this, you can find relevant files and classes to do work based on your Thrift definition.


$ ls gen-cpp/
calculator_constants.cpp  Calculator_server.skeleton.cpp
calculator_constants.h    calculator_types.cpp
Calculator.cpp            calculator_types.h
Calculator.h

I renamed the generated Calculator_server.skeleton.cpp file (you’ll want to make sure you do this so your work isn’t overwritten the next time you generate Thrift code), and filled in the function stubs adding more functionality as necessary. This file is the only file containing code which you need to edit for your server – you need to fill in the logic here. The other autogenerated files contain necessary transport class, struct, and function code for your server to work. On the other end of things, Ian generated the PHP code and filled in those stubs – you can find his blog post for this project here. We also threw all the code online under Ian’s Github account – you can find all the source here.

Below I’ll list the code I filled in for the backend-side of this project.


#include "Calculator.h"
#include <stdint.h>
#include <cmath>
#include <protocol/TBinaryProtocol.h>
#include <server/TSimpleServer.h>
#include <transport/TServerSocket.h>
#include <transport/TBufferTransports.h>
#include <thrift/concurrency/ThreadManager.h>
#include <thrift/concurrency/PosixThreadFactory.h>
#include <server/TThreadedServer.h>
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::server;
using namespace ::apache::thrift::concurrency;
using boost::shared_ptr;
using namespace calculator;

class CalculatorHandler : virtual public CalculatorIf
{
private:
/* It might be cleaner to stick all these private class functions inside some other class which isn't related to the Thrift interface, but for the sake of brevity, we'll leave them here. */
  double
  __add (double lh_term, double rh_term)
  {
    return (lh_term + rh_term);
  }

  double
  __sub (double lh_term, double rh_term)
  {
    return (lh_term - rh_term);
  }

  double
  __mult (double lh_term, double rh_term)
  {
    return (lh_term * rh_term);
  }

  double
  __div (double lh_term, double rh_term)
  {
    if (rh_term == 0.0)
      {
        ArithmeticException ae;
        ae.msg = std::string ("Division by zero error!");
        throw ae;
      }

    return (lh_term / rh_term);
  }

  double
  __mod (double lh_term, double rh_term)
  {
    if (rh_term == 0.0)
      {
        ArithmeticException ae;
        ae.msg = std::string ("Modulus by zero error!");
        throw ae;
      }

    return std::fmod (lh_term, rh_term);
  }

public:

  CalculatorHandler ()
  {
  }
/* Given the ArithmeticOperation, ensure it's valid and return the resulting value. */
  double
  calc (const ArithmeticOperation& op)
  {
    switch (op.op)
      {
      case ADDITION:
        return __add (op.lh_term, op.rh_term);

      case SUBTRACTION:
        return __sub (op.lh_term, op.rh_term);

      case MULTIPLICATION:
        return __mult (op.lh_term, op.rh_term);

      case DIVISION:
        return __div (op.lh_term, op.rh_term);

      case MODULUS:
        return __mod (op.lh_term, op.rh_term);

      default:
        ArithmeticException ae;
        ae.msg = std::string ("Invalid binary operator provided!");
        throw ae;
      }
  }
/* Multiply A and B together, placing the result in the "return value" C, which is passed as a Matrix reference parameter. */
  void
  mult (Matrix& C, const Matrix& A, const Matrix& B)
  {
    if (A.cols == B.rows)
      {
        double tmp;

        C.rows = A.rows;
        C.cols = B.cols;
        C.data.resize (C.rows);

        for (uint64_t i = 0; i < A.rows; i++)
          {
            C.data[i].resize (B.cols);

            for (uint64_t j = 0; j < B.cols; j++)
              {
                tmp = 0;
                for (uint64_t k = 0; k < B.rows; k++)
                  {
                    tmp += A.data[i][k] * B.data[k][j];
                  }
                  C.data[i][j] = tmp;
              }
         }
      }
    else
      {
        MatrixException me;
        me.msg = std::string ("Matrices have invalid dimensions for multiplication!");
        throw me;
      }
  }
/* Take the transpose of A and stuff it into the return Matrix T. */
  void
  transpose (Matrix& T, const Matrix& A)
  {
    T.rows = A.cols;
    T.cols = A.rows;
    T.data.resize (A.cols);

    for (uint64_t i = 0; i < A.rows; i++)
      {
        for (uint64_t j = 0; j < A.cols; j++)
          {
            T.data[j].push_back (A.data[i][j]);
          }
      }
  }
};

int
main (int argc, char **argv)
{
  int port = 9090;
  shared_ptr<CalculatorHandler> handler(new CalculatorHandler());
  shared_ptr<TProcessor> processor(new CalculatorProcessor(handler));
  shared_ptr<TServerTransport> serverTransport(new TServerSocket(port));
  shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
  shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
  shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager (4);
  shared_ptr<PosixThreadFactory> threadFactory(new PosixThreadFactory ());
  threadManager->threadFactory (threadFactory);
  threadManager->start ();

 /* This time we'll try using a TThreadedServer, a better server than the TSimpleServer in the last tutorial */
 TThreadedServer server(processor, serverTransport, transportFactory, protocolFactory);
 server.serve();
 return 0;
}

Finally, the code was compiled either with the Makefile I posted onto Ian’s Github repo, or the following build script:

g++ -o calc_server -I./gen-cpp -I/usr/local/include/thrift/ CalculatorServer.cpp gen-cpp/calculator_constants.cpp gen-cpp/Calculator.cpp gen-cpp/calculator_types.cpp -lthrift

So this really isn’t the most complicated program in the world, but it gets the job done fairly simply and effectively (and yes, it actually works!). Note that, as opposed to last time, I used a TThreadedServer as the base Thrift server type here. It’s a little more complicated to set up, but obviously more useful than a single-threaded server. Interesting things to note:

  • Use of TThreadedServer for a multithreaded server
  • You fill in Thrift exceptions like any other struct, and throw them like any other exception
  • You can use typedefs and enums
  • You can’t overload function names

The last point is a real pain, as far as I am concerned. I’m not sure why the Thrift people couldn’t just mangle the function names so that they resolve to unique entities, but whatever. Anyways, what’s really cool is that we managed to build two programs in two completely different languages using a single Thrift definition file: a backend in C++, and a frontend in PHP. Hope you find this useful – happy hacking!


Bloom Filters

I’ve always thought that bloom filters were really cool data structures. They belong to a small class of data structures described as “probabilistic”, meaning, there’s a trade-off between performance and accuracy. CS people know about trade-offs all too well, whether they’re related to space/time or approximations of expensive algorithms. Bloom filters kind of fall into both categories.

The purpose of a bloom filter is to indicate, with some chance of error, whether an element belongs to a set. This error refers to the fact that it is possible that the bloom filter indicates some element is in the set, when it in fact is not in the set (false positive). The reverse, however, is not possible – if some element is in the set, the bloom filter cannot indicate that it is not in the set (false negative).

So how does it work?

A bloom filter requires a set of k hash functions (one of my favourite topics) to generate indices into a table. These can either be a series of fairly simple and quick functions, or some manipulation of one or a few hash functions. For example, you could run your data through SHA-1 and extract integral types to use as indices, or use a faster hash function like MurmurHash with several different seeds.

Each of these k hashes is mapped to a bit in a bitmap array, which is initialized to contain all zeroes. Think of it like a hash table where each entry is hashed to multiple locations. But instead of storing the hash and the object, you set the bit at each location to 1. Because we only care about existence or non-existence in this set, we don’t need to store any additional information.

(Diagram shamelessly taken from Wikipedia; and yes, I checked the license first =)

So, for example, let’s say that we’ve inserted elements x, y, and z into the bloom filter with k = 3 hash functions, like above. Each of these three elements has three bits set to 1. When a bit is already set to 1, it’s left that way. When we try to look up element w in the set, the bloom filter tells us that w isn’t an element because at least one of its bitmap indices is 0.

This is the reason why bloom filters demonstrate false positives but not false negatives. A new element may overlap with the combined bitmap entries of many different elements already in the set. However, there is no operation that sets a bit to 0, so if a bitmap entry is 0, no element ever inserted into the set mapped to that location. As a related side note, it is extremely difficult to come up with a scheme that removes elements from these sets; usually they are just rebuilt instead.

So why is this useful, as opposed to a hashtable-set type data structure? The space requirements of a bloom filter are significantly lower than a hashtable implementation. This is because

  • Bloom filters use bits instead of larger elements to determine set existence
  • These bits may overlap between set entries, so not every entry has k dedicated bits

It’s also interesting because the insertion and checking operations have a time complexity of O(k), which depends only on the number of hash functions, not on the number of elements inserted into the set. There’s no probing or other collision-handling strategy here, as with hash tables. There are “collisions” in the sense of false positives, but these do not impact performance; with a large enough bitmap and a high enough k value, they can be greatly minimized. In addition, these data structures are typically employed in non-critical services (such as caching existence of data) under the assumption that false positives may occur.
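
For reference, the standard approximation for the false positive rate of a bloom filter with an m-bit bitmap, k hash functions, and n inserted elements is

p ≈ (1 - e^(-kn/m))^k

so sizing the bitmap and choosing k against the expected number of insertions directly controls how often those false positives occur.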

Here’s some simple Python that illustrates insertion and existence-checking operations on a basic bloom filter:


# Assumption: 8 bits in a byte

class BloomFilter:

    def __init__(self, N, hashfnct_list):
        # bitmap of N bits, stored as a byte array (rounded up to whole bytes)
        self.__bitmap_vector = bytearray((N + 7) // 8)
        self.__fnct_list = hashfnct_list
        self.__N = N

    def insert(self, data):
        # iterate over each hash function in the list and hash the data;
        # we could even do this in parallel
        for f in self.__fnct_list:
            # mod the hash by __N because that is the size
            # of our bitmap vector (in bits)
            h = f(data) % self.__N
            m = h // 8
            # m is the byte index; the hash's remainder mod 8
            # is the bit index within that byte
            self.__bitmap_vector[m] |= (1 << (h % 8))

    def contains(self, data):
        for f in self.__fnct_list:
            h = f(data) % self.__N
            m = h // 8
            if (self.__bitmap_vector[m] & (1 << (h % 8))) == 0:
                return False

        return True

There are several variations on this basic implementation which increase the feature set of bloom filters. You can find more details here.


Strassen’s Algorithm (Theory vs Application part 2)

So last time I mentioned that I didn’t really want to get into anything too complicated. However, I decided to go the whole nine yards and build an implementation of Strassen’s Algorithm for matrix multiplication, to compare it against the more conventional methods I was experimenting with. Thanks to a little help from CLRS I got the job done in a few hours, and began running some benchmarks.

Disclaimer: I am not some kind of rocket scientist who is giving you a guarantee the algorithms I implemented are perfect or super-optimized. I’ll only give you the assurance that they are a reasonable implementation in C++.

Strassen’s algorithm is interesting because, given two matrices A and B to be multiplied, it uses divide-and-conquer to break each input matrix into submatrices and then performs seven submatrix multiplications instead of the eight a naive divide-and-conquer algorithm would do. As a result, the asymptotic complexity works out to O(n^(log2 7)) ≈ O(n^2.807), which is better (again, in theory) than naive implementations, which run at O(n^3).
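
For reference, the seven products come from Strassen’s identities on a 2x2 block decomposition. Here’s a minimal C++ sketch of a single step, using plain doubles to stand in for the blocks; in the real algorithm each entry below is an N/2 x N/2 submatrix, every '*' is a recursive Strassen multiplication, and the '+'/'-' are matrix additions and subtractions.

/* One Strassen step on a 2x2 block matrix. The blocks are plain doubles here
 * for clarity; in the actual algorithm each m11..m22 is an N/2 x N/2
 * submatrix and each '*' is a recursive call. */
struct Block2x2 { double m11, m12, m21, m22; };

static Block2x2
strassen_step (const Block2x2 &A, const Block2x2 &B)
{
  double p1 = (A.m11 + A.m22) * (B.m11 + B.m22);
  double p2 = (A.m21 + A.m22) * B.m11;
  double p3 = A.m11 * (B.m12 - B.m22);
  double p4 = A.m22 * (B.m21 - B.m11);
  double p5 = (A.m11 + A.m12) * B.m22;
  double p6 = (A.m21 - A.m11) * (B.m11 + B.m12);
  double p7 = (A.m12 - A.m22) * (B.m21 + B.m22);

  Block2x2 C;
  C.m11 = p1 + p4 - p5 + p7;   /* seven multiplications instead of eight */
  C.m12 = p3 + p5;
  C.m21 = p2 + p4;
  C.m22 = p1 - p2 + p3 + p6;
  return C;
}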

My initial experiments with Strassen’s were fairly disappointing. It ran significantly more slowly than a transpose-naive algorithm, and consumed considerably more memory. In addition, Strassen’s requires the input matrices to be square, with each dimension a power of two. If A and B don’t meet this requirement, they are padded with zeroes. This adds preliminary overhead to its execution, not to mention it artificially inflates the matrices to be potentially significantly larger, which means more computations. Also, it’s extremely inefficient on small matrices. So here are my customized hacks to the vanilla algorithm:

  • Input matrices or submatrices which have dimension N <= 256 are multiplied using my transpose-naive algorithm from earlier; larger ones continue with recursive Strassen operations.
  • Sections of the algorithm which focus on multiplying submatrices from zero-padded sections of the input matrices skip computation and return a zero-matrix.

In addition, I also implemented a parallel version of the algorithm:

  • After the initial compartmentalization of the two input matrices, separate threads handle the subsequent seven recursive submatrix multiplication operations.

My simple benchmarks involved multiplying random matrices of incrementally growing sizes and timing how long it took for each N x N multiplication. Each was repeated numerous times and I took the average. On to the pictures:

As you can see, I had to stop collecting data for the naive algorithm, because it was bringing the experiment to a standstill. We all knew how that one would perform; I thought it would be interesting to see it in perspective. Remember that the Naive and Transpose algorithms are asymptotically equal (!). Here’s a version with the naive data removed:

Some notes:

  • The transpose-naive algorithm performed very well with smaller matrices, however as N grew past 3000 elements, its performance degraded badly.
  • I don’t know why the transpose-naive method resulted in such uneven growth.
  • The Strassen algorithm performed poorly until about N = 3000, at which point it began to outperform the parallel transpose-naive algorithm.
  • Strassen’s Algorithm suffers a sharp performance drop each time N grows past a power of 2, since it needs to expand the input matrices to the next larger power of 2.
  • After N = 2500, the parallel Strassen’s Algorithm was the clear winner.

To see what I mean about the performance penalty suffered by Strassen’s, take a look at the graph zoomed-in between 0 < N < 2500:

At N = 1024, 2048, you can see the sharp jump in runtime in Strassen’s compared to the other conventional algorithms. Unfortunately, I had to end the test around N = 4000 because the Strassen calculation caused my computer to run out of memory. I would make the assumption that similar performance penalties are suffered at larger powers of 2.

So, what are the conclusions? Well, it appears that the long-term clear winner is a parallel implementation of Strassen’s Algorithm. In general, this algorithm (both in serial and parallel) performed well on large matrices, but it doesn’t necessarily outperform simpler algorithms when the matrix size isn’t close to the next power of 2. When the matrix size is close to the next power of two, Strassen’s can even outperform parallel versions of simpler methods, even with all its extra nasty overhead. Pretty cool – theory wins.

Update: I’ve posted most of the source code I used for this post here


Theory vs Application

A few days ago I started hacking on a small and fairly simple linear algebra library in C++ (’cause that’s how I roll). It does the typical stuff, like various matrix and vector operations. When it came time to implement matrix multiplication, I really didn’t feel like getting into LAPACK and all that stuff, so I threw the functionality together myself. After running some tests with relatively large random matrices (a few thousand rows and columns), I was disappointed with how badly the naive matrix multiplication performed. As I should be, right? After all, we learned back in school that this operation is an O(n^3) algorithm.

Well, hold on.

I started thinking about quick ways to improve this performance without getting into anything too complicated. I realized that a huge part of the problem here is how the elements of these matrices are accessed. Your typical matrix multiplication takes the rows of A and the columns of B, and dot-products each pair together into an element of C. However, unless you’re programming in FORTRAN, those column accesses are slow. In row-major languages like C++, you want your memory accesses to be as close together as possible, thanks to caching optimizations in your hardware and OS. Jumping all over columns isn’t good for memory locality.

So how do you avoid jumping over columns and access only rows for both matrices? I took the transpose and multiplied the matrices that way. Guess what? Huge (relative) speed improvements. Now, the MM algorithm runs along pairs of rows from A and B to compute the elements of C. What makes this better is that it becomes even more embarrassingly parallel. After throwing in a few pthreads, the performance on my 4-core machine improved even further (although not always quadruple).
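
To illustrate the idea (this is a simplified sketch, not the library code from the post), the multiply below first builds the transpose of B and then dot-products rows of A against rows of B^T, so both inner-loop traversals walk contiguous memory:

#include <cstddef>
#include <vector>

typedef std::vector<std::vector<double> > matrix;

/* Multiply A (n x p) by B (p x m) by first forming Bt = B^T, so the innermost
 * loop reads a row of A and a row of Bt instead of a column of B. */
static matrix
transpose_multiply (const matrix &A, const matrix &B)
{
  size_t n = A.size (), p = B.size (), m = B[0].size ();

  matrix Bt (m, std::vector<double> (p));
  for (size_t k = 0; k < p; k++)
    for (size_t j = 0; j < m; j++)
      Bt[j][k] = B[k][j];

  matrix C (n, std::vector<double> (m, 0.0));
  for (size_t i = 0; i < n; i++)
    for (size_t j = 0; j < m; j++)
      {
        double dot = 0.0;
        for (size_t k = 0; k < p; k++)
          dot += A[i][k] * Bt[j][k];   /* row of A dotted with row of Bt */
        C[i][j] = dot;
      }

  return C;
}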

The point of this discussion is that even after taking the matrix transpose into account, the second method is significantly faster. If you want to get technical, its complexity works out to O(n^3 + n^2), which goes to show you that the constant factor in front of the naive implementation must be enormous. Here are some matrix multiplication results from identically-sized random matrices to illustrate the significant difference between these algorithms:


MM array size: 1023 x 1023
MM time [standard]: 6.200862
MM time [transpose]: 1.138916
MM time [parallel]: 0.323271

MM array size: 2317 x 2317
MM time [standard]: 76.909222
MM time [transpose]: 12.978604
MM time [parallel]: 5.118813

MM array size: 3179 x 3179
MM time [standard]: 220.900443
MM time [transpose]: 33.185827
MM time [parallel]: 13.594514