feat: implement parent tag search (#673)

* feat: implement parent tag search

* feat: add tests for parent tag search

* fix: typo

* feat: log the time it takes to build the SQL Expression

* feat: instead of hardcoding child tag ids into main query, include subquery

* Revert "feat: instead of hardcoding child tag ids into main query, include subquery"

This reverts commit 2615e7dab4.
This commit is contained in:
Jann Stute
2025-01-01 11:36:11 +01:00
committed by GitHub
parent 5a0ba54454
commit 7b672e03a1
3 changed files with 53 additions and 4 deletions

View File

@@ -543,10 +543,18 @@ class Library:
statement = select(Entry)
if search.ast:
start_time = time.time()
statement = statement.outerjoin(Entry.tag_box_fields).where(
SQLBoolExpressionBuilder(self).visit(search.ast)
)
end_time = time.time()
logger.info(
f"SQL Expression Builder finished ({format_timespan(end_time - start_time)})"
)
extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST)

View File

@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from sqlalchemy import and_, distinct, func, or_, select
import structlog
from sqlalchemy import and_, distinct, func, or_, select, text
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
@@ -16,6 +17,20 @@ if TYPE_CHECKING:
else:
Library = None # don't import .library because of circular imports
logger = structlog.get_logger(__name__)
# Recursive CTE that yields the tag id bound as :tag_id together with the ids of all
# tags reachable from it through the tag_subtags table (its descendants).
# The recursive step is combined with UNION rather than UNION ALL so that a row is
# only ever queued once: tags reachable via several paths are returned a single time
# and the recursion terminates even if the tag hierarchy contains a cycle.
CHILDREN_QUERY = text("""
-- Note for this entire query that tag_subtags.child_id is the parent id and tag_subtags.parent_id is the child id due to bad naming
WITH RECURSIVE Subtags AS (
    SELECT :tag_id AS child_id
    UNION
    SELECT ts.parent_id AS child_id
    FROM tag_subtags ts
    INNER JOIN Subtags s ON ts.child_id = s.child_id
)
SELECT * FROM Subtags;
""")  # noqa: E501
def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
for s in FILETYPE_EQUIVALENTS:
@@ -98,16 +113,28 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
def visit_not(self, node: Not) -> ColumnExpressionArgument:
    """Return the inverted (``~``) expression built for the child of a NOT node."""
    child_expression = self.__entry_satisfies_ast(node.child)
    return ~child_expression
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
    """Given a tag name find the ids of all tags that this name could refer to.

    Args:
        tag_name: Name to look up. It is compared with ``ilike`` against the tag
            name, the tag shorthand and the tag aliases.
        include_children: If True (the default), the result additionally contains
            every id returned by ``CHILDREN_QUERY`` for each matching tag.

    Returns:
        The matching tag ids. When ``include_children`` is True the list is
        de-duplicated and keeps first-seen order.
    """
    with Session(self.lib.engine) as session:
        tag_ids = list(
            session.scalars(
                select(Tag.id)
                .where(or_(Tag.name.ilike(tag_name), Tag.shorthand.ilike(tag_name)))
                .union(select(TagAlias.tag_id).where(TagAlias.name.ilike(tag_name)))
            )
        )
        if len(tag_ids) > 1:
            logger.debug(
                f'Tag Constraint "{tag_name}" is ambiguous, {len(tag_ids)} matching tags found',
                tag_ids=tag_ids,
                include_children=include_children,
            )
        if not include_children:
            return tag_ids
        # Several matching tags can share descendants (or be descendants of one
        # another); a dict is used as an insertion-ordered set so that each id is
        # returned only once.
        unique_ids: dict[int, None] = {}
        for tag_id in tag_ids:
            unique_ids.update(
                dict.fromkeys(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id}))
            )
        return list(unique_ids)
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""

View File

@@ -116,6 +116,20 @@ def test_parentheses(search_library: Library, query: str, count: int):
verify_count(search_library, query, count)
@pytest.mark.parametrize(
    ("query", "count"),
    [
        ("ellipse", 17),
        ("yellow", 15),
        ("color", 24),
        ("shape", 24),
        ("yellow not green", 10),
    ],
)
def test_parent_tags(search_library: Library, query: str, count: int):
    """Check that each query yields the expected number of results.

    Every case pairs a search query with the count that is passed on to
    ``verify_count`` for the ``search_library`` fixture.
    """
    verify_count(search_library, query, count)
@pytest.mark.parametrize(
"invalid_query", ["asd AND", "asd AND AND", "tag:(", "(asd", "asd[]", "asd]", ":", "tag: :"]
)