perf: bulk insert/delete tag_entries (#1296)

* perf: Bulk insert/delete tag_entries

* add type annotations
This commit is contained in:
TheBobBobs
2026-04-29 08:30:57 +00:00
committed by GitHub
parent 695b3923c9
commit 6c8d800598

View File

@@ -42,6 +42,7 @@ from sqlalchemy import (
text,
update,
)
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import (
InstanceState,
@@ -1478,7 +1479,7 @@ class Library:
return None
def add_tags_to_entries(
self, entry_ids: int | list[int] | set[int], tag_ids: int | list[int] | set[int]
self, entry_ids: int | Iterable[int], tag_ids: int | Iterable[int]
) -> int:
"""Add one or more tags to one or more entries.
@@ -1494,45 +1495,57 @@ class Library:
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
values: list[tuple[int, int]] = []
for tag_id in tag_ids_:
values.extend((tag_id, entry_id) for entry_id in entry_ids_)
with Session(self.engine, expire_on_commit=False) as session:
for tag_id in tag_ids_:
for entry_id in entry_ids_:
try:
session.add(TagEntry(tag_id=tag_id, entry_id=entry_id))
total_added += 1
session.commit()
except IntegrityError:
session.rollback()
for sub_list in [
values[i : i + MAX_SQL_VARIABLES // 2]
for i in range(0, len(values), MAX_SQL_VARIABLES // 2)
]:
stmt = (
sqlite.insert(TagEntry)
.values(sub_list)
.on_conflict_do_nothing()
.returning(TagEntry)
)
added = session.scalars(stmt).all()
total_added += len(added)
session.commit()
return total_added
def remove_tags_from_entries(
self, entry_ids: int | list[int] | set[int], tag_ids: int | list[int] | set[int]
) -> bool:
self, entry_ids: int | Iterable[int], tag_ids: int | Iterable[int]
):
"""Remove one or more tags from one or more entries."""
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
logger.info(
"[Library][remove_tags_from_entries]",
entry_ids=entry_ids,
tag_ids=tag_ids,
)
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else list(entry_ids)
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else list(tag_ids)
with Session(self.engine, expire_on_commit=False) as session:
try:
for tag_id in tag_ids_:
for entry_id in entry_ids_:
tag_entry = session.scalars(
select(TagEntry).where(
and_(
TagEntry.tag_id == tag_id,
TagEntry.entry_id == entry_id,
)
)
).first()
if tag_entry:
session.delete(tag_entry)
session.flush()
session.commit()
return True
except IntegrityError as e:
logger.error(e)
session.rollback()
return False
for tags_sub_list in [
tag_ids_[i : i + MAX_SQL_VARIABLES // 2]
for i in range(0, len(tag_ids_), MAX_SQL_VARIABLES // 2)
]:
for entries_sub_list in [
entry_ids_[i : i + MAX_SQL_VARIABLES // 2]
for i in range(0, len(entry_ids_), MAX_SQL_VARIABLES // 2)
]:
stmt = delete(TagEntry).where(
and_(
TagEntry.tag_id.in_(tags_sub_list),
TagEntry.entry_id.in_(entries_sub_list),
)
)
session.execute(stmt)
session.commit()
def add_color(self, color_group: TagColorGroup) -> TagColorGroup | None:
with Session(self.engine, expire_on_commit=False) as session: