3つ以上の集合の直積を求めるプログラム

http://d.hatena.ne.jp/yagiey/20100705/1278340237
http://d.hatena.ne.jp/youkoso_guest/20081212/1229089595
に、直積を作る関数について書かれていたのですが、どちらも有限なlistに限定したものだったので、無限リストにも扱える関数を書いてみました。

itertools.productというのが標準ライブラリにあるので、思いっきり車輪の再生産ですが。

#encoding:shift-jis
"""3つ以上の集合の直積を求めるプログラム"""
from __future__ import division, print_function, unicode_literals
__metaclass__ = type 

from functools import partial
from itertools import izip, count

def product_rec(*iterables):
    #再起版直積関数
    if not iterables:
        yield ()
    else:
        it_first = iter(iterables[0])
        it_rests = product_rec(*iterables[1:])
        
        firsts = []
        rests  = []
        
        for height in count():
            try:
                firsts.append(next(it_first))
            except StopIteration:
                pass
            try:
                rests.append(next(it_rests))
            except StopIteration:
                pass
            
            if len(firsts) + len(rests) < height:
                return
            
            for i, x in enumerate(firsts[:height + 1]):
                k = height - i
                if k < len(rests):
                    yield (x,) + rests[k]


def product_non_rec(*iterables):
    #非再起版直積関数
    iterables = [iter(it) for it in iterables]
    
    values = [[] for it in iterables]
    stack = []
    push = stack.append
    def pop():
        try:
            return stack.pop(-1)
        except IndexError:
            raise StopIteration
    
    i = 0
    while 1:
        if len(stack) == len(values):
            yield tuple(values[j][k] for j, k in enumerate(stack))
            i = pop() + 1
        else:
            if len(values[len(stack)]) == i:
                try:
                    v = next(iterables[len(stack)])
                except StopIteration:
                    i = pop() + 1
                    continue
                else:
                    values[len(stack)].append(v)
            push(i)
            i = 0
    
def main():
    import itertools
    
    numbers = [1, 2, 3]
    for prod in [product_rec, product_non_rec, itertools.product]:
        p0 = sorted(prod())
        assert p0 == [()], (prod, p0)
        
        p1 = sorted(prod(numbers))
        q1 = sorted([(1,), (2,), (3,)])
        assert p1 == q1, (prod, p1)
        
        p2 = sorted(prod(numbers, numbers))
        q2 = sorted([(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)])
        assert p2 == q2, (prod, p2)
        
        p3 = sorted(prod(numbers, numbers, numbers))
        q3 = sorted([(1, 1, 1), (1, 1, 2), (1, 1, 3),
                  (1, 2, 1), (1, 2, 2), (1, 2, 3),
                  (1, 3, 1), (1, 3, 2), (1, 3, 3),
                  (2, 1, 1), (2, 1, 2), (2, 1, 3),
                  (2, 2, 1), (2, 2, 2), (2, 2, 3),
                  (2, 3, 1), (2, 3, 2), (2, 3, 3),
                  (3, 1, 1), (3, 1, 2), (3, 1, 3),
                  (3, 2, 1), (3, 2, 2), (3, 2, 3),
                  (3, 3, 1), (3, 3, 2), (3, 3, 3),
            ])
        assert p3 == q3, (prod, p3)
        
if "__main__" == __name__:
    main()


なお、再起版のproduct_recはリストを渡しすぎると、再帰呼び出しが深すぎてエラーになります。

$>python -c "from product import product_rec;print next(product_rec(*[[1, 2, 3]]*1000))"
  File "product.py", line 26, in product_rec
    rests.append(next(it_rests))
  File "product.py", line 26, in product_rec
    rests.append(next(it_rests))
      (中略)
  File "product.py", line 15, in product_rec
    it_rests = product_rec(*iterables[1:])
RuntimeError: maximum recursion depth exceeded while calling a Python object