itertools.combinationsを無限長イテレータに対応させる

pythonのitertools.combinationsやitertools.permutationには無限長のイテレータを渡す事はできません。内部でイテレータをタプルに変換しようとするのでフリーズしてしまいます。

無限長にも対応する、アルゴリズム自体は難しくはありません。

そこで、python-ml-jpで、なぜitertools.combinationsは無限長に対応していないか質問したところ、

タプルの代わりにリストを利用することで、最初に各イテレータから全要素を取り出すことを
やめられるでしょうが、それだとreallocの回数分パフォーマンスが悪くなります。

実際に、無限長のイテレータに対応したバージョンを作ってみました。

>>> from itertools2 import combinations #無限長に対応したバージョン
>>> from itertools import count, islice
>>> it = combinations(count(), 2)       #無限長のを渡してもフリーズしない
>>> list(islice(it, 30))
[(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3), (0, 4), (1, 4), (2, 4), (3, 4),
 (0, 5), (1, 5), (2, 5), (3, 5), (4, 5), (0, 6), (1, 6), (2, 6), (3, 6), (4, 6),
 (5, 6), (0, 7), (1, 7), (2, 7), (3, 7), (4, 7), (5, 7), (6, 7), (0, 8), (1, 8)]

そして、時間測定

from __future__ import division, print_function, unicode_literals

import itertools2
import itertools

import timeit
for mod in ["itertools", "itertools2"]:
    print(mod)
    for n, r in [(10, 2), (100, 2), (1000, 2)]:
        t = timeit.Timer(
            "for x in combinations(xrange({}), {}):pass".format(n, r),
            "from {} import combinations".format(mod)
        )
        print(t.timeit(number=100))
    print()

測定結果

itertools
0.000719923900943
0.0634376715476
6.90153106051

itertools2
0.000939225516092
0.083500556635
9.52643394621

やっぱり無限長に対応した物の方が遅いです。1.5倍ぐらいずつかかっています。itertoolsのような基礎的なモジュールで50%の速度差は致命的です。しかし、単に私のコードが下手だからという可能性はぬぐえません。

無限シーケンスにcombinationsを使う機会は、実際どれくらいあるものなんでしょう?

以下、ソース。初めて作ったC拡張を作りました。参照関係のデバッグは地獄です。

/* itertools2.c */

#include "Python.h"
#include "structmember.h"

/* combinations object ************************************************************/

typedef struct
{
    PyObject_HEAD

    PyObject *it;
    PyObject  *pool;            /* input converted to a tuple */
    PyObject *elem;
    Py_ssize_t *indices;        /* one index per result element */
    Py_ssize_t r;               /* size of result tuple */
    Py_ssize_t i;               /* current axis of indices */
    Py_ssize_t first;
    int stopped;                /* set to 1 when the combinations iterator is exhausted */
}
combinationsobject;

static PyTypeObject combinations_type;

int next_indices(Py_ssize_t r, Py_ssize_t n, Py_ssize_t *indices, Py_ssize_t *i)
{
    if (r == 0 || n == 0 || *i < 0 || n <= *i || r < n)
    {
        return 0;
    }
    while(1)
    {
        indices[*i] += 1;
        if(*i == n - 1 && r <= indices[*i])
        {
            return 0;
        }
        else if (*i < n - 1 && indices[*i + 1] <= indices[*i])
        {
            *i += 1;
        }
        else if (0 < *i)
        {
            *i -= 1;
            indices[*i] = -1;
        }
        else
        {
            return 1;
        }
    }
}

static PyObject *
combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
    combinationsobject *co;
    Py_ssize_t r;
    PyObject *iterable = NULL;
    Py_ssize_t *indices = NULL;
    PyObject *pool = NULL;
    static char *kwargs[] = {"iterable", "r", NULL};

    if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations", kwargs,
                                     &iterable, &r))
    {
        return NULL;
    }

    if (r < 0)
    {
        PyErr_SetString(PyExc_ValueError, "r must be non-negative");
        goto error;
    }

    if(r > 0)
    {
        indices = PyMem_Malloc((r - 1) * sizeof(Py_ssize_t));
        if (indices == NULL)
        {
            PyErr_NoMemory();
            goto error;
        }
    }

    pool = PyList_New(0);
    if(pool == NULL)
    {
        goto error;
    }

    /* create combinationsobject structure */
    co = (combinationsobject *) type->tp_alloc(type, 0);
    if (co == NULL)
    {
        goto error;
    }

    co->pool = pool;
    co->indices = indices;
    co->it = PyObject_GetIter(iterable);
    co->first = 1;
    co->r = r;
    co->stopped = 0;

    return (PyObject *)co;
error:
    if (indices != NULL)
    {
        PyMem_Free(indices);
    }
    Py_XDECREF(pool);
    return NULL;
}

static void
combinations_dealloc(combinationsobject *co)
{
    PyObject_GC_UnTrack(co);
    Py_XDECREF(co->pool);
    Py_XDECREF(co->it);
    Py_XDECREF(co->elem);

    if (co->indices != NULL)
    {
        PyMem_Free(co->indices);
    }
    Py_TYPE(co)->tp_free(co);
}

static int
combinations_traverse(combinationsobject *co, visitproc visit, void *arg)
{
    Py_VISIT(co->it);
    Py_VISIT(co->pool);
    Py_VISIT(co->elem);
    return 0;
}

