Wednesday, March 9, 2011

Making numpy ndarrays Hashable

I am doing some research work using numpy and scipy, and I have been amazed by how fast they run. The other day, though, I stumbled on what seemed to be a roadblock: I wanted to use some fairly large arrays (length 1024) as keys for storing data in Python dictionaries, but numpy's ndarray objects are not hashable. Converting them to tuples would take forever, besides eating up an unbelievably large amount of memory. Searching the web provided no satisfactory answers.

After a while, though, I realized that it would be simple to implement a hashable wrapper for my array objects. So here you have it:
from hashlib import sha1

from numpy import all, array, uint8


class hashable(object):
    r'''Hashable wrapper for ndarray objects.

        Instances of ndarray are not hashable, meaning they cannot be added to
        sets, nor used as keys in dictionaries. This is by design - ndarray
        objects are mutable, and therefore cannot reliably implement the
        __hash__() method.

        The hashable class allows a way around this limitation. It implements
        the required methods for hashable objects in terms of an encapsulated
        ndarray object. This can be either a copied instance (which is safer)
        or the original object (which requires the user to be careful enough
        not to modify it).
    '''
    def __init__(self, wrapped, tight=False):
        r'''Creates a new hashable object encapsulating an ndarray.

            wrapped
                The wrapped ndarray.

            tight
                Optional. If True, a copy of the input ndarray is created.
                Defaults to False.
        '''
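        self.__tight = tight
        # A "tight" wrapper keeps a private copy, so later changes to the
        # caller's array cannot invalidate the cached hash.
        self.__wrapped = array(wrapped) if tight else wrapped
        # Hash the raw bytes once and cache the result. tobytes() also copes
        # with non-contiguous arrays, which view(uint8) may reject.
        self.__hash = int(sha1(self.__wrapped.tobytes()).hexdigest(), 16)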

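    def __eq__(self, other):
        # Only compare against other wrappers. Arrays of different shapes
        # are never equal; without this check they could broadcast or raise.
        if not isinstance(other, hashable):
            return NotImplemented
        if self.__wrapped.shape != other.__wrapped.shape:
            return False
        return bool(all(self.__wrapped == other.__wrapped))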

    def __hash__(self):
        return self.__hash

    def unwrap(self):
        r'''Returns the encapsulated ndarray.

            If the wrapper is "tight", a copy of the encapsulated ndarray is
            returned. Otherwise, the encapsulated ndarray itself is returned.
        '''
        if self.__tight:
            return array(self.__wrapped)

        return self.__wrapped
Using the wrapper class is simple enough:
>>> from numpy import arange
>>> a = arange(0, 1024)
>>> d = {}
>>> d[a] = 'foo'
TypeError: unhashable type: 'numpy.ndarray'
>>> b = hashable(a)
>>> d[b] = 'bar'
>>> d[b]
'bar'
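One detail the session above does not show is that lookups go by content, not identity: a second wrapper around an equal array finds the same entry. The sketch below illustrates this with a condensed stand-in for the class (the name HashableArray and its simplifications are assumptions made for illustration, not the class listed above), and also shows what the tight flag protects against.

```python
from hashlib import sha1

import numpy as np


class HashableArray(object):
    """Condensed, illustrative stand-in for the hashable wrapper."""

    def __init__(self, wrapped, tight=False):
        # tight=True stores a private copy of the input array.
        self._wrapped = np.array(wrapped) if tight else wrapped
        # The hash is computed once, from the bytes at construction time.
        self._hash = int(sha1(self._wrapped.tobytes()).hexdigest(), 16)

    def __eq__(self, other):
        return (isinstance(other, HashableArray)
                and np.array_equal(self._wrapped, other._wrapped))

    def __hash__(self):
        return self._hash


a = np.arange(0, 1024)
d = {HashableArray(a, tight=True): 'bar'}

# A fresh wrapper around an equal (but distinct) array finds the same entry.
found = d[HashableArray(np.arange(0, 1024))]

# Mutating the original does not affect the tight key, which holds a copy...
a[0] = 99
still_found = HashableArray(np.arange(0, 1024)) in d

# ...whereas a loose wrapper shares memory with the array it wraps, so its
# cached hash goes stale as soon as that array is modified.
b = np.arange(0, 1024)
loose = HashableArray(b)
b[0] = 99
stale = hash(loose) != hash(HashableArray(b))
```

This is why the loose (default) mode is only safe if the wrapped array is never modified while it serves as a key.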
In my profiling sessions, adding the wrapped-up 1024-long arrays as keys to a dictionary amounted to no more overhead than adding the naked arrays to a list.
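Timings like these vary by machine, so here is a rough way to measure them yourself. This is not the profiling session mentioned above, just a minimal timeit sketch under the assumption that the SHA-1 digest dominates the cost of building a wrapper; it compares that digest against the tuple conversion the wrapper is meant to avoid.

```python
from hashlib import sha1
from timeit import timeit

import numpy as np

a = np.arange(0, 1024)

# Cost of building a hashable key by converting the array to a tuple.
t_tuple = timeit(lambda: tuple(a), number=1000)

# Cost of the digest that the wrapper computes once per array.
t_digest = timeit(lambda: int(sha1(a.tobytes()).hexdigest(), 16), number=1000)

print('tuple(a):    %.6f s per 1000 calls' % t_tuple)
print('sha1 digest: %.6f s per 1000 calls' % t_digest)
```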