3つ以上の集合の直積を求めるプログラム
http://d.hatena.ne.jp/yagiey/20100705/1278340237
http://d.hatena.ne.jp/youkoso_guest/20081212/1229089595
に、直積を作る関数について書かれていたのですが、どちらも有限なlistに限定したものだったので、無限リストにも扱える関数を書いてみました。
itertools.productというのが標準ライブラリにあるので、思いっきり車輪の再生産ですが。
#encoding:shift-jis """3つ以上の集合の直積を求めるプログラム""" from __future__ import division, print_function, unicode_literals __metaclass__ = type from functools import partial from itertools import izip, count def product_rec(*iterables): #再起版直積関数 if not iterables: yield () else: it_first = iter(iterables[0]) it_rests = product_rec(*iterables[1:]) firsts = [] rests = [] for height in count(): try: firsts.append(next(it_first)) except StopIteration: pass try: rests.append(next(it_rests)) except StopIteration: pass if len(firsts) + len(rests) < height: return for i, x in enumerate(firsts[:height + 1]): k = height - i if k < len(rests): yield (x,) + rests[k] def product_non_rec(*iterables): #非再起版直積関数 iterables = [iter(it) for it in iterables] values = [[] for it in iterables] stack = [] push = stack.append def pop(): try: return stack.pop(-1) except IndexError: raise StopIteration i = 0 while 1: if len(stack) == len(values): yield tuple(values[j][k] for j, k in enumerate(stack)) i = pop() + 1 else: if len(values[len(stack)]) == i: try: v = next(iterables[len(stack)]) except StopIteration: i = pop() + 1 continue else: values[len(stack)].append(v) push(i) i = 0 def main(): import itertools numbers = [1, 2, 3] for prod in [product_rec, product_non_rec, itertools.product]: p0 = sorted(prod()) assert p0 == [()], (prod, p0) p1 = sorted(prod(numbers)) q1 = sorted([(1,), (2,), (3,)]) assert p1 == q1, (prod, p1) p2 = sorted(prod(numbers, numbers)) q2 = sorted([(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)]) assert p2 == q2, (prod, p2) p3 = sorted(prod(numbers, numbers, numbers)) q3 = sorted([(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 2, 1), (1, 2, 2), (1, 2, 3), (1, 3, 1), (1, 3, 2), (1, 3, 3), (2, 1, 1), (2, 1, 2), (2, 1, 3), (2, 2, 1), (2, 2, 2), (2, 2, 3), (2, 3, 1), (2, 3, 2), (2, 3, 3), (3, 1, 1), (3, 1, 2), (3, 1, 3), (3, 2, 1), (3, 2, 2), (3, 2, 3), (3, 3, 1), (3, 3, 2), (3, 3, 3), ]) assert p3 == q3, (prod, p3) if "__main__" == __name__: main()
なお、再起版のproduct_recはリストを渡しすぎると、再帰呼び出しが深すぎてエラーになります。
$>python -c "from product import product_rec;print next(product_rec(*[[1, 2, 3]]*1000))" File "product.py", line 26, in product_rec rests.append(next(it_rests)) File "product.py", line 26, in product_rec rests.append(next(it_rests)) (中略) File "product.py", line 15, in product_rec it_rests = product_rec(*iterables[1:]) RuntimeError: maximum recursion depth exceeded while calling a Python object