mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-02-12 04:39:31 +00:00
feat: implement parent tag search (#673)
* feat: implement parent tag search
* feat: add tests for parent tag search
* fix: typo
* feat: log the time it takes to build the SQL Expression
* feat: instead of hardcoding child tag ids into main query, include subquery
* Revert "feat: instead of hardcoding child tag ids into main query, include subquery"
This reverts commit 2615e7dab4.
This commit is contained in:
@@ -543,10 +543,18 @@ class Library:
|
||||
statement = select(Entry)
|
||||
|
||||
if search.ast:
|
||||
start_time = time.time()
|
||||
|
||||
statement = statement.outerjoin(Entry.tag_box_fields).where(
|
||||
SQLBoolExpressionBuilder(self).visit(search.ast)
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
logger.info(
|
||||
f"SQL Expression Builder finished ({format_timespan(end_time - start_time)})"
|
||||
)
|
||||
|
||||
extensions = self.prefs(LibraryPrefs.EXTENSION_LIST)
|
||||
is_exclude_list = self.prefs(LibraryPrefs.IS_EXCLUDE_LIST)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import and_, distinct, func, or_, select
|
||||
import structlog
|
||||
from sqlalchemy import and_, distinct, func, or_, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import BinaryExpression, ColumnExpressionArgument
|
||||
from src.core.media_types import FILETYPE_EQUIVALENTS, MediaCategories
|
||||
@@ -16,6 +17,20 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
Library = None # don't import .library because of circular imports
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Recursive CTE returning :tag_id itself plus every tag reachable below it through tag_subtags.
# The recursive step uses UNION rather than UNION ALL so rows that were already produced are
# discarded: a cyclic tag relationship then terminates instead of recursing forever, and a tag
# reachable through several paths is returned only once.
CHILDREN_QUERY = text("""
-- Note for this entire query that tag_subtags.child_id is the parent id and tag_subtags.parent_id is the child id due to bad naming
WITH RECURSIVE Subtags AS (
    SELECT :tag_id AS child_id
    UNION
    SELECT ts.parent_id AS child_id
    FROM tag_subtags ts
    INNER JOIN Subtags s ON ts.child_id = s.child_id
)
SELECT * FROM Subtags;
""")  # noqa: E501
|
||||
|
||||
|
||||
def get_filetype_equivalency_list(item: str) -> list[str] | set[str]:
|
||||
for s in FILETYPE_EQUIVALENTS:
|
||||
@@ -98,16 +113,28 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnExpressionArgument]):
|
||||
def visit_not(self, node: Not) -> ColumnExpressionArgument:
    """Return the negation of the expression built for the child of ``node``."""
    child_expression = self.__entry_satisfies_ast(node.child)
    return ~child_expression
|
||||
|
||||
def __get_tag_ids(self, tag_name: str, include_children: bool = True) -> list[int]:
    """Given a tag name find the ids of all tags that this name could refer to.

    A tag is a candidate when its name or shorthand matches ``tag_name`` via a
    case-insensitive ``ilike`` comparison, or when one of its aliases does.

    Args:
        tag_name: The name, shorthand or alias to look up.
        include_children: If True, each matching id is expanded with the rows
            that ``CHILDREN_QUERY`` yields for it.

    Returns:
        The matching tag ids. When ``include_children`` is True the expanded
        ids are returned without duplicates, in first-seen order.
    """
    with Session(self.lib.engine) as session:
        tag_ids = list(
            session.scalars(
                select(Tag.id)
                .where(or_(Tag.name.ilike(tag_name), Tag.shorthand.ilike(tag_name)))
                .union(select(TagAlias.tag_id).where(TagAlias.name.ilike(tag_name)))
            )
        )
        if len(tag_ids) > 1:
            logger.debug(
                f'Tag Constraint "{tag_name}" is ambiguous, {len(tag_ids)} matching tags found',
                tag_ids=tag_ids,
                include_children=include_children,
            )
        if not include_children:
            return tag_ids

        # Several matching tags can share descendants, so the same id may be
        # produced more than once. Dict keys drop the repeats while keeping
        # the order in which ids were first seen.
        expanded: dict[int, None] = {}
        for tag_id in tag_ids:
            expanded.update(dict.fromkeys(session.scalars(CHILDREN_QUERY, {"tag_id": tag_id})))
        return list(expanded)
|
||||
|
||||
def __entry_has_all_tags(self, tag_ids: list[int]) -> BinaryExpression[bool]:
|
||||
"""Returns Binary Expression that is true if the Entry has all provided tag ids."""
|
||||
|
||||
@@ -116,6 +116,20 @@ def test_parentheses(search_library: Library, query: str, count: int):
|
||||
verify_count(search_library, query, count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("query", "count"),
    [
        ("ellipse", 17),
        ("yellow", 15),
        ("color", 24),
        ("shape", 24),
        ("yellow not green", 10),
    ],
)
def test_parent_tags(search_library: Library, query: str, count: int):
    """Check the number of results returned for each parent tag search query."""
    verify_count(search_library, query, count)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_query", ["asd AND", "asd AND AND", "tag:(", "(asd", "asd[]", "asd]", ":", "tag: :"]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user