Can I override a C++ virtual function within Python with Cython?

前端 未结 2 1358
感情败类
感情败类 2020-12-30 06:40

I have a C++ class with a virtual method:

//C++
class A
{

    public:
        A() {};
        virtual int override_me(int a) {return 2*a;};
        int calc         


        
2条回答
  •  暗喜
    暗喜 (楼主)
    2020-12-30 07:06

    Excellent !

    Not complete, but sufficient. I was able to make it work for my own purpose by combining this post with the sources linked above. It wasn't easy, since I'm a beginner at Cython, but I can confirm it is the only way I could find on the web.

    Thanks a lot to you guys.

    Sorry, I don't have much time to go into textual details, but here are my files (they might offer an additional point of view on how to put all of this together):

    setup.py :

    # Build script for the `elps` extension module.
    #
    # setuptools is used instead of distutils, which is deprecated since
    # Python 3.10 (PEP 632) and removed in 3.12. cythonize() is Cython's
    # recommended build entry point and replaces Cython.Distutils.build_ext.
    from setuptools import setup, Extension
    from Cython.Build import cythonize

    elps_extension = Extension(
        "elps",
        # elps.pyx is translated to C++ by Cython (its `public api`
        # declarations also produce elps_api.h, which src/ITestClass.h
        # includes) and compiled together with the C++ bridge class.
        sources=["elps.pyx", "src/ITestClass.cpp"],
        # Presumably the native library that provides the rest of elps;
        # TODO confirm against the project's build.
        libraries=["elp"],
        language="c++",
    )

    setup(
        ext_modules=cythonize([elps_extension]),
    )
    

    TestClass.h :

    #ifndef TESTCLASS_H_
    #define TESTCLASS_H_


    namespace elps {

    /**
     * Plain C++ base class exposing a virtual hook, override_me(), that
     * calculate() dispatches through. Subclasses (in C++, or in Python via the
     * ITestClass bridge) override the hook to change calculate()'s result.
     */
    class TestClass {

    public:
        // `a` is initialised here because nothing else in the class assigns
        // it; previously getA() read an indeterminate value (undefined
        // behaviour).
        TestClass() : a(0) {}
        virtual ~TestClass() {}

        /// @return the stored member `a` (0: nothing assigns it after
        ///         construction).
        int getA() { return this->a; }

        /// Default hook used when no subclass overrides it.
        /// @return 2
        virtual int override_me() { return 2; }

        /// @param a multiplier (this parameter shadows the member `a`).
        /// @return a * override_me(), dispatched virtually so that overrides
        ///         in subclasses are honoured.
        int calculate(int a) { return a * this->override_me(); }

    private:
        int a;

    };

    } /* namespace elps */
    #endif /* TESTCLASS_H_ */
    

    ITestClass.h :

    #ifndef ITESTCLASS_H_
    #define ITESTCLASS_H_

    // Created by Cython when providing 'public api' keywords; declares
    // import_elps() and the exported Cython functions (e.g. cy_call_func)
    // and pulls in the Python C API (PyObject).
    #include "../elps_api.h"

    #include "../../inc/TestClass.h"

    namespace elps {

    /**
     * Bridge subclass of TestClass whose override_me() is implemented in
     * ITestClass.cpp by looking up an "override_me" attribute on a Python
     * object, so Python subclasses of the Cython wrapper can override the
     * C++ virtual method.
     */
    class ITestClass : public TestClass {
    public:
        /// Python object consulted by override_me(); may be NULL.
        PyObject *m_obj;

        /// @param obj Python object whose attributes override_me() consults.
        /// explicit: a raw PyObject* should never silently become an
        /// ITestClass.
        explicit ITestClass(PyObject *obj);
        virtual ~ITestClass();
        virtual int override_me();

    private:
        // Non-copyable: a copy would share the raw m_obj pointer with no
        // coordination of the Python object's lifetime or reference count.
        // Declared but intentionally never defined (C++98-compatible form of
        // `= delete`).
        ITestClass(const ITestClass&);
        ITestClass& operator=(const ITestClass&);
    };

    } /* namespace elps */
    #endif /* ITESTCLASS_H_ */
    

    ITestClass.cpp :

    #include "ITestClass.h"

    namespace elps {

    // Stores `obj` as a BORROWED reference: no Py_INCREF here and no
    // Py_DECREF in the destructor.
    //
    // The Python wrapper (PyTestClass in elps.pyx) creates this object in
    // __cinit__ and deletes it in __dealloc__, so the Python object always
    // outlives it. Taking a strong reference would form a cycle the Python
    // garbage collector cannot see (Python object -> C++ object -> Python
    // object): the wrapper's refcount could never reach zero, __dealloc__
    // would never run, and both objects would leak.
    ITestClass::ITestClass(PyObject *obj): m_obj(obj) {
        // import_elps() is generated into "elps_api.h"; it resolves the
        // functions exported with `cdef public api` (cy_call_func) and
        // returns non-zero on failure.
        if (import_elps() != 0) {
            // cy_call_func cannot be used, so detach from Python and let
            // override_me() fall back to the C++ default. Report and clear
            // any pending exception so it does not surface in unrelated
            // Python code.
            if (PyErr_Occurred())
                PyErr_Print();
            this->m_obj = NULL;
        }
    }

    ITestClass::~ITestClass() {
        // Nothing to release: m_obj is a borrowed reference (see the
        // constructor).
    }

    // NOTE(review): calls into Python, so this assumes the caller holds the
    // GIL (true when reached through PyTestClass.calculate) — confirm before
    // calling it from native threads.
    int ITestClass::override_me()
    {
        if (this->m_obj == NULL) {
            // No Python object attached: behave like the plain C++ class
            // instead of returning an arbitrary 0.
            return TestClass::override_me();
        }
        // Pre-set to "not found" so the C++ fallback is taken if cy_call_func
        // returns without writing the flag (e.g. on an unexpected exception).
        int error = 1;
        // Call a Python-side override, if the object has one. const_cast is
        // needed only because the Cython-exported prototype takes `char*`.
        int result = cy_call_func(this->m_obj, const_cast<char*>("override_me"), &error);
        if (error)
            // No override found: call the parent method.
            result = TestClass::override_me();
        return result;
    }

    } /* namespace elps */
    

    EDIT2: A note about PURE virtual methods (this appears to be a recurring concern). With the approach shown above, "TestClass::override_me()" CANNOT be pure virtual, because it must be callable whenever the method is not overridden in the Python subclass (i.e., whenever execution falls into the "error"/"override not found" branch of "ITestClass::override_me()").

    Extension : elps.pyx :

    cimport cpython.ref as cpy_ref

    cdef extern from "src/ITestClass.h" namespace "elps" :
        cdef cppclass ITestClass:
            ITestClass(cpy_ref.PyObject *obj)
            int getA()
            int override_me()
            int calculate(int a)


    cdef class PyTestClass:
        """Python wrapper that owns one elps::ITestClass.

        A Python subclass may define ``override_me(self)`` returning an int.
        The C++ ITestClass::override_me() looks that attribute up by name
        through cy_call_func, so the override is honoured by ``calculate``.
        """
        cdef ITestClass* thisptr

        def __cinit__(self):
            # The explicit cast is required: Cython does not implicitly
            # convert a Python object into a raw PyObject*.
            self.thisptr = new ITestClass(<cpy_ref.PyObject*>self)

        def __dealloc__(self):
            if self.thisptr:
                del self.thisptr

        def getA(self):
            return self.thisptr.getA()

        # Do NOT define override_me on this base class (for example as
        # `return self.thisptr.override_me()`): cy_call_func would always find
        # it, and ITestClass::override_me -> Python override_me ->
        # ITestClass::override_me would recurse without end. The attribute's
        # absence is what signals "no Python override".

        cpdef int calculate(self, int a):
            return self.thisptr.calculate(a)


    cdef public api int cy_call_func(object self, char* method, int *error):
        """Call ``getattr(self, method)()`` on behalf of C++ code.

        Sets ``error[0]`` to 1 and returns 0 when the attribute does not
        exist; otherwise sets it to 0 and returns the call's result.
        """
        # Assume "not found" until the lookup succeeds, so the C++ caller
        # falls back to its default if this function exits early.
        error[0] = 1
        try:
            # Decode: under Python 3 a char* converts to bytes, and getattr()
            # rejects bytes names with TypeError (not AttributeError).
            func = getattr(self, method.decode('ascii'))
        except AttributeError:
            return 0
        error[0] = 0
        # NOTE(review): the C++ caller never checks for a pending Python
        # exception, so an exception raised by the override is not surfaced
        # there; depending on the Cython version's default exception handling
        # for cdef functions it is printed as "unraisable" or left set.
        return func()
    

    Finally, the Python calls:

    from elps import PyTestClass as TC

    # No Python override: TestClass::override_me() (returns 2) is used.
    a = TC()
    print(a.calculate(1))  # 2


    class B(TC):
        # Picked up by the C++ ITestClass::override_me() through cy_call_func.
        # (Replace this method with `pass` to see the C++ default used again.)
        def override_me(self):
            return 5


    b = B()
    print(b.calculate(1))  # 5
    

    Hopefully this makes the previously linked material more directly applicable to the point we're discussing here.

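    EDIT: Alternatively, the code above could use 'hasattr' instead of a try/except block (note that Python 3's hasattr() is itself implemented by calling getattr() and checking for AttributeError, so any speed-up should be measured):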

    cdef public api int cy_call_func_int_fast(object self, char* method, bint *error):
        # Variant of cy_call_func using hasattr() instead of try/except.
        # Sets error[0] to 1 and returns 0 when the attribute is missing;
        # otherwise sets it to 0 and returns the call's result.
        #
        # Decode once: under Python 3 a char* converts to bytes, which
        # hasattr()/getattr() reject as an attribute name.
        name = method.decode('ascii')
        # NOTE: Python 2's hasattr() swallows every exception raised during
        # the lookup; Python 3's only swallows AttributeError.
        if not hasattr(self, name):
            error[0] = 1
            # Explicit value instead of silently falling off the end.
            return 0
        error[0] = 0
        return getattr(self, name)()
    

    The above code, of course, makes a difference only in the case where we don't override the 'override_me' method.

提交回复
热议问题