CmpToKeyをCythonで



記事「名前順でソート(XPスタイル)」で引用したCmpToKeyを、Cythonで書き直してみた。


参考:

Cythonを使ってみた - loooo




まず本体(cmp_to_key.pyx)。

#encoding:utf-8
#cmp_to_key.pyx
from functools import partial

cdef class _K(object):
    """Sort-key wrapper that orders objects through an old-style cmp function.

    ``cmp(a, b)`` follows the Python 2 ``cmp`` convention: negative when
    a < b, zero when a == b, positive when a > b.  Only the *sign* of the
    result is inspected (as ``functools.cmp_to_key`` does), so cmp
    functions that return e.g. ``a - b`` work as well as ones returning
    exactly -1/0/1.
    """
    cdef readonly cmp, obj

    def __cinit__(self, cmp, obj):
        self.obj = obj
        self.cmp = cmp

    def __richcmp__(x, y, int op):
        # Call the user's cmp function once, then branch on its sign.
        # op is CPython's rich-comparison code:
        #   0=Py_LT, 1=Py_LE, 2=Py_EQ, 3=Py_NE, 4=Py_GT, 5=Py_GE
        c = x.cmp(x.obj, y.obj)
        if op == 0:
            return c < 0
        elif op == 1:
            return c <= 0
        elif op == 2:
            return c == 0
        elif op == 3:
            return c != 0
        elif op == 4:
            return c > 0
        elif op == 5:
            return c >= 0
        return NotImplemented

def CmpToKey(mycmp):
    """Turn an old-style cmp function into a callable usable as ``key=``.

    The returned callable takes one object and wraps it in ``_K`` together
    with ``mycmp``; comparisons between wrappers are delegated to ``mycmp``.
    """
    key_factory = partial(_K, mycmp)
    return key_factory

ビルド用スクリプト(setup.py)

#encoding:utf-8
# Build script for the cmp_to_key Cython extension.
#
# distutils is deprecated (PEP 632) and removed in Python 3.12, and
# Cython.Distutils.build_ext is Cython's legacy build hook; setuptools +
# Cython.Build.cythonize is the build path Cython documents.
#
# To place the built module next to the benchmark script, run:
#     python setup.py build_ext --inplace
from setuptools import setup, Extension
from Cython.Build import cythonize

setup(
    ext_modules=cythonize([Extension('cmp_to_key', ['cmp_to_key.pyx'])]),
)



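コンパイルはsetup.pyのいつもの方法で、

python setup.py build

(この場合、拡張モジュールは build/ 以下に作られる。ベンチマークスクリプトから `from cmp_to_key import CmpToKey` するなら、`python setup.py build_ext --inplace` でカレントディレクトリに置くか、できたファイルをコピーする。)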



pure Python版と、パフォーマンスを比べてみます。比較するのは、Cython版CmpToKey(sorted_cy)、pure Python版CmpToKey(sorted_py)、組み込みの cmp= 引数(sorted_py_cmp)の3つ。速くなっていればいいな。

#encoding:utf-8
from __future__ import (
    with_statement, 
    division,
    print_function,
)

from myutil import *

import ctypes  
SHLWAPI = ctypes.windll.LoadLibrary("SHLWAPI.dll")


def cmp_filename_logical(f1, f2):
    """Compare two names in Explorer-style "logical" order via StrCmpLogicalW.

    Both arguments are converted with ``unicode()`` first, so non-string
    values (such as the ints used in the benchmark) are compared by their
    text form.  The C int returned by StrCmpLogicalW is passed straight
    through (Microsoft documents it as -1, 0 or 1).
    """
    str_cmp_logical = SHLWAPI.StrCmpLogicalW
    left, right = unicode(f1), unicode(f2)
    return str_cmp_logical(left, right)

from cmp_to_key import CmpToKey

def sorted_cy(alist):
    """Return *alist* sorted with the Cython CmpToKey around cmp_filename_logical."""
    key_func = CmpToKey(cmp_filename_logical)
    return sorted(alist, key=key_func)

def PyCmpToKey(mycmp):
    """Convert a cmp= function into a key= function.

    ``mycmp(a, b)`` follows the old ``cmp`` convention (negative / zero /
    positive).  Only the sign of its result is used, as in
    ``functools.cmp_to_key``, so cmp functions returning values other than
    exactly -1/0/1 (e.g. ``a - b``) also sort correctly.
    """
    class K(object):
        # One K is created per element being sorted; __slots__ avoids a
        # per-instance __dict__.
        __slots__ = ('obj',)

        def __init__(self, obj, *args):
            self.obj = obj

        def __lt__(self, other):
            return mycmp(self.obj, other.obj) < 0

        def __gt__(self, other):
            return mycmp(self.obj, other.obj) > 0

        def __eq__(self, other):
            return mycmp(self.obj, other.obj) == 0

        def __le__(self, other):
            return mycmp(self.obj, other.obj) <= 0

        def __ge__(self, other):
            return mycmp(self.obj, other.obj) >= 0

        def __ne__(self, other):
            return mycmp(self.obj, other.obj) != 0

    return K

def sorted_py(alist):
    """Return *alist* sorted with the pure-Python PyCmpToKey around cmp_filename_logical."""
    key_func = PyCmpToKey(cmp_filename_logical)
    return sorted(alist, key=key_func)

def sorted_py_cmp(alist):
    """Return *alist* sorted via the built-in ``cmp=`` argument (Python 2 only)."""
    result = sorted(alist, cmp=cmp_filename_logical)
    return result

def main(n=1000, number=100):
    """Time each sorting strategy on the same random list and print the results.

    For each of sorted_cy, sorted_py and sorted_py_cmp, prints the function
    name followed by the total seconds taken by *number* runs of sorting
    *n* integers drawn with ``random.randrange(n)``.
    """
    # Import explicitly rather than relying on ``from myutil import *``
    # happening to re-export ``random``.
    import random
    from timeit import Timer

    test_list = [random.randrange(n) for _ in xrange(n)]

    for f in [sorted_cy, sorted_py, sorted_py_cmp]:
        print(f.__name__)
        # Bind f as a default so the timed callable does not depend on the
        # loop variable's late binding.
        t = Timer(lambda f=f: f(test_list))
        print(t.timeit(number))


if __name__ == "__main__":
    main()
    
#結果 (各関数を100回実行した合計秒数)
#sorted_cy
#2.97824779406
#sorted_py
#3.38345068996
#sorted_py_cmp
#2.53016179431


ウ〜ン、微妙。この程度ならPython版でもいいかも?

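上手に書けば、また違うのかもしれませんが。比較のたびに Python 関数 cmp_filename_logical(ctypes 経由の StrCmpLogicalW 呼び出し)が呼ばれるので、キー側のラッパーだけを Cython 化しても大きな差は出ないのだろう。実際、ラッパーを使わない cmp= 版(sorted_py_cmp)が一番速かった。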

広告を非表示にする