3つ以上の集合の直積を求めるプログラム その2

3つ以上の集合の直積を求めるプログラムClojureで書いてみました。

Python版では無限リストを扱う関係の部分でゴチャゴチャしていたのが、Clojureは無限リストを自然に扱えます。

一方Clojureは状態変化が基本的に無いので、Python版は直訳は出来ません。

clojure.contrib.combinatorics/cartesian-productのソースをカンニングしながら、書き直してみました。結局再帰もスタックも無い形に書き直せたので、Pythonにも翻訳してみました。

(ns product
  (:use clojure.contrib.seq-utils
        clojure.contrib.test-is
        clojure.contrib.combinatorics))

(defn has-nth [seq n]
  (boolean (nthnext seq n)))

(def not-nil? #(not (nil? %)))

(defn zip [seq1 seq2]
  (map #(vector %1 %2) seq1 seq2))

(defn incr-multi-index [multi-index seqs]
  (first
    (filter not-nil?
      (for [[n [index s]] (reverse (indexed (zip multi-index seqs)))]
        (if (has-nth s (+ 1 index))
          (concat
            (take n multi-index)
            [(+ 1 index)]
            (repeat (- (count multi-index) n 1) 0))
          nil)))))

(defn product [& seqs]
  (let [incr #(incr-multi-index % seqs)
        multi-index0 (repeat (count seqs) 0)
        multi-indexes (take-while not-nil? (iterate incr multi-index0))]
    (for [multi-index multi-indexes]
      (for [[k index] (indexed multi-index)]
        (nth (nth seqs k) index)))))

(deftest test-product
  (let [numbers [1 2 3]]
    (doseq [i (range 1)]
      (is (= (sort (apply cartesian-product (repeat i numbers))) 
             (sort (apply product           (repeat i numbers))))))))
#encoding:shift-jis
#product.py
"""3つ以上の集合の直積を求めるプログラム"""
from __future__ import division, print_function, unicode_literals
__metaclass__ = type 
from itertools import izip, count

class CachedIter:
    def __init__(self, iterable):
        self._iterable = iter(iterable)
        self._values = []
        
    def has_nth(self, n):
        if not self._iterable:
            return n < len(self._values)
        elif n < len(self._values):
            return True
        else:
            for i in xrange(len(self._values), n + 1):
                try:
                    v = next(self._iterable)
                except StopIteration:
                    self._value = None
                    return False
                else:
                    self._values.append(v)
            else:
                return True
    
    def __getitem__(self, n):
        if self.has_nth(n):
            return self._values[n]
        else:
            raise IndexError
    

def incr_multi_index(multiIndex, iterables):
    enum = list(enumerate(izip(iterables, multiIndex)))
    for i, (it, index) in reversed(enum):
        if it.has_nth(index + 1):
            newMultiIndex = list(multiIndex)
            newMultiIndex[i] = index + 1
            for k in xrange(i + 1, len(newMultiIndex)):
                newMultiIndex[k] = 0
            return tuple(newMultiIndex)
    else:
        return None

def product(*iterables):
    #添え字を使った非再帰版直積関数
    iterables = [CachedIter(it) for it in iterables]
    
    multiIndex = (0,) * len(iterables)
    while multiIndex is not None:
        yield tuple(it[index] for it, index in izip(iterables, multiIndex))
        multiIndex = incr_multi_index(multiIndex, iterables)

def main():
    import itertools
    
    numbers = [1, 2, 3]
    for i in xrange(3):
        p = sorted(product(*[numbers]*i))
        q = sorted(itertools.product(*[numbers]*i))
        assert p == q, p
        
if "__main__" == __name__:
    main()