Arbitrarily nested dictionary from tuples

猫巷女王i 2020-12-19 18:21

Given a dictionary with tuples as keys (and numbers/scalars as values), what is a Pythonic way to convert to a nested dictionary? The hitch is that from input-to-input, the

2 Answers
  • 2020-12-19 19:11

    You can use itertools.groupby with recursion (the function takes an iterable of (key, value) pairs, such as d.items()):

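    from itertools import groupby

    def create_nested_dict(d):
        # group (key_tuple, value) pairs by the first key element; keep the rest of each key
        _c = [[a, [(c, v) for (_, *c), v in b]]
              for a, b in groupby(sorted(d, key=lambda x: x[0][0]), key=lambda x: x[0][0])]
        # a group with no key elements left is a leaf; otherwise recurse one level down
        return {a: b[0][-1] if not any(c for c, _ in b) else create_nested_dict(b) for a, b in _c}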
    

    from itertools import product
    
    d1 = dict(zip(product('AB', [0, 1]), range(2*2)))
    d2 = dict(zip(product('AB', [0, 1], [True, False]), range(2*2*2)))
    d3 = dict(zip(product('CD', [0, 1], [True, False], 'AB'), range(2*2*2*2)))
    print(create_nested_dict(d1.items()))
    print(create_nested_dict(d2.items()))
    print(create_nested_dict(d3.items())) 
    

    Output:

    {'A': {0: 0, 1: 1}, 'B': {0: 2, 1: 3}}
    {'A': {0: {False: 1, True: 0}, 1: {False: 3, True: 2}}, 'B': {0: {False: 5, True: 4}, 1: {False: 7, True: 6}}}
    {'C': {0: {False: {'A': 2, 'B': 3}, True: {'A': 0, 'B': 1}}, 1: {False: {'A': 6, 'B': 7}, True: {'A': 4, 'B': 5}}}, 'D': {0: {False: {'A': 10, 'B': 11}, True: {'A': 8, 'B': 9}}, 1: {False: {'A': 14, 'B': 15}, True: {'A': 12, 'B': 13}}}}
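    Why sort first? groupby() only merges consecutive items that share a key, so without the sorted() call the same first key element could appear as several separate groups. A minimal illustration on hand-picked, unsorted pairs (the `pairs` list and the `first` helper are illustrative only, not part of the answer):

```python
from itertools import groupby

# (key_tuple, value) pairs, deliberately NOT sorted by the first key element
pairs = [(('A', 0), 0), (('B', 0), 2), (('A', 1), 1), (('B', 1), 3)]


def first(pair):
    """Return the first element of the key tuple, i.e. the grouping key at this level."""
    return pair[0][0]


# groupby only merges *consecutive* equal keys, so 'A' and 'B' each appear twice
unsorted_groups = [k for k, _ in groupby(pairs, key=first)]

# sorting by the same key first puts equal keys next to each other
sorted_groups = [k for k, _ in groupby(sorted(pairs, key=first), key=first)]

print(unsorted_groups)  # ['A', 'B', 'A', 'B']
print(sorted_groups)    # ['A', 'B']
```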
    
  • 2020-12-19 19:19

    Just loop over each key and use all but the last element of the key to create (or descend into) nested dictionaries. Keep a reference to the innermost dictionary reached so far, then use the last element of the key tuple to set the key-value pair in that dictionary:

    def nest(d: dict) -> dict:
        result = {}
        for key, value in d.items():
            target = result
            for k in key[:-1]:  # traverse all keys but the last
                target = target.setdefault(k, {})
            target[key[-1]] = value
        return result
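    The loop relies on dict.setdefault() doing two jobs: it inserts an empty dictionary for a missing key, and either way it returns the dictionary stored under that key. A minimal sketch of that step for the two d1 keys ('A', 0) and ('A', 1), written out by hand (the variable names are just for illustration):

```python
result = {}

# first key ('A', 0): 'A' is missing, so setdefault() stores a new {} and returns it
inner = result.setdefault('A', {})
inner[0] = 0

# second key ('A', 1): 'A' exists now, so setdefault() returns the stored dict
# (the {} argument is not used) and the value lands in the same sub-dictionary
again = result.setdefault('A', {})
again[1] = 1

print(again is inner)  # True
print(result)          # {'A': {0: 0, 1: 1}}
```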
    

    You could use functools.reduce() to handle the traversing-down-the-dictionaries work:

    from functools import reduce
    
    def nest(d: dict) -> dict:
        result = {}
        traverse = lambda r, k: r.setdefault(k, {})
        for key, value in d.items():
            reduce(traverse, key[:-1], result)[key[-1]] = value
        return result
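    To see what the reduce() call does, it helps to unroll it: reduce(traverse, key[:-1], result) applies traverse to result and the first key element, then to that sub-dictionary and the next element, and so on. A small sketch for the single d3 key ('C', 0, True, 'A'), using a def that behaves like the traverse lambda above:

```python
from functools import reduce


def traverse(r: dict, k) -> dict:
    # same as the answer's lambda: return r[k], creating an empty dict first if needed
    return r.setdefault(k, {})


key, value = ('C', 0, True, 'A'), 0

# reduce() version: walks result -> result['C'] -> result['C'][0] -> result['C'][0][True]
result = {}
reduce(traverse, key[:-1], result)[key[-1]] = value

# the same thing with the reduce() call unrolled by hand
explicit = {}
traverse(traverse(traverse(explicit, 'C'), 0), True)['A'] = 0

print(result)              # {'C': {0: {True: {'A': 0}}}}
print(result == explicit)  # True
```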
    

    I used dict.setdefault() rather than the autovivification option (a defaultdict(nested_dict) tree), as this produces regular dictionaries that won't implicitly add keys when you later look up a key that doesn't exist yet.
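    For comparison, here is a sketch of the autovivification variant. It assumes defaultdict(nested_dict) refers to the common recursive-factory pattern, where nested_dict is a hypothetical function returning defaultdict(nested_dict); nest_autoviv is likewise a hypothetical name. It shows the trade-off: simply reading a missing key on the result adds an empty sub-dictionary, whereas a plain dict raises KeyError.

```python
from collections import defaultdict


def nested_dict():
    """Hypothetical factory: a defaultdict whose missing values are new nested_dicts."""
    return defaultdict(nested_dict)


def nest_autoviv(d: dict) -> defaultdict:
    """Hypothetical autovivification counterpart of nest(): no setdefault() needed."""
    result = nested_dict()
    for key, value in d.items():
        target = result
        for k in key[:-1]:
            target = target[k]  # a missing key silently creates a new level
        target[key[-1]] = value
    return result


tree = nest_autoviv({('A', 0): 0, ('A', 1): 1})
print(dict(tree['A']))  # {0: 0, 1: 1}

# The catch: just reading a key that isn't there adds it to the result
print('B' in tree)  # False
tree['B']
print('B' in tree)  # True

# A plain dict, as built with setdefault(), raises KeyError instead
plain = {'A': {0: 0, 1: 1}}
try:
    plain['B']
except KeyError:
    print("KeyError; plain dict left unchanged:", plain)
```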

    Demo:

    >>> from pprint import pprint
    >>> pprint(nest(d1))
    {'A': {0: 0, 1: 1}, 'B': {0: 2, 1: 3}}
    >>> pprint(nest(d2))
    {'A': {0: {False: 1, True: 0}, 1: {False: 3, True: 2}},
     'B': {0: {False: 5, True: 4}, 1: {False: 7, True: 6}}}
    >>> pprint(nest(d3))
    {'C': {0: {False: {'A': 2, 'B': 3}, True: {'A': 0, 'B': 1}},
           1: {False: {'A': 6, 'B': 7}, True: {'A': 4, 'B': 5}}},
     'D': {0: {False: {'A': 10, 'B': 11}, True: {'A': 8, 'B': 9}},
           1: {False: {'A': 14, 'B': 15}, True: {'A': 12, 'B': 13}}}}
    

    Note that the above solution is a clean O(N) loop (N being the length of the input dictionary), whereas a groupby solution as proposed by Ajax1234 has to sort the input to work, making it an O(N log N) solution. Roughly speaking, for a dictionary with 1,000 elements a groupby() approach needs on the order of 10,000 steps to produce the output (log2(1000) is about 10), whereas an O(N) linear loop takes only about 1,000. For a million keys this grows to about 20 million steps, versus 1 million for the loop.
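    These step counts are proportional estimates rather than exact operation counts (big-O hides constant factors); they come from N × log2(N). A quick sketch of that arithmetic (n_log_n is just an illustrative helper):

```python
from math import log2


def n_log_n(n: int) -> float:
    """Rough operation count for an O(N log N) step such as sorting N items."""
    return n * log2(n)


for n in (1_000, 1_000_000):
    # a linear pass does about n operations; sorting about n * log2(n)
    print(f"N = {n:,}: linear ~ {n:,}, N log2 N ~ {n_log_n(n):,.0f}")
```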

    Moreover, recursion in Python is slow: Python doesn't optimise recursive solutions into iterative ones (there is no tail-call elimination). Function calls are relatively expensive, so recursion can carry significant performance costs, as it greatly increases the number of function calls and, by extension, stack-frame operations.
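    One way to make the extra calls visible is to count them. The sketch below wraps the groupby answer's create_nested_dict in a hypothetical counting decorator (count_calls); because the recursion looks the function up by its global name, recursive calls go through the wrapper too. It only counts create_nested_dict itself, not the sorted() and lambda key calls each level also makes:

```python
from functools import wraps
from itertools import groupby, product

calls = 0  # number of create_nested_dict invocations seen so far


def count_calls(func):
    """Hypothetical helper: wrap func so that every call increments the global counter."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        global calls
        calls += 1
        return func(*args, **kwargs)
    return wrapper


@count_calls
def create_nested_dict(d):
    # same logic as the groupby answer; the recursive call below resolves to the wrapper
    _c = [[a, [(c, v) for (_, *c), v in b]]
          for a, b in groupby(sorted(d, key=lambda x: x[0][0]), key=lambda x: x[0][0])]
    return {a: b[0][-1] if not any(c for c, _ in b) else create_nested_dict(b) for a, b in _c}


d3 = dict(zip(product('CD', [0, 1], [True, False], 'AB'), range(2*2*2*2)))
create_nested_dict(d3.items())
print(calls)  # 15
```

    For the 16-key d3 that is 1 + 2 + 4 + 8 = 15 calls (one per group that still has key elements left, plus the top-level call), versus a single call for nest().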

    A time trial shows how much this matters: using your sample d3 and 100,000 runs, there is roughly a 5x speed difference:

    >>> from timeit import timeit
    >>> timeit('n(d)', 'from __main__ import create_nested_dict as n, d3; d=d3.items()', number=100_000)
    8.210276518017054
    >>> timeit('n(d)', 'from __main__ import nest as n, d3 as d', number=100_000)
    1.6089267160277814
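    If you want to repeat the comparison without `from __main__ import ...`, timeit() also accepts a globals= namespace. A self-contained sketch with copies of both answers' functions; the absolute numbers depend on your machine and Python version, so no particular ratio is guaranteed:

```python
from functools import reduce
from itertools import groupby, product
from timeit import timeit


def create_nested_dict(d):
    # groupby + recursion version (expects (key, value) pairs)
    _c = [[a, [(c, v) for (_, *c), v in b]]
          for a, b in groupby(sorted(d, key=lambda x: x[0][0]), key=lambda x: x[0][0])]
    return {a: b[0][-1] if not any(c for c, _ in b) else create_nested_dict(b) for a, b in _c}


def nest(d: dict) -> dict:
    # single pass with setdefault(), using reduce() to walk down the levels
    result = {}
    for key, value in d.items():
        reduce(lambda r, k: r.setdefault(k, {}), key[:-1], result)[key[-1]] = value
    return result


d3 = dict(zip(product('CD', [0, 1], [True, False], 'AB'), range(2*2*2*2)))

# sanity check: both approaches must build the same nested dictionary
assert create_nested_dict(d3.items()) == nest(d3)

# setup runs once per timing; globals= makes d3 and both functions visible to timeit
t_groupby = timeit('create_nested_dict(items)', setup='items = d3.items()',
                   globals=globals(), number=10_000)
t_loop = timeit('nest(d3)', globals=globals(), number=10_000)
print(f"groupby + recursion: {t_groupby:.3f}s")
print(f"setdefault loop:     {t_loop:.3f}s")
```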
    