static PyObject *
combinations_next(combinationsobject *co)
{
    Py_ssize_t j;
    PyObject *result;
    Py_ssize_t *indices = co->indices;
    Py_ssize_t *i = &co->i;
    PyObject *pool = co->pool;
    PyObject *it = co->it;
    Py_ssize_t r = co->r;

    if (co->stopped)
    {
        return NULL;
    }
    else if (r == 0)
    {
        co->first = 0;
        co->stopped = 1;
        result = PyTuple_New(0);
        if (result == NULL)
        {
            goto empty;
        }
        return result;
    }
    else if(r == 1)
    {
        PyObject *x = NULL;
        co->first = 0;
        x = PyIter_Next(it);
        if(x == NULL)
        {
            goto empty;
        }
        result = PyTuple_New(1);
        if (result == NULL)
        {
            goto empty;
        }
        PyTuple_SetItem(result, 0, x);
        return result;
    }
    else
    {
        if (co->first)
        {
            PyObject *x = NULL;
            co->first = 0;
            for(j=0; j < r - 1; ++j)
            {
                x = PyIter_Next(it);
                if(x == NULL)
                {
                    goto empty;
                }
                if (PyList_Append(pool, x) == -1)
                {
                    goto empty;
                }
                Py_DECREF(x);
            }
            x = PyIter_Next(it);
            if(x == NULL)
            {
                goto empty;
            }
            co->elem = x;
            for (j=0 ; j < r - 1  ; j++)
            {
                indices[j] = -1;
            }
            *i = r - 2;
        }
        while(1)
        {
            if(next_indices(PyList_Size(pool), r - 1, indices, i))
            {
                PyObject *x;
                result = PyTuple_New(r);
                if (result == NULL)
                {
                    goto empty;
                }
                for(j=0; j < r - 1; ++j)
                {
                    x = PyList_GetItem(pool, indices[j]);
                    if(x == NULL)
                    {
                        goto empty;
                    }
                    Py_INCREF(x);
                    PyTuple_SetItem(result, j, x);
                }
                Py_INCREF(co->elem);
                PyTuple_SetItem(result, r - 1, co->elem);
                return result;
            }
            else
            {
                if (PyList_Append(pool, co->elem) == -1)
                {
                    goto empty;
                }
                Py_DECREF(co->elem);
                co->elem = PyIter_Next(it);
                if(co->elem == NULL)
                {
                    goto empty;
                }
                for (j=0 ; j < r - 1  ; j++)
                {
                    indices[j] = -1;
                }
                *i = r - 2;
            }
        }
    }
empty:
    //Py_XDECREF(result);
    co->stopped = 1;
    return NULL;
}

PyDoc_STRVAR(combinations_doc,
             "combinations(iterable, r) --> combinations object\n\
             \n\
             Return successive r-length combinations of elements in the iterable.\n\n\
             combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)");

static PyTypeObject combinations_type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "itertools2.combinations",                   /* tp_name */
    sizeof(combinationsobject),         /* tp_basicsize */
    0,                                  /* tp_itemsize */
    /* methods */
    (destructor)combinations_dealloc,           /* tp_dealloc */
    0,                                  /* tp_print */
    0,                                  /* tp_getattr */
    0,                                  /* tp_setattr */
    0,                                  /* tp_compare */
    0,                                  /* tp_repr */
    0,                                  /* tp_as_number */
    0,                                  /* tp_as_sequence */
    0,                                  /* tp_as_mapping */
    0,                                  /* tp_hash */
    0,                                  /* tp_call */
    0,                                  /* tp_str */
    PyObject_GenericGetAttr,            /* tp_getattro */
    0,                                  /* tp_setattro */
    0,                                  /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
    Py_TPFLAGS_BASETYPE,            /* tp_flags */
    combinations_doc,                           /* tp_doc */
    (traverseproc)combinations_traverse,        /* tp_traverse */
    0,                                  /* tp_clear */
    0,                                  /* tp_richcompare */
    0,                                  /* tp_weaklistoffset */
    PyObject_SelfIter,                  /* tp_iter */
    (iternextfunc)combinations_next,            /* tp_iternext */
    0,                                  /* tp_methods */
    0,                                  /* tp_members */
    0,                                  /* tp_getset */
    0,                                  /* tp_base */
    0,                                  /* tp_dict */
    0,                                  /* tp_descr_get */
    0,                                  /* tp_descr_set */
    0,                                  /* tp_dictoffset */
    0,                                  /* tp_init */
    0,                                  /* tp_alloc */
    combinations_new,                           /* tp_new */
    PyObject_GC_Del,                    /* tp_free */
};



/* module level code ********************************************************/

PyDoc_STRVAR(module_doc,
             "");


static PyMethodDef module_methods[] = {
  {NULL, NULL}           /* sentinel */
};

PyMODINIT_FUNC
inititertools2(void)
{
    int i;
    PyObject *m;
    char *name;
    PyTypeObject *typelist[] = {
       &combinations_type,
       NULL
   };

    m = Py_InitModule3("itertools2", module_methods, module_doc);
    if (m == NULL)
        return;

    for (i=0 ; typelist[i] != NULL ; i++)
    {
        if (PyType_Ready(typelist[i]) < 0)
            return;
        name = strchr(typelist[i]->tp_name, '.');
        assert (name != NULL);
        Py_INCREF(typelist[i]);
        PyModule_AddObject(m, name+1, (PyObject *)typelist[i]);
    }
}
広告を非表示にする