mirror of
https://github.com/TagStudioDev/TagStudio.git
synced 2026-05-20 15:55:19 +00:00
perf: bulk insert/delete tag_entries (#1296)
* perf: Bulk insert/delete tag_entries * add type annotations
This commit is contained in:
@@ -42,6 +42,7 @@ from sqlalchemy import (
|
||||
text,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import (
|
||||
InstanceState,
|
||||
@@ -1478,7 +1479,7 @@ class Library:
|
||||
return None
|
||||
|
||||
def add_tags_to_entries(
|
||||
self, entry_ids: int | list[int] | set[int], tag_ids: int | list[int] | set[int]
|
||||
self, entry_ids: int | Iterable[int], tag_ids: int | Iterable[int]
|
||||
) -> int:
|
||||
"""Add one or more tags to one or more entries.
|
||||
|
||||
@@ -1494,45 +1495,57 @@ class Library:
|
||||
|
||||
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
|
||||
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
|
||||
values: list[tuple[int, int]] = []
|
||||
for tag_id in tag_ids_:
|
||||
values.extend((tag_id, entry_id) for entry_id in entry_ids_)
|
||||
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
for tag_id in tag_ids_:
|
||||
for entry_id in entry_ids_:
|
||||
try:
|
||||
session.add(TagEntry(tag_id=tag_id, entry_id=entry_id))
|
||||
total_added += 1
|
||||
session.commit()
|
||||
except IntegrityError:
|
||||
session.rollback()
|
||||
for sub_list in [
|
||||
values[i : i + MAX_SQL_VARIABLES // 2]
|
||||
for i in range(0, len(values), MAX_SQL_VARIABLES // 2)
|
||||
]:
|
||||
stmt = (
|
||||
sqlite.insert(TagEntry)
|
||||
.values(sub_list)
|
||||
.on_conflict_do_nothing()
|
||||
.returning(TagEntry)
|
||||
)
|
||||
added = session.scalars(stmt).all()
|
||||
total_added += len(added)
|
||||
session.commit()
|
||||
|
||||
return total_added
|
||||
|
||||
def remove_tags_from_entries(
|
||||
self, entry_ids: int | list[int] | set[int], tag_ids: int | list[int] | set[int]
|
||||
) -> bool:
|
||||
self, entry_ids: int | Iterable[int], tag_ids: int | Iterable[int]
|
||||
):
|
||||
"""Remove one or more tags from one or more entries."""
|
||||
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else entry_ids
|
||||
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else tag_ids
|
||||
logger.info(
|
||||
"[Library][remove_tags_from_entries]",
|
||||
entry_ids=entry_ids,
|
||||
tag_ids=tag_ids,
|
||||
)
|
||||
|
||||
entry_ids_ = [entry_ids] if isinstance(entry_ids, int) else list(entry_ids)
|
||||
tag_ids_ = [tag_ids] if isinstance(tag_ids, int) else list(tag_ids)
|
||||
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
for tag_id in tag_ids_:
|
||||
for entry_id in entry_ids_:
|
||||
tag_entry = session.scalars(
|
||||
select(TagEntry).where(
|
||||
and_(
|
||||
TagEntry.tag_id == tag_id,
|
||||
TagEntry.entry_id == entry_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if tag_entry:
|
||||
session.delete(tag_entry)
|
||||
session.flush()
|
||||
session.commit()
|
||||
return True
|
||||
except IntegrityError as e:
|
||||
logger.error(e)
|
||||
session.rollback()
|
||||
return False
|
||||
for tags_sub_list in [
|
||||
tag_ids_[i : i + MAX_SQL_VARIABLES // 2]
|
||||
for i in range(0, len(tag_ids_), MAX_SQL_VARIABLES // 2)
|
||||
]:
|
||||
for entries_sub_list in [
|
||||
entry_ids_[i : i + MAX_SQL_VARIABLES // 2]
|
||||
for i in range(0, len(entry_ids_), MAX_SQL_VARIABLES // 2)
|
||||
]:
|
||||
stmt = delete(TagEntry).where(
|
||||
and_(
|
||||
TagEntry.tag_id.in_(tags_sub_list),
|
||||
TagEntry.entry_id.in_(entries_sub_list),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
def add_color(self, color_group: TagColorGroup) -> TagColorGroup | None:
|
||||
with Session(self.engine, expire_on_commit=False) as session:
|
||||
|
||||
Reference in New Issue
Block a user