Best way to do enum in Sqlalchemy?

后端 未结 4 731
挽巷
挽巷 2020-12-04 11:57

I'm reading about SQLAlchemy and I saw the following code:

employees_table = Table('employees', metadata,
    Column('employee_id', Integer, primary_key=True), ...


        
4条回答
  •  暗喜
    暗喜 (楼主)
    2020-12-04 12:37

    Python's enumerated types are directly acceptable by the SQLAlchemy Enum type as of SQLAlchemy 1.1:

    import enum

    from sqlalchemy import Column, Enum, Integer
    from sqlalchemy.ext.declarative import declarative_base

    # Declarative base class that MyClass is mapped against.
    Base = declarative_base()


    class MyEnum(enum.Enum):
        """Example enumeration handed to the SQLAlchemy ``Enum`` column type."""

        one = 1
        two = 2
        three = 3


    class MyClass(Base):
        """Mapped class whose ``value`` column is typed by ``MyEnum``."""

        __tablename__ = 'some_table'
        id = Column(Integer, primary_key=True)
        # The member names ("one", "two", "three") are what gets persisted,
        # not the integer values.
        value = Column(Enum(MyEnum))
    

    Note that above, the string values "one", "two", "three" are persisted, not the integer values.

    For older versions of SQLAlchemy, I wrote a post which creates its own Enumerated type (http://techspot.zzzeek.org/2011/01/14/the-enum-recipe/)

    from sqlalchemy.types import SchemaType, TypeDecorator, Enum
    from sqlalchemy import __version__
    import re
    
    # Compare the version numerically. A plain string comparison is
    # lexicographic and would order e.g. '0.10.0' before '0.6.5'.
    _version_match = re.match(r'(\d+)(?:\.(\d+))?(?:\.(\d+))?', __version__)
    if _version_match is not None and tuple(
        int(part or 0) for part in _version_match.groups()
    ) < (0, 6, 5):
        raise NotImplementedError("Version 0.6.5 or higher of SQLAlchemy is required.")
    
    class EnumSymbol(object):
        """One named member of an enumeration class.

        Stores the owning class, the attribute name under which the symbol
        lives on that class, the value, and a description.
        """

        def __init__(self, cls_, name, value, description):
            self.cls_, self.name = cls_, name
            self.value, self.description = value, description

        def __reduce__(self):
            """Pickle as ``getattr(cls_, name)``.

            Unpickling therefore looks the symbol up on its owning class
            instead of building a new object.
            """
            lookup_args = (self.cls_, self.name)
            return getattr, lookup_args

        def __iter__(self):
            """Iterate over the value followed by the description."""
            pair = [self.value, self.description]
            return iter(pair)

        def __repr__(self):
            return "<%s>" % self.name
    
    class EnumMeta(type):
        """Metaclass that turns tuple-valued class attributes into symbols."""

        def __init__(cls, classname, bases, dict_):
            # Each class gets its own copy of the inherited registry, so
            # symbols added here do not leak into the parent class.
            registry = cls._reg.copy()
            cls._reg = registry
            for attr_name, attr_value in dict_.items():
                if not isinstance(attr_value, tuple):
                    continue
                # The tuple is unpacked as the symbol's remaining arguments;
                # its first element is the registry key.
                symbol = EnumSymbol(cls, attr_name, *attr_value)
                registry[attr_value[0]] = symbol
                setattr(cls, attr_name, symbol)
            return type.__init__(cls, classname, bases, dict_)

        def __iter__(cls):
            """Iterate over the entries registered on this class."""
            registered = cls._reg.values()
            return iter(registered)
    
    class DeclEnum(EnumMeta("_DeclEnumBase", (object,), {"_reg": {}})):
        """Declarative enumeration.

        Subclasses declare members as tuple-valued class attributes; the
        EnumMeta metaclass replaces each tuple with an EnumSymbol and
        registers it under the tuple's first element.

        The metaclass is applied by inheriting from a base class created by
        calling EnumMeta directly. A ``__metaclass__`` attribute only takes
        effect on Python 2 and is silently ignored on Python 3, whereas this
        form works on both.
        """

        # Registry mapping each member's key to its symbol; EnumMeta gives
        # every subclass its own copy.
        _reg = {}

        @classmethod
        def from_string(cls, value):
            """Return the symbol registered under ``value``.

            Raises:
                ValueError: if ``value`` is not a registered key.
            """
            try:
                return cls._reg[value]
            except KeyError:
                raise ValueError(
                        "Invalid value for %r: %r" %
                        (cls.__name__, value)
                    )

        @classmethod
        def values(cls):
            """Return the registered keys as a list."""
            # list() gives the same type on Python 2 and Python 3.
            return list(cls._reg)

        @classmethod
        def db_type(cls):
            """Return a DeclEnumType column type bound to this enumeration."""
            return DeclEnumType(cls)
    
    class DeclEnumType(SchemaType, TypeDecorator):
        """Column type that stores a DeclEnum symbol via its ``value``.

        The underlying implementation type is a SQLAlchemy ``Enum`` built
        from the enumeration's registered keys.
        """

        def __init__(self, enum):
            """Bind the type to ``enum``, a DeclEnum subclass."""
            self.enum = enum
            # Name is "ck" plus the class name with each capital replaced by
            # "_" + lowercase, e.g. EmployeeType -> "ck_employee_type".
            type_name = "ck%s" % re.sub(
                '([A-Z])',
                lambda m: "_" + m.group(1).lower(),
                enum.__name__,
            )
            self.impl = Enum(*enum.values(), name=type_name)

        def _set_table(self, table, column):
            # Delegate to the wrapped Enum implementation.
            self.impl._set_table(table, column)

        def copy(self, **kw):
            """Return a new DeclEnumType for the same enumeration.

            Keyword arguments are accepted and ignored, because SQLAlchemy
            passes ``**kw`` to ``copy()`` when it copies a column's type; a
            zero-argument signature raises TypeError there.
            """
            return DeclEnumType(self.enum)

        def process_bind_param(self, value, dialect):
            """Convert a symbol to its ``value`` for the database."""
            if value is None:
                return None
            return value.value

        def process_result_value(self, value, dialect):
            """Convert a database string back to the matching symbol."""
            if value is None:
                return None
            # Strip surrounding whitespace before the registry lookup.
            return self.enum.from_string(value.strip())
    
    if __name__ == '__main__':
        # Demo: map a class with a DeclEnum-typed column onto an in-memory
        # SQLite database, insert a few rows and query by enum symbol.
        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy import Column, Integer, String, create_engine
        from sqlalchemy.orm import Session

        Base = declarative_base()

        class EmployeeType(DeclEnum):
            """Employment categories, each declared as (value, description)."""

            part_time = "P", "Part Time"
            full_time = "F", "Full Time"
            contractor = "C", "Contractor"

        class Employee(Base):
            """Employee row whose ``type`` column holds an EmployeeType symbol."""

            __tablename__ = 'employee'

            id = Column(Integer, primary_key=True)
            name = Column(String(60), nullable=False)
            type = Column(EmployeeType.db_type())

            def __repr__(self):
                return "Employee(%r, %r)" % (self.name, self.type)

        engine = create_engine('sqlite://', echo=True)
        Base.metadata.create_all(engine)

        sess = Session(engine)

        sess.add_all([
            Employee(name='e1', type=EmployeeType.full_time),
            Employee(name='e2', type=EmployeeType.full_time),
            Employee(name='e3', type=EmployeeType.part_time),
            Employee(name='e4', type=EmployeeType.contractor),
            Employee(name='e5', type=EmployeeType.contractor),
        ])
        sess.commit()

        # Call form of print works on both Python 2 and Python 3; the bare
        # print statement is a syntax error on Python 3.
        print(sess.query(Employee).filter_by(type=EmployeeType.contractor).all())
    

提交回复
热议问题