Hierarchical data: efficiently build a list of every descendant for each node

闹比i 2020-12-24 04:26

I have a two-column data set depicting multiple child-parent relationships that form a large tree. I would like to use this to build an updated list of every descendant for each node.

4 answers
  •  悲哀的现实
    2020-12-24 05:06

    Here is a method that first builds a dict of lists, mapping each parent to its direct children, to make the tree easier to navigate. It then walks the tree once, adding each node's children to its grandparents and all higher ancestors. Finally, it appends the new (child, parent) rows to the dataframe.

    Code:

    def add_children_of_children(dataframe, root_node):
        # build a dict of lists to allow easy tree descent
        tree = {}
        for idx, (child, parent) in dataframe.iterrows():
            tree.setdefault(parent, []).append(child)
    
        data = []
    
        def descend_tree(parent):
            # get list of children of this parent
            children = tree[parent]
    
            # reverse order so that we can modify the list while looping
            for child in reversed(children):
                if child in tree:
    
                    # descend tree and find children which need to be added
                    lower_children = descend_tree(child)
    
                    # add children from below to parent at this level
                    data.extend([(c, parent) for c in lower_children])
    
                    # return lower children to parents above
                    children.extend(lower_children)
    
            return children
    
        descend_tree(root_node)
    
        # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement
        return pd.concat(
            [dataframe, pd.DataFrame(data, columns=dataframe.columns)])
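
To make the dict-of-lists step concrete, here is a minimal pandas-free sketch run on the sample data from the test code. The names `build_tree` and `all_descendants` are hypothetical helpers introduced only for this illustration. Unlike `add_children_of_children()`, this version does not modify the lists in place and does not reuse results from lower levels; it only shows what the intermediate structure looks like and how descending it yields every descendant of a node.

```python
def build_tree(pairs):
    """Map each parent to the list of its direct children."""
    tree = {}
    for child, parent in pairs:
        tree.setdefault(parent, []).append(child)
    return tree


def all_descendants(tree, node):
    """Collect every descendant of node by descending the tree recursively."""
    result = []
    for child in tree.get(node, []):
        result.append(child)
        result.extend(all_descendants(tree, child))
    return result


# (child, parent) pairs, same values as the sample dataframe
pairs = [(3102, 2010), (2010, 1000), (3011, 2010), (3000, 2110),
         (3033, 2100), (2110, 1000), (3111, 2110), (2100, 1000)]

tree = build_tree(pairs)
print(tree)
# {2010: [3102, 3011], 1000: [2010, 2110, 2100], 2110: [3000, 3111], 2100: [3033]}

print(sorted(all_descendants(tree, 1000)))
# [2010, 2100, 2110, 3000, 3011, 3033, 3102, 3111]
```

Leaf nodes such as 3102 never appear as keys, which is why the answer's code checks `if child in tree` before descending.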
    

    Timings:

    The test code contains three test methods. Times are in seconds for a timeit run of 50 calls each:

    • 0.073 - add_children_of_children() from above.
    • 0.153 - add_children_of_children() with the output sorted.
    • 3.385 - original get_ancestry_dataframe_flat() pandas implementation.

    So the native data structure approach is about 46 times faster than the original implementation on this data (0.073 s vs. 3.385 s), and still about 22 times faster when the output is sorted.
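
As a further illustration of the native-data-structure idea, the sketch below produces the full (ancestor, descendant) list with plain Python only. It is not the method timed above: it swaps the recursion for an explicit stack, which avoids Python's recursion limit on very deep trees, and it re-walks each subtree instead of reusing results from lower levels. The function name `descendant_pairs` is hypothetical, and the input is assumed to be a tree (no cycles).

```python
def descendant_pairs(pairs):
    """Return sorted (ancestor, descendant) tuples for every ancestor.

    pairs is an iterable of (child, parent) tuples that is assumed
    to describe a tree, i.e. it contains no cycles.
    """
    children = {}
    for child, parent in pairs:
        children.setdefault(parent, []).append(child)

    result = []
    for ancestor in children:
        # explicit stack instead of recursion
        stack = list(children[ancestor])
        while stack:
            node = stack.pop()
            result.append((ancestor, node))
            # queue up this node's own children, if it has any
            stack.extend(children.get(node, []))
    return sorted(result)


# (child, parent) pairs, same values as the sample dataframe
pairs = [(3102, 2010), (2010, 1000), (3011, 2010), (3000, 2110),
         (3033, 2100), (2110, 1000), (3111, 2110), (2100, 1000)]

for ancestor, descendant in descendant_pairs(pairs):
    print(descendant, ancestor)
```

For the sample data this prints 13 descendant/ancestor rows, the same rows that appear in the Test Results.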

    Test Code:

    import pandas as pd
    
    df = pd.DataFrame(
        {
            'child': [3102, 2010, 3011, 3000, 3033, 2110, 3111, 2100],
            'parent': [2010, 1000, 2010, 2110, 2100, 1000, 2110, 1000]
        }, columns=['child', 'parent']
    )
    
    def method1():
        # the root node is the node which is not a child
        root = set(df.parent) - set(df.child)
        assert len(root) == 1, "Number of roots != 1 '{}'".format(root)
        return add_children_of_children(df, root.pop())
    
    def method2():
        dataframe = method1()
        names = ['ancestor', 'descendant']
        rename = {o: n for o, n in zip(dataframe.columns, reversed(names))}
        return dataframe.rename(columns=rename) \
            .sort_values(names).reset_index(drop=True)
    
    def method3():
        return get_ancestry_dataframe_flat(df)
    
    def get_ancestry_dataframe_flat(df):
    
        def get_child_list(parent_id):
    
            list_of_children = list()
            list_of_children.append(
                df[df['parent'] == parent_id]['child'].values)
    
            for i, r in df[df['parent'] == parent_id].iterrows():
                if r['child'] != parent_id:
                    list_of_children.append(get_child_list(r['child']))
    
            # flatten list
            list_of_children = [
                item for sublist in list_of_children for item in sublist]
            return list_of_children
    
        new_df = pd.DataFrame(columns=['descendant', 'ancestor']).astype(int)
        for index, row in df.iterrows():
            temp_df = pd.DataFrame(columns=['descendant', 'ancestor'])
            temp_df['descendant'] = pd.Series(get_child_list(row['parent']))
            temp_df['ancestor'] = row['parent']
            # DataFrame.append was removed in pandas 2.0; use pd.concat
            new_df = pd.concat([new_df, temp_df])
    
        new_df = new_df\
            .drop_duplicates()\
            .sort_values(['ancestor', 'descendant'])\
            .reset_index(drop=True)
    
        return new_df
    
    print(method2())
    print(method3())
    
    from timeit import timeit
    print(timeit(method1, number=50))
    print(timeit(method2, number=50))
    print(timeit(method3, number=50))
    

    Test Results:

        descendant  ancestor
    0         2010      1000
    1         2100      1000
    2         2110      1000
    3         3000      1000
    4         3011      1000
    5         3033      1000
    6         3102      1000
    7         3111      1000
    8         3011      2010
    9         3102      2010
    10        3033      2100
    11        3000      2110
    12        3111      2110
    
        descendant  ancestor
    0         2010      1000
    1         2100      1000
    2         2110      1000
    3         3000      1000
    4         3011      1000
    5         3033      1000
    6         3102      1000
    7         3111      1000
    8         3011      2010
    9         3102      2010
    10        3033      2100
    11        3000      2110
    12        3111      2110
    
    0.0737142168563
    0.153700592966
    3.38558308083
    
