SQLAlchemy, array_agg, and matching an input list

Anonymous (unverified), submitted 2019-12-03 01:01:02

Question:

I am attempting to use SQLAlchemy more fully, rather than just falling back to pure SQL at the first sign of distress. In this case, I have a table in a Postgres database (9.5) which stores a set of integers as a group by associating individual items atom_id with a group identifier group_id.

Given a list of atom_ids, I'd like to be able to figure out which group_id, if any, that set of atom_ids belong to. Solving this with just the group_id and atom_id columns was straightforward.

Now I'm trying to generalize such that a 'group' is made up of not just a list of atom_ids, but other context as well. In the example below, the list is ordered by including a sequence column, but conceptually other columns could be used instead, such as a weight column which gives each atom_id a [0,1] floating point value representing that atom's 'share' of the group.

Below is most of a unit test demonstrating my issue.

First, some setup:

    def test_multi_column_grouping(self):
        class MultiColumnGroups(base.Base):
            __tablename__ = 'multi_groups'

            id = Column(Integer, primary_key=True)  # a declarative mapping needs a primary key
            group_id = Column(Integer)
            atom_id = Column(Integer)
            sequence = Column(Integer)  # arbitrary 'other' column.  In this case, an integer, but it could be a float (e.g. weighting factor)

        base.Base.metadata.create_all(self.engine)

        # Insert 6 rows representing 2 different 'groups' of values
        vals = [
            # Group 1
            {'group_id': 1, 'atom_id': 1, 'sequence': 1},
            {'group_id': 1, 'atom_id': 2, 'sequence': 2},
            {'group_id': 1, 'atom_id': 3, 'sequence': 3},
            # Group 2
            {'group_id': 2, 'atom_id': 1, 'sequence': 3},
            {'group_id': 2, 'atom_id': 2, 'sequence': 2},
            {'group_id': 2, 'atom_id': 3, 'sequence': 1},
        ]

        self.session.bulk_save_objects(
            [MultiColumnGroups(**x) for x in vals])
        self.session.flush()

        self.assertEqual(6, len(self.session.query(MultiColumnGroups).all()))

Now, I want to query the above table to find which group a specific set of inputs belongs to. I'm using a list of (named) tuples to represent the query parameters.

    from collections import namedtuple
    Entity = namedtuple('Entity', ['atom_id', 'sequence'])
    values_to_match = [
        # (atom_id, sequence)
        Entity(1, 3),
        Entity(2, 2),
        Entity(3, 1),
        ]
    # The above list _should_ match with `group_id == 2`

Here is a raw SQL solution. I'd prefer not to fall back on it, since part of this exercise is to learn more SQLAlchemy.

    r = self.session.execute('''
        select group_id
        from multi_groups
        group by group_id
        having array_agg((atom_id, sequence)) = :query_tuples
        ''', {'query_tuples': values_to_match}).fetchone()
    print(r)  # > (2,)
    self.assertEqual(2, r[0])

Here is the above raw-SQL solution converted fairly directly into a broken SQLAlchemy query. Running this produces a psycopg2 error: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]. I believe that I need to cast the array_agg into an int[]? That would work so long as the grouping columns are all integers (which, if need be, is an acceptable limitation), but ideally this would work with mixed-type input tuples / table columns.

    from sqlalchemy import tuple_
    from sqlalchemy.dialects.postgresql import array_agg

    existing_group = self.session.query(MultiColumnGroups).\
        with_entities(MultiColumnGroups.group_id).\
        group_by(MultiColumnGroups.group_id).\
        having(array_agg(tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.sequence)) == values_to_match).\
        one_or_none()

    self.assertIsNotNone(existing_group)
    print('|{}|'.format(existing_group))

Is the above session.query() close? Have I blinded myself here, and am missing something super obvious that would solve this problem in some other way?

Answer 1:

I think your solution would produce indeterminate results, because the rows within a group are in unspecified order, and so the comparison between the array aggregate and given array may produce true or false based on that:

    [local]:5432 u@sopython*=> select group_id
    [local] u@sopython-> from multi_groups
    [local] u@sopython-> group by group_id
    [local] u@sopython-> having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
     group_id
    ----------
            2
    (1 row)

    [local]:5432 u@sopython*=> update multi_groups set atom_id = atom_id where atom_id = 2;
    UPDATE 2
    [local]:5432 u@sopython*=> select group_id from multi_groups group by group_id having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
     group_id
    ----------
    (0 rows)

You could apply an ordering to both, or try something entirely different: instead of array comparison you could use relational division.
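To show why ordering matters, here is a minimal pure-Python sketch of comparing an aggregated list against the input. It illustrates the idea and is not PostgreSQL itself. The two read orders for group 2 are hypothetical, modelled on the psql session above, where an UPDATE changed the order in which rows come back. Sorting both sides makes the comparison order-independent. In PostgreSQL, the aggregate side can be ordered with array_agg(expr ORDER BY ...), which SQLAlchemy exposes as sqlalchemy.dialects.postgresql.aggregate_order_by. The input list would then have to be sorted the same way.

```python
from collections import namedtuple

Entity = namedtuple('Entity', ['atom_id', 'sequence'])
values_to_match = [Entity(1, 3), Entity(2, 2), Entity(3, 1)]

# Hypothetical read orders for group 2's rows: array_agg without ORDER BY
# collects rows in whatever order the executor happens to read them.
before_update = [(1, 3), (2, 2), (3, 1)]
after_update = [(1, 3), (3, 1), (2, 2)]  # e.g. the updated row is now read last


def unordered_equal(aggregated, wanted):
    """Element-by-element comparison, like array_agg(...) = ARRAY[...]."""
    return aggregated == [tuple(e) for e in wanted]


def ordered_equal(aggregated, wanted):
    """Sort both sides first (by atom_id, then sequence) so read order no longer matters."""
    return sorted(aggregated) == sorted(tuple(e) for e in wanted)


print(unordered_equal(before_update, values_to_match))  # True
print(unordered_equal(after_update, values_to_match))   # False
print(ordered_equal(after_update, values_to_match))     # True
```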

In order to divide you have to form a temporary relation from your list of Entity records. Again, there are many ways to approach that. Here's one using unnested arrays:

    In [112]: vtm = select([
         ...:     func.unnest(postgresql.array([
         ...:         getattr(e, f) for e in values_to_match
         ...:     ])).label(f)
         ...:     for f in Entity._fields
         ...: ]).alias()
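The unnest() version above builds one array per field of Entity and lets PostgreSQL unnest those arrays side by side. Below is a pure-Python analogue of that transpose and re-zip. It is a sketch, not SQLAlchemy, and it assumes all arrays have the same length, as they do here.

```python
from collections import namedtuple

Entity = namedtuple('Entity', ['atom_id', 'sequence'])
values_to_match = [Entity(1, 3), Entity(2, 2), Entity(3, 1)]

# One list per field, mirroring the postgresql.array([...]) built for each f in Entity._fields.
columns = {f: [getattr(e, f) for e in values_to_match] for f in Entity._fields}
print(columns)  # {'atom_id': [1, 2, 3], 'sequence': [3, 2, 1]}

# unnest() calls over equal-length arrays in one SELECT list advance together,
# so output row i pairs the i-th element of every array; zip() does the same.
rows = [Entity(*vals) for vals in zip(*(columns[f] for f in Entity._fields))]
print(rows == values_to_match)  # True
```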

And another using a union:

    In [114]: vtm = union_all(*[
         ...:     select([literal(e.atom_id).label('atom_id'),
         ...:             literal(e.sequence).label('sequence')])
         ...:     for e in values_to_match
         ...: ]).alias()

A temporary table would do as well.
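Here is a minimal sketch of the temporary-table route. It uses Python's built-in sqlite3 module as a stand-in for PostgreSQL, so no server is needed. The table layout mirrors the question, and vtm is simply the name this answer uses for the input relation. Once loaded, the input can be joined like any other table, for example to count how many input entities each group contains.

```python
import sqlite3

# Assumption: sqlite3 (stdlib) stands in for PostgreSQL; names mirror the question.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE multi_groups (group_id INTEGER, atom_id INTEGER, sequence INTEGER);
    INSERT INTO multi_groups VALUES
        (1, 1, 1), (1, 2, 2), (1, 3, 3),
        (2, 1, 3), (2, 2, 2), (2, 3, 1);
""")

values_to_match = [(1, 3), (2, 2), (3, 1)]  # (atom_id, sequence)

# Load the input list into a temporary table so it can be joined like any other relation.
conn.execute("CREATE TEMP TABLE vtm (atom_id INTEGER, sequence INTEGER)")
conn.executemany("INSERT INTO vtm VALUES (?, ?)", values_to_match)

# How many of the input entities each group contains.
per_group = conn.execute("""
    SELECT g.group_id, COUNT(*)
    FROM multi_groups AS g
    JOIN vtm ON vtm.atom_id = g.atom_id AND vtm.sequence = g.sequence
    GROUP BY g.group_id
    ORDER BY g.group_id
""").fetchall()
print(per_group)  # [(1, 1), (2, 3)]
```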

With the new relation at hand you want to find the answer to "find those multi_groups for which no entity exists that is not in the group". It's a horrible sentence, but makes sense:

    In [117]: mg = aliased(MultiColumnGroups)

    In [119]: session.query(MultiColumnGroups.group_id).\
         ...:     filter(~exists().
         ...:         select_from(vtm).
         ...:         where(~exists().
         ...:             where(MultiColumnGroups.group_id == mg.group_id).
         ...:             where(tuple_(vtm.c.atom_id, vtm.c.sequence) ==
         ...:                   tuple_(mg.atom_id, mg.sequence)).
         ...:             correlate_except(mg))).\
         ...:     distinct().\
         ...:     all()
         ...:
    Out[119]: [(2,)]
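The same double NOT EXISTS can be tried with Python's built-in sqlite3 module, used here as a stand-in for PostgreSQL. The row-value comparison is spelled out as separate equalities. Group 3 is hypothetical: it holds group 2's rows plus one extra row. It shows that relational division finds groups that contain the input, not only groups equal to it.

```python
import sqlite3

# Assumption: sqlite3 (stdlib) stands in for PostgreSQL; group 3 is a
# hypothetical superset of group 2, added only to illustrate the semantics.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE multi_groups (group_id INTEGER, atom_id INTEGER, sequence INTEGER);
    INSERT INTO multi_groups VALUES
        (1, 1, 1), (1, 2, 2), (1, 3, 3),
        (2, 1, 3), (2, 2, 2), (2, 3, 1),
        (3, 1, 3), (3, 2, 2), (3, 3, 1), (3, 4, 4);
    CREATE TEMP TABLE vtm (atom_id INTEGER, sequence INTEGER);
    INSERT INTO vtm VALUES (1, 3), (2, 2), (3, 1);
""")

# Relational division: keep groups g for which no input row exists
# that is missing from g ("no entity exists that is not in the group").
division_sql = """
    SELECT DISTINCT g.group_id
    FROM multi_groups AS g
    WHERE NOT EXISTS (
        SELECT 1 FROM vtm
        WHERE NOT EXISTS (
            SELECT 1 FROM multi_groups AS mg
            WHERE mg.group_id = g.group_id
              AND mg.atom_id = vtm.atom_id
              AND mg.sequence = vtm.sequence
        )
    )
    ORDER BY g.group_id
"""
matches = [row[0] for row in conn.execute(division_sql)]
print(matches)  # [2, 3]
```

If only exact matches are wanted, the division result still needs a size check, such as comparing each group's row count with the number of input entities.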

On the other hand you could also just select the intersection of groups with the given entities:

    In [19]: gs = intersect(*[
        ...:     session.query(MultiColumnGroups.group_id).
        ...:         filter(MultiColumnGroups.atom_id == vtm.atom_id,
        ...:                MultiColumnGroups.sequence == vtm.sequence)
        ...:     for vtm in values_to_match
        ...: ])

    In [20]: session.execute(gs).fetchall()
    Out[20]: [(2,)]
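Here is a sketch of the same intersection as plain SQL, again with sqlite3 standing in for PostgreSQL. There is one parameterised SELECT per entity, and the SELECTs are joined with INTERSECT, so only group_ids returned for every entity survive. Like the division query, this matches any group that contains all the entities, including supersets.

```python
import sqlite3

# Assumption: sqlite3 (stdlib) stands in for PostgreSQL; data mirrors the question.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE multi_groups (group_id INTEGER, atom_id INTEGER, sequence INTEGER);
    INSERT INTO multi_groups VALUES
        (1, 1, 1), (1, 2, 2), (1, 3, 3),
        (2, 1, 3), (2, 2, 2), (2, 3, 1);
""")

values_to_match = [(1, 3), (2, 2), (3, 1)]  # (atom_id, sequence)

# One SELECT per entity; INTERSECT keeps only group_ids that every SELECT returns.
sql = " INTERSECT ".join(
    "SELECT group_id FROM multi_groups WHERE atom_id = ? AND sequence = ?"
    for _ in values_to_match
)
params = [value for entity in values_to_match for value in entity]
groups = [row[0] for row in conn.execute(sql, params)]
print(groups)  # [2]
```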

The error

    ProgrammingError: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]
    LINE 3: ...gg((multi_groups.atom_id, multi_groups.sequence)) = ARRAY[AR...
                                                                 ^
    HINT:  No operator matches the given name and argument type(s). You might need to add explicit type casts.
     [SQL: 'SELECT multi_groups.group_id AS multi_groups_group_id \nFROM multi_groups GROUP BY multi_groups.group_id \nHAVING array_agg((multi_groups.atom_id, multi_groups.sequence)) = %(array_agg_1)s'] [parameters: {'array_agg_1': [[1, 3], [2, 2], [3, 1]]}] (Background on this error at: http://sqlalche.me/e/f405)

is a result of how your values_to_match is first converted to a list of lists (for reasons unknown) and then adapted to an array by your DB-API driver. The result is an array of integer arrays, not an array of records (int, int). With a raw DB-API connection and cursor, passing a list of tuples works as you'd expect.

In SQLAlchemy, if you wrap the list values_to_match with sqlalchemy.dialects.postgresql.array(), the query works as you intended, though remember that the results are indeterminate.



Answer 2:

I found your answer very helpful as well. Since I don't have enough reputation to comment on your solution, I'll post the changes I made based on your help.

I found that the double-negative query generated less-than-ideal SQL, so I worked backwards from the SQL to find something a bit cleaner.

Here is some simple data. The example has been slightly modified to use a text field role instead of a sequence field. This should be generalizable to other types as well:

    drop table if exists multi_groups;
    create table multi_groups (group_id, atom_id, role) as
    values
      (1, 1, 'referrer'),
      (1, 2, 'rendering'),
      (1, 3, 'attending'),
      (2, 1, 'attending'),
      (2, 2, 'rendering'),
      (2, 3, 'referrer');

The original solution generated sql similar to:

    select distinct
      dim_staging.multi_groups.group_id as dim_staging_multi_groups_group_id
    from dim_staging.multi_groups
    where not (
      exists (
        select *
        from (
               select
                 unnest(
                   array[1, 2, 3]
                 ) as atom_id,
                 unnest(
                   array['referrer', 'rendering', 'attending']
                 ) as role
        ) as anon_1
        where not (
          exists (
            select *
            from dim_staging.multi_groups as multi_groups_1
            where dim_staging.multi_groups.group_id = multi_groups_1.group_id
              and (anon_1.atom_id, anon_1.role) = (multi_groups_1.atom_id, multi_groups_1.role)
          )
        )
      )
    );

I used that and worked on the sql a bit to get:

    with vtm as (
      select
        unnest(array[1, 2, 3]) as atom_id,
        unnest(array['attending', 'rendering', 'referrer']) as role
    ),
    matched as (
      select
        dim_staging.multi_groups.group_id as group_id,
        vtm.atom_id as atom_id,
        3 as cnt
      from dim_staging.multi_groups
      full outer join vtm
        on (vtm.atom_id, vtm.role) = (dim_staging.multi_groups.atom_id, dim_staging.multi_groups.role)
    )
    select matched.group_id
    from matched
    where not (
      exists (
        select *
        from matched
        where matched.group_id is null
      )
    )
    -- cnt is constant, but it must be grouped to be referenced in HAVING
    group by matched.group_id, matched.cnt
    having count(1) filter (where matched.atom_id is null) = 0
      and count(1) = matched.cnt;
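As a cross-check on what this query is meant to compute, here is a small pure-Python reference. It is not part of the answer. It treats each group's (atom_id, role) pairs as a multiset and keeps the groups equal to the input. Group 3 is a hypothetical superset group, added to show that it is rejected.

```python
from collections import Counter, namedtuple

Entity = namedtuple('Entity', ['atom_id', 'role'])
values_to_match = [Entity(1, 'attending'), Entity(2, 'rendering'), Entity(3, 'referrer')]

# (group_id, atom_id, role): the answer's sample data plus a hypothetical
# group 3 that contains group 2's rows and one extra row.
rows = [
    (1, 1, 'referrer'), (1, 2, 'rendering'), (1, 3, 'attending'),
    (2, 1, 'attending'), (2, 2, 'rendering'), (2, 3, 'referrer'),
    (3, 1, 'attending'), (3, 2, 'rendering'), (3, 3, 'referrer'), (3, 4, 'referrer'),
]


def exact_matches(rows, values):
    """Return the group_ids whose (atom_id, role) pairs equal `values` as a multiset."""
    groups = {}
    for group_id, atom_id, role in rows:
        groups.setdefault(group_id, Counter())[(atom_id, role)] += 1
    wanted = Counter((v.atom_id, v.role) for v in values)
    return sorted(g for g, members in groups.items() if members == wanted)


print(exact_matches(rows, values_to_match))  # [2]
```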

Here is a full test script that generates the above SQL:

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.ext.declarative import declarative_base
    import os
    from sqlalchemy import (
        Column,
        Integer,
        Text
    )
    from sqlalchemy.sql.expression import func, select, tuple_, exists, join, literal, label
    from sqlalchemy.dialects import postgresql
    from collections import namedtuple


    db_url = os.getenv('DB_URL', 'postgresql://localhost:5432/dw')
    engine = create_engine(db_url, echo=False)
    Session = sessionmaker(bind=engine)
    session = Session()


    Base = declarative_base()


    class MultiColumnGroups(Base):
        __tablename__ = 'multi_groups'
        id = Column(Integer, primary_key=True)
        group_id = Column(Integer)
        atom_id = Column(Integer)
        role = Column(Text)


    Base.metadata.drop_all(engine, [MultiColumnGroups.__table__])
    Base.metadata.create_all(engine, [MultiColumnGroups.__table__])

    vals = [
        # Group 1
        {'group_id': 1, 'atom_id': 1, 'role': 'referrer'},
        {'group_id': 1, 'atom_id': 2, 'role': 'rendering'},
        {'group_id': 1, 'atom_id': 3, 'role': 'attending'},
        # Group 2
        {'group_id': 2, 'atom_id': 1, 'role': 'attending'},
        {'group_id': 2, 'atom_id': 2, 'role': 'rendering'},
        {'group_id': 2, 'atom_id': 3, 'role': 'referrer'},
    ]

    session.bulk_save_objects(
        [MultiColumnGroups(**x) for x in vals]
    )
    session.commit()

    Entity = namedtuple('Entity', ['atom_id', 'role'])
    values_to_match = [
        # (atom_id, role)
        # Entity(1, 'referrer'),
        # Entity(2, 'rendering'),
        # Entity(3, 'attending'),
        Entity(1, 'attending'),
        Entity(2, 'rendering'),
        Entity(3, 'referrer'),
    ]

    # One unnested array per Entity field, turned into a CTE named "vtm"
    vtm = select(
        [
            func.unnest(
                postgresql.array([
                    getattr(e, f) for e in values_to_match
                ])
            ).label(f)
            for f in Entity._fields
        ]
    ).cte(name='vtm')

    j = join(
        MultiColumnGroups, vtm,
        tuple_(vtm.c.atom_id, vtm.c.role) == tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.role),
        full=True
    )
    matched = select([
        MultiColumnGroups.group_id,
        vtm.c.atom_id,
        label('cnt', literal(len(values_to_match), type_=Integer)),
    ]).select_from(j).cte(name='matched')

    group_id = session.query(matched.c.group_id).\
        filter(
            ~exists().
            select_from(matched).
            where(matched.c.group_id == None)
        ).\
        group_by(matched.c.group_id, matched.c.cnt).\
        having(func.count(1).filter(matched.c.atom_id == None) == 0).\
        having(func.count(1) == matched.c.cnt).one().group_id
    # cnt is grouped above because PostgreSQL rejects ungrouped columns in HAVING

    print(group_id)

EDIT: The case where the input matches only a subgroup of a larger group, causing multiple matches, is solved by including the number of values being compared as a count in the query and checking that each matched group's count equals the number of values. Sorry for the oversight.
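The count check described in the EDIT can be tried without a PostgreSQL server. The sketch below uses Python's sqlite3 module as a stand-in. It is a LEFT JOIN variant, not the answer's FULL OUTER JOIN query, because SQLite only gained FULL OUTER JOIN in version 3.39. It assumes the input list has no duplicate entities and each group's (atom_id, role) rows are distinct. Group 3 is a hypothetical superset of the input.

```python
import sqlite3

# Assumptions: sqlite3 (stdlib) stands in for PostgreSQL, the input has no
# duplicate entities, and group 3 is a hypothetical superset of the input.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE multi_groups (group_id INTEGER, atom_id INTEGER, role TEXT);
    INSERT INTO multi_groups VALUES
        (1, 1, 'referrer'), (1, 2, 'rendering'), (1, 3, 'attending'),
        (2, 1, 'attending'), (2, 2, 'rendering'), (2, 3, 'referrer'),
        (3, 1, 'attending'), (3, 2, 'rendering'), (3, 3, 'referrer'), (3, 4, 'referrer');
    CREATE TABLE vtm (atom_id INTEGER, role TEXT);
    INSERT INTO vtm VALUES (1, 'attending'), (2, 'rendering'), (3, 'referrer');
""")

# LEFT JOIN each group row to the input. A group matches exactly when
#   - every one of its rows found an input row (no NULLs from vtm), and
#   - the number of matched rows equals the number of input rows.
exact_sql = """
    SELECT g.group_id
    FROM multi_groups AS g
    LEFT JOIN vtm ON vtm.atom_id = g.atom_id AND vtm.role = g.role
    GROUP BY g.group_id
    HAVING COUNT(*) = COUNT(vtm.atom_id)
       AND COUNT(vtm.atom_id) = (SELECT COUNT(*) FROM vtm)
"""
exact = [row[0] for row in conn.execute(exact_sql)]
print(exact)  # [2]
```

Group 3 contains all three input entities, but its extra row leaves a NULL from vtm, so the first HAVING condition rejects it.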


