How to properly implement disjoint set data structure for finding spanning forests in Python?

落爺英雄遲暮 submitted on 2021-02-05 08:34:51

Question


Recently, I was trying to implement solutions to Google Kick Start's 2019 programming questions. I tried to implement Round E's Cherries Mesh by following the analysis. Here is the link to the question and the analysis: https://codingcompetitions.withgoogle.com/kickstart/round/0000000000050edb/0000000000170721

Here is the code I implemented:

t = int(input())
for k in range(1,t+1):
    n, q = map(int,input().split())
    se = list()
    for _ in range(q):
        a,b = map(int,input().split())
        se.append((a,b))
    l = [{x} for x in range(1,n+1)]
    #print(se)
    for s in se:
        i = 0
        while ({s[0]}.isdisjoint(l[i])):
            i += 1
        j = 0
        while ({s[1]}.isdisjoint(l[j])):
            j += 1
        if i!=j:
            l[i].update(l[j])
            l.pop(j)
        #print(l)
    count = q+2*(len(l)-1)
    print('Case #',k,': ',count,sep='')



This passes the sample case but not the test cases. To the best of my knowledge, this should be right. Am I doing something wrong?


Answer 1:


Two issues:

  • Your algorithm checks whether an edge links two disjoint sets and joins them if not. It is inefficient because it scans the list of sets for both endpoints of every edge. The Union-Find algorithm on a Disjoint-Set data structure is much more efficient.
  • The final count should not be based on the original number of black edges. Those black edges may form cycles, and an edge that closes a cycle is not used, so it should not be counted. Instead, count how many edges there are in total (irrespective of colour). The solution is a minimum spanning tree, so there are n-1 edges. Then add one for each red edge, since a red edge costs one more than a black one. There is one red edge fewer than the number of disjoint sets you end up with, which you already computed.
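This small example shows why the number of black edges q cannot be used directly. It is a minimal sketch: the function name `sugar_needed` is hypothetical. It also uses a plain parent list with path halving rather than the Node class below. It counts only the black edges that actually merge two components. Three black edges forming a triangle connect three nodes, but only two of them are needed:

```python
def sugar_needed(n, black_edges):
    """Minimum sugar for nodes 1..n, given the list of black edges.

    Black edges cost 1 and red edges cost 2; the spanning tree has n - 1 edges.
    Hypothetical helper for illustration only.
    """
    parent = list(range(n + 1))  # index 0 unused; nodes are 1..n

    def find(x):
        # Path halving: point each visited node at its grandparent.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    black_used = 0
    for a, b in black_edges:
        ra, rb = find(a), find(b)
        if ra != rb:           # this black edge joins two components
            parent[ra] = rb
            black_used += 1    # edges that would close a cycle are skipped
    red_used = (n - 1) - black_used
    return black_used + 2 * red_used


triangle = [(1, 2), (2, 3), (1, 3)]
print(sugar_needed(3, triangle))  # 2: two black edges, no red ones
```

For the triangle, the question's formula `q+2*(len(l)-1)` gives 3 + 2*0 = 3. That overcounts by the one black edge that closes the cycle.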

I would also advise using meaningful variable names, which make the code much easier to understand. One-letter variables like t, q or s are not very helpful.

There are several ways to implement the Union-Find functions. Here I have defined a Node class which has those methods:

# Implementation of Union-Find (Disjoint Set)
class Node:
    def __init__(self):
        self.parent = self
        self.rank = 0

    def find(self):
        if self.parent.parent != self.parent:
            self.parent = self.parent.find()
        return self.parent

    def union(self, other):
        node = self.find()
        other = other.find()
        if node == other:
            return True # was already in same set
        if node.rank > other.rank:
            node, other = other, node
        node.parent = other
        other.rank = max(other.rank, node.rank + 1)
        return False # was not in same set, but now is

testcount = int(input())
for testid in range(1, testcount + 1):
    nodecount, blackcount = map(int, input().split())
    # use Union-Find data structure
    nodes = [Node() for _ in range(nodecount)]
    blackedges = []
    for _ in range(blackcount):
        start, end = map(int, input().split())
        blackedges.append((nodes[start - 1], nodes[end - 1]))

    # Start with assumption that all edges on MST are red:
    sugarcount = nodecount * 2 - 2
    for start, end in blackedges:
        if not start.union(end): # When edge connects two disjoint sets:
            sugarcount -= 1 # Use this black edge instead of red one

    print('Case #{}: {}'.format(testid, sugarcount))



Answer 2:


You are getting an incorrect answer because you're calculating the count incorrectly. It takes n-1 edges to connect n nodes into a tree. Of those, num_clusters-1 have to be red, where num_clusters is the number of sets left after processing all the black edges.
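Written out, the count is (n - 1) + (num_clusters - 1) = n + num_clusters - 2. Every tree edge costs at least one, and each red edge costs one extra. Here is a quick arithmetic check; the helper names `count_from_clusters` and `count_from_black_used` are hypothetical. An alternative is to start from 2(n - 1), as if every tree edge were red, and subtract one for each black edge that merges two clusters. The check confirms both give the same result, because exactly n - num_clusters black edges do such a merge:

```python
def count_from_clusters(n, num_clusters):
    # n - 1 tree edges cost 1 each; the num_clusters - 1 red ones cost 1 more.
    return (n - 1) + (num_clusters - 1)


def count_from_black_used(n, black_used):
    # Start with all n - 1 tree edges red (cost 2), save 1 per usable black edge.
    return 2 * (n - 1) - black_used


# Each merging black edge reduces the cluster count by one, starting from n,
# so after all black edges: black_used == n - num_clusters.
for n in range(1, 50):
    for num_clusters in range(1, n + 1):
        black_used = n - num_clusters
        assert count_from_clusters(n, num_clusters) == count_from_black_used(n, black_used)

print(count_from_clusters(3, 2))  # 3
```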

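But even if you fix that, your program will still be very slow because of your disjoint set implementation. Each edge triggers a linear scan of the list of sets for both endpoints, plus a linear-time `l.pop(j)`.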
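You may want to keep the question's idea of storing explicit Python sets. One alternative (a swapped-in technique, not what either answer uses) is to keep a lookup table from each node to the set object containing it. You then always merge the smaller set into the larger one ("small-to-large" merging). Finding an endpoint's set becomes a dictionary lookup, and each node is moved at most about log2(n) times. This is a sketch; the function name `count_clusters` is hypothetical:

```python
def count_clusters(n, edges):
    """Return the number of connected components among nodes 1..n."""
    # set_of[x] is the (shared) set object that currently contains node x.
    set_of = {x: {x} for x in range(1, n + 1)}
    clusters = n
    for a, b in edges:
        sa, sb = set_of[a], set_of[b]
        if sa is sb:
            continue               # already connected: the edge closes a cycle
        if len(sa) < len(sb):
            sa, sb = sb, sa        # make sa the larger set
        sa.update(sb)              # copy the smaller set's members into the larger
        for x in sb:
            set_of[x] = sa         # repoint only the nodes that moved
        clusters -= 1
    return clusters


# Corrected count for n = 3 with one black edge (2, 3): n + clusters - 2
print(3 + count_clusters(3, [(2, 3)]) - 2)  # 3
```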

Thankfully, it's actually pretty easy to implement a very efficient disjoint set data structure in a single array/list/vector in just about any programming language. Here's a nice one in Python. I have Python 2 on my box, so my print and input statements are a little different from yours:

# Create a disjoint set data structure, with n singletons, numbered 0 to n-1
# This is a simple array where for each item x:
# x > 0 => a set of size x, and x <= 0 => a link to -x

def ds_create(n):
    return [1]*n

# Find the current root set for original singleton index

def ds_find(ds, index):
    val = ds[index]
    if (val > 0):
        return index
    root = ds_find(ds, -val)
    if (val != -root):
        ds[index] = -root # path compression
    return root

# Merge given sets. returns False if they were already merged

def ds_union(ds, a, b):
    aroot = ds_find(ds, a)
    broot = ds_find(ds, b)
    if aroot == broot:
        return False
    # union by size
    if ds[aroot] >= ds[broot]:
        ds[aroot] += ds[broot]
        ds[broot] = -aroot
    else:
        ds[broot] += ds[aroot]
        ds[aroot] = -broot
    return True

# Count root sets

def ds_countRoots(ds):
    return sum(1 for v in ds if v > 0)

#
# CherriesMesh solution
#
numTests = int(raw_input())
for testNum in range(1,numTests+1):
    numNodes, numEdges = map(int,raw_input().split())
    sets = ds_create(numNodes)
    for _ in range(numEdges):
        a,b = map(int,raw_input().split())
        ds_union(sets, a-1, b-1)
    count = numNodes + ds_countRoots(sets) - 2
    print 'Case #{0}: {1}'.format(testNum, count)
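The question uses Python 3, so here is a sketch of the same single-list representation ported to Python 3. A positive entry is a root holding its set size; a non-positive entry -r links to root r. The code is wrapped in a hypothetical `solve_case` function so it can be tested without stdin, and `ds_countRoots` is renamed `ds_count_roots`. The recursive find is replaced by an iterative one with full path compression. That is a design choice of this sketch, not part of the answer above:

```python
def ds_create(n):
    # Each of the n items starts as a singleton root of size 1.
    return [1] * n


def ds_find(ds, index):
    # Walk up to the root: roots hold a positive size, links hold -parent.
    root = index
    while ds[root] <= 0:
        root = -ds[root]
    # Path compression: point every node on the path directly at the root.
    while ds[index] <= 0 and -ds[index] != root:
        nxt = -ds[index]
        ds[index] = -root
        index = nxt
    return root


def ds_union(ds, a, b):
    # Merge the sets of a and b; return False if they were already merged.
    aroot, broot = ds_find(ds, a), ds_find(ds, b)
    if aroot == broot:
        return False
    if ds[aroot] < ds[broot]:
        aroot, broot = broot, aroot   # union by size: keep the larger root
    ds[aroot] += ds[broot]
    ds[broot] = -aroot
    return True


def ds_count_roots(ds):
    return sum(1 for v in ds if v > 0)


def solve_case(num_nodes, edges):
    # Hypothetical wrapper; edges use 1-based node numbers as in the input.
    sets = ds_create(num_nodes)
    for a, b in edges:
        ds_union(sets, a - 1, b - 1)
    return num_nodes + ds_count_roots(sets) - 2
```

For the judge, you can read each case with `input()` as in the question. Build `edges` as a list of `(a, b)` tuples and print `'Case #{}: {}'.format(test_num, solve_case(num_nodes, edges))`. The iterative find avoids Python's recursion limit on long link chains. With union by size, though, the recursive version is also shallow (depth at most about log2(n)).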


Source: https://stackoverflow.com/questions/60004277/how-to-properly-implement-disjoint-set-data-structure-for-finding-spanning-fores
