Einsum for Tensor Manipulation

In the ethereal dance of the cosmos, where the arcane whispers intertwine with the silent echoes of unseen dimensions, the Ioun Stone of Mastery emerges as a beacon of unparalleled prowess. This luminescent orb, orbiting its bearer’s head, is a testament to the mastery of both magical and mathematical realms, offering a bridge between the manipulation of arcane energies and the intricate ballet of tensor mathematics. As the stone orbits, it casts a subtle glow, its presence a constant reminder of the dual dominion it grants over the spellbinding complexities of magic and the abstract elegance of multidimensional calculations, making the wielder a maestro of both mystical incantations and the unseen algebra of the universe.

[ioun-stone.png: Ioun Stone of Mastery]

The Quest

Study how the Ioun Stone powers work. Understand how Einsum operates over tensors.

Einsum Uses

Einsum (and einops in general) is a great tool for manipulating tensors. In ML it is often used to implement matrix multiplication or dot products. The simplest case would look like:

import torch
from einops import einsum

x = torch.rand((4, 5))
y = torch.rand((5, 3))

# the torch way
res = x @ y

# einsum way
res = einsum(x, y, 'a b, b c -> a c')
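
Einsum also covers dot products. As a small illustrative example (not from the original notebook), when no indices appear on the output side, the shared dimension is summed out into a scalar:

a = torch.rand(7)
b = torch.rand(7)

# the torch way
dot = a @ b

# einsum way: 'i' appears in both inputs and not in the output, so it gets summed out
dot2 = einsum(a, b, 'i, i ->')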

Right now it looks like a verbose way of doing the same thing, but it can bring the following advantages:

  • documenting the tensor dimensions for ease of reading
  • implicit reordering of dimensions
query = torch.rand((100, 20, 32))
key   = torch.rand((100, 20, 32))

# the torch way
keyT = key.permute((0, 2, 1))
res = query @ keyT

# einsum way
res2 = einsum(query, key, 'batch seq_q d_model, batch seq_k d_model -> batch seq_q seq_k')

Einsum the Iterative way

Conceptually, it’s possible to think of einsum as a bunch of nested loops:

  • the first set of nested loops indexes into the inputs and the output.
  • the second set of nested loops sums over all the leftover dimensions that get reduced.

It could be written by hand as:

result = torch.zeros((100, 20, 20))
for batch in range(100):
  for seq_q in range(20):
    for seq_k in range(20):
      tot = 0
      for d_model in range(32):
        tot += query[batch, seq_q, d_model] * key[batch, seq_k, d_model]
      result[batch, seq_q, seq_k] = tot

One way to generate these nested loops is to use recursion:

def dumbsum(x, y, shapes):
  '''
  dumb implementation for my own intuition-building sake, with absolutely no value for real-life use.
  not vectorized, and does not handle splitting / merging / creating extra dims.
  
  the main idea is to:
  1- generate nested loops for indexing for each dim in the output
  2- generate nested loops for summing everything else
  e.g. 'a b c d e, a c e -> a d b'
  for a in range(x.shape[0]):
    for d in range(x.shape[3]):
      for b in range(x.shape[1]):
        tot = 0
        for c in range(x.shape[2]):
          for e in range(x.shape[4]):
            tot += x[a, b, c, d, e] * y[a, c, e]
        res[a, d, b] = tot

  in practice I initialize res to a tensor of zeros, and update it in place instead of accumulating in a tot
  res[a, d, b] += x[a, b, c, d, e] * y[a, c, e]
  '''
  def split_shape(shape):
    return [x for x in shape.split(' ') if x]
  def parse(shapes):
    assert shapes.count(',') == 1
    assert shapes.count('->') == 1
    shapes, res_shape = shapes.split('->')
    x_shape, y_shape = shapes.split(',')
    x_shape, y_shape, res_shape = (split_shape(s) for s in (x_shape, y_shape, res_shape))
    sum_shape = list(set(x_shape + y_shape) - set(res_shape))
    assert set(res_shape).issubset(set(x_shape + y_shape))
    return x_shape, y_shape, res_shape, sum_shape
  def build_dim_lookup(t, t_shape, lookup=None):
    if not lookup: lookup = {}
    dims = t.shape
    for dim, letter in zip(dims, t_shape):
      assert lookup.get(letter, dim) == dim  # a letter must always map to the same size
      lookup[letter] = dim
    return lookup
  def iterate(shape, sum_shape, fn, lookup, indexes):
    if not shape:
      iterate_sum(sum_shape[:], fn, lookup, indexes)
      return
    dim = shape.pop(-1)
    # print(f'iterate over → {dim}')
    for i in range(lookup[dim]):
      indexes[dim] = i
      iterate(shape[:], sum_shape, fn, lookup, indexes)
  def iterate_sum(sum_shape, fn, lookup, indexes):
    if not sum_shape:
      fn(indexes)
      return
    dim = sum_shape.pop(-1)
    # print(f'sum over → {dim}')
    for i in range(lookup[dim]):
      indexes[dim] = i
      iterate_sum(sum_shape[:], fn, lookup, indexes)
  def ind(t_shape, indexes):
    # build the index tuple for a tensor, following the order of its letters
    return tuple(indexes[dim] for dim in t_shape)
  def close_sum(x, y, res, x_shape, y_shape, res_shape):
    def fn(indexes):
      # print(f'res[{ind(res_shape, indexes)}] += x[{ind(x_shape, indexes)}] * y[{ind(y_shape, indexes)}]')
      res[ind(res_shape, indexes)] += x[ind(x_shape, indexes)] * y[ind(y_shape, indexes)]
    return fn

  x_shape, y_shape, res_shape, sum_shape = parse(shapes)
  assert len(x_shape) == x.dim()
  assert len(y_shape) == y.dim()
  lookup = build_dim_lookup(x, x_shape)
  lookup = build_dim_lookup(y, y_shape, lookup=lookup)
  res = torch.zeros(tuple(lookup[s] for s in res_shape))
  fn = close_sum(x, y, res, x_shape, y_shape, res_shape)
  iterate(res_shape[:], sum_shape[:], fn, lookup, {})
  return res
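
As a quick sanity check (illustrative, not part of the original notebook), dumbsum should reproduce the matrix multiplication example from earlier:

x = torch.rand((4, 5))
y = torch.rand((5, 3))
assert dumbsum(x, y, 'a b, b c -> a c').allclose(x @ y)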

Einsum Vectorized

The loop version is great for intuition building, but it is extremely slow. Another way to implement einsum is to compose vectorized torch operations.

By hand it would look something like:

query = query[..., None] # add a seq_k dimension
key = key[..., None]     # add a seq_q dimension
query = query.permute((0, 1, 3, 2)) # align the dimensions as: batch, seq_q, seq_k, d_model
key = key.permute((0, 3, 1, 2))     # align the dimensions as: batch, seq_q, seq_k, d_model 
product = query * key # multiply element wise using implicit broadcasting
result = product.sum(3) # sum out the d_model dimension
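
As a quick check (not from the original notebook), the same alignment can be written with None-indexing instead of unsqueeze + permute, and it should agree with the usual batched matmul formulation of attention scores:

query = torch.rand((100, 20, 32))
key = torch.rand((100, 20, 32))

# batch, seq_q, seq_k, d_model via broadcasting
product = query[:, :, None, :] * key[:, None, :, :]
result = product.sum(3)

# should match query @ key^T over the last two dimensions
assert result.allclose(query @ key.transpose(1, 2))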

Generalized into a function, the code could look a little something like:

def dumbsum_vectorized(x, y, shapes):
  '''
  vectorized version, still does not handle splitting / merging / creating extra dims.
  it also does not handle repeated dims (e.g. 'a a b, a a c -> a a').
  
  the main idea is to:
  1- align the dimensions of x and y, completing the holes with fake `1` dimensions
  2- multiply x and y
  3- sum out the extra dims
  e.g. 'a b c d e, a c e -> a d b'
  # align dims, filling the missing ones with fake `1` dimensions
  x = reshape('a b c d e -> a b c d e')  # already has every dim
  y = reshape('a c e     -> a 1 c 1 e')
  # order dims to match the output, extra dims last
  x = reshape('a b c d e -> a d b c e')
  y = reshape('a 1 c 1 e -> a 1 1 c e')
  # mult and sum
  res = x * y
  res = res.sum((3, 4))
  '''
  def split_shape(shape):
    return [x for x in shape.split(' ') if x]
  def parse(shapes):
    assert shapes.count(',') == 1
    assert shapes.count('->') == 1
    shapes, res_shape = shapes.split('->')
    x_shape, y_shape = shapes.split(',')
    x_shape, y_shape, res_shape = (split_shape(s) for s in (x_shape, y_shape, res_shape))
    sum_shape = list(set(x_shape + y_shape) - set(res_shape))
    assert set(res_shape).issubset(set(x_shape + y_shape))
    return x_shape, y_shape, res_shape, sum_shape
  def build_dim_pos_lookup(t_shape):
    return {letter: dim for dim, letter in enumerate(t_shape)}
  def expand(t, t_shape, merged):
    lookup = build_dim_pos_lookup(t_shape)
    ind = len(lookup)
    for dim in merged:
      if dim not in lookup:
        t = t.unsqueeze(-1)
        lookup[dim] = ind
        ind += 1
    return t, lookup
  def align(t, lookup, res_lookup):
    # rely on dict being ordered (python >= 3.7)
    permuted_dims = tuple(lookup[dim] for dim in res_lookup)
    return t.permute(permuted_dims)
  def dims_to_sum(res_shape, res_lookup):
    return tuple(range(len(res_shape), len(res_lookup)))

  x_shape, y_shape, res_shape, sum_shape = parse(shapes)
  assert len(x_shape) == x.dim()
  assert len(y_shape) == y.dim()
  merged = set(x_shape + y_shape)
  x, x_lookup = expand(x, x_shape, merged)
  y, y_lookup = expand(y, y_shape, merged)
  _, res_lookup = expand(torch.zeros(0), res_shape, merged)
  x = align(x, x_lookup, res_lookup)
  y = align(y, y_lookup, res_lookup)
  res = x * y
  dims = dims_to_sum(res_shape, res_lookup)
  if dims: res = res.sum(dims)
  return res

Compare both

Correctness

We can verify that both versions produce the same results as the reference einops.einsum:

import torch, einops

def einops_test(x, y, pattern):
  a = dumbsum(x, y, pattern)
  b = dumbsum_vectorized(x, y, pattern)
  c = einops.einsum(x, y, pattern)
  assert a.allclose(c)
  assert b.allclose(c)

x = torch.rand((10, 5, 2, 3))
y = torch.rand((3, 10, 5, 7))
einops_test(x, y, 'a b c d, d a b e -> b e c')
einops_test(x, y, 'a b c d, d a b e -> a b c d e')
einops_test(x, y, 'a b c d, d a b e -> e d c b a')
einops_test(x, y, 'a b c d, d a b e -> a')
einops_test(x, y, 'a b c d, d a b e ->')
einops_test(x, y, 'a b c d, d a b e -> a e')

Speed

Timing the iterative version:

%%time
query = torch.rand((100, 20, 32))
key = torch.rand((100, 20, 32))
_ = dumbsum(query, key, 'batch seq_q d_model, batch seq_k d_model -> batch seq_q seq_k')
CPU times: total: 9.58 s
Wall time: 31.3 s

Against the vectorized version:

%%time
query = torch.rand((100, 20, 32))
key = torch.rand((100, 20, 32))
_ = dumbsum_vectorized(query, key, 'batch seq_q d_model, batch seq_k d_model -> batch seq_q seq_k')
CPU times: total: 0 ns
Wall time: 975 µs

This demonstrates the massive speedup from vectorization, roughly 30,000× in wall time (31.3 s vs 975 µs): the vectorized ops run in optimized C kernels instead of Python-level loops.

The code

You can get the code at https://github.com/peluche/ml-misc/blob/master/einsum-intuition.ipynb