Question:
I am attempting to use SQLAlchemy more fully, rather than just falling back to pure SQL at the first sign of distress. In this case, I have a table in a Postgres database (9.5) which stores a set of integers as a group by associating individual items (atom_id) with a group identifier (group_id).
Given a list of atom_ids, I'd like to be able to figure out which group_id, if any, that set of atom_ids belongs to. Solving this with just the group_id and atom_id columns was straightforward.
Now I'm trying to generalize such that a 'group' is made up of not just a list of atom_ids, but other context as well. In the example below, the list is ordered by including a sequence column, but conceptually other columns could be used instead, such as a weight column which gives each atom_id a [0, 1] floating point value representing that atom's 'share' of the group.
Below is most of a unit test demonstrating my issue.
First, some setup:
def test_multi_column_grouping(self):
    class MultiColumnGroups(base.Base):
        __tablename__ = 'multi_groups'

        group_id = Column(Integer)
        atom_id = Column(Integer)
        # Arbitrary 'other' column. In this case an integer, but it could be
        # a float (e.g. a weighting factor).
        sequence = Column(Integer)

    base.Base.metadata.create_all(self.engine)

    # Insert 6 rows representing 2 different 'groups' of values
    vals = [
        # Group 1
        {'group_id': 1, 'atom_id': 1, 'sequence': 1},
        {'group_id': 1, 'atom_id': 2, 'sequence': 2},
        {'group_id': 1, 'atom_id': 3, 'sequence': 3},
        # Group 2
        {'group_id': 2, 'atom_id': 1, 'sequence': 3},
        {'group_id': 2, 'atom_id': 2, 'sequence': 2},
        {'group_id': 2, 'atom_id': 3, 'sequence': 1},
    ]
    self.session.bulk_save_objects([MultiColumnGroups(**x) for x in vals])
    self.session.flush()

    self.assertEqual(6, len(self.session.query(MultiColumnGroups).all()))
Now, I want to query the above table to find which group a specific set of inputs belongs to. I'm using a list of (named) tuples to represent the query parameters.
from collections import namedtuple

Entity = namedtuple('Entity', ['atom_id', 'sequence'])

values_to_match = [
    # (atom_id, sequence)
    Entity(1, 3),
    Entity(2, 2),
    Entity(3, 1),
]
# The above list _should_ match with group_id == 2
Raw SQL solution. I'd prefer not to fall back on this, as a part of this exercise is to learn more SQLAlchemy.
r = self.session.execute('''
    select group_id
    from multi_groups
    group by group_id
    having array_agg((atom_id, sequence)) = :query_tuples
''', {'query_tuples': values_to_match}).fetchone()

print(r)
# > (2,)
self.assertEqual(2, r[0])
Here is the above raw-SQL solution converted fairly directly into a broken SQLAlchemy query. Running this produces a psycopg2 error: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]. I believe that I need to cast the array_agg into an int[]? That would work so long as the grouping columns are all integers (which, if need be, is an acceptable limitation), but ideally this would work with mixed-type input tuples / table columns.
from sqlalchemy import tuple_
from sqlalchemy.dialects.postgresql import array_agg

existing_group = self.session.query(MultiColumnGroups).\
    with_entities(MultiColumnGroups.group_id).\
    group_by(MultiColumnGroups.group_id).\
    having(array_agg(tuple_(MultiColumnGroups.atom_id,
                            MultiColumnGroups.sequence)) == values_to_match).\
    one_or_none()

self.assertIsNotNone(existing_group)
print('|{}|'.format(existing_group))
Is the above session.query() close? Have I blinded myself here, and am I missing something super obvious that would solve this problem in some other way?
Answer 1:
I think your solution would produce indeterminate results, because the rows within a group are in unspecified order, and so the comparison between the array aggregate and given array may produce true or false based on that:
[local]:5432 u@sopython*=> select group_id
[local] u@sopython-> from multi_groups
[local] u@sopython-> group by group_id
[local] u@sopython-> having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
 group_id 
----------
        2
(1 row)

[local]:5432 u@sopython*=> update multi_groups set atom_id = atom_id where atom_id = 2;
UPDATE 2

[local]:5432 u@sopython*=> select group_id from multi_groups group by group_id having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
 group_id 
----------
(0 rows)
You could apply an ordering to both, or try something entirely different: instead of array comparison you could use relational division.
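If the ordering route appeals to you, a rough, untested sketch might look like the following. It leans on aggregate_order_by to pin the aggregate's order and on the postgresql.array() wrapping discussed at the end of this answer, sorting both sides by atom_id before comparing:

from sqlalchemy import tuple_
from sqlalchemy.dialects.postgresql import aggregate_order_by, array, array_agg

# Sort the Python-side tuples the same way the aggregate is ordered, so the
# comparison no longer depends on physical row order.
ordered_values = sorted(values_to_match, key=lambda e: e.atom_id)

existing_group = session.query(MultiColumnGroups.group_id).\
    group_by(MultiColumnGroups.group_id).\
    having(array_agg(aggregate_order_by(
               tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.sequence),
               MultiColumnGroups.atom_id)) ==
           array(ordered_values)).\
    one_or_none()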
In order to divide you have to form a temporary relation from your list of Entity records. Again, there are many ways to approach that. Here's one using unnested arrays:
In [112]: vtm = select([
     ...:     func.unnest(postgresql.array([
     ...:         getattr(e, f) for e in values_to_match
     ...:     ])).label(f)
     ...:     for f in Entity._fields
     ...: ]).alias()
And another using a union:
In [114]: vtm = union_all(*[
     ...:     select([literal(e.atom_id).label('atom_id'),
     ...:             literal(e.sequence).label('sequence')])
     ...:     for e in values_to_match
     ...: ]).alias()
A temporary table would do as well.
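For example, a rough, untested sketch of the temporary-table route could look like this (the table and variable names are just illustrative):

from sqlalchemy import Column, Integer, MetaData, Table

# Create a TEMPORARY table, load the tuples into it, and use it in place of
# the unnest()/union based vtm relation.
tmp_metadata = MetaData()
vtm_table = Table(
    'vtm', tmp_metadata,
    Column('atom_id', Integer),
    Column('sequence', Integer),
    prefixes=['TEMPORARY'],
)
vtm_table.create(bind=session.connection())
session.execute(vtm_table.insert(), [e._asdict() for e in values_to_match])
vtm = vtm_table.alias()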
With the new relation at hand you want to find the answer to "find those multi_groups for which no entity exists that is not in the group". It's a horrible sentence, but it makes sense:
In [117]: mg = aliased(MultiColumnGroups)

In [119]: session.query(MultiColumnGroups.group_id).\
     ...:     filter(~exists().
     ...:            select_from(vtm).
     ...:            where(~exists().
     ...:                  where(MultiColumnGroups.group_id == mg.group_id).
     ...:                  where(tuple_(vtm.c.atom_id, vtm.c.sequence) ==
     ...:                        tuple_(mg.atom_id, mg.sequence)).
     ...:                  correlate_except(mg))).\
     ...:     distinct().\
     ...:     all()
Out[119]: [(2,)]
On the other hand you could also just select the intersection of groups with the given entities:
In [19]: gs = intersect(*[
    ...:     session.query(MultiColumnGroups.group_id).
    ...:     filter(MultiColumnGroups.atom_id == vtm.atom_id,
    ...:            MultiColumnGroups.sequence == vtm.sequence)
    ...:     for vtm in values_to_match
    ...: ])

In [20]: session.execute(gs).fetchall()
Out[20]: [(2,)]
The error
ProgrammingError: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]
LINE 3: ...gg((multi_groups.atom_id, multi_groups.sequence)) = ARRAY[AR...
                                                             ^
HINT:  No operator matches the given name and argument type(s). You might need to add explicit type casts.
 [SQL: 'SELECT multi_groups.group_id AS multi_groups_group_id \nFROM multi_groups GROUP BY multi_groups.group_id \nHAVING array_agg((multi_groups.atom_id, multi_groups.sequence)) = %(array_agg_1)s']
 [parameters: {'array_agg_1': [[1, 3], [2, 2], [3, 1]]}]
 (Background on this error at: http://sqlalche.me/e/f405)
is a result of how your values_to_match is first converted to a list of lists (for reasons unknown) and then converted to an array by your DB-API driver. It results in an array of arrays of integers, not an array of (int, int) records. Using a raw DB-API connection and cursor, passing a list of tuples works as you'd expect.
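For illustration, such a raw DB-API round trip might look roughly like this (an untested sketch; it grabs the psycopg2 connection underlying the Session):

# psycopg2 adapts each Python tuple to a composite value and the surrounding
# list to an array, so the bound parameter becomes ARRAY[(1, 3), (2, 2), (3, 1)].
dbapi_conn = session.connection().connection
cur = dbapi_conn.cursor()
cur.execute(
    """
    select group_id
    from multi_groups
    group by group_id
    having array_agg((atom_id, sequence)) = %(query_tuples)s
    """,
    {'query_tuples': [tuple(e) for e in values_to_match]},
)
print(cur.fetchone())  # (2,)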
In SQLAlchemy, if you wrap the list values_to_match with sqlalchemy.dialects.postgresql.array(), it works as you meant it to work, though remember that the results are indeterminate.
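In other words, a minimal, untested adjustment to your original query would be something like:

from sqlalchemy import tuple_
from sqlalchemy.dialects.postgresql import array, array_agg

# Binding the tuples inside postgresql.array() keeps them as records instead
# of a nested integer array. The row-ordering caveat above still applies.
existing_group = self.session.query(MultiColumnGroups.group_id).\
    group_by(MultiColumnGroups.group_id).\
    having(array_agg(tuple_(MultiColumnGroups.atom_id,
                            MultiColumnGroups.sequence)) ==
           array(values_to_match)).\
    one_or_none()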
Answer 2:
I found your answer very helpful as well. As I do not have enough reputation to comment on your solution, I'll post the changes that I made based on your help.
I found that the double-negative query generated some less-than-ideal SQL, so I worked backwards from the SQL to find something a bit cleaner.
Here is some simple data. The example has been slightly modified to use a text field role instead of a sequence field. This should be generalizable to other types as well:
drop table if exists multi_groups;
create table multi_groups (group_id, atom_id, role) as
values
    (1, 1, 'referrer'),
    (1, 2, 'rendering'),
    (1, 3, 'attending'),
    (2, 1, 'attending'),
    (2, 2, 'rendering'),
    (2, 3, 'referrer');
The original solution generated SQL similar to:
select distinct dim_staging.multi_groups.group_id as dim_staging_multi_groups_group_id
from dim_staging.multi_groups
where not (
    exists (
        select *
        from (
            select
                unnest(array[1, 2, 3]) as atom_id,
                unnest(array['referrer', 'rendering', 'attending']) as role
        ) as anon_1
        where not (
            exists (
                select *
                from dim_staging.multi_groups as multi_groups_1
                where dim_staging.multi_groups.group_id = multi_groups_1.group_id
                  and (anon_1.atom_id, anon_1.role) = (multi_groups_1.atom_id, multi_groups_1.role)
            )
        )
    )
);
I used that and worked on the SQL a bit to get:
with vtm as (
    select
        unnest(array[1, 2, 3]) as atom_id,
        unnest(array['attending', 'rendering', 'referrer']) as role
),
matched as (
    select
        dim_staging.multi_groups.group_id as group_id,
        vtm.atom_id as atom_id,
        3 as cnt
    from dim_staging.multi_groups
    full outer join vtm
        on (vtm.atom_id, vtm.role) = (dim_staging.multi_groups.atom_id, dim_staging.multi_groups.role)
)
select matched.group_id
from matched
where not (
    exists (
        select *
        from matched
        where matched.group_id is null
    )
)
-- cnt is constant, but it must be listed in group by (or wrapped in an
-- aggregate) before it can be referenced in having
group by matched.group_id, matched.cnt
having count(1) filter (where matched.atom_id is null) = 0
   and count(1) = matched.cnt;
Here is a full test script demonstrating how to generate the above SQL:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
import os
from sqlalchemy import (
    Column,
    Integer,
    Text
)
from sqlalchemy.sql.expression import func, select, tuple_, exists, join, literal, label
from sqlalchemy.dialects import postgresql
from collections import namedtuple

db_url = os.getenv('DB_URL', 'postgresql://localhost:5432/dw')
engine = create_engine(db_url, echo=False)
Session = sessionmaker(bind=engine)
session = Session()

Base = declarative_base()


class MultiColumnGroups(Base):
    __tablename__ = 'multi_groups'

    id = Column(Integer, primary_key=True)
    group_id = Column(Integer)
    atom_id = Column(Integer)
    role = Column(Text)


Base.metadata.drop_all(engine, [MultiColumnGroups.__table__])
Base.metadata.create_all(engine, [MultiColumnGroups.__table__])

vals = [
    # Group 1
    {'group_id': 1, 'atom_id': 1, 'role': 'referrer'},
    {'group_id': 1, 'atom_id': 2, 'role': 'rendering'},
    {'group_id': 1, 'atom_id': 3, 'role': 'attending'},
    # Group 2
    {'group_id': 2, 'atom_id': 1, 'role': 'attending'},
    {'group_id': 2, 'atom_id': 2, 'role': 'rendering'},
    {'group_id': 2, 'atom_id': 3, 'role': 'referrer'},
]
session.bulk_save_objects([MultiColumnGroups(**x) for x in vals])
session.commit()

Entity = namedtuple('Entity', ['atom_id', 'role'])

values_to_match = [
    # (atom_id, role)
    # Entity(1, 'referrer'),
    # Entity(2, 'rendering'),
    # Entity(3, 'attending'),
    Entity(1, 'attending'),
    Entity(2, 'rendering'),
    Entity(3, 'referrer'),
]

# The values to match, as a CTE of unnested arrays (one array per field).
vtm = select([
    func.unnest(
        postgresql.array([getattr(e, f) for e in values_to_match])
    ).label(f)
    for f in Entity._fields
]).cte(name='vtm')

# Full outer join, so rows missing on either side show up with NULLs.
j = join(
    MultiColumnGroups,
    vtm,
    tuple_(vtm.c.atom_id, vtm.c.role) ==
    tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.role),
    full=True
)

matched = select([
    MultiColumnGroups.group_id,
    vtm.c.atom_id,
    label('cnt', literal(len(values_to_match), type_=Integer))
]).select_from(j).cte(name='matched')

# Note: cnt must appear in GROUP BY (or inside an aggregate) so that it can
# be referenced in the HAVING clause.
group_id = session.query(matched.c.group_id).\
    filter(
        ~exists().
        select_from(matched).
        where(matched.c.group_id == None)
    ).\
    group_by(matched.c.group_id, matched.c.cnt).\
    having(func.count(1).filter(matched.c.atom_id == None) == 0).\
    having(func.count(1) == matched.c.cnt).\
    one().group_id

print(group_id)
EDIT: The case where a subgroup exists (which could cause multiple matches) is handled by including the number of values being compared as a count (cnt) in the query and checking that each matched group's row count equals that number. Sorry for the oversight.
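For example, to sanity-check that case you could add a hypothetical group that is a strict subset of the values being matched (this row is not part of the script above):

# Group 3 contains only one of the three values being matched. It joins
# without NULLs, but its row count (1) never equals cnt (3), so re-running
# the query above still returns only group 2.
session.bulk_save_objects([MultiColumnGroups(group_id=3, atom_id=1, role='attending')])
session.commit()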