素集合データ構造
元ネタ:
http://con-leche.blogspot.com/2010/03/google-devfest-2010.html
参考:
http://www.kmonos.net/wlog/88.html
http://d.hatena.ne.jp/rubyco/20080719/equiv
wikipedia:素集合データ構造
素集合データ構造(Union-Find)は、要素を同値類(素集合)に分類するためのデータ構造です。例えば、"a=b, b=c, c=d, A=B, B=C, C=D"という関係が与えられたとき、{a, b, c, d}と{A, B, C, D}という2つのグループに要素を分類します。
「同値類」ではなく「素集合」と呼ぶあたりが、数学ではなく計算機科学っぽいところです。しかし、「素集合データ構造」という身も蓋もない名前からは、白衣の研究者って感じがして素敵です。
実装方法
同値類ごとに、同値類の元をノードとする木構造を作ります。木のルートノードが同値類の代表元になります。
ただし、普通木構造はルート→葉の方向にリンクを張りますが、素集合データ構造では葉→ルートの方向にリンクを張ります。そのためどのノードからも、リンクを辿っていくとルートノードに到達できるようになっています(findメソッド)。
そして、同値関係が追加されるごとに、それぞれの属する木を統合します( unionメソッド)。
速度
ただし、単純に作るとfindにかかる時間が要素数に比例して(O(n))増えてしまいます。
高速にfindするには、要素がルートを直接指すように追加していけばよろしい。しかし、すると今度はunionで統合するノードのリンクを張り替えなければならないので、unionにO(n)の時間がかかってしまいます。
また、木の長さがわかっているなら、unionする時に短い木を長い木に連結すれば、速度が上がります。
そこで、
- findでルートへのリンクをたどっていく際、要素のリンクをルートに張りかえながら辿る
- リストの長さを比べる代わりに、ヒューリスティクスとして要素毎にrankを定め、rankの小さいリンクを大きいリンクに連結する
すると、要素追加にかかる時間を O(α(n)) に抑えることができます。
ここで、α(n)は関数 f(n) = A(n, n) の逆関数で、A(n, m)はアッカーマン関数という非常に急速に増加する関数なので、
α(n)はほとんど定数関数と言えます。
class UnionFind: def __init__(self): self._parent = {} self._rank = {} self._equiv = {} def _find(self, x): if self._parent[x] == x: return self._parent[x] else: p = self._find(self._parent[x]) self._parent[x] = p return p def _set_parent(self, child, parent): self._parent[child] = parent self._equiv[parent].update(self._equiv.pop(child)) def _union(self, x, y): x = self._find(x) y = self._find(y) if x == y: return if self._rank[x] > self._rank[y]: self._set_parent(y, x) elif self._rank[x] < self._rank[y]: self._set_parent(x, y) else: self._set_parent(x, y) self._rank[x] += 1 def _make_set(self, x): self._parent[x] = x self._rank[x] = 0 self._equiv[x] = {x} def unite(self, *elements): if not elements: return for x in elements: if x in self._parent: continue self._make_set(x) x0 = elements[0] for x in elements[1:]: self._union(x0, x) def groups(self): return [set(e) for e in self._equiv.values()] if __name__ == "__main__": uf = UnionFind() uf.unite("a", "b") uf.unite("b", "c") uf.unite("c", "d") uf.unite("A", "B") uf.unite("B", "C") uf.unite("C", "D") print(uf.groups()) #=> [{'a', 'c', 'b', 'd'}, {'A', 'C', 'B', 'D'}]
パッチワーク問題を解いてみる
http://con-leche.blogspot.com/2010/03/google-devfest-2010.html
ここに "A" または "B" という文字のみを含む 600 桁、600 行のテキストがあります。 これを 600 x 600 の升目状に並べ、上下左右に同じ文字がある部分をつながっているとみなします。
まず、最も多くの文字がつながっている領域をすべて "_" で塗りつぶしてください。 最も多くの文字がつながっている領域が複数存在するならば、それらすべての領域を "_"で塗りつぶすこととします。
そして、各行ごとに "_" が何個含まれているかを数え、それらの数値を改行区切りのテキスト(600 行)として答えてください。
以下の例1を見てください。この入力には単一の文字4つでつながった領域が3箇所あります。これらすべてが「最も多くの文字がつながっている領域」なので、全て"_"で塗りつぶし、その数を数えています。
例1:
入力ABAAB
BABAA
BAABB
ABABB
BABAA塗りつぶしたところ.
AB__B
B_B__
B____
AB___
BABAA答え
2
3
4
3
0
例2:
入力BAABBABBBB
BAABABBBBB
BBAABABBBA
BABBBABBAA
BBABAAABAB
BABABBBAAA
AABBBAAAAA
BAAAAAABBB
AAABABBAAB
AABAABBABA塗りつぶしたところ.
BAABBABBBB
BAABABBBBB
BBAABABBB_
BABBBABB__
BBABAAAB_B
B_BABBB___
__BBB_____
B______BBB
___B_BBAAB
__B__BBABA答え
0
0
1
2
1
4
7
6
4
4
#patchwork.py import sys from itertools import product from unionfind import UnionFind def solve(w, h, elements): uf = UnionFind() for p in product(xrange(w), xrange(h)): left = (p[0]-1, p[1]) up = (p[0], p[1]-1) for q in [left, up]: if q not in elements: continue if elements[p] == elements[q]: uf.unite(p, q) groups = uf.groups() maxSize = max(len(e) for e in groups) toDelete = set() for e in groups: if len(e) == maxSize: toDelete.update(e) results = [] for y in xrange(h): rowIndexes = ((x, y) for x in xrange(w)) v = len(toDelete.intersection(rowIndexes)) results.append(v) return results def main(): elements = {} for y, line in enumerate(open(sys.argv[1])): for x, c in enumerate(line.strip()): elements[x, y] = c size = y for x in solve(size, size, elements): print(x) if __name__ == "__main__": main()
$ cat data.txt BAABBABBBB BAABABBBBB BBAABABBBA BABBBABBAA BBABAAABAB BABABBBAAA AABBBAAAAA BAAAAAABBB AAABABBAAB AABAABBABA $ patchwork.py data.txt 0 0 0 1 1 3 6 6 4