3つ以上の集合の直積を求めるプログラム その2
3つ以上の集合の直積を求めるプログラムをClojureで書いてみました。
Python版では無限リストを扱う関係の部分でゴチャゴチャしていたのが、Clojureは無限リストを自然に扱えます。
一方Clojureは状態変化が基本的に無いので、Python版は直訳は出来ません。
clojure.contrib.combinatorics/cartesian-productのソースをカンニングしながら、書き直してみました。結局再帰もスタックも無い形に書き直せたので、Pythonにも翻訳してみました。
(ns product (:use clojure.contrib.seq-utils clojure.contrib.test-is clojure.contrib.combinatorics)) (defn has-nth [seq n] (boolean (nthnext seq n))) (def not-nil? #(not (nil? %))) (defn zip [seq1 seq2] (map #(vector %1 %2) seq1 seq2)) (defn incr-multi-index [multi-index seqs] (first (filter not-nil? (for [[n [index s]] (reverse (indexed (zip multi-index seqs)))] (if (has-nth s (+ 1 index)) (concat (take n multi-index) [(+ 1 index)] (repeat (- (count multi-index) n 1) 0)) nil))))) (defn product [& seqs] (let [incr #(incr-multi-index % seqs) multi-index0 (repeat (count seqs) 0) multi-indexes (take-while not-nil? (iterate incr multi-index0))] (for [multi-index multi-indexes] (for [[k index] (indexed multi-index)] (nth (nth seqs k) index))))) (deftest test-product (let [numbers [1 2 3]] (doseq [i (range 1)] (is (= (sort (apply cartesian-product (repeat i numbers))) (sort (apply product (repeat i numbers))))))))
#encoding:shift-jis #product.py """3つ以上の集合の直積を求めるプログラム""" from __future__ import division, print_function, unicode_literals __metaclass__ = type from itertools import izip, count class CachedIter: def __init__(self, iterable): self._iterable = iter(iterable) self._values = [] def has_nth(self, n): if not self._iterable: return n < len(self._values) elif n < len(self._values): return True else: for i in xrange(len(self._values), n + 1): try: v = next(self._iterable) except StopIteration: self._value = None return False else: self._values.append(v) else: return True def __getitem__(self, n): if self.has_nth(n): return self._values[n] else: raise IndexError def incr_multi_index(multiIndex, iterables): enum = list(enumerate(izip(iterables, multiIndex))) for i, (it, index) in reversed(enum): if it.has_nth(index + 1): newMultiIndex = list(multiIndex) newMultiIndex[i] = index + 1 for k in xrange(i + 1, len(newMultiIndex)): newMultiIndex[k] = 0 return tuple(newMultiIndex) else: return None def product(*iterables): #添え字を使った非再帰版直積関数 iterables = [CachedIter(it) for it in iterables] multiIndex = (0,) * len(iterables) while multiIndex is not None: yield tuple(it[index] for it, index in izip(iterables, multiIndex)) multiIndex = incr_multi_index(multiIndex, iterables) def main(): import itertools numbers = [1, 2, 3] for i in xrange(3): p = sorted(product(*[numbers]*i)) q = sorted(itertools.product(*[numbers]*i)) assert p == q, p if "__main__" == __name__: main()