素集合データ構造

元ネタ:
http://con-leche.blogspot.com/2010/03/google-devfest-2010.html

参考:
http://www.kmonos.net/wlog/88.html
http://d.hatena.ne.jp/rubyco/20080719/equiv
wikipedia:素集合データ構造


素集合データ構造(Union-Find)は、要素を同値類(素集合)に分類するためのデータ構造です。例えば、"a=b, b=c, c=d, A=B, B=C, C=D"という関係が与えられたとき、{a, b, c, d}と{A, B, C, D}という2つのグループに要素を分類します。


「同値類」ではなく「素集合」と呼ぶあたりが、数学ではなく計算機科学っぽいところです。しかし、「素集合データ構造」という身も蓋もない名前からは、白衣の研究者って感じがして素敵です。

実装方法

同値類ごとに、同値類の元をノードとする木構造を作ります。木のルートノードが同値類の代表元になります。

ただし、普通木構造はルート→葉の方向にリンクを張りますが、素集合データ構造では葉→ルートの方向にリンクを張ります。そのためどのノードからも、リンクを辿っていくとルートノードに到達できるようになっています(findメソッド)。

そして、同値関係が追加されるごとに、それぞれの属する木を統合します( unionメソッド)。

速度

ただし、単純に作るとfindにかかる時間が要素数に比例して(O(n))増えてしまいます。
高速にfindするには、要素がルートを直接指すように追加していけばよろしい。しかし、すると今度はunionで統合するノードのリンクを張り替えなければならないので、unionにO(n)の時間がかかってしまいます。
また、木の長さがわかっているなら、unionする時に短い木を長い木に連結すれば、速度が上がります。

そこで、

  • findでルートへのリンクをたどっていく際、要素のリンクをルートに張りかえながら辿る
  • リストの長さを比べる代わりに、ヒューリスティクスとして要素毎にrankを定め、rankの小さいリンクを大きいリンクに連結する

すると、要素追加にかかる時間を O(α(n)) に抑えることができます。

ここで、α(n)は関数 f(n) = A(n, n) の逆関数で、A(n, m)はアッカーマン関数という非常に急速に増加する関数なので、
α(n)はほとんど定数関数と言えます。

class UnionFind:
    def __init__(self):
        self._parent = {}
        self._rank = {}
        self._equiv = {}
    
    def _find(self, x):
        if self._parent[x] == x:
            return self._parent[x]
        else:
            p = self._find(self._parent[x])
            self._parent[x] = p
            return p
    
    def _set_parent(self, child, parent):
        self._parent[child] = parent
        self._equiv[parent].update(self._equiv.pop(child))
    
    def _union(self, x, y):
        x = self._find(x)
        y = self._find(y)
        if x == y: return
        if self._rank[x] > self._rank[y]:
            self._set_parent(y, x)
        elif  self._rank[x] < self._rank[y]:
            self._set_parent(x, y)
        else:
            self._set_parent(x, y)
            self._rank[x] += 1
    
    def _make_set(self, x):
        self._parent[x] = x
        self._rank[x] = 0
        self._equiv[x] = {x}
    
    def unite(self, *elements):
        if not elements: return
        for x in elements:
            if x in self._parent: continue
            self._make_set(x)
        x0 = elements[0]
        for x in elements[1:]:
            self._union(x0, x)
    
    def groups(self):
        return [set(e) for e in self._equiv.values()]

if __name__ == "__main__":
    uf = UnionFind()
    uf.unite("a", "b")
    uf.unite("b", "c")
    uf.unite("c", "d")
    uf.unite("A", "B")
    uf.unite("B", "C")
    uf.unite("C", "D")
    
    print(uf.groups()) #=> [{'a', 'c', 'b', 'd'}, {'A', 'C', 'B', 'D'}]

パッチワーク問題を解いてみる

http://con-leche.blogspot.com/2010/03/google-devfest-2010.html
ここに "A" または "B" という文字のみを含む 600 桁、600 行のテキストがあります。 これを 600 x 600 の升目状に並べ、上下左右に同じ文字がある部分をつながっているとみなします。
まず、最も多くの文字がつながっている領域をすべて "_" で塗りつぶしてください。 最も多くの文字がつながっている領域が複数存在するならば、それらすべての領域を "_"で塗りつぶすこととします。
そして、各行ごとに "_" が何個含まれているかを数え、それらの数値を改行区切りのテキスト(600 行)として答えてください。
以下の例1を見てください。この入力には単一の文字4つでつながった領域が3箇所あります。これらすべてが「最も多くの文字がつながっている領域」なので、全て"_"で塗りつぶし、その数を数えています。
例1:
入力

  ABAAB
  BABAA
  BAABB
  ABABB
  BABAA

塗りつぶしたところ.

  AB__B
  B_B__
  B____
  AB___
  BABAA

答え

  2
  3
  4
  3
  0


例2:
入力

  BAABBABBBB
  BAABABBBBB
  BBAABABBBA
  BABBBABBAA
  BBABAAABAB
  BABABBBAAA
  AABBBAAAAA
  BAAAAAABBB
  AAABABBAAB
  AABAABBABA

塗りつぶしたところ.

  BAABBABBBB
  BAABABBBBB
  BBAABABBB_
  BABBBABB__
  BBABAAAB_B
  B_BABBB___
  __BBB_____
  B______BBB
  ___B_BBAAB
  __B__BBABA

答え

  0
  0
  1
  2
  1
  4
  7
  6
  4
  4

#patchwork.py
import sys
from itertools import product
from unionfind import UnionFind

def solve(w, h, elements):
    uf = UnionFind()
    for p in product(xrange(w), xrange(h)):
        left = (p[0]-1, p[1])
        up = (p[0], p[1]-1)
        for q in [left, up]:
            if q not in elements: continue
            if elements[p] == elements[q]:
                uf.unite(p, q)
    
    groups =  uf.groups()
    maxSize = max(len(e) for e in groups)
    toDelete = set()
    for e in groups:
        if len(e) == maxSize:
            toDelete.update(e)
    
    results = []
    for y in xrange(h):
        rowIndexes = ((x, y) for x in xrange(w))
        v = len(toDelete.intersection(rowIndexes))
        results.append(v)
    return results

def main():
    elements = {}
    for y, line in enumerate(open(sys.argv[1])):
        for x, c in enumerate(line.strip()):
            elements[x, y] = c
    size = y
    for x in solve(size, size, elements):
        print(x)

if __name__ == "__main__":
    main()
$ cat data.txt
BAABBABBBB
BAABABBBBB
BBAABABBBA
BABBBABBAA
BBABAAABAB
BABABBBAAA
AABBBAAAAA
BAAAAAABBB
AAABABBAAB
AABAABBABA

$ patchwork.py data.txt
0
0
0
1
1
3
6
6
4