Update the Python2-compatible version of datanommer.models

Signed-off-by: Aurélien Bompard <aurelien@bompard.org>
This commit is contained in:
Aurélien Bompard 2021-10-21 12:09:06 +02:00
parent 46ec765e00
commit 208429e3b6
No known key found for this signature in database
GPG key ID: 31584CFEB9BF64AD

View file

@ -13,10 +13,18 @@
# #
# You should have received a copy of the GNU General Public License along # You should have received a copy of the GNU General Public License along
# with this program. If not, see <http://www.gnu.org/licenses/>. # with this program. If not, see <http://www.gnu.org/licenses/>.
# This is a Python2-compatible file for the badges app, that is still running on Python2.
# Compatibility fixes done by pasteurize: http://python-future.org/pasteurize.html
from __future__ import absolute_import, division, print_function, unicode_literals
import datetime import datetime
import json
import logging import logging
import math import math
import traceback import traceback
import uuid
from warnings import warn from warnings import warn
import pkg_resources import pkg_resources
@ -32,20 +40,28 @@ from sqlalchemy import (
Integer, Integer,
not_, not_,
or_, or_,
String,
Table, Table,
TypeDecorator,
Unicode, Unicode,
UnicodeText, UnicodeText,
UniqueConstraint, UniqueConstraint,
) )
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import (
from sqlalchemy.orm import relationship, scoped_session, sessionmaker, validates declarative_base,
relationship,
scoped_session,
sessionmaker,
validates,
)
from sqlalchemy.sql import operators
try: try:
from psycopg2.errors import UniqueViolation from psycopg2.errors import UniqueViolation
except ImportError: except ImportError: # pragma: no cover
from psycopg2.errorcodes import lookup as lookup_error from psycopg2.errorcodes import lookup as lookup_error
UniqueViolation = lookup_error("23505") UniqueViolation = lookup_error("23505")
@ -83,17 +99,17 @@ def init(uri=None, alembic_ini=None, engine=None, create=False):
session.configure(bind=engine) session.configure(bind=engine)
DeclarativeBase.query = session.query_property() DeclarativeBase.query = session.query_property()
# Loads the alembic configuration and generates the version table, with
# the most recent revision stamped as head
if alembic_ini is not None:
from alembic import command
from alembic.config import Config
alembic_cfg = Config(alembic_ini)
command.stamp(alembic_cfg, "head")
if create: if create:
session.execute("CREATE EXTENSION IF NOT EXISTS timescaledb")
DeclarativeBase.metadata.create_all(engine) DeclarativeBase.metadata.create_all(engine)
# Loads the alembic configuration and generates the version table, with
# the most recent revision stamped as head
if alembic_ini is not None: # pragma: no cover
from alembic import command
from alembic.config import Config
alembic_cfg = Config(alembic_ini)
command.stamp(alembic_cfg, "head")
def add(message): def add(message):
@ -133,6 +149,35 @@ def source_version_default(context):
return dist.version return dist.version
# https://docs.sqlalchemy.org/en/14/core/custom_types.html#marshal-json-strings
class JSONEncodedDict(TypeDecorator):
"""Represents an immutable structure as a json-encoded string."""
impl = UnicodeText
cache_ok = True
def process_bind_param(self, value, dialect):
if value is not None:
value = json.dumps(value)
return value
def process_result_value(self, value, dialect):
if value is not None:
value = json.loads(value)
return value
def coerce_compared_value(self, op, value):
# https://docs.sqlalchemy.org/en/14/core/custom_types.html#dealing-with-comparison-operations
if op in (operators.like_op, operators.not_like_op):
return String()
else:
return self
users_assoc_table = Table( users_assoc_table = Table(
"users_messages", "users_messages",
DeclarativeBase.metadata, DeclarativeBase.metadata,
@ -166,7 +211,7 @@ class Message(DeclarativeBase):
crypto = Column(UnicodeText) crypto = Column(UnicodeText)
source_name = Column(Unicode, default="datanommer") source_name = Column(Unicode, default="datanommer")
source_version = Column(Unicode, default=source_version_default) source_version = Column(Unicode, default=source_version_default)
msg = Column(postgresql.JSONB, nullable=False) msg = Column(JSONEncodedDict, nullable=False)
headers = Column(postgresql.JSONB(none_as_null=True)) headers = Column(postgresql.JSONB(none_as_null=True))
users = relationship( users = relationship(
"User", "User",
@ -206,6 +251,9 @@ class Message(DeclarativeBase):
def create(cls, **kwargs): def create(cls, **kwargs):
users = kwargs.pop("users") users = kwargs.pop("users")
packages = kwargs.pop("packages") packages = kwargs.pop("packages")
if not kwargs.get("msg_id"):
log.info("Message on %s was received without a msg_id", kwargs["topic"])
kwargs["msg_id"] = str(uuid.uuid4())
obj = cls(**kwargs) obj = cls(**kwargs)
try: try:
@ -273,9 +321,14 @@ class Message(DeclarativeBase):
) )
def as_fedora_message_dict(self): def as_fedora_message_dict(self):
headers = self.headers
if "sent-at" not in headers:
headers["sent-at"] = self.timestamp.astimezone(
datetime.timezone.utc
).isoformat()
return dict( return dict(
body=self.msg, body=self.msg,
headers=self.headers, headers=headers,
id=self.msg_id, id=self.msg_id,
queue=None, queue=None,
topic=self.topic, topic=self.topic,
@ -389,16 +442,18 @@ class Message(DeclarativeBase):
if contains: if contains:
query = query.filter( query = query.filter(
or_(*(Message._msg.like("%%%s%%" % contain) for contain in contains)) or_(*(Message.msg.like("%{}%".format(contain)) for contain in contains))
) )
# And then the four negative filters as necessary # And then the four negative filters as necessary
if not_users: if not_users:
query = query.filter(not_(or_(*(Message.users.any(u) for u in not_users)))) query = query.filter(
not_(or_(*(Message.users.any(User.name == u) for u in not_users)))
)
if not_packs: if not_packs:
query = query.filter( query = query.filter(
not_(or_(*(Message.packages.any(p) for p in not_packs))) not_(or_(*(Message.packages.any(Package.name == p) for p in not_packs)))
) )
if not_cats: if not_cats:
@ -429,10 +484,10 @@ class Message(DeclarativeBase):
return total, pages, messages return total, pages, messages
class NamedSingleton: class NamedSingleton(object):
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(UnicodeText, index=True) name = Column(UnicodeText, index=True, unique=True)
@classmethod @classmethod
def get_or_create(cls, name): def get_or_create(cls, name):