How to “perfectly” override a dict?

情话喂你 2020-11-22 08:14

How can I make as "perfect" a subclass of dict as possible? The end goal is to have a simple dict in which the keys are lowercase.

It would seem

5 answers
  •  谎友^ (original poster)
    2020-11-22 09:08

    My requirements were a bit stricter:

    • I had to retain case info (the strings are paths to files displayed to the user, but it's a Windows app, so internally all operations must be case insensitive).
    • I needed keys to be as small as possible (it did make a difference in memory performance: it chopped 110 MB off 370 MB). This meant that caching a lowercase version of the keys was not an option.
    • I needed creation of the data structures to be as fast as possible (again it made a difference in performance, speed this time). I had to go with a builtin.

    My initial thought was to replace our clunky Path class with a case insensitive unicode subclass - but:

    • it proved hard to get that right - see: A case insensitive string class in python
    • it turns out that explicit handling of dict keys makes code verbose, messy and error prone (structures are passed hither and thither, it is not clear whether they have CIstr instances as keys/elements, wrapping is easy to forget, and some_dict[CIstr(path)] is ugly)
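
To illustrate the "easy to forget" point, here is a minimal sketch. It assumes Python 3 (so `str` stands in for `unicode`), trims CIstr down to hashing and equality, and uses a made-up `sizes` dict; it is not the code from the application described above.

```python
class CIstr(str):
    """Python 3 stand-in for CIstr, trimmed to hashing and equality."""
    __slots__ = ()

    def __hash__(self):
        return hash(self.lower())

    def __eq__(self, other):
        if isinstance(other, CIstr):
            return self.lower() == other.lower()
        return NotImplemented

    def __ne__(self, other):
        # str defines its own __ne__, so it must be overridden explicitly
        if isinstance(other, CIstr):
            return self.lower() != other.lower()
        return NotImplemented


# Hypothetical data: a plain dict whose keys were wrapped by hand.
sizes = {CIstr(r'Docs\Readme.TXT'): 1024}

# Wrapping the lookup key makes the comparison case insensitive...
wrapped_hit = CIstr(r'docs\readme.txt') in sizes
# ...but a forgotten wrapper silently falls back to case-sensitive matching.
unwrapped_hit = r'docs\README.txt' in sizes

print(wrapped_hit, unwrapped_hit)
```

The wrapped lookup finds the entry while the unwrapped one misses it, and nothing in the calling code signals which kind of key a given structure expects.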

    So I finally had to write that case insensitive dict. Thanks to code by @AaronHall, that was made 10 times easier.

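    # Python 2 code: relies on unicode, basestring and dict.iteritems
    from itertools import chain  # used by LowerDict._process_args
    
    class CIstr(unicode):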
        """See https://stackoverflow.com/a/43122305/281545, especially for inlines"""
        __slots__ = () # does make a difference in memory performance
    
        #--Hash/Compare
        def __hash__(self):
            return hash(self.lower())
        def __eq__(self, other):
            if isinstance(other, CIstr):
                return self.lower() == other.lower()
            return NotImplemented
        def __ne__(self, other):
            if isinstance(other, CIstr):
                return self.lower() != other.lower()
            return NotImplemented
        def __lt__(self, other):
            if isinstance(other, CIstr):
                return self.lower() < other.lower()
            return NotImplemented
        def __ge__(self, other):
            if isinstance(other, CIstr):
                return self.lower() >= other.lower()
            return NotImplemented
        def __gt__(self, other):
            if isinstance(other, CIstr):
                return self.lower() > other.lower()
            return NotImplemented
        def __le__(self, other):
            if isinstance(other, CIstr):
                return self.lower() <= other.lower()
            return NotImplemented
        #--repr
        def __repr__(self):
            return '{0}({1})'.format(type(self).__name__,
                                     super(CIstr, self).__repr__())
    
    def _ci_str(maybe_str):
        """dict keys can be any hashable object - only call CIstr if str"""
        return CIstr(maybe_str) if isinstance(maybe_str, basestring) else maybe_str
    
    class LowerDict(dict):
        """Dictionary that transforms its keys to CIstr instances.
        Adapted from: https://stackoverflow.com/a/39375731/281545
        """
        __slots__ = () # no __dict__ - that would be redundant
    
        @staticmethod # because this doesn't make sense as a global function.
        def _process_args(mapping=(), **kwargs):
            if hasattr(mapping, 'iteritems'):
                mapping = getattr(mapping, 'iteritems')()
            return ((_ci_str(k), v) for k, v in
                    chain(mapping, getattr(kwargs, 'iteritems')()))
        def __init__(self, mapping=(), **kwargs):
            # dicts take a mapping or iterable as their optional first argument
            super(LowerDict, self).__init__(self._process_args(mapping, **kwargs))
        def __getitem__(self, k):
            return super(LowerDict, self).__getitem__(_ci_str(k))
        def __setitem__(self, k, v):
            return super(LowerDict, self).__setitem__(_ci_str(k), v)
        def __delitem__(self, k):
            return super(LowerDict, self).__delitem__(_ci_str(k))
        def copy(self): # don't delegate w/ super - dict.copy() -> dict :(
            return type(self)(self)
        def get(self, k, default=None):
            return super(LowerDict, self).get(_ci_str(k), default)
        def setdefault(self, k, default=None):
            return super(LowerDict, self).setdefault(_ci_str(k), default)
        __no_default = object()
        def pop(self, k, v=__no_default):
            if v is LowerDict.__no_default:
                # super will raise KeyError if no default and key does not exist
                return super(LowerDict, self).pop(_ci_str(k))
            return super(LowerDict, self).pop(_ci_str(k), v)
        def update(self, mapping=(), **kwargs):
            super(LowerDict, self).update(self._process_args(mapping, **kwargs))
        def __contains__(self, k):
            return super(LowerDict, self).__contains__(_ci_str(k))
        @classmethod
        def fromkeys(cls, keys, v=None):
            return super(LowerDict, cls).fromkeys((_ci_str(k) for k in keys), v)
        def __repr__(self):
            return '{0}({1})'.format(type(self).__name__,
                                     super(LowerDict, self).__repr__())
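
For readers on Python 3, a rough port might look like the sketch below. It assumes `str` in place of `unicode`/`basestring` and `items()` in place of `iteritems()`, and it ports only a subset of the methods above (no ordering comparisons, `get`, `pop`, `update` and so on); the sample paths are made up.

```python
from itertools import chain


class CIstr(str):
    """Case-insensitive string key (hashing and equality only)."""
    __slots__ = ()

    def __hash__(self):
        return hash(self.lower())

    def __eq__(self, other):
        if isinstance(other, CIstr):
            return self.lower() == other.lower()
        return NotImplemented

    def __ne__(self, other):
        # str defines its own __ne__, so it must be overridden explicitly
        if isinstance(other, CIstr):
            return self.lower() != other.lower()
        return NotImplemented


def _ci_str(maybe_str):
    """Wrap strings in CIstr; leave other hashable keys untouched."""
    return CIstr(maybe_str) if isinstance(maybe_str, str) else maybe_str


class LowerDict(dict):
    """dict that converts string keys to CIstr on the way in."""
    __slots__ = ()

    @staticmethod
    def _process_args(mapping=(), **kwargs):
        if hasattr(mapping, 'items'):
            mapping = mapping.items()
        return ((_ci_str(k), v) for k, v in chain(mapping, kwargs.items()))

    def __init__(self, mapping=(), **kwargs):
        super().__init__(self._process_args(mapping, **kwargs))

    def __getitem__(self, k):
        return super().__getitem__(_ci_str(k))

    def __setitem__(self, k, v):
        super().__setitem__(_ci_str(k), v)

    def __contains__(self, k):
        return super().__contains__(_ci_str(k))


# Hypothetical usage: callers pass plain strings, the dict does the wrapping.
paths = LowerDict({r'C:\Mods\Readme.TXT': 1})
paths[r'c:\mods\readme.txt'] = 2   # same key, compared case insensitively
paths[42] = 'non-string keys pass through unchanged'

original_case = str(next(iter(paths)))  # the first-inserted spelling is kept
```

Because a dict keeps the key object it stored first, the original casing survives later assignments under a different spelling, which matches the requirement to retain case info for display.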
    

    Implicit vs explicit is still a problem, but once the dust settles, I think that renaming attributes/variables to start with ci (plus a big fat doc comment explaining that ci stands for case insensitive) is a perfect solution, as readers of the code must be fully aware that we are dealing with case insensitive underlying data structures. This will hopefully fix some hard to reproduce bugs, which I suspect boil down to case sensitivity.

    Comments/corrections welcome :)